From 287005809cec5388dcb75a3d99bc0f0461b9bb69 Mon Sep 17 00:00:00 2001 From: "sergei.rubtcov" Date: Fri, 10 Mar 2017 15:03:27 +0200 Subject: [PATCH 0001/1765] [SPARK-19228][SQL] Introduce tryParseDate method to process csv date, add a type-widening rule in findTightestCommonType between DateType and TimestampType, add an end-to-end test case --- .../datasources/csv/CSVInferSchema.scala | 19 +++++++-- .../test-data/dates-and-timestamps.csv | 4 ++ .../datasources/csv/CSVInferSchemaSuite.scala | 10 ++++- .../execution/datasources/csv/CSVSuite.scala | 39 +++++++++++++++++++ 4 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/dates-and-timestamps.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index b64d71bb4eef2..6249a235ad502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -90,7 +90,9 @@ private[csv] object CSVInferSchema { // DecimalTypes have different precisions and scales, so we try to find the common type. findTightestCommonType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) case DoubleType => tryParseDouble(field, options) - case TimestampType => tryParseTimestamp(field, options) + case DateType => tryParseDate(field, options) + case TimestampType => + findTightestCommonType(typeSoFar, tryParseTimestamp(field, options)).getOrElse(StringType) case BooleanType => tryParseBoolean(field, options) case StringType => StringType case other: DataType => @@ -140,17 +142,26 @@ private[csv] object CSVInferSchema { private def tryParseDouble(field: String, options: CSVOptions): DataType = { if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) { DoubleType + } else { + tryParseDate(field, options) + } + } + + private def tryParseDate(field: String, options: CSVOptions): DataType = { + // This case infers a custom `dateFormat` is set. + if ((allCatch opt options.dateFormat.parse(field)).isDefined) { + DateType } else { tryParseTimestamp(field, options) } } private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { - // This case infers a custom `dataFormat` is set. + // This case infers a custom `timestampFormat` is set. if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - // We keep this for backwords competibility. + // We keep this for backwards compatibility. TimestampType } else { tryParseBoolean(field, options) @@ -216,6 +227,8 @@ private[csv] object CSVInferSchema { } else { Some(DecimalType(range + scale, scale)) } + // By design 'TimestampType' (8 bytes) is larger than 'DateType' (4 bytes). + case (t1: DateType, t2: TimestampType) => Some(TimestampType) case _ => None } diff --git a/sql/core/src/test/resources/test-data/dates-and-timestamps.csv b/sql/core/src/test/resources/test-data/dates-and-timestamps.csv new file mode 100644 index 0000000000000..0a9a4c2f8566c --- /dev/null +++ b/sql/core/src/test/resources/test-data/dates-and-timestamps.csv @@ -0,0 +1,4 @@ +timestamp,date +26/08/2015 22:31:46.913,27/09/2015 +27/10/2014 22:33:31.601,26/12/2016 +28/01/2016 22:33:52.888,28/01/2017 \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..d1a8822ca025e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -59,13 +59,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) } - test("Timestamp field types are inferred correctly via custom data format") { + test("Timestamp field types are inferred correctly via custom date format") { var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } + test("Date field types are inferred correctly via custom date and timestamp format") { + val options = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy", + "timestampFormat" -> "dd/MM/yyyy HH:mm:ss.SSS"), "GMT") + assert(CSVInferSchema.inferField(TimestampType, + "28/01/2017 22:31:46.913", options) == TimestampType) + assert(CSVInferSchema.inferField(DateType, "16/12/2012", options) == DateType) + } + test("Timestamp field types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], "GMT") assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4435e4df38ef6..f5da8451d46bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -53,6 +53,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val simpleSparseFile = "test-data/simple_sparse.csv" private val numbersFile = "test-data/numbers.csv" private val datesFile = "test-data/dates.csv" + private val datesAndTimestampsFile = "test-data/dates-and-timestamps.csv" private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" @@ -531,6 +532,44 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(results.toSeq.map(_.toSeq) === expected) } + test("inferring timestamp types and date types via custom formats") { + val options = Map( + "header" -> "true", + "inferSchema" -> "true", + "timestampFormat" -> "dd/MM/yyyy HH:mm:ss.SSS", + "dateFormat" -> "dd/MM/yyyy") + val results = spark.read + .format("csv") + .options(options) + .load(testFile(datesAndTimestampsFile)) + assert(results.schema{0}.dataType===TimestampType) + assert(results.schema{1}.dataType===DateType) + val timestamps = spark.read + .format("csv") + .options(options) + .load(testFile(datesAndTimestampsFile)) + .select("timestamp") + .collect() + val timestampFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss.SSS", Locale.US) + val timestampExpected = + Seq(Seq(new Timestamp(timestampFormat.parse("26/08/2015 22:31:46.913").getTime)), + Seq(new Timestamp(timestampFormat.parse("27/10/2014 22:33:31.601").getTime)), + Seq(new Timestamp(timestampFormat.parse("28/01/2016 22:33:52.888").getTime))) + assert(timestamps.toSeq.map(_.toSeq) === timestampExpected) + val dates = spark.read + .format("csv") + .options(options) + .load(testFile(datesAndTimestampsFile)) + .select("date") + .collect() + val dateFormat = new SimpleDateFormat("dd/MM/yyyy", Locale.US) + val dateExpected = + Seq(Seq(new Date(dateFormat.parse("27/09/2015").getTime)), + Seq(new Date(dateFormat.parse("26/12/2016").getTime)), + Seq(new Date(dateFormat.parse("28/01/2017").getTime))) + assert(dates.toSeq.map(_.toSeq) === dateExpected) + } + test("load date types via custom date format") { val customSchema = new StructType(Array(StructField("date", DateType, true))) val options = Map( From 72c66dbbb4dacaf5fd77bca58c952f34eba7c147 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 13 Mar 2017 16:30:15 -0700 Subject: [PATCH 0002/1765] [MINOR][ML] Improve MLWriter overwrite error message ## What changes were proposed in this pull request? Give proper syntax for Java and Python in addition to Scala. ## How was this patch tested? Manually. Author: Joseph K. Bradley Closes #17215 from jkbradley/write-err-msg. --- .../src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 09bddcdb810bb..a8b80031faf86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -104,8 +104,9 @@ abstract class MLWriter extends BaseReadWrite with Logging { // TODO: Revert back to the original content if save is not successful. fs.delete(qualifiedOutputPath, true) } else { - throw new IOException( - s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + throw new IOException(s"Path $path already exists. To overwrite it, " + + s"please use write.overwrite().save(path) for Scala and use " + + s"write().overwrite().save(path) for Java and Python.") } } saveImpl(path) From 4dc3a8171c31e11aafa85200d3928b1745aa32bd Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 14 Mar 2017 12:06:01 +0800 Subject: [PATCH 0003/1765] [SPARK-19924][SQL] Handle InvocationTargetException for all Hive Shim ### What changes were proposed in this pull request? Since we are using shim for most Hive metastore APIs, the exceptions thrown by the underlying method of Method.invoke() are wrapped by `InvocationTargetException`. Instead of doing it one by one, we should handle all of them in the `withClient`. If any of them is missing, the error message could looks unfriendly. For example, below is an example for dropping tables. ``` Expected exception org.apache.spark.sql.AnalysisException to be thrown, but java.lang.reflect.InvocationTargetException was thrown. ScalaTestFailureLocation: org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14 at (ExternalCatalogSuite.scala:193) org.scalatest.exceptions.TestFailedException: Expected exception org.apache.spark.sql.AnalysisException to be thrown, but java.lang.reflect.InvocationTargetException was thrown. at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:496) at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555) at org.scalatest.Assertions$class.intercept(Assertions.scala:1004) at org.scalatest.FunSuite.intercept(FunSuite.scala:1555) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply$mcV$sp(ExternalCatalogSuite.scala:193) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(ExternalCatalogSuite.scala:40) at org.scalatest.BeforeAndAfterEach$class.runTest(BeforeAndAfterEach.scala:255) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite.runTest(ExternalCatalogSuite.scala:40) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31) at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:55) at org.scalatest.tools.Runner$$anonfun$doRunRunRunDaDoRunRun$3.apply(Runner.scala:2563) at org.scalatest.tools.Runner$$anonfun$doRunRunRunDaDoRunRun$3.apply(Runner.scala:2557) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:2557) at org.scalatest.tools.Runner$$anonfun$runOptionallyWithPassFailReporter$2.apply(Runner.scala:1044) at org.scalatest.tools.Runner$$anonfun$runOptionallyWithPassFailReporter$2.apply(Runner.scala:1043) at org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:2722) at org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:1043) at org.scalatest.tools.Runner$.run(Runner.scala:883) at org.scalatest.tools.Runner.run(Runner.scala) at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2(ScalaTestRunner.java:138) at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:28) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at com.intellij.rt.execution.application.AppMain.main(AppMain.java:147) Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.sql.hive.client.Shim_v0_14.dropTable(HiveShim.scala:736) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$dropTable$1.apply$mcV$sp(HiveClientImpl.scala:451) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$dropTable$1.apply(HiveClientImpl.scala:451) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$dropTable$1.apply(HiveClientImpl.scala:451) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:287) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:228) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:227) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:270) at org.apache.spark.sql.hive.client.HiveClientImpl.dropTable(HiveClientImpl.scala:450) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$dropTable$1.apply$mcV$sp(HiveExternalCatalog.scala:456) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$dropTable$1.apply(HiveExternalCatalog.scala:454) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$dropTable$1.apply(HiveExternalCatalog.scala:454) at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:94) at org.apache.spark.sql.hive.HiveExternalCatalog.dropTable(HiveExternalCatalog.scala:454) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14$$anonfun$apply$mcV$sp$8.apply$mcV$sp(ExternalCatalogSuite.scala:194) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14$$anonfun$apply$mcV$sp$8.apply(ExternalCatalogSuite.scala:194) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14$$anonfun$apply$mcV$sp$8.apply(ExternalCatalogSuite.scala:194) at org.scalatest.Assertions$class.intercept(Assertions.scala:997) ... 57 more Caused by: org.apache.hadoop.hive.ql.metadata.HiveException: NoSuchObjectException(message:db2.unknown_table table not found) at org.apache.hadoop.hive.ql.metadata.Hive.dropTable(Hive.java:1038) ... 79 more Caused by: NoSuchObjectException(message:db2.unknown_table table not found) at org.apache.hadoop.hive.metastore.HiveMetaStore$HMSHandler.get_table_core(HiveMetaStore.java:1808) at org.apache.hadoop.hive.metastore.HiveMetaStore$HMSHandler.get_table(HiveMetaStore.java:1778) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.hadoop.hive.metastore.RetryingHMSHandler.invoke(RetryingHMSHandler.java:107) at com.sun.proxy.$Proxy10.get_table(Unknown Source) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.getTable(HiveMetaStoreClient.java:1208) at org.apache.hadoop.hive.ql.metadata.SessionHiveMetaStoreClient.getTable(SessionHiveMetaStoreClient.java:131) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.dropTable(HiveMetaStoreClient.java:952) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.dropTable(HiveMetaStoreClient.java:904) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.hadoop.hive.metastore.RetryingMetaStoreClient.invoke(RetryingMetaStoreClient.java:156) at com.sun.proxy.$Proxy11.dropTable(Unknown Source) at org.apache.hadoop.hive.ql.metadata.Hive.dropTable(Hive.java:1035) ... 79 more ``` After unwrapping the exception, the message is like ``` org.apache.hadoop.hive.ql.metadata.HiveException: NoSuchObjectException(message:db2.unknown_table table not found); org.apache.spark.sql.AnalysisException: org.apache.hadoop.hive.ql.metadata.HiveException: NoSuchObjectException(message:db2.unknown_table table not found); at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:100) at org.apache.spark.sql.hive.HiveExternalCatalog.dropTable(HiveExternalCatalog.scala:460) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply$mcV$sp(ExternalCatalogSuite.scala:193) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) ... ``` ### How was this patch tested? Covered by the existing test case in `test("drop table when database/table does not exist")` in `ExternalCatalogSuite`. Author: Xiao Li Closes #17265 from gatorsmile/InvocationTargetException. --- .../spark/sql/hive/HiveExternalCatalog.scala | 12 ++++++++++-- .../apache/spark/sql/hive/client/HiveShim.scala | 14 +++----------- .../spark/sql/hive/execution/HiveDDLSuite.scala | 13 ++++++------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 78aa2bd2494f3..fd633869dde57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.IOException +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util @@ -68,7 +69,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Exceptions thrown by the hive client that we would like to wrap private val clientExceptions = Set( classOf[HiveException].getCanonicalName, - classOf[TException].getCanonicalName) + classOf[TException].getCanonicalName, + classOf[InvocationTargetException].getCanonicalName) /** * Whether this is an exception thrown by the hive client that should be wrapped. @@ -94,7 +96,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat try { body } catch { - case NonFatal(e) if isClientException(e) => + case NonFatal(exception) if isClientException(exception) => + val e = exception match { + // Since we are using shim, the exceptions thrown by the underlying method of + // Method.invoke() are wrapped by InvocationTargetException + case i: InvocationTargetException => i.getCause + case o => o + } throw new AnalysisException( e.getClass.getCanonicalName + ": " + e.getMessage, cause = Some(e)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index c6188fc683e77..153f1673c96f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -733,12 +733,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { deleteData: Boolean, ignoreIfNotExists: Boolean, purge: Boolean): Unit = { - try { - dropTableMethod.invoke(hive, dbName, tableName, deleteData: JBoolean, - ignoreIfNotExists: JBoolean, purge: JBoolean) - } catch { - case e: InvocationTargetException => throw e.getCause() - } + dropTableMethod.invoke(hive, dbName, tableName, deleteData: JBoolean, + ignoreIfNotExists: JBoolean, purge: JBoolean) } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -824,11 +820,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { val dropOptions = dropOptionsClass.newInstance().asInstanceOf[Object] dropOptionsDeleteData.setBoolean(dropOptions, deleteData) dropOptionsPurge.setBoolean(dropOptions, purge) - try { - dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) - } catch { - case e: InvocationTargetException => throw e.getCause() - } + dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index d29242bb47e36..d752c415c1ed8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import java.io.File -import java.lang.reflect.InvocationTargetException import java.net.URI import org.apache.hadoop.fs.Path @@ -1799,9 +1798,9 @@ class HiveDDLSuite assert(loc.listFiles().length >= 1) checkAnswer(spark.table("t"), Row("1") :: Nil) } else { - val e = intercept[InvocationTargetException] { + val e = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t SELECT 1") - }.getTargetException.getMessage + }.getMessage assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) } } @@ -1836,14 +1835,14 @@ class HiveDDLSuite checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) } else { - val e = intercept[InvocationTargetException] { + val e = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") - }.getTargetException.getMessage + }.getMessage assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) - val e1 = intercept[InvocationTargetException] { + val e1 = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") - }.getTargetException.getMessage + }.getMessage assert(e1.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) } } From 415f9f3423aacc395097e40427364c921a2ed7f1 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 14 Mar 2017 14:19:02 +0800 Subject: [PATCH 0004/1765] [SPARK-19921][SQL][TEST] Enable end-to-end testing using different Hive metastore versions. ### What changes were proposed in this pull request? To improve the quality of our Spark SQL in different Hive metastore versions, this PR is to enable end-to-end testing using different versions. This PR allows the test cases in sql/hive to pass the existing Hive client to create a SparkSession. - Since Derby does not allow concurrent connections, the pre-built Hive clients use different database from the TestHive's built-in 1.2.1 client. - Since our test cases in sql/hive only can create a single Spark context in the same JVM, the newly created SparkSession share the same spark context with the existing TestHive's corresponding SparkSession. ### How was this patch tested? Fixed the existing test cases. Author: Xiao Li Closes #17260 from gatorsmile/versionSuite. --- .../spark/sql/internal/SharedState.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 69 +++++++++++++-- .../spark/sql/hive/client/VersionsSuite.scala | 85 +++++++++++-------- 4 files changed, 112 insertions(+), 46 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 86129fa87feaa..1ef9d52713d92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -87,7 +87,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - val externalCatalog: ExternalCatalog = + lazy val externalCatalog: ExternalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index fd633869dde57..33802ae62333e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -62,7 +62,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * A Hive client used to interact with the metastore. */ - val client: HiveClient = { + lazy val client: HiveClient = { HiveUtils.newClientForMetadata(conf, hadoopConf) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 076c40d45932b..b63ed76967bd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -24,23 +24,24 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation} -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. @@ -58,6 +59,37 @@ object TestHive .set("spark.ui.enabled", "false"))) +case class TestHiveVersion(hiveClient: HiveClient) + extends TestHiveContext(TestHive.sparkContext, hiveClient) + + +private[hive] class TestHiveExternalCatalog( + conf: SparkConf, + hadoopConf: Configuration, + hiveClient: Option[HiveClient] = None) + extends HiveExternalCatalog(conf, hadoopConf) with Logging { + + override lazy val client: HiveClient = + hiveClient.getOrElse { + HiveUtils.newClientForMetadata(conf, hadoopConf) + } +} + + +private[hive] class TestHiveSharedState( + sc: SparkContext, + hiveClient: Option[HiveClient] = None) + extends SharedState(sc) { + + override lazy val externalCatalog: ExternalCatalog = { + new TestHiveExternalCatalog( + sc.conf, + sc.hadoopConfiguration, + hiveClient) + } +} + + /** * A locally running test instance of Spark's Hive execution engine. * @@ -81,6 +113,12 @@ class TestHiveContext( this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), loadTestTables)) } + def this(sc: SparkContext, hiveClient: HiveClient) { + this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), + hiveClient, + loadTestTables = false)) + } + override def newSession(): TestHiveContext = { new TestHiveContext(sparkSession.newSession()) } @@ -115,7 +153,7 @@ class TestHiveContext( */ private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, - @transient private val existingSharedState: Option[SharedState], + @transient private val existingSharedState: Option[TestHiveSharedState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -126,6 +164,13 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + def this(sc: SparkContext, hiveClient: HiveClient, loadTestTables: Boolean) { + this( + sc, + existingSharedState = Some(new TestHiveSharedState(sc, Some(hiveClient))), + loadTestTables) + } + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", @@ -141,8 +186,8 @@ private[hive] class TestHiveSparkSession( assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive") @transient - override lazy val sharedState: SharedState = { - existingSharedState.getOrElse(new SharedState(sc)) + override lazy val sharedState: TestHiveSharedState = { + existingSharedState.getOrElse(new TestHiveSharedState(sc)) } @transient @@ -463,6 +508,14 @@ private[hive] class TestHiveSparkSession( FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } + // HDFS root scratch dir requires the write all (733) permission. For each connecting user, + // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with + // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to + // delete it. Later, it will be re-created with the right permission. + val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(sc.hadoopConfiguration) + fs.delete(location, true) + // Some tests corrupt this value on purpose, which breaks the RESET call below. sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6025f8adbce28..cb1386111035a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,21 +21,20 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.net.URI import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat +import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPermanentFunctionException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.sql.hive.test.TestHiveVersion import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.StructType import org.apache.spark.tags.ExtendedHiveTest @@ -48,11 +47,31 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} * is not fully tested. */ @ExtendedHiveTest -class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private val clientBuilder = new HiveClientBuilder import clientBuilder.buildClient + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + versionSpark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + test("success sanity check") { val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) @@ -93,6 +112,8 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w private var client: HiveClient = null + private var versionSpark: TestHiveVersion = null + versions.foreach { version => test(s"$version: create client") { client = null @@ -105,6 +126,10 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w hadoopConf.set("datanucleus.schema.autoCreateAll", "true") } client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + if (versionSpark != null) versionSpark.reset() + versionSpark = TestHiveVersion(client) + assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { @@ -545,22 +570,22 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test(s"$version: CREATE TABLE AS SELECT") { withTable("tbl") { - spark.sql("CREATE TABLE tbl AS SELECT 1 AS a") - assert(spark.table("tbl").collect().toSeq == Seq(Row(1))) + versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) } } test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { - spark.sql( + versionSpark.sql( s""" |CREATE TABLE tab(c1 string) |location '${tmpDir.toURI.toString}' """.stripMargin) (1 to 3).map { i => - spark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") + versionSpark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") } def listFiles(path: File): List[String] = { val dir = path.listFiles() @@ -569,7 +594,9 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w folders.flatMap(listFiles) ++: filePaths } // expect 2 files left: `.part-00000-random-uuid.crc` and `part-00000-random-uuid` - assert(listFiles(tmpDir).length == 2) + // 0.12, 0.13, 1.0 and 1.1 also has another two more files ._SUCCESS.crc and _SUCCESS + val metadataFiles = Seq("._SUCCESS.crc", "_SUCCESS") + assert(listFiles(tmpDir).filterNot(metadataFiles.contains).length == 2) } } } @@ -609,7 +636,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w withTable(tableName, tempTableName) { // Creates the external partitioned Avro table to be tested. - sql( + versionSpark.sql( s"""CREATE EXTERNAL TABLE $tableName |PARTITIONED BY (ds STRING) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' @@ -622,7 +649,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w ) // Creates an temporary Avro table used to prepare testing Avro file. - sql( + versionSpark.sql( s"""CREATE EXTERNAL TABLE $tempTableName |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' |STORED AS @@ -634,43 +661,29 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w ) // Generates Avro data. - sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") + versionSpark.sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") // Adds generated Avro data as a new partition to the testing table. - sql(s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") + versionSpark.sql( + s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") // The following query fails before SPARK-13709 is fixed. This is because when reading // data from table partitions, Avro deserializer needs the Avro schema, which is defined // in table property "avro.schema.literal". However, we only initializes the deserializer // using partition properties, which doesn't include the wanted property entry. Merging // two sets of properties solves the problem. - checkAnswer( - sql(s"SELECT * FROM $tableName"), - Row(1, Row(2, 2.5D), "foo") - ) + assert(versionSpark.sql(s"SELECT * FROM $tableName").collect() === + Array(Row(1, Row(2, 2.5D), "foo"))) } } } test(s"$version: CTAS for managed data source tables") { withTable("t", "t1") { - import spark.implicits._ - - val tPath = new Path(spark.sessionState.conf.warehousePath, "t") - Seq("1").toDF("a").write.saveAsTable("t") - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - - assert(table.location == makeQualifiedPath(tPath.toString)) - assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath)) - checkAnswer(spark.table("t"), Row("1") :: Nil) - - val t1Path = new Path(spark.sessionState.conf.warehousePath, "t1") - spark.sql("create table t1 using parquet as select 2 as a") - val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - - assert(table1.location == makeQualifiedPath(t1Path.toString)) - assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path)) - checkAnswer(spark.table("t1"), Row(2) :: Nil) + versionSpark.range(1).write.saveAsTable("t") + assert(versionSpark.table("t").collect() === Array(Row(0))) + versionSpark.sql("create table t1 using parquet as select 2 as a") + assert(versionSpark.table("t1").collect() === Array(Row(2))) } } // TODO: add more tests. From f6314eab4b494bd5b5e9e41c6f582d4f22c0967a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 14 Mar 2017 00:50:38 -0700 Subject: [PATCH 0005/1765] [SPARK-19391][SPARKR][ML] Tweedie GLM API for SparkR ## What changes were proposed in this pull request? Port Tweedie GLM #16344 to SparkR felixcheung yanboliang ## How was this patch tested? new test in SparkR Author: actuaryzhang Closes #16729 from actuaryzhang/sparkRTweedie. --- R/pkg/R/mllib_regression.R | 55 ++++++++++++++++--- .../tests/testthat/test_mllib_regression.R | 38 ++++++++++++- R/pkg/vignettes/sparkr-vignettes.Rmd | 19 ++++++- .../GeneralizedLinearRegressionWrapper.scala | 19 +++++-- 4 files changed, 117 insertions(+), 14 deletions(-) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 648d363f1a255..d59c890f3e5fd 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -53,12 +53,23 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{Gamma}, \code{poisson} and \code{tweedie}. +#' +#' Note that there are two ways to specify the tweedie family. +#' \itemize{ +#' \item Set \code{family = "tweedie"} and specify the var.power and link.power; +#' \item When package \code{statmod} is loaded, the tweedie family is specified using the +#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param regParam regularization parameter for L2 regularization. +#' @param var.power the power in the variance function of the Tweedie distribution which provides +#' the relationship between the variance and mean of the distribution. Only +#' applicable to the Tweedie family. +#' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -84,14 +95,30 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' # can also read back the saved model and print #' savedModel <- read.ml(path) #' summary(savedModel) +#' +#' # fit tweedie model +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", +#' var.power = 1.2, link.power = 0) +#' summary(model) +#' +#' # use the tweedie family from statmod +#' library(statmod) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0)) +#' summary(model) #' } #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) + # Handle when family = "tweedie" + if (tolower(family) == "tweedie") { + family <- list(family = "tweedie", link = NULL) + } else { + family <- get(family, mode = "function", envir = parent.frame()) + } } if (is.function(family)) { family <- family() @@ -100,6 +127,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), print(family) stop("'family' not recognized") } + # Handle when family = statmod::tweedie() + if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { + var.power <- log(family$variance(exp(1))) + link.power <- log(family$linkfun(exp(1))) + family <- list(family = "tweedie", link = NULL) + } formula <- paste(deparse(formula), collapse = "") if (!is.null(weightCol) && weightCol == "") { @@ -111,7 +144,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), weightCol, regParam) + tol, as.integer(maxIter), weightCol, regParam, + as.double(var.power), as.double(link.power)) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -126,11 +160,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{poisson}, \code{Gamma}, and \code{tweedie}. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param epsilon positive convergence tolerance of iterations. #' @param maxit integer giving the maximal number of IRLS iterations. +#' @param var.power the index of the power variance function in the Tweedie family. +#' @param link.power the index of the power link function in the Tweedie family. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -145,8 +181,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, + var.power = 0.0, link.power = 1.0 - var.power) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, + var.power = var.power, link.power = link.power) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -172,9 +210,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), deviance <- callJMethod(jobj, "rDeviance") df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") - aic <- callJMethod(jobj, "rAic") iter <- callJMethod(jobj, "rNumIterations") family <- callJMethod(jobj, "rFamily") + aic <- callJMethod(jobj, "rAic") + if (family == "tweedie" && aic == 0) aic <- NA deviance.resid <- if (is.loaded) { NULL } else { diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 81a5bdc414927..3e9ad77198073 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -77,6 +77,24 @@ test_that("spark.glm and predict", { out <- capture.output(print(summary(model))) expect_true(any(grepl("Dispersion parameter for gamma family", out))) + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) @@ -233,7 +251,7 @@ test_that("glm and predict", { training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - prediction <- predict(model, training) + prediction <- predict(model, training) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") vals <- collect(select(prediction, "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) @@ -249,6 +267,24 @@ test_that("glm and predict", { data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # tweedie family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 43c255cff3028..a6ff650c33fea 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -672,6 +672,7 @@ gaussian | identity, log, inverse binomial | logit, probit, cloglog (complementary log-log) poisson | log, identity, sqrt gamma | inverse, identity, log +tweedie | power link function There are three ways to specify the `family` argument. @@ -679,7 +680,11 @@ There are three ways to specify the `family` argument. * Family function, e.g. `family = binomial`. -* Result returned by a family function, e.g. `family = poisson(link = log)` +* Result returned by a family function, e.g. `family = poisson(link = log)`. + +* Note that there are two ways to specify the tweedie family: + a) Set `family = "tweedie"` and specify the `var.power` and `link.power` + b) When package `statmod` is loaded, the tweedie family is specified using the family definition therein, i.e., `tweedie()`. For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). @@ -695,6 +700,18 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` +The following is the same fit using the tweedie family: +```{r} +tweedieGLM1 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 0.0) +summary(tweedieGLM1) +``` +We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: +```{r} +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", + var.power = 1.2, link.power = 0.0) +summary(tweedieGLM2) +``` + #### Isotonic Regression `spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index cbd6cd1c7933c..c49416b240181 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -71,7 +71,9 @@ private[r] object GeneralizedLinearRegressionWrapper tol: Double, maxIter: Int, weightCol: String, - regParam: Double): GeneralizedLinearRegressionWrapper = { + regParam: Double, + variancePower: Double, + linkPower: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -83,13 +85,17 @@ private[r] object GeneralizedLinearRegressionWrapper // assemble and fit the pipeline val glr = new GeneralizedLinearRegression() .setFamily(family) - .setLink(link) .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - + // set variancePower and linkPower if family is tweedie; otherwise, set link function + if (family.toLowerCase == "tweedie") { + glr.setVariancePower(variancePower).setLinkPower(linkPower) + } else { + glr.setLink(link) + } if (weightCol != null) glr.setWeightCol(weightCol) val pipeline = new Pipeline() @@ -145,7 +151,12 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = summary.aic + val rAic: Double = if (family.toLowerCase == "tweedie" && + !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { + 0.0 + } else { + summary.aic + } val rNumIterations: Int = summary.numIterations new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion, From 4ce970d71488c7de6025ef925f75b8b92a5a6a79 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 14 Mar 2017 10:37:10 +0100 Subject: [PATCH 0006/1765] [SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase ## What changes were proposed in this pull request? Currently Analyzer as part of ResolveSubquery, pulls up the correlated predicates to its originating SubqueryExpression. The subquery plan is then transformed to remove the correlated predicates after they are moved up to the outer plan. In this PR, the task of pulling up correlated predicates is deferred to Optimizer. This is the initial work that will allow us to support the form of correlated subqueries that we don't support today. The design document from nsyca can be found in the following link : [DesignDoc](https://docs.google.com/document/d/1QDZ8JwU63RwGFS6KVF54Rjj9ZJyK33d49ZWbjFBaIgU/edit#) The brief description of code changes (hopefully to aid with code review) can be be found in the following link: [CodeChanges](https://docs.google.com/document/d/18mqjhL9V1An-tNta7aVE13HkALRZ5GZ24AATA-Vqqf0/edit#) ## How was this patch tested? The test case PRs were submitted earlier using. [16337](https://github.com/apache/spark/pull/16337) [16759](https://github.com/apache/spark/pull/16759) [16841](https://github.com/apache/spark/pull/16841) [16915](https://github.com/apache/spark/pull/16915) [16798](https://github.com/apache/spark/pull/16798) [16712](https://github.com/apache/spark/pull/16712) [16710](https://github.com/apache/spark/pull/16710) [16760](https://github.com/apache/spark/pull/16760) [16802](https://github.com/apache/spark/pull/16802) Author: Dilip Biswal Closes #16954 from dilipbiswal/SPARK-18874. --- .../sql/catalyst/analysis/Analyzer.scala | 314 ++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 40 +-- .../sql/catalyst/analysis/TypeCoercion.scala | 130 ++++++-- .../sql/catalyst/expressions/predicates.scala | 43 ++- .../sql/catalyst/expressions/subquery.scala | 256 ++++++++++---- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../sql/catalyst/optimizer/subquery.scala | 159 ++++++++- .../analysis/AnalysisErrorSuite.scala | 11 +- .../analysis/ResolveSubquerySuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 2 - .../apache/spark/sql/execution/subquery.scala | 3 - .../invalid-correlation.sql.out | 4 +- .../org/apache/spark/sql/SubquerySuite.scala | 7 +- 13 files changed, 675 insertions(+), 300 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 93666f14958e9..a3764d8c843dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,12 +21,13 @@ import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} @@ -162,6 +163,8 @@ class Analyzer( FixNullability), Batch("ResolveTimeZone", Once, ResolveTimeZone), + Batch("Subquery", Once, + UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -710,13 +713,72 @@ class Analyzer( } transformUp { case other => other transformExpressions { case a: Attribute => - attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } newRight } } + private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + } + + /** + * The outer plan may have been de-duplicated and the function below updates the + * outer references to refer to the de-duplicated attributes. + * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. + */ + private def dedupOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(dedupAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -1132,28 +1194,21 @@ class Analyzer( } /** - * Pull out all (outer) correlated predicates from a given subquery. This method removes the - * correlated predicates from subquery [[Filter]]s and adds the references of these predicates - * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to - * be able to evaluate the predicates at the top level. - * - * This method returns the rewritten subquery and correlated predicates. + * Validates to make sure the outer references appearing inside the subquery + * are legal. This function also returns the list of expressions + * that contain outer references. These outer references would be kept as children + * of subquery expressions by the caller of this function. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { + val outerReferences = ArrayBuffer.empty[Expression] // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (p.collectFirst(predicateMap).nonEmpty) { + if (hasOuterReferences(p)) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - // Helper function for locating outer references. - def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined - } - // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { if (p.expressions.exists(containsOuter)) { @@ -1194,20 +1249,11 @@ class Analyzer( } } - /** Determine which correlated predicate references are missing from this plan. */ - def missingReferences(p: LogicalPlan): AttributeSet = { - val localPredicateReferences = p.collect(predicateMap) - .flatten - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - localPredicateReferences -- p.outputSet - } - var foundNonEqualCorrelatedPred : Boolean = false - // Simplify the predicates before pulling them out. - val transformed = BooleanSimplification(sub) transformUp { + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: @@ -1229,80 +1275,48 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case p: BroadcastHint => - p - case p: Distinct => - p - case p: LeafNode => - p - case p: Repartition => - p - case p: SubqueryAlias => - p + case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. - case p: Sort => - failOnOuterReference(p) - p - case p: RepartitionByExpression => - failOnOuterReference(p) - p + case s: Sort => + failOnOuterReference(s) + case r: RepartitionByExpression => + failOnOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f @ Filter(cond, child) => + case f: Filter => // Find all predicates with an outer reference. - val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { case _: EqualTo | _: EqualNullSafe => false case _ => true } - - // Rewrite the filter without the correlated predicates if any. - correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> xs - newFilter - case xs => - predicateMap += child -> xs - child - } + // The aggregate expressions are treated in a special way by getOuterReferences. If the + // aggregate expression contains only outer reference attributes then the entire aggregate + // expression is isolated as an OuterReference. + // i.e min(OuterReference(b)) => OuterReference(min(b)) + outerReferences ++= getOuterReferences(correlated) // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. - case p @ Project(expressions, child) => + case p: Project => failOnOuterReference(p) - val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) - } else { - p - } - // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains // only equality correlated predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. - case a @ Aggregate(grouping, expressions, child) => + case a: Aggregate => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - val referencesToAdd = missingReferences(a) - if (referencesToAdd.nonEmpty) { - Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) - } else { - a - } - // Join can host correlated expressions. case j @ Join(left, right, joinType, _) => joinType match { @@ -1332,7 +1346,6 @@ class Analyzer( case _ => failOnOuterReferenceInSubTree(j) } - j // Generator with join=true, i.e., expressed with // LATERAL VIEW [OUTER], similar to inner join, @@ -1340,9 +1353,8 @@ class Analyzer( // but must not host any outer references. // Note: // Generator with join=false is treated as Category 4. - case p @ Generate(generator, true, _, _, _, _) => - failOnOuterReference(p) - p + case g: Generate if g.join => + failOnOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only @@ -1350,54 +1362,17 @@ class Analyzer( // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - p } - (transformed, predicateMap.values.flatten.toSeq) + outerReferences } /** - * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same - * attributes. - */ - private def rewriteSubQuery( - sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - - // Make sure the inner and the outer query attributes do not collide. - val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = basePlan.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = basePlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, basePlan) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (basePlan, baseConditions) - } - // Remove outer references from the correlated predicates. We wait with extracting - // these until collisions between the inner and outer query attributes have been - // solved. - val conditions = deDuplicatedConditions.map(_.transform { - case OuterReference(ref) => ref - }) - (plan, conditions) - } - - /** - * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method + * Resolves the subquery. The subquery is resolved using its outer plans. This method * will resolve the subquery by alternating between the regular analyzer and by applying the * resolveOuterReferences rule. * - * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + * Outer references from the correlated predicates are updated as children of + * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, @@ -1420,7 +1395,8 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: Pull out the predicates if the plan is resolved. + // Step 2: If the subquery plan is fully resolved, pull the outer references and record + // them as children of SubqueryExpression. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. @@ -1428,34 +1404,37 @@ class Analyzer( failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + s"does not match the required number of columns ($requiredColumns)") } - // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + // Validate the outer reference and record the outer references as children of + // subquery expression. + f(current, checkAndGetOuterReferences(current)) } else { e.withNewPlan(current) } } /** - * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS - * expressions into PredicateSubquery expression once the are resolved. + * Resolves the subquery. Apart of resolving the subquery and outer references (if any) + * in the subquery plan, the children of subquery expression are updated to record the + * outer references. This is needed to make sure + * (1) The column(s) referred from the outer query are not pruned from the plan during + * optimization. + * (2) Any aggregate expression(s) that reference outer attributes are pushed down to + * outer plan to get evaluated. */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) - case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + case e @ Exists(sub, _, exprId) if !sub.resolved => + resolveSubQuery(e, plans)(Exists(_, _, exprId)) + case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => // Get the left hand side expressions. - val expressions = e match { + val expressions = value match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => - // Construct the IN conditions. - val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) - PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) - } + val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + In(value, Seq(expr)) } } @@ -2353,6 +2332,11 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) } } } @@ -2533,3 +2517,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * The aggregate expressions from subquery referencing outer query block are pushed + * down to the outer query block for evaluation. This rule below updates such outer references + * as AttributeReference referring attributes from the parent/outer query block. + * + * For example (SQL): + * {{{ + * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) + * }}} + * Plan before the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < min(outer(b#227))) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + * Plan after the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < outer(min(b#227)#249)) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + */ +object UpdateOuterReferences extends Rule[LogicalPlan] { + private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } + + private def updateOuterReferenceInSubquery( + plan: LogicalPlan, + refExprs: Seq[Expression]): LogicalPlan = { + plan transformAllExpressions { case e => + val outerAlias = + refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) + outerAlias match { + case Some(a: Alias) => OuterReference(a.toAttribute) + case _ => e + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case f @ Filter(_, a: Aggregate) if f.resolved => + f transformExpressions { + case s: SubqueryExpression if s.children.nonEmpty => + // Collect the aliases from output of aggregate. + val outerAliases = a.aggregateExpressions collect { case a: Alias => a } + // Update the subquery plan to record the OuterReference to point to outer query plan. + s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d32fbeb4e91ef..da0c6b098f5ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper { if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - // Collect the columns from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) - + } + else if (conditions.nonEmpty) { def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. @@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates @@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper { // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => - // SPARK-18814: Map any aliases to their AttributeReference children - // for the checking in the Aggregate operators below this Project. - subqueryColumns = subqueryColumns.map { - xs => p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => - child - }.getOrElse(xs) - } - - cleanQuery(p.child) + case p: Project => cleanQuery(p.child) case child => child } @@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case Filter(condition, _) => - splitConjunctivePredicates(condition).foreach { - case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => - failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - s" conditions: $e") - case e => - } + case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => + failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + + s"conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( @@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper { s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") } - case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => - failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => + p match { + case _: Filter => // Ok + case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + } case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2c00957bd6afb..768897dc0713c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -108,6 +108,28 @@ object TypeCoercion { case _ => None } + /** + * This function determines the target type of a comparison operator when one operand + * is a String and the other is not. It also handles when one op is a Date and the + * other is a Timestamp by making the target type to be String. + */ + val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + // We should cast all relative timestamp/date/string comparison into string comparisons + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true + case (StringType, DateType) => Some(StringType) + case (DateType, StringType) => Some(StringType) + case (StringType, TimestampType) => Some(StringType) + case (TimestampType, StringType) => Some(StringType) + case (TimestampType, DateType) => Some(StringType) + case (DateType, TimestampType) => Some(StringType) + case (StringType, NullType) => Some(StringType) + case (NullType, StringType) => Some(StringType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) + case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l, r) => None + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -305,6 +327,14 @@ object TypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -321,37 +351,10 @@ object TypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisons - // This behaves as a user would expect because timestamp strings sort lexicographically. - // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true - case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - - // Comparisons between dates and timestamps. - case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - - // Checking NullType - case p @ BinaryComparison(left @ StringType(), right @ NullType()) => - p.makeCopy(Array(left, Literal.create(null, StringType))) - case p @ BinaryComparison(left @ NullType(), right @ StringType()) => - p.makeCopy(Array(Literal.create(null, StringType), right)) - - // When compare string with atomic type, case string to that type. - case p @ BinaryComparison(left @ StringType(), right @ AtomicType()) - if right.dataType != StringType => - p.makeCopy(Array(Cast(left, right.dataType), right)) - case p @ BinaryComparison(left @ AtomicType(), right @ StringType()) - if left.dataType != StringType => - p.makeCopy(Array(left, Cast(right, left.dataType))) + case p @ BinaryComparison(left, right) + if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) @@ -365,17 +368,72 @@ object TypeCoercion { } /** - * Convert the value and in list expressions to the common operator type - * by looking at all the argument types and finding the closest one that - * all the arguments can be cast to. When no common operator type is found - * the original expression will be returned and an Analysis Exception will - * be raised at type checking phase. + * Handles type coercion for both IN expression with subquery and IN + * expressions without subquery. + * 1. In the first case, find the common type by comparing the left hand side (LHS) + * expression types against corresponding right hand side (RHS) expression derived + * from the subquery expression's plan output. Inject appropriate casts in the + * LHS and RHS side of IN expression. + * + * 2. In the second case, convert the value and in list expressions to the + * common operator type by looking at all the argument types and finding + * the closest one that all the arguments can be cast to. When no common + * operator type is found the original expression will be returned and an + * Analysis Exception will be raised at the type checking phase. */ object InConversion extends Rule[LogicalPlan] { + private def flattenExpr(expr: Expression): Seq[Expression] = { + expr match { + // Multi columns in IN clause is represented as a CreateNamedStruct. + // flatten the named struct to get the list of expressions. + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Handle type casting required between value expression and subquery output + // in IN subquery. + case i @ In(a, Seq(ListQuery(sub, children, exprId))) + if !i.resolved && flattenExpr(a).length == sub.output.length => + // LHS is the value expression of IN subquery. + val lhs = flattenExpr(a) + + // RHS is the subquery output. + val rhs = sub.output + + val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => + findCommonTypeForBinaryComparison(l.dataType, r.dataType) + .orElse(findTightestCommonType(l.dataType, r.dataType)) + } + + // The number of columns/expressions must match between LHS and RHS of an + // IN subquery expression. + if (commonTypes.length == lhs.length) { + val castedRhs = rhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + val castedLhs = lhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Cast(e, dt) + case (e, _) => e + } + + // Before constructing the In expression, wrap the multi values in LHS + // in a CreatedNamedStruct. + val newLhs = castedLhs match { + case Seq(lhs) => lhs + case _ => CreateStruct(castedLhs) + } + + In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) + } else { + i + } + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac56ff13fa5bf..e5d1a1e2996c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -123,19 +123,44 @@ case class Not(child: Expression) */ @ExpressionDescription( usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.") -case class In(value: Expression, list: Seq[Expression]) extends Predicate - with ImplicitCastInputTypes { +case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + override def checkInputDataTypes(): TypeCheckResult = { + list match { + case ListQuery(sub, _, _) :: Nil => + val valExprs = value match { + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } - override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } - override def checkInputDataTypes(): TypeCheckResult = { - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be same type") - } else { - TypeCheckResult.TypeCheckSuccess + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } + case _ => + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e2e7d98e33459..ad11700fa28d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.types._ /** @@ -40,19 +43,184 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { /** * A base interface for expressions that contain a [[LogicalPlan]]. */ -abstract class SubqueryExpression extends PlanExpression[LogicalPlan] { +abstract class SubqueryExpression( + plan: LogicalPlan, + children: Seq[Expression], + exprId: ExprId) extends PlanExpression[LogicalPlan] { + + override lazy val resolved: Boolean = childrenResolved && plan.resolved + override lazy val references: AttributeSet = + if (plan.resolved) super.references -- plan.outputSet else super.references override def withNewPlan(plan: LogicalPlan): SubqueryExpression + override def semanticEquals(o: Expression): Boolean = o match { + case p: SubqueryExpression => + this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) && + children.length == p.children.length && + children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) + case _ => false + } } object SubqueryExpression { + /** + * Returns true when an expression contains an IN or EXISTS subquery and false otherwise. + */ + def hasInOrExistsSubquery(e: Expression): Boolean = { + e.find { + case _: ListQuery | _: Exists => true + case _ => false + }.isDefined + } + + /** + * Returns true when an expression contains a subquery that has outer reference(s). The outer + * reference attributes are kept as children of subquery expression by + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]] + */ def hasCorrelatedSubquery(e: Expression): Boolean = { e.find { - case e: SubqueryExpression if e.children.nonEmpty => true + case s: SubqueryExpression => s.children.nonEmpty case _ => false }.isDefined } } +object SubExprUtils extends PredicateHelper { + /** + * Returns true when an expression contains correlated predicates i.e outer references and + * returns false otherwise. + */ + def containsOuter(e: Expression): Boolean = { + e.find(_.isInstanceOf[OuterReference]).isDefined + } + + /** + * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. + */ + def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { + splitConjunctivePredicates(condition).exists { + case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => + false + case e => e.find { x => + x.isInstanceOf[Not] && e.find { + case In(_, Seq(_: ListQuery)) => true + case _ => false + }.isDefined + }.isDefined + } + + } + + /** + * Returns an expression after removing the OuterReference shell. + */ + def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r } + + /** + * Returns the list of expressions after removing the OuterReference shell from each of + * the expression. + */ + def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference) + + /** + * Returns the logical plan after removing the OuterReference shell from all the expressions + * of the input logical plan. + */ + def stripOuterReferences(p: LogicalPlan): LogicalPlan = { + p.transformAllExpressions { + case OuterReference(a) => a + } + } + + /** + * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. + */ + def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find { + case f: Filter => containsOuter(f.condition) + case other => false + }.isDefined + } + + /** + * Given a list of expressions, returns the expressions which have outer references. Aggregate + * expressions are treated in a special way. If the children of aggregate expression contains an + * outer reference, then the entire aggregate expression is marked as an outer reference. + * Example (SQL): + * {{{ + * SELECT a FROM l GROUP by 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b)) + * }}} + * In the above case, we want to mark the entire min(b) as an outer reference + * OuterReference(min(b)) instead of min(OuterReference(b)). + * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of + * outer references and local references under an aggregate expression. + * For example (SQL): + * {{{ + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a) + max(p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + sq.c) > 1)) + * }}} + * The code below needs to change when we support the above cases. + */ + def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + val outerExpressions = ArrayBuffer.empty[Expression] + conditions foreach { expr => + expr transformDown { + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e + } + } + outerExpressions + } + + /** + * Returns all the expressions that have outer references from a logical plan. Currently only + * Filter operator can host outer references. + */ + def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + getOuterReferences(conditions) + } + + /** + * Returns the correlated predicates from a logical plan. The OuterReference wrapper + * is removed before returning the predicate to the caller. + */ + def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + conditions.flatMap { e => + val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter) + stripOuterReferences(correlated) match { + case Nil => None + case xs => xs + } + } + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -63,14 +231,8 @@ case class ScalarSubquery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved: Boolean = childrenResolved && plan.resolved - override lazy val references: AttributeSet = { - if (plan.resolved) super.references -- plan.outputSet - else super.references - } + extends SubqueryExpression(plan, children, exprId) with Unevaluable { override def dataType: DataType = plan.schema.fields.head.dataType - override def foldable: Boolean = false override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" @@ -79,59 +241,12 @@ case class ScalarSubquery( object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { e.find { - case e: ScalarSubquery if e.children.nonEmpty => true + case s: ScalarSubquery => s.children.nonEmpty case _ => false }.isDefined } } -/** - * A predicate subquery checks the existence of a value in a sub-query. We currently only allow - * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will - * be rewritten into a left semi/anti join during analysis. - */ -case class PredicateSubquery( - plan: LogicalPlan, - children: Seq[Expression] = Seq.empty, - nullAware: Boolean = false, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = childrenResolved && plan.resolved - override lazy val references: AttributeSet = super.references -- plan.outputSet - override def nullable: Boolean = nullAware - override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan) - override def semanticEquals(o: Expression): Boolean = o match { - case p: PredicateSubquery => - plan.sameResult(p.plan) && nullAware == p.nullAware && - children.length == p.children.length && - children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) - case _ => false - } - override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" -} - -object PredicateSubquery { - def hasPredicateSubquery(e: Expression): Boolean = { - e.find { - case _: PredicateSubquery | _: ListQuery | _: Exists => true - case _ => false - }.isDefined - } - - /** - * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could - * turn the null-aware predicate into not-null-aware predicate. - */ - def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { - e.find{ x => - x.isInstanceOf[Not] && e.find { - case p: PredicateSubquery => p.nullAware - case _ => false - }.isDefined - }.isDefined - } -} - /** * A [[ListQuery]] expression defines the query which we want to search in an IN subquery * expression. It should and can only be used in conjunction with an IN expression. @@ -144,18 +259,20 @@ object PredicateSubquery { * FROM b) * }}} */ -case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty - override def dataType: DataType = ArrayType(NullType) +case class ListQuery( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = plan.schema.fields.head.dataType override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def toString: String = s"list#${exprId.id}" + override def toString: String = s"list#${exprId.id} $conditionString" } /** * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * * For example (SQL): * {{{ * SELECT * @@ -165,11 +282,12 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr * WHERE b.id = a.id) * }}} */ -case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty +case class Exists( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id}" + override def toString: String = s"exists#${exprId.id} $conditionString" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index caafa1c134cd4..e9dbded3d4d02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,6 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Pullup Correlated Expressions", Once, + PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, @@ -885,7 +887,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { val attributes = plan.outputSet val matched = condition.find { - case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty + case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty case _ => false } matched.isEmpty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index fb7ce6aecea53..ba3fd1d5f802f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -41,10 +42,17 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + private def getValueExpression(e: Expression): Seq[Expression] = { + e match { + case cns : CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = - splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { @@ -54,20 +62,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { - case (p, PredicateSubquery(sub, conditions, _, _)) => + case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, false, _))) => + case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, true, _))) => + case (p, In(value, Seq(ListQuery(sub, conditions, _)))) => + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, sub, LeftSemi, joinCond) + case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: // (a1,b1,...) = (a2,b2,...) @@ -83,11 +96,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } /** - * Given a predicate expression and an input plan, it rewrites - * any embedded existential sub-query into an existential join. - * It returns the rewritten expression together with the updated plan. - * Currently, it does not support null-aware joins. Embedded NOT IN predicates - * are blocked in the Analyzer. + * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query + * into an existential join. It returns the rewritten expression together with the updated plan. + * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in + * the Analyzer. */ private def rewriteExistentialExpr( exprs: Seq[Expression], @@ -95,17 +107,138 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { var newPlan = plan val newExprs = exprs.map { e => e transformUp { - case PredicateSubquery(sub, conditions, nullAware, _) => - // TODO: support null-aware join + case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - } + case In(value, Seq(ListQuery(sub, conditions, _))) => + val exists = AttributeReference("exists", BooleanType, nullable = false)() + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions) + exists + } } (newExprs.reduceOption(And), newPlan) } } + /** + * Pull out all (outer) correlated predicates from a given subquery. This method removes the + * correlated predicates from subquery [[Filter]]s and adds the references of these predicates + * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to + * be able to evaluate the predicates at the top level. + * + * TODO: Look to merge this rule with RewritePredicateSubquery. + */ +object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper { + /** + * Returns the correlated predicates and a updated plan that removes the outer references. + */ + private def pullOutCorrelatedPredicates( + sub: LogicalPlan, + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + + /** Determine which correlated predicate references are missing from this plan. */ + def missingReferences(p: LogicalPlan): AttributeSet = { + val localPredicateReferences = p.collect(predicateMap) + .flatten + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + localPredicateReferences -- p.outputSet + } + + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { + case f @ Filter(cond, child) => + val (correlated, local) = + splitConjunctivePredicates(cond).partition(containsOuter) + + // Rewrite the filter without the correlated predicates if any. + correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> xs + newFilter + case xs => + predicateMap += child -> xs + child + } + case p @ Project(expressions, child) => + val referencesToAdd = missingReferences(p) + if (referencesToAdd.nonEmpty) { + Project(expressions ++ referencesToAdd, child) + } else { + p + } + case a @ Aggregate(grouping, expressions, child) => + val referencesToAdd = missingReferences(a) + if (referencesToAdd.nonEmpty) { + Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) + } else { + a + } + case p => + p + } + + // Make sure the inner and the outer query attributes do not collide. + // In case of a collision, change the subquery plan's output to use + // different attribute by creating alias(s). + val baseConditions = predicateMap.values.flatten.toSeq + val (newPlan, newCond) = if (outer.nonEmpty) { + val outputSet = outer.map(_.outputSet).reduce(_ ++ _) + val duplicates = transformed.outputSet.intersect(outputSet) + val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = transformed.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, transformed) + val aliasedConditions = baseConditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (transformed, baseConditions) + } + (plan, stripOuterReferences(deDuplicatedConditions)) + } else { + (transformed, stripOuterReferences(baseConditions)) + } + (newPlan, newCond) + } + + private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = { + plan transformExpressions { + case ScalarSubquery(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ScalarSubquery(newPlan, newCond, exprId) + case Exists(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + Exists(newPlan, newCond, exprId) + case ListQuery(sub, _, exprId) => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ListQuery(newPlan, newCond, exprId) + } + } + + /** + * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case f @ Filter(_, a: Aggregate) => + rewriteSubQueries(f, Seq(a, a.child)) + // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. + case q: UnaryNode => + rewriteSubQueries(q, q.children) + } +} /** * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index c5e877d12811c..d2ebca5a83dd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) val plan4 = Filter( Exists( Limit(1, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) @@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 4aafb2b83fb69..55693121431a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1) + val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.ResolveSubquery(expr) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e9b7a0c6ad671..5eb31413ad70f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { e.copy(exprId = ExprId(0)) case l: ListQuery => l.copy(exprId = ExprId(0)) - case p: PredicateSubquery => - p.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 730ca27f82bac..58be2d1da2816 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -144,9 +144,6 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { ScalarSubquery( SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan), subquery.exprId) - case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) => - val executedPlan = new QueryExecution(sparkSession, query).executedPlan - InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) } } } diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 50ae01e181bcf..f7bbb35aad6ce 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -46,7 +46,7 @@ and t2b = (select max(avg) struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; +grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.; -- !query 4 @@ -63,4 +63,4 @@ where t1a in (select min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)]; +resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 25dbecb5894e4..6f1cd49c08ee1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -622,7 +622,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { checkAnswer( - sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + sql( + """ + |select l.b, (select (r.c + count(*)) is null + |from r + |where l.a = r.c group by r.c) from l + """.stripMargin), Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) } From 0ee38a39e43dd7ad9d50457e446ae36f64621a1b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Mar 2017 19:02:30 +0800 Subject: [PATCH 0007/1765] [SPARK-19944][SQL] Move SQLConf from sql/core to sql/catalyst ## What changes were proposed in this pull request? This patch moves SQLConf from sql/core to sql/catalyst. To minimize the changes, the patch used type alias to still keep CatalystConf (as a type alias) and SimpleCatalystConf (as a concrete class that extends SQLConf). Motivation for the change is that it is pretty weird to have SQLConf only in sql/core and then we have to duplicate config options that impact optimizer/analyzer in sql/catalyst using CatalystConf. ## How was this patch tested? N/A Author: Reynold Xin Closes #17285 from rxin/SPARK-19944. --- .../spark/sql/catalyst/CatalystConf.scala | 93 --------------- .../sql/catalyst/SimpleCatalystConf.scala | 48 ++++++++ .../apache/spark/sql/catalyst/package.scala | 7 ++ .../apache/spark/sql/internal/SQLConf.scala | 106 +++++------------- .../spark/sql/internal/StaticSQLConf.scala | 84 ++++++++++++++ 5 files changed, 165 insertions(+), 173 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala rename sql/{core => catalyst}/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (91%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala deleted file mode 100644 index cff0efa979932..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import java.util.TimeZone - -import org.apache.spark.sql.catalyst.analysis._ - -/** - * Interface for configuration options used in the catalyst module. - */ -trait CatalystConf { - def caseSensitiveAnalysis: Boolean - - def orderByOrdinal: Boolean - def groupByOrdinal: Boolean - - def optimizerMaxIterations: Int - def optimizerInSetConversionThreshold: Int - def maxCaseBranchesForCodegen: Int - - def tableRelationCacheSize: Int - - def runSQLonFile: Boolean - - def warehousePath: String - - def sessionLocalTimeZone: String - - /** If true, cartesian products between relations will be allowed for all - * join types(inner, (left|right|full) outer). - * If false, cartesian products will require explicit CROSS JOIN syntax. - */ - def crossJoinEnabled: Boolean - - /** - * Returns the [[Resolver]] for the current configuration, which can be used to determine if two - * identifiers are equal. - */ - def resolver: Resolver = { - if (caseSensitiveAnalysis) caseSensitiveResolution else caseInsensitiveResolution - } - - /** - * Enables CBO for estimation of plan statistics when set true. - */ - def cboEnabled: Boolean - - /** Enables join reorder in CBO. */ - def joinReorderEnabled: Boolean - - /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */ - def joinReorderDPThreshold: Int - - override def clone(): CatalystConf = throw new CloneNotSupportedException() -} - - -/** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf( - caseSensitiveAnalysis: Boolean, - orderByOrdinal: Boolean = true, - groupByOrdinal: Boolean = true, - optimizerMaxIterations: Int = 100, - optimizerInSetConversionThreshold: Int = 10, - maxCaseBranchesForCodegen: Int = 20, - tableRelationCacheSize: Int = 1000, - runSQLonFile: Boolean = true, - crossJoinEnabled: Boolean = false, - cboEnabled: Boolean = false, - joinReorderEnabled: Boolean = false, - joinReorderDPThreshold: Int = 12, - warehousePath: String = "/user/hive/warehouse", - sessionLocalTimeZone: String = TimeZone.getDefault().getID) - extends CatalystConf { - - override def clone(): SimpleCatalystConf = this.copy() -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala new file mode 100644 index 0000000000000..746f84459de26 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.util.TimeZone + +import org.apache.spark.sql.internal.SQLConf + + +/** + * A SQLConf that can be used for local testing. This class is only here to minimize the change + * for ticket SPARK-19944 (moves SQLConf from sql/core to sql/catalyst). This class should + * eventually be removed (test cases should just create SQLConf and set values appropriately). + */ +case class SimpleCatalystConf( + override val caseSensitiveAnalysis: Boolean, + override val orderByOrdinal: Boolean = true, + override val groupByOrdinal: Boolean = true, + override val optimizerMaxIterations: Int = 100, + override val optimizerInSetConversionThreshold: Int = 10, + override val maxCaseBranchesForCodegen: Int = 20, + override val tableRelationCacheSize: Int = 1000, + override val runSQLonFile: Boolean = true, + override val crossJoinEnabled: Boolean = false, + override val cboEnabled: Boolean = false, + override val joinReorderEnabled: Boolean = false, + override val joinReorderDPThreshold: Int = 12, + override val warehousePath: String = "/user/hive/warehouse", + override val sessionLocalTimeZone: String = TimeZone.getDefault().getID) + extends SQLConf { + + override def clone(): SimpleCatalystConf = this.copy() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 105cdf52500c6..4af56afebb762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.internal.SQLConf + /** * Catalyst is a library for manipulating relational query plans. All classes in catalyst are * considered an internal API to Spark SQL and are subject to change between minor releases. @@ -29,4 +31,9 @@ package object catalyst { */ protected[sql] object ScalaReflectionLock + /** + * This class is only here to minimize the change for ticket SPARK-19944 + * (moves SQLConf from sql/core to sql/catalyst). This class should eventually be removed. + */ + type CatalystConf = SQLConf } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8e3f567b7dd90..315bedb12e716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -24,15 +24,11 @@ import scala.collection.JavaConverters._ import scala.collection.immutable import org.apache.hadoop.fs.Path -import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol -import org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.analysis.Resolver //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -251,7 +247,7 @@ object SQLConf { "of org.apache.parquet.hadoop.ParquetOutputCommitter.") .internal() .stringConf - .createWithDefault(classOf[ParquetOutputCommitter].getName) + .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") val PARQUET_VECTORIZED_READER_ENABLED = buildConf("spark.sql.parquet.enableVectorizedReader") @@ -417,7 +413,8 @@ object SQLConf { buildConf("spark.sql.sources.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[SQLHadoopMapReduceCommitProtocol].getName) + .createWithDefault( + "org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol") val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") @@ -578,7 +575,7 @@ object SQLConf { buildConf("spark.sql.streaming.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[ManifestFileCommitProtocol].getName) + .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") @@ -723,7 +720,7 @@ object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with Logging { +class SQLConf extends Serializable with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -833,6 +830,18 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two + * identifiers are equal. + */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) @@ -890,7 +899,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def dataFramePivotMaxValues: Int = getConf(DATAFRAME_PIVOT_MAX_VALUES) - override def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) + def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) @@ -907,21 +916,21 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def hiveThriftServerSingleSession: Boolean = getConf(StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION) - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) - override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) - override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) - override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) + def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) def ndvMaxError: Double = getConf(NDV_MAX_ERROR) - override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) - override def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) - override def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) /** ********************** SQLConf functionality methods ************ */ @@ -1050,66 +1059,3 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { result } } - -/** - * Static SQL configuration is a cross-session, immutable Spark configuration. External users can - * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. - */ -object StaticSQLConf { - - import SQLConf.buildStaticConf - - val WAREHOUSE_PATH = buildStaticConf("spark.sql.warehouse.dir") - .doc("The default location for managed databases and tables.") - .stringConf - .createWithDefault(Utils.resolveURI("spark-warehouse").toString) - - val CATALOG_IMPLEMENTATION = buildStaticConf("spark.sql.catalogImplementation") - .internal() - .stringConf - .checkValues(Set("hive", "in-memory")) - .createWithDefault("in-memory") - - val GLOBAL_TEMP_DATABASE = buildStaticConf("spark.sql.globalTempDatabase") - .internal() - .stringConf - .createWithDefault("global_temp") - - // This is used to control when we will split a schema's JSON string to multiple pieces - // in order to fit the JSON string in metastore's table property (by default, the value has - // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default - // value of this property). We will split the JSON string of a schema to its length exceeds the - // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, - // that's why this conf has to be a static SQL conf. - val SCHEMA_STRING_LENGTH_THRESHOLD = - buildStaticConf("spark.sql.sources.schemaStringLengthThreshold") - .doc("The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.") - .internal() - .intConf - .createWithDefault(4000) - - val FILESOURCE_TABLE_RELATION_CACHE_SIZE = - buildStaticConf("spark.sql.filesourceTableRelationCacheSize") - .internal() - .doc("The maximum size of the cache that maps qualified table names to table relation plans.") - .intConf - .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") - .createWithDefault(1000) - - // When enabling the debug, Spark SQL internal table properties are not filtered out; however, - // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. - val DEBUG_MODE = buildStaticConf("spark.sql.debug") - .internal() - .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") - .booleanConf - .createWithDefault(false) - - val HIVE_THRIFT_SERVER_SINGLESESSION = - buildStaticConf("spark.sql.hive.thriftServer.singleSession") - .doc("When set to true, Hive Thrift server is running in a single session mode. " + - "All the JDBC/ODBC connections share the temporary views, function registries, " + - "SQL configuration and the current database.") - .booleanConf - .createWithDefault(false) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala new file mode 100644 index 0000000000000..af1a9cee2962a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.util.Utils + + +/** + * Static SQL configuration is a cross-session, immutable Spark configuration. External users can + * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. + */ +object StaticSQLConf { + + import SQLConf.buildStaticConf + + val WAREHOUSE_PATH = buildStaticConf("spark.sql.warehouse.dir") + .doc("The default location for managed databases and tables.") + .stringConf + .createWithDefault(Utils.resolveURI("spark-warehouse").toString) + + val CATALOG_IMPLEMENTATION = buildStaticConf("spark.sql.catalogImplementation") + .internal() + .stringConf + .checkValues(Set("hive", "in-memory")) + .createWithDefault("in-memory") + + val GLOBAL_TEMP_DATABASE = buildStaticConf("spark.sql.globalTempDatabase") + .internal() + .stringConf + .createWithDefault("global_temp") + + // This is used to control when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default + // value of this property). We will split the JSON string of a schema to its length exceeds the + // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, + // that's why this conf has to be a static SQL conf. + val SCHEMA_STRING_LENGTH_THRESHOLD = + buildStaticConf("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) + + val FILESOURCE_TABLE_RELATION_CACHE_SIZE = + buildStaticConf("spark.sql.filesourceTableRelationCacheSize") + .internal() + .doc("The maximum size of the cache that maps qualified table names to table relation plans.") + .intConf + .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") + .createWithDefault(1000) + + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, + // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. + val DEBUG_MODE = buildStaticConf("spark.sql.debug") + .internal() + .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") + .booleanConf + .createWithDefault(false) + + val HIVE_THRIFT_SERVER_SINGLESESSION = + buildStaticConf("spark.sql.hive.thriftServer.singleSession") + .doc("When set to true, Hive Thrift server is running in a single session mode. " + + "All the JDBC/ODBC connections share the temporary views, function registries, " + + "SQL configuration and the current database.") + .booleanConf + .createWithDefault(false) +} From a0b92f73fed9b91883f08cced1c09724e09e1883 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Mar 2017 12:49:30 +0100 Subject: [PATCH 0008/1765] [SPARK-19850][SQL] Allow the use of aliases in SQL function calls ## What changes were proposed in this pull request? We currently cannot use aliases in SQL function calls. This is inconvenient when you try to create a struct. This SQL query for example `select struct(1, 2) st`, will create a struct with column names `col1` and `col2`. This is even more problematic when we want to append a field to an existing struct. For example if we want to a field to struct `st` we would issue the following SQL query `select struct(st.*, 1) as st from src`, the result will be struct `st` with an a column with a non descriptive name `col3` (if `st` itself has 2 fields). This PR proposes to change this by allowing the use of aliased expression in function parameters. For example `select struct(1 as a, 2 as b) st`, will create a struct with columns `a` & `b`. ## How was this patch tested? Added a test to `ExpressionParserSuite` and added a test file for `SQLQueryTestSuite`. Author: Herman van Hovell Closes #17245 from hvanhovell/SPARK-19850. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 ++- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../parser/ExpressionParserSuite.scala | 2 + .../resources/sql-tests/inputs/struct.sql | 20 +++++++ .../sql-tests/results/struct.sql.out | 60 +++++++++++++++++++ 5 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/struct.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/struct.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 59f93b3c469d5..cc3b8fd3b4689 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -506,10 +506,10 @@ expression booleanExpression : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists | predicated #booleanDefault | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary - | EXISTS '(' query ')' #exists ; // workaround for: @@ -546,9 +546,10 @@ primaryExpression | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star - | '(' expression (',' expression)+ ')' #rowConstructor + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3cf11adc1953b..4c9fb2ec2774a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1016,7 +1016,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.expression().asScala.map(expression) match { + val arguments = ctx.namedExpression().asScala.map(expression) match { case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) @@ -1127,7 +1127,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[CreateStruct]] expression. */ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.expression.asScala.map(expression)) + CreateStruct(ctx.namedExpression().asScala.map(expression)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2fecb8dc4a60e..c2e62e739776f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -209,6 +209,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) } test("window function expressions") { @@ -278,6 +279,7 @@ class ExpressionParserSuite extends PlanTest { // Note that '(a)' will be interpreted as a nested expression. assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + assertEqual("(a as b, b as c)", CreateStruct(Seq('a as 'b, 'b as 'c))) } test("scalar sub-query") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql new file mode 100644 index 0000000000000..e56344dc4de80 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -0,0 +1,20 @@ +CREATE TEMPORARY VIEW tbl_x AS VALUES + (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')), + (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')), + (3, NAMED_STRUCT('C', 'theta', 'D', 'iota')) + AS T(ID, ST); + +-- Create a struct +SELECT STRUCT('alpha', 'beta') ST; + +-- Create a struct with aliases +SELECT STRUCT('alpha' AS A, 'beta' AS B) ST; + +-- Star expansion in a struct. +SELECT ID, STRUCT(ST.*) NST FROM tbl_x; + +-- Append a column to a struct +SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; + +-- Prepend a column to a struct +SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out new file mode 100644 index 0000000000000..3e32f46195464 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -0,0 +1,60 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE TEMPORARY VIEW tbl_x AS VALUES + (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')), + (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')), + (3, NAMED_STRUCT('C', 'theta', 'D', 'iota')) + AS T(ID, ST) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT STRUCT('alpha', 'beta') ST +-- !query 1 schema +struct> +-- !query 1 output +{"col1":"alpha","col2":"beta"} + + +-- !query 2 +SELECT STRUCT('alpha' AS A, 'beta' AS B) ST +-- !query 2 schema +struct> +-- !query 2 output +{"A":"alpha","B":"beta"} + + +-- !query 3 +SELECT ID, STRUCT(ST.*) NST FROM tbl_x +-- !query 3 schema +struct> +-- !query 3 output +1 {"C":"gamma","D":"delta"} +2 {"C":"epsilon","D":"eta"} +3 {"C":"theta","D":"iota"} + + +-- !query 4 +SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x +-- !query 4 schema +struct> +-- !query 4 output +1 {"C":"gamma","D":"delta","E":"1"} +2 {"C":"epsilon","D":"eta","E":"2"} +3 {"C":"theta","D":"iota","E":"3"} + + +-- !query 5 +SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x +-- !query 5 schema +struct> +-- !query 5 output +1 {"AA":"1","C":"gamma","D":"delta"} +2 {"AA":"2","C":"epsilon","D":"eta"} +3 {"AA":"3","C":"theta","D":"iota"} From 1c7275efa7bfaaa92719750e93a7b35cbcb48e45 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Mar 2017 14:02:48 +0100 Subject: [PATCH 0009/1765] [SPARK-18874][SQL] Fix 2.10 build after moving the subquery rules to optimization ## What changes were proposed in this pull request? Commit https://github.com/apache/spark/commit/4ce970d71488c7de6025ef925f75b8b92a5a6a79 in accidentally broke the 2.10 build for Spark. This PR fixes this by simplifying the offending pattern match. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17288 from hvanhovell/SPARK-18874. --- .../org/apache/spark/sql/catalyst/expressions/subquery.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index ad11700fa28d2..59db28d58afce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -100,8 +100,8 @@ object SubExprUtils extends PredicateHelper { */ def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { - case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => - false + case _: Exists | Not(_: Exists) => false + case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { case In(_, Seq(_: ListQuery)) => true From 5e96a57b2f383d4b33735681b41cd3ec06570671 Mon Sep 17 00:00:00 2001 From: Asher Krim Date: Tue, 14 Mar 2017 13:08:11 +0000 Subject: [PATCH 0010/1765] [SPARK-19922][ML] small speedups to findSynonyms Currently generating synonyms using a large model (I've tested with 3m words) is very slow. These efficiencies have sped things up for us by ~17% I wasn't sure if such small changes were worthy of a jira, but the guidelines seemed to suggest that that is the preferred approach ## What changes were proposed in this pull request? Address a few small issues in the findSynonyms logic: 1) remove usage of ``Array.fill`` to zero out the ``cosineVec`` array. The default float value in Scala and Java is 0.0f, so explicitly setting the values to zero is not needed 2) use Floats throughout. The conversion to Doubles before doing the ``priorityQueue`` is totally superfluous, since all the similarity computations are done using Floats anyway. Creating a second large array just serves to put extra strain on the GC 3) convert the slow ``for(i <- cosVec.indices)`` to an ugly, but faster, ``while`` loop These efficiencies are really only apparent when working with a large model ## How was this patch tested? Existing unit tests + some in-house tests to time the difference cc jkbradley MLNick srowen Author: Asher Krim Author: Asher Krim Closes #17263 from Krimit/fasterFindSynonyms. --- .../apache/spark/mllib/feature/Word2Vec.scala | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 531c8b07910fc..6f96813497b62 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -491,8 +491,8 @@ class Word2VecModel private[spark] ( // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. - private val wordVecNorms: Array[Double] = { - val wordVecNorms = new Array[Double](numWords) + private val wordVecNorms: Array[Float] = { + val wordVecNorms = new Array[Float](numWords) var i = 0 while (i < numWords) { val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) @@ -570,7 +570,7 @@ class Word2VecModel private[spark] ( require(num > 0, "Number of similar words should > 0") val fVector = vector.toArray.map(_.toFloat) - val cosineVec = Array.fill[Float](numWords)(0) + val cosineVec = new Array[Float](numWords) val alpha: Float = 1 val beta: Float = 0 // Normalize input vector before blas.sgemv to avoid Inf value @@ -581,22 +581,23 @@ class Word2VecModel private[spark] ( blas.sgemv( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) - val cosVec = cosineVec.map(_.toDouble) - var ind = 0 - while (ind < numWords) { - val norm = wordVecNorms(ind) - if (norm == 0.0) { - cosVec(ind) = 0.0 + var i = 0 + while (i < numWords) { + val norm = wordVecNorms(i) + if (norm == 0.0f) { + cosineVec(i) = 0.0f } else { - cosVec(ind) /= norm + cosineVec(i) /= norm } - ind += 1 + i += 1 } - val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) + val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2)) - for(i <- cosVec.indices) { - pq += Tuple2(wordList(i), cosVec(i)) + var j = 0 + while (j < numWords) { + pq += Tuple2(wordList(j), cosineVec(j)) + j += 1 } val scored = pq.toSeq.sortBy(-_._2) @@ -606,7 +607,10 @@ class Word2VecModel private[spark] ( case None => scored } - filtered.take(num).toArray + filtered + .take(num) + .map { case (word, score) => (word, score.toDouble) } + .toArray } /** From d4a637cd46b6dd5cc71ea17a55c4a26186e592c7 Mon Sep 17 00:00:00 2001 From: zero323 Date: Tue, 14 Mar 2017 07:34:44 -0700 Subject: [PATCH 0011/1765] [SPARK-19940][ML][MINOR] FPGrowthModel.transform should skip duplicated items ## What changes were proposed in this pull request? This commit moved `distinct` in its intended place to avoid duplicated predictions and adds unit test covering the issue. ## How was this patch tested? Unit tests. Author: zero323 Closes #17283 from zero323/SPARK-19940. --- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 417968d9b817d..fa39dd954af57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -245,10 +245,10 @@ class FPGrowthModel private[ml] ( rule._2.filter(item => !itemset.contains(item)) } else { Seq.empty - }) + }).distinct } else { Seq.empty - }.distinct }, dt) + }}, dt) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 076d55c180548..910d4b07d1302 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul FPGrowthSuite.allParamSettings, checkModelData) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("features") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } } object FPGrowthSuite { From 85941ecf28362f35718ebcd3a22dbb17adb49154 Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Tue, 14 Mar 2017 07:45:42 -0700 Subject: [PATCH 0012/1765] [SPARK-11569][ML] Fix StringIndexer to handle null value properly ## What changes were proposed in this pull request? This PR is to enhance StringIndexer with NULL values handling. Before the PR, StringIndexer will throw an exception when encounters NULL values. With this PR: - handleInvalid=error: Throw an exception as before - handleInvalid=skip: Skip null values as well as unseen labels - handleInvalid=keep: Give null values an additional index as well as unseen labels BTW, I noticed someone was trying to solve the same problem ( #9920 ) but seems getting no progress or response for a long time. Would you mind to give me a chance to solve it ? I'm eager to help. :-) ## How was this patch tested? new unit tests Author: Menglong TAN Author: Menglong TAN Closes #17233 from crackcell/11569_StringIndexer_NULL. --- .../spark/ml/feature/StringIndexer.scala | 54 +++++++++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 45 ++++++++++++++++ 2 files changed, 77 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 810b02febbe77..99321bcc7cf98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { /** - * Param for how to handle unseen labels. Options are 'skip' (filter out rows with - * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional - * bucket, at index numLabels. + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). * Default: "error" * @group param */ @Since("1.6.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + - "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + - "at index numLabels).", + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) - setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) /** @group getParam */ @Since("1.6.0") @@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.select(col($(inputCol)).cast(StringType)) + val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) .countByValue() @@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { - private[feature] val SKIP_UNSEEN_LABEL: String = "skip" - private[feature] val ERROR_UNSEEN_LABEL: String = "error" - private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = - Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -188,7 +189,7 @@ class StringIndexerModel ( transformSchema(dataset.schema, logging = true) val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case StringIndexer.KEEP_INVALID => labels :+ "__unknown" case _ => labels } @@ -196,22 +197,31 @@ class StringIndexerModel ( .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_UNSEEN_LABEL => + case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) } - (dataset.where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) + (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) } val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else if (keepInvalid) { - labels.length + if (label == null) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } } else { - throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + - s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 188dffb3dd55f..8d9042b31e033 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,6 +122,51 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") From a02a0b1703dafab541c9b57939e3ed37e412d0f8 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 14 Mar 2017 10:13:50 -0700 Subject: [PATCH 0013/1765] [SPARK-18961][SQL] Support `SHOW TABLE EXTENDED ... PARTITION` statement ## What changes were proposed in this pull request? We should support the statement `SHOW TABLE EXTENDED LIKE 'table_identifier' PARTITION(partition_spec)`, just like that HIVE does. When partition is specified, the `SHOW TABLE EXTENDED` command should output the information of the partitions instead of the tables. Note that in this statement, we require exact matched partition spec. For example: ``` CREATE TABLE show_t1(a String, b Int) PARTITIONED BY (c String, d String); ALTER TABLE show_t1 ADD PARTITION (c='Us', d=1) PARTITION (c='Us', d=22); -- Output the extended information of Partition(c='Us', d=1) SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1); -- Throw an AnalysisException SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us'); ``` ## How was this patch tested? Add new test sqls in file `show-tables.sql`. Add new test case in `DDLSuite`. Author: jiangxingbo Closes #16373 from jiangxb1987/show-partition-extended. --- .../spark/sql/execution/QueryExecution.scala | 4 +- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../spark/sql/execution/command/tables.scala | 44 ++++-- .../sql-tests/inputs/show-tables.sql | 15 +- .../sql-tests/results/show-tables.sql.out | 133 ++++++++++++++---- .../apache/spark/sql/SQLQueryTestSuite.scala | 5 +- .../sql/execution/command/DDLSuite.scala | 34 ----- 7 files changed, 163 insertions(+), 83 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9a3656ddc79f4..8e8210e334a1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -127,8 +127,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .map(s => String.format(s"%-20s", s)) .mkString("\t") } - // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. - case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => + // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 00d1d6d2701f2..abea7a3bcf146 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -134,7 +134,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - isExtended = false) + isExtended = false, + partitionSpec = None) } /** @@ -146,14 +147,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.partitionSpec != null) { - operationNotAllowed("SHOW TABLE EXTENDED ... PARTITION", ctx) - } - + val partitionSpec = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - isExtended = true) + isExtended = true, + partitionSpec = partitionSpec) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 86394ff23e379..beb3dcafd64f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -616,13 +616,15 @@ case class DescribeTableCommand( * The syntax of using this command in SQL is: * {{{ * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; - * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards'; + * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards' + * [PARTITION(partition_spec)]; * }}} */ case class ShowTablesCommand( databaseName: Option[String], tableIdentifierPattern: Option[String], - isExtended: Boolean = false) extends RunnableCommand { + isExtended: Boolean = false, + partitionSpec: Option[TablePartitionSpec] = None) extends RunnableCommand { // The result of SHOW TABLES/SHOW TABLE has three basic columns: database, tableName and // isTemporary. If `isExtended` is true, append column `information` to the output columns. @@ -642,18 +644,34 @@ case class ShowTablesCommand( // instead of calling tables in sparkSession. val catalog = sparkSession.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) - val tables = - tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) - tables.map { tableIdent => - val database = tableIdent.database.getOrElse("") - val tableName = tableIdent.table - val isTemp = catalog.isTemporaryTable(tableIdent) - if (isExtended) { - val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString - Row(database, tableName, isTemp, s"${information}\n") - } else { - Row(database, tableName, isTemp) + if (partitionSpec.isEmpty) { + // Show the information of tables. + val tables = + tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) + tables.map { tableIdent => + val database = tableIdent.database.getOrElse("") + val tableName = tableIdent.table + val isTemp = catalog.isTemporaryTable(tableIdent) + if (isExtended) { + val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString + Row(database, tableName, isTemp, s"$information\n") + } else { + Row(database, tableName, isTemp) + } } + } else { + // Show the information of partitions. + // + // Note: tableIdentifierPattern should be non-empty, otherwise a [[ParseException]] + // should have been thrown by the sql parser. + val tableIdent = TableIdentifier(tableIdentifierPattern.get, Some(db)) + val table = catalog.getTableMetadata(tableIdent).identifier + val partition = catalog.getPartition(tableIdent, partitionSpec.get) + val database = table.database.getOrElse("") + val tableName = table.table + val isTemp = catalog.isTemporaryTable(table) + val information = partition.toString + Seq(Row(database, tableName, isTemp, s"$information\n")) } } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql index 10c379dfa014e..3c77c9977d80f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql @@ -17,10 +17,21 @@ SHOW TABLES LIKE 'show_t1*|show_t2*'; SHOW TABLES IN showdb 'show_t*'; -- SHOW TABLE EXTENDED --- Ignore these because there exist timestamp results, e.g. `Created`. --- SHOW TABLE EXTENDED LIKE 'show_t*'; +SHOW TABLE EXTENDED LIKE 'show_t*'; SHOW TABLE EXTENDED; + +-- SHOW TABLE EXTENDED ... PARTITION +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1); +-- Throw a ParseException if table name is not specified. +SHOW TABLE EXTENDED PARTITION(c='Us', d=1); +-- Don't support regular expression for table name if a partition specification is present. +SHOW TABLE EXTENDED LIKE 'show_t*' PARTITION(c='Us', d=1); +-- Partition specification is not complete. SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us'); +-- Partition specification is invalid. +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(a='Us', d=1); +-- Partition specification doesn't exist. +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Ch', d=1); -- Clean Up DROP TABLE show_t1; diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 3d287f43accc9..6d62e6092147b 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 26 -- !query 0 @@ -114,76 +114,159 @@ show_t3 -- !query 12 -SHOW TABLE EXTENDED +SHOW TABLE EXTENDED LIKE 'show_t*' -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '' expecting 'LIKE'(line 1, pos 19) - -== SQL == -SHOW TABLE EXTENDED --------------------^^^ +show_t3 true CatalogTable( + Table: `show_t3` + Created: + Last Access: + Type: VIEW + Schema: [StructField(e,IntegerType,true)] + Storage()) + +showdb show_t1 false CatalogTable( + Table: `showdb`.`show_t1` + Created: + Last Access: + Type: MANAGED + Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] + Provider: parquet + Partition Columns: [`c`, `d`] + Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1) + Partition Provider: Catalog) + +showdb show_t2 false CatalogTable( + Table: `showdb`.`show_t2` + Created: + Last Access: + Type: MANAGED + Schema: [StructField(b,StringType,true), StructField(d,IntegerType,true)] + Provider: parquet + Storage(Location: sql/core/spark-warehouse/showdb.db/show_t2)) -- !query 13 -SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') +SHOW TABLE EXTENDED -- !query 13 schema struct<> -- !query 13 output org.apache.spark.sql.catalyst.parser.ParseException -Operation not allowed: SHOW TABLE EXTENDED ... PARTITION(line 1, pos 0) +mismatched input '' expecting 'LIKE'(line 1, pos 19) == SQL == -SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -^^^ +SHOW TABLE EXTENDED +-------------------^^^ -- !query 14 -DROP TABLE show_t1 +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1) -- !query 14 schema -struct<> +struct -- !query 14 output - +showdb show_t1 false CatalogPartition( + Partition Values: [c=Us, d=1] + Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1) + Partition Parameters:{}) -- !query 15 -DROP TABLE show_t2 +SHOW TABLE EXTENDED PARTITION(c='Us', d=1) -- !query 15 schema struct<> -- !query 15 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'PARTITION' expecting 'LIKE'(line 1, pos 20) +== SQL == +SHOW TABLE EXTENDED PARTITION(c='Us', d=1) +--------------------^^^ -- !query 16 -DROP VIEW show_t3 +SHOW TABLE EXTENDED LIKE 'show_t*' PARTITION(c='Us', d=1) -- !query 16 schema struct<> -- !query 16 output - +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'show_t*' not found in database 'showdb'; -- !query 17 -DROP VIEW global_temp.show_t4 +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -- !query 17 schema struct<> -- !query 17 output - +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`showdb`.`show_t1`'; -- !query 18 -USE default +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(a='Us', d=1) -- !query 18 schema struct<> -- !query 18 output - +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (a, d) must match the partition spec (c, d) defined in table '`showdb`.`show_t1`'; -- !query 19 -DROP DATABASE showdb +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Ch', d=1) -- !query 19 schema struct<> -- !query 19 output +org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +Partition not found in table 'show_t1' database 'showdb': +c -> Ch +d -> 1; + + +-- !query 20 +DROP TABLE show_t1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP TABLE show_t2 +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP VIEW show_t3 +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +DROP VIEW global_temp.show_t4 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +USE default +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +DROP DATABASE showdb +-- !query 25 schema +struct<> +-- !query 25 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 68ababcd11027..c285995514c85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -222,7 +222,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val df = session.sql(sql) val schema = df.schema // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x")) + val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") + .replaceAll("Location: .*/sql/core/", "Location: sql/core/") + .replaceAll("Created: .*\n", "Created: \n") + .replaceAll("Last Access: .*\n", "Last Access: \n")) // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 0666f446f3b52..6eed10ec51464 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -977,40 +977,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { testRenamePartitions(isDatasourceTable = false) } - test("show table extended") { - withTempView("show1a", "show2b") { - sql( - """ - |CREATE TEMPORARY VIEW show1a - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - | - |) - """.stripMargin) - sql( - """ - |CREATE TEMPORARY VIEW show2b - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - assert( - sql("SHOW TABLE EXTENDED LIKE 'show*'").count() >= 2) - assert( - sql("SHOW TABLE EXTENDED LIKE 'show*'").schema == - StructType(StructField("database", StringType, false) :: - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: - StructField("information", StringType, false) :: Nil)) - } - } - test("show databases") { sql("CREATE DATABASE showdb2B") sql("CREATE DATABASE showdb1A") From 6325a2f82a95a63bee020122620bc4f5fd25d059 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 14 Mar 2017 18:51:05 +0100 Subject: [PATCH 0014/1765] [SPARK-19923][SQL] Remove unnecessary type conversions per call in Hive ## What changes were proposed in this pull request? This pr removed unnecessary type conversions per call in Hive: https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala#L116 ## How was this patch tested? Existing tests Author: Takeshi Yamamuro Closes #17264 from maropu/SPARK-19923. --- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 3 ++- .../apache/spark/sql/hive/orc/OrcFileFormat.scala | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 506949cb682b0..51c814cf32a81 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -108,12 +108,13 @@ private[hive] case class HiveSimpleUDF( private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private val wrapper = wrapperFor(oi, dataType) private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi, dataType) + override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef] } private[hive] case class HiveGenericUDF( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index f496c01ce9ff7..3a34ec55c8b07 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc import java.net.URI import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -196,6 +198,11 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format + private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map { + case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType) + } + private[this] def wrapOrcStruct( struct: OrcStruct, oi: SettableStructObjectInspector, @@ -208,10 +215,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) oi.setStructFieldData( struct, fieldRefs.get(i), - wrap( - row.get(i, dataSchema(i).dataType), - fieldRefs.get(i).getFieldObjectInspector, - dataSchema(i).dataType)) + wrappers(i)(row.get(i, dataSchema(i).dataType)) + ) i += 1 } } From e04c05cf41a125b0526f59f9b9e7fdf0b78b8b21 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Mar 2017 18:52:16 +0100 Subject: [PATCH 0015/1765] [SPARK-19933][SQL] Do not change output of a subquery ## What changes were proposed in this pull request? The `RemoveRedundantAlias` rule can change the output attributes (the expression id's to be precise) of a query by eliminating the redundant alias producing them. This is no problem for a regular query, but can cause problems for correlated subqueries: The attributes produced by the subquery are used in the parent plan; changing them will break the parent plan. This PR fixes this by wrapping a subquery in a `Subquery` top level node when it gets optimized. The `RemoveRedundantAlias` rule now recognizes `Subquery` and makes sure that the output attributes of the `Subquery` node are retained. ## How was this patch tested? Added a test case to `RemoveRedundantAliasAndProjectSuite` and added a regression test to `SubquerySuite`. Author: Herman van Hovell Closes #17278 from hvanhovell/SPARK-19933. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 15 ++++++++++++--- .../plans/logical/basicLogicalOperators.scala | 8 ++++++++ .../RemoveRedundantAliasAndProjectSuite.scala | 8 ++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 14 ++++++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e9dbded3d4d02..c8ed4190a13ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,7 +142,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) object OptimizeSubqueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => - s.withNewPlan(Optimizer.this.execute(s.plan)) + val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) + s.withNewPlan(newPlan) } } } @@ -187,7 +188,10 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // If the alias name is different from attribute name, we can't strip it either, or we // may accidentally change the output schema name of the root plan. case a @ Alias(attr: Attribute, name) - if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) => + if a.metadata == Metadata.empty && + name == attr.name && + !blacklist.contains(attr) && + !blacklist.contains(a) => attr case a => a } @@ -195,10 +199,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { /** * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) - * join. + * join or to prevent the removal of top-level subquery attributes. */ private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { plan match { + // We want to keep the same output attributes for subqueries. This means we cannot remove + // the aliases that produce these attributes + case Subquery(child) => + Subquery(removeRedundantAliases(child, blacklist ++ child.outputSet)) + // A join has to be treated differently, because the left and the right side of the join are // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 31b6ed48a2230..5cbf263d1ce42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -38,6 +38,14 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * This node is inserted at the top of a subquery when it is optimized. This makes sure we can + * recognize a subquery as such, and it allows us to write subquery aware transformations. + */ +case class Subquery(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index c01ea01ec6808..1973b5abb462d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -116,4 +116,12 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper val expected = relation.window(Seq('b), Seq('a), Seq()).analyze comparePlans(optimized, expected) } + + test("do not remove output attributes from a subquery") { + val relation = LocalRelation('a.int, 'b.int) + val query = Subquery(relation.select('a as "a", 'b as "b").where('b < 10).select('a).analyze) + val optimized = Optimize.execute(query) + val expected = Subquery(relation.select('a as "a", 'b).where('b < 10).select('a).analyze) + comparePlans(optimized, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 6f1cd49c08ee1..5fe6667ceca18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -830,4 +830,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(1) :: Row(0) :: Nil) } } + + test("SPARK-19933 Do not eliminate top-level aliases in sub-queries") { + withTempView("t1", "t2") { + spark.range(4).createOrReplaceTempView("t1") + checkAnswer( + sql("select * from t1 where id in (select id as id from t1)"), + Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil) + + spark.range(2).createOrReplaceTempView("t2") + checkAnswer( + sql("select * from t1 where id in (select id as id from t2)"), + Row(0) :: Row(1) :: Nil) + } + } } From 6eac96823c7b244773bd810812b369e336a65837 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 14 Mar 2017 20:34:59 +0100 Subject: [PATCH 0016/1765] [SPARK-18966][SQL] NOT IN subquery with correlated expressions may return incorrect result ## What changes were proposed in this pull request? This PR fixes the following problem: ```` Seq((1, 2)).toDF("a1", "a2").createOrReplaceTempView("a") Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("b1", "b2").createOrReplaceTempView("b") // The expected result is 1 row of (1,2) as shown in the next statement. sql("select * from a where a1 not in (select b1 from b where b2 = a2)").show +---+---+ | a1| a2| +---+---+ +---+---+ sql("select * from a where a1 not in (select b1 from b where b2 = 2)").show +---+---+ | a1| a2| +---+---+ | 1| 2| +---+---+ ```` There are a number of scenarios to consider: 1. When the correlated predicate yields a match (i.e., B.B2 = A.A2) 1.1. When the NOT IN expression yields a match (i.e., A.A1 = B.B1) 1.2. When the NOT IN expression yields no match (i.e., A.A1 = B.B1 returns false) 1.3. When A.A1 is null 1.4. When B.B1 is null 1.4.1. When A.A1 is not null 1.4.2. When A.A1 is null 2. When the correlated predicate yields no match (i.e.,B.B2 = A.A2 is false or unknown) 2.1. When B.B2 is null and A.A2 is null 2.2. When B.B2 is null and A.A2 is not null 2.3. When the value of A.A2 does not match any of B.B2 ```` A.A1 A.A2 B.B1 B.B2 ----- ----- ----- ----- 1 1 1 1 (1.1) 2 1 (1.2) null 1 (1.3) 1 3 null 3 (1.4.1) null 3 (1.4.2) 1 null 1 null (2.1) null 2 (2.2 & 2.3) ```` We can divide the evaluation of the above correlated NOT IN subquery into 2 groups:- Group 1: The rows in A when there is a match from the correlated predicate (A.A1 = B.B1) In this case, the result of the subquery is not empty and the semantics of the NOT IN depends solely on the evaluation of the equality comparison of the columns of NOT IN, i.e., A1 = B1, which says - If A.A1 is null, the row is filtered (1.3 and 1.4.2) - If A.A1 = B.B1, the row is filtered (1.1) - If B.B1 is null, any rows of A in the same group (A.A2 = B.B2) is filtered (1.4.1 & 1.4.2) - Otherwise, the row is qualified. Hence, in this group, the result is the row from (1.2). Group 2: The rows in A when there is no match from the correlated predicate (A.A2 = B.B2) In this case, all the rows in A, including the rows where A.A1, are qualified because the subquery returns an empty set and by the semantics of the NOT IN, all rows from the parent side qualifies as the result set, that is, the rows from (2.1, 2.2 and 2.3). In conclusion, the correct result set of the above query is ```` A.A1 A.A2 ----- ----- 2 1 (1.2) 1 null (2.1) null 2 (2.2 & 2.3) ```` ## How was this patch tested? unit tests, regression tests, and new test cases focusing on the problem being fixed. Author: Nattavut Sutyanyong Closes #17294 from nsyca/18966. --- .../sql/catalyst/optimizer/subquery.scala | 13 +++-- .../inputs/subquery/in-subquery/simple-in.sql | 24 +++++++++ .../subquery/in-subquery/simple-in.sql.out | 50 ++++++++++++++++++- 3 files changed, 82 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index ba3fd1d5f802f..2a3e07aebe709 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -80,14 +80,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) - val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: - // (a1,b1,...) = (a2,b2,...) + // (a1,a2,...) = (b1,b2,...) // to - // (a1=a2 OR isnull(a1=a2)) AND (b1=b2 OR isnull(b1=b2)) AND ... + // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... val joinConds = splitConjunctivePredicates(joinCond.get) - val pairs = joinConds.map(c => Or(c, IsNull(c))).reduceLeft(And) + // After that, add back the correlated join predicate(s) in the subquery + // Example: + // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) + // will have the final conditions in the LEFT ANTI as + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) + val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) Join(outerPlan, sub, LeftAnti, Option(pairs)) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql index 20370b045e803..f19567d2fac20 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql @@ -109,4 +109,28 @@ FROM t1 WHERE t1a NOT IN (SELECT t2a FROM t2); +-- DDLs +create temporary view a as select * from values + (1, 1), (2, 1), (null, 1), (1, 3), (null, 3), (1, null), (null, 2) + as a(a1, a2); +create temporary view b as select * from values + (1, 1, 2), (null, 3, 2), (1, null, 2), (1, 2, null) + as b(b1, b2, b3); + +-- TC 02.01 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2) +; + +-- TC 02.02 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2 + AND b.b3 > 1) +; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out index 66493d7fcc92d..d69b4bcf185c3 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 14 -- !query 0 @@ -174,3 +174,51 @@ t1a 6 2014-04-04 01:02:00.001 t1d 10 2015-05-04 01:01:00 t1d NULL 2014-06-04 01:01:00 t1d NULL 2014-07-04 01:02:00.001 + + +-- !query 10 +create temporary view a as select * from values + (1, 1), (2, 1), (null, 1), (1, 3), (null, 3), (1, null), (null, 2) + as a(a1, a2) +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +create temporary view b as select * from values + (1, 1, 2), (null, 3, 2), (1, null, 2), (1, 2, null) + as b(b1, b2, b3) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2) +-- !query 12 schema +struct +-- !query 12 output +1 NULL +2 1 + + +-- !query 13 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2 + AND b.b3 > 1) +-- !query 13 schema +struct +-- !query 13 output +1 NULL +2 1 +NULL 2 From 7ded39c223429265b23940ca8244660dbee8320c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 14 Mar 2017 13:57:23 -0700 Subject: [PATCH 0017/1765] [SPARK-19817][SQL] Make it clear that `timeZone` option is a general option in DataFrameReader/Writer. ## What changes were proposed in this pull request? As timezone setting can also affect partition values, it works for all formats, we should make it clear. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #17281 from ueshin/issues/SPARK-19817. --- python/pyspark/sql/readwriter.py | 46 +++++++++++-------- .../sql/catalyst/catalog/interface.scala | 3 +- .../spark/sql/catalyst/json/JSONOptions.scala | 5 +- .../sql/catalyst/util/DateTimeUtils.scala | 2 + .../expressions/JsonExpressionsSuite.scala | 9 ++-- .../apache/spark/sql/DataFrameReader.scala | 22 +++++++-- .../apache/spark/sql/DataFrameWriter.scala | 22 +++++++-- .../execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../datasources/FileFormatWriter.scala | 2 +- .../PartitioningAwareFileIndex.scala | 2 +- .../datasources/csv/CSVOptions.scala | 5 +- .../execution/datasources/csv/CSVSuite.scala | 5 +- .../datasources/json/JsonSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 11 +++-- .../sql/sources/PartitionedWriteSuite.scala | 4 +- .../sql/sources/ResolvedDataSourceSuite.scala | 3 +- .../spark/sql/hive/HiveExternalCatalog.scala | 2 +- 17 files changed, 101 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4354345ebc550..705803791d894 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,6 +109,11 @@ def schema(self, schema): @since(1.5) def option(self, key, value): """Adds an input option for the underlying data source. + + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ self._jreader = self._jreader.option(key, to_str(value)) return self @@ -116,6 +121,11 @@ def option(self, key, value): @since(1.4) def options(self, **options): """Adds input options for the underlying data source. + + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ for k in options: self._jreader = self._jreader.option(k, to_str(options[k])) @@ -159,7 +169,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None, wholeFile=None): + wholeFile=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -214,8 +224,6 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -234,7 +242,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) + timestampFormat=timestampFormat, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -307,7 +315,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. @@ -367,8 +375,6 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non uses the default value, ``10``. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record, and puts the malformed string into a field configured by \ @@ -399,7 +405,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] @@ -521,6 +527,11 @@ def format(self, source): @since(1.5) def option(self, key, value): """Adds an output option for the underlying data source. + + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ self._jwrite = self._jwrite.option(key, to_str(value)) return self @@ -528,6 +539,11 @@ def option(self, key, value): @since(1.4) def options(self, **options): """Adds output options for the underlying data source. + + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ for k in options: self._jwrite = self._jwrite.option(k, to_str(options[k])) @@ -619,8 +635,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - timeZone=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -641,15 +656,12 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to format timestamps. - If None is set, it uses the default value, session local timezone. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - timeZone=timeZone) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.json(path) @since(1.4) @@ -696,7 +708,7 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None, timeZone=None): + timestampFormat=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -736,15 +748,13 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, - dateFormat=dateFormat, timestampFormat=timestampFormat, timeZone=timeZone) + dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.csv(path) @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e3631b0c07737..b862deaf36369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -113,7 +113,8 @@ case class CatalogTablePartition( */ def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) - val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId) + val timeZoneId = caseInsensitiveProperties.getOrElse( + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5a91f9c1939aa..5f222ec602c99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -23,7 +23,7 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util._ /** * Options for parsing JSON data into Spark SQL rows. @@ -69,7 +69,8 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) + val timeZone: TimeZone = TimeZone.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 9e1de0fd2f3d6..9b94c1e2b40bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -60,6 +60,8 @@ object DateTimeUtils { final val TimeZoneGMT = TimeZone.getTimeZone("GMT") final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) + val TIMEZONE_OPTION = "timeZone" + def defaultTimeZone(): TimeZone = TimeZone.getDefault() // Reuse the Calendar object in each thread as it is expensive to create in each method call. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index e3584909ddc4a..19d0c8eb92f1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -471,7 +471,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( JsonToStruct( schema, - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> tz.getID), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), gmtId), InternalRow(c.getTimeInMillis * 1000L) @@ -523,14 +524,16 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StructToJson( - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> gmtId.get), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> gmtId.get), struct, gmtId), """{"t":"2016-01-01T00:00:00"}""" ) checkEvaluation( StructToJson( - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> "PST"), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> "PST"), struct, gmtId), """{"t":"2015-12-31T16:00:00"}""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4f4cc93117494..f1bce1aa41029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -70,6 +70,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Adds an input option for the underlying data source. * + * You can set the following option(s): + *
    + *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or parttion values.
  • + *
+ * * @since 1.4.0 */ def option(key: String, value: String): DataFrameReader = { @@ -101,6 +107,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * (Scala-specific) Adds input options for the underlying data source. * + * You can set the following option(s): + *
    + *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or parttion values.
  • + *
+ * * @since 1.4.0 */ def options(options: scala.collection.Map[String, String]): DataFrameReader = { @@ -111,6 +123,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Adds input options for the underlying data source. * + * You can set the following option(s): + *
    + *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or parttion values.
  • + *
+ * * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameReader = { @@ -305,8 +323,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, * per file
  • * @@ -478,8 +494,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 49e85dc7b13f6..608160a214fba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -90,6 +90,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Adds an output option for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or parttion values.
    • + *
    + * * @since 1.4.0 */ def option(key: String, value: String): DataFrameWriter[T] = { @@ -121,6 +127,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * (Scala-specific) Adds output options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or parttion values.
    • + *
    + * * @since 1.4.0 */ def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = { @@ -131,6 +143,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Adds output options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or parttion values.
    • + *
    + * * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameWriter[T] = { @@ -457,8 +475,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps.
  • * * * @since 1.4.0 @@ -565,8 +581,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index aa578f4d23133..769deb1890b6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -105,7 +105,7 @@ case class OptimizeMetadataOnlyQuery( val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) val caseInsensitiveProperties = CaseInsensitiveMap(relation.tableMeta.storage.properties) - val timeZoneId = caseInsensitiveProperties.get("timeZone") + val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(conf.sessionLocalTimeZone) val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 30a09a9ad3370..ce33298aeb1da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -141,7 +141,7 @@ object FileFormatWriter extends Logging { customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), - timeZoneId = caseInsensitiveOptions.get("timeZone") + timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index c8097a7fabc2e..a5fa8b3f9385e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -127,7 +127,7 @@ abstract class PartitioningAwareFileIndex( }.keys.toSeq val caseInsensitiveOptions = CaseInsensitiveMap(parameters) - val timeZoneId = caseInsensitiveOptions.get("timeZone") + val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) userPartitionSchema match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 0b1e5dac2da66..2632e87971d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -24,7 +24,7 @@ import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, Unescape import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util._ class CSVOptions( @transient private val parameters: CaseInsensitiveMap[String], @@ -120,7 +120,8 @@ class CSVOptions( name.map(CompressionCodecs.getCodecClassName) } - val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) + val timeZone: TimeZone = TimeZone.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4435e4df38ef6..95dfdf5b298e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -912,7 +913,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .format("csv") .option("header", "true") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .save(timestampsWithFormatPath) // This will load back the timestamps as string. @@ -934,7 +935,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .option("inferSchema", "true") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .load(timestampsWithFormatPath) checkAnswer(readBack, timestampsWithFormat) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0aaf148dac258..9b0efcbdaf5c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1767,7 +1767,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { timestampsWithFormat.write .format("json") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .save(timestampsWithFormatPath) // This will load back the timestamps as string. @@ -1785,7 +1785,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val readBack = spark.read .schema(customSchema) .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .json(timestampsWithFormatPath) checkAnswer(readBack, timestampsWithFormat) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 88cb8a0bad21e..2b20b9716bf80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} import org.apache.spark.sql.functions._ @@ -708,10 +709,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } withTempPath { dir => - df.write.option("timeZone", "GMT") + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .load(dir.toString).select(fields: _*), row) } } @@ -749,10 +751,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } withTempPath { dir => - df.write.option("timeZone", "GMT") + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name)) - checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .load(dir.toString).select(fields: _*), row) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index f251290583c5e..a2f3afe3ce236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -142,7 +143,8 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => - df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath) + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .partitionBy("ts").parquet(f.getAbsolutePath) val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 9b5e364e512a2..0f97fd78d2ffb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -27,7 +27,8 @@ class ResolvedDataSourceSuite extends SparkFunSuite { DataSource( sparkSession = null, className = name, - options = Map("timeZone" -> DateTimeUtils.defaultTimeZone().getID)).providingClass + options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID) + ).providingClass test("jdbc") { assert( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 33802ae62333e..8860b7dc079cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.hive.client.HiveClient From dacc382f0c918f1ca808228484305ce0e21c705e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Mar 2017 08:24:41 +0800 Subject: [PATCH 0018/1765] [SPARK-19887][SQL] dynamic partition keys can be null or empty string ## What changes were proposed in this pull request? When dynamic partition value is null or empty string, we should write the data to a directory like `a=__HIVE_DEFAULT_PARTITION__`, when we read the data back, we should respect this special directory name and treat it as null. This is the same behavior of impala, see https://issues.apache.org/jira/browse/IMPALA-252 ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17277 from cloud-fan/partition. --- .../catalog/ExternalCatalogUtils.scala | 2 +- .../sql/catalyst/catalog/interface.scala | 9 +++++-- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/FileFormatWriter.scala | 11 ++++----- .../datasources/PartitioningUtils.scala | 3 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 4 ++-- .../PartitionProviderCompatibilitySuite.scala | 24 ++++++++++++++++++- 7 files changed, 39 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index a418edc302d9c..a8693dcca539d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -118,7 +118,7 @@ object ExternalCatalogUtils { } def getPartitionPathString(col: String, value: String): String = { - val partitionString = if (value == null) { + val partitionString = if (value == null || value.isEmpty) { DEFAULT_PARTITION_NAME } else { escapePathName(value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b862deaf36369..70ed44e025f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -116,7 +116,12 @@ case class CatalogTablePartition( val timeZoneId = caseInsensitiveProperties.getOrElse( DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => - Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() + val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) { + null + } else { + spec(field.name) + } + Cast(Literal(partValue), field.dataType, Option(timeZoneId)).eval() }) } } @@ -164,7 +169,7 @@ case class BucketSpec( * @param tracksPartitionsInCatalog whether this table's partition metadata is stored in the * catalog. If false, it is inferred automatically based on file * structure. - * @param schemaPresevesCase Whether or not the schema resolved for this table is case-sensitive. + * @param schemaPreservesCase Whether or not the schema resolved for this table is case-sensitive. * When using a Hive Metastore, this flag is set to false if a case- * sensitive schema was unable to be read from the table properties. * Used to trigger case-sensitive schema inference at query time, when diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 39b010efec7b0..8ebad676ca310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -319,7 +319,7 @@ case class FileSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index ce33298aeb1da..7957224ce48b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -335,14 +335,11 @@ object FileFormatWriter extends Logging { /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = ScalaUDF( - ExternalCatalogUtils.escapePathName _, + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, StringType, - Seq(Cast(c, StringType, Option(desc.timeZoneId))), - Seq(StringType)) - val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(ExternalCatalogUtils.escapePathName(c.name) + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 09876bbc2f85d..03980922ab38f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -129,7 +128,7 @@ object PartitioningUtils { // "hdfs://host:9000/invalidPath" // "hdfs://host:9000/path" // TODO: Selective case sensitivity. - val discoveredBasePaths = optDiscoveredBasePaths.flatMap(x => x).map(_.toString.toLowerCase()) + val discoveredBasePaths = optDiscoveredBasePaths.flatten.map(_.toString.toLowerCase()) assert( discoveredBasePaths.distinct.size == 1, "Conflicting directory structures detected. Suspicious paths:\b" + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8860b7dc079cb..8a3c81ac8b0fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1012,8 +1012,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val partColNameMap = buildLowerCasePartColNameMap(catalogTable).mapValues(escapePathName) val clientPartitionNames = client.getPartitionNames(catalogTable, partialSpec.map(lowerCasePartitionSpec)) - clientPartitionNames.map { partName => - val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partName) + clientPartitionNames.map { partitionPath => + val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partitionPath) partSpec.map { case (partName, partValue) => partColNameMap(partName.toLowerCase) + "=" + escapePathName(partValue) }.mkString("/") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 96385961c9a52..9440a17677ebf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -22,7 +22,7 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.metrics.source.HiveCatalogMetrics -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -316,6 +316,28 @@ class PartitionProviderCompatibilitySuite } } } + + test(s"SPARK-19887 partition value is null - partition management $enabled") { + withTable("test") { + Seq((1, "p", 1), (2, null, 2)).toDF("a", "b", "c") + .write.partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Nil) + + Seq((3, null: String, 3)).toDF("a", "b", "c") + .write.mode("append").partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Row(3, null, 3) :: Nil) + // make sure partition pruning also works. + checkAnswer(spark.table("test").filter($"b".isNotNull), Row(1, "p", 1)) + + // empty string is an invalid partition value and we treat it as null when read back. + Seq((4, "", 4)).toDF("a", "b", "c") + .write.mode("append").partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Row(3, null, 3) :: Row(4, null, 4) :: Nil) + } + } } /** From 8fb2a02e2ce6832e3d9338a7d0148dfac9fa24c2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 15 Mar 2017 10:19:19 +0800 Subject: [PATCH 0019/1765] [SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource ## What changes were proposed in this pull request? This PR proposes to use text datasource when Json schema inference. This basically proposes the similar approach in https://github.com/apache/spark/pull/15813 If we use Dataset for initial loading when inferring the schema, there are advantages. Please refer SPARK-18362 It seems JSON one was supposed to be fixed together but taken out according to https://github.com/apache/spark/pull/15813 > A similar problem also affects the JSON file format and this patch originally fixed that as well, but I've decided to split that change into a separate patch so as not to conflict with changes in another JSON PR. Also, this seems affecting some functionalities because it does not use `FileScanRDD`. This problem is described in SPARK-19885 (but it was CSV's case). ## How was this patch tested? Existing tests should cover this and manual test by `spark.read.json(path)` and check the UI. Author: hyukjinkwon Closes #17255 from HyukjinKwon/json-filescanrdd. --- .../apache/spark/sql/DataFrameReader.scala | 9 +- .../datasources/json/JsonDataSource.scala | 145 ++++++++---------- .../datasources/json/JsonFileFormat.scala | 2 +- .../datasources/json/JsonInferSchema.scala | 9 +- .../datasources/json/JsonUtils.scala | 51 ++++++ 5 files changed, 122 insertions(+), 94 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f1bce1aa41029..309654c804148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema +import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val createParser = CreateJacksonParser.string _ val schema = userSpecifiedSchema.getOrElse { - JsonInferSchema.infer( - jsonDataset.rdd, - parsedOptions, - createParser) + TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 18843bfc307b3..84f026620d907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,32 +17,30 @@ package org.apache.spark.sql.execution.datasources.json -import scala.reflect.ClassTag - import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** * Common functions for parsing JSON files - * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]] */ -abstract class JsonDataSource[T] extends Serializable { +abstract class JsonDataSource extends Serializable { def isSplitable: Boolean /** @@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable { file: PartitionedFile, parser: JacksonParser): Iterator[InternalRow] - /** - * Create an [[RDD]] that handles the preliminary parsing of [[T]] records - */ - protected def createBaseRdd( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[T] - - /** - * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]] - * for an instance of [[T]] - */ - def createParser(jsonFactory: JsonFactory, value: T): JsonParser - - final def infer( + final def inferSchema( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Option[StructType] = { if (inputPaths.nonEmpty) { - val jsonSchema = JsonInferSchema.infer( - createBaseRdd(sparkSession, inputPaths), - parsedOptions, - createParser) + val jsonSchema = infer(sparkSession, inputPaths, parsedOptions) checkConstraints(jsonSchema) Some(jsonSchema) } else { @@ -82,6 +64,11 @@ abstract class JsonDataSource[T] extends Serializable { } } + protected def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType + /** Constraints to be imposed on schema to be stored. */ private def checkConstraints(schema: StructType): Unit = { if (schema.fieldNames.length != schema.fieldNames.distinct.length) { @@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable { } object JsonDataSource { - def apply(options: JSONOptions): JsonDataSource[_] = { + def apply(options: JSONOptions): JsonDataSource = { if (options.wholeFile) { WholeFileJsonDataSource } else { TextInputJsonDataSource } } - - /** - * Create a new [[RDD]] via the supplied callback if there is at least one file to process, - * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned. - */ - def createBaseRdd[T : ClassTag]( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus])( - fn: (Configuration, String) => RDD[T]): RDD[T] = { - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) - FileInputFormat.setInputPaths(job, paths: _*) - fn(job.getConfiguration, paths.mkString(",")) - } else { - sparkSession.sparkContext.emptyRDD[T] - } - } } -object TextInputJsonDataSource extends JsonDataSource[Text] { +object TextInputJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { // splittable if the underlying source is true } - override protected def createBaseRdd( + override def infer( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[Text] = { - JsonDataSource.createBaseRdd(sparkSession, inputPaths) { - case (conf, name) => - sparkSession.sparkContext.newAPIHadoopRDD( - conf, - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]) - .setName(s"JsonLines: $name") - .values // get the text column - } + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType = { + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + inferFromDataset(json, parsedOptions) + } + + def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { + val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) + val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) + JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as(Encoders.STRING) } override def readFile( @@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] { parser: JacksonParser): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.flatMap(parser.parse(_, createParser, textToUTF8String)) + linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String)) } private def textToUTF8String(value: Text): UTF8String = { UTF8String.fromBytes(value.getBytes, 0, value.getLength) } - - override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = { - CreateJacksonParser.text(jsonFactory, value) - } } -object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { +object WholeFileJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } - override protected def createBaseRdd( + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType = { + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) + JsonInferSchema.infer(sampled, parsedOptions, createParser) + } + + private def createBaseRdd( sparkSession: SparkSession, inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { - JsonDataSource.createBaseRdd(sparkSession, inputPaths) { - case (conf, name) => - new BinaryFileRDD( - sparkSession.sparkContext, - classOf[StreamInputFormat], - classOf[String], - classOf[PortableDataStream], - conf, - sparkSession.sparkContext.defaultMinPartitions) - .setName(s"JsonFile: $name") - .values - } + val paths = inputPaths.map(_.getPath) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val conf = job.getConfiguration + val name = paths.mkString(",") + FileInputFormat.setInputPaths(job, paths: _*) + new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + .setName(s"JsonFile: $name") + .values } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 902fee5a7e3f7..a9dd91eba6f72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - JsonDataSource(parsedOptions).infer( + JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index ab09358115c0a..7475f8ec79331 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -40,18 +40,11 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - require(configOptions.samplingRatio > 0, - s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") val shouldHandleCorruptRecord = configOptions.permissive val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val schemaData = if (configOptions.samplingRatio > 0.99) { - json - } else { - json.sample(withReplacement = false, configOptions.samplingRatio, 1) - } // perform schema inference on each row and merge afterwards - val rootType = schemaData.mapPartitions { iter => + val rootType = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala new file mode 100644 index 0000000000000..d511594c5de1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions + +object JsonUtils { + /** + * Sample JSON dataset as configured by `samplingRatio`. + */ + def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample JSON RDD as configured by `samplingRatio`. + */ + def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, options.samplingRatio, 1) + } + } +} From d1f6c64c4b763c05d6d79ae5497f298dc3835f3e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 14 Mar 2017 19:51:25 -0700 Subject: [PATCH 0020/1765] [SPARK-19828][R] Support array type in from_json in R ## What changes were proposed in this pull request? Since we could not directly define the array type in R, this PR proposes to support array types in R as string types that are used in `structField` as below: ```R jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) collect(select(df, alias(from_json(df$people, "array>"), "arrcol"))) ``` prints ```R arrcol 1 Bob, Alice ``` ## How was this patch tested? Unit tests in `test_sparkSQL.R`. Author: hyukjinkwon Closes #17178 from HyukjinKwon/SPARK-19828. --- R/pkg/R/functions.R | 12 ++++++++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 ++++++++++++ .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index edf2bcf8fdb3c..9867f2d5b7c51 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2437,6 +2437,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' @param asJsonArray indicating if input string is JSON array of objects or a single object. #' @param ... additional named properties to control how the json is parsed, accepts the same #' options as the JSON data source. #' @@ -2452,11 +2453,18 @@ setMethod("date_format", signature(y = "Column", x = "character"), #'} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, ...) { + function(x, schema, asJsonArray = FALSE, ...) { + if (asJsonArray) { + jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", + "createArrayType", + schema$jobj) + } else { + jschema <- schema$jobj + } options <- varargsToStrEnv(...) jc <- callJStatic("org.apache.spark.sql.functions", "from_json", - x@jc, schema$jobj, options) + x@jc, jschema, options) column(jc) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9735fe3201553..f7081cb1d4e50 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1364,6 +1364,18 @@ test_that("column functions", { # check for unparseable df <- as.DataFrame(list(list("a" = ""))) expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + + # check if array type in string is correctly supported. + jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" + df <- as.DataFrame(list(list("people" = jsonArr))) + schema <- structType(structField("name", "string")) + arr <- collect(select(df, alias(from_json(df$people, schema, asJsonArray = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") }) test_that("column binary mathfunctions", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a4c5bf756cd5a..c77328690daec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -81,7 +81,7 @@ private[sql] object SQLUtils extends Logging { new JavaSparkContext(spark.sparkContext) } - def createStructType(fields : Seq[StructField]): StructType = { + def createStructType(fields: Seq[StructField]): StructType = { StructType(fields) } From f9a93b1b4a20e7c72d900362b269edab66e73dd8 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 15 Mar 2017 10:53:58 +0800 Subject: [PATCH 0021/1765] [SPARK-18112][SQL] Support reading data from Hive 2.1 metastore ### What changes were proposed in this pull request? This PR is to support reading data from Hive 2.1 metastore. Need to update shim class because of the Hive API changes caused by the following three Hive JIRAs: - [HIVE-12730 MetadataUpdater: provide a mechanism to edit the basic statistics of a table (or a partition)](https://issues.apache.org/jira/browse/HIVE-12730) - [Hive-13341 Stats state is not captured correctly: differentiate load table and create table](https://issues.apache.org/jira/browse/HIVE-13341) - [HIVE-13622 WriteSet tracking optimizations](https://issues.apache.org/jira/browse/HIVE-13622) There are three new fields added to Hive APIs. - `boolean hasFollowingStatsTask`. We always set it to `false`. This is to keep the existing behavior unchanged (starting from 0.13), no matter which Hive metastore client version users choose. If we set it to `true`, the basic table statistics is not collected by Hive. For example, ```SQL CREATE TABLE tbl AS SELECT 1 AS a ``` When setting `hasFollowingStatsTask ` to `false`, the table properties is like ``` Properties: [numFiles=1, transient_lastDdlTime=1489513927, totalSize=2] ``` When setting `hasFollowingStatsTask ` to `true`, the table properties is like ``` Properties: [transient_lastDdlTime=1489513563] ``` - `AcidUtils.Operation operation`. Obviously, we do not support ACID. Thus, we set it to `AcidUtils.Operation.NOT_ACID`. - `EnvironmentContext environmentContext`. So far, this is always set to `null`. This was introduced for supporting DDL `alter table s update statistics set ('numRows'='NaN')`. Using this DDL, users can specify the statistics. So far, our Spark SQL does not need it, because we use different table properties to store our generated statistics values. However, when Spark SQL issues ALTER TABLE DDL statements, Hive metastore always automatically invalidate the Hive-generated statistics. In the follow-up PR, we can fix it by explicitly adding a property to `environmentContext`. ```JAVA putToProperties(StatsSetupConst.STATS_GENERATED, StatsSetupConst.USER) ``` Another alternative is to set `DO_NOT_UPDATE_STATS`to `TRUE`. See the Hive JIRA: https://issues.apache.org/jira/browse/HIVE-15653. We will not address it in this PR. ### How was this patch tested? Added test cases to VersionsSuite.scala Author: Xiao Li Closes #17232 from gatorsmile/Hive21. --- .../spark/sql/hive/HiveExternalCatalog.scala | 1 - .../sql/hive/client/HiveClientImpl.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 181 ++++++++++++++++-- .../hive/client/IsolatedClientLoader.scala | 1 + .../spark/sql/hive/client/package.scala | 6 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 19 +- 7 files changed, 190 insertions(+), 25 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8a3c81ac8b0fc..33b21be37203b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive import java.io.IOException import java.lang.reflect.InvocationTargetException -import java.net.URI import java.util import scala.collection.mutable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6e1f429286cfa..989fdc5564d39 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -97,6 +97,7 @@ private[hive] class HiveClientImpl( case hive.v1_1 => new Shim_v1_1() case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() + case hive.v2_1 => new Shim_v2_1() } // Create an internal session state for this HiveClientImpl. @@ -455,7 +456,7 @@ private[hive] class HiveClientImpl( val hiveTable = toHiveTable(table, Some(conf)) // Do not use `table.qualifiedName` here because this may be a rename val qualifiedTableName = s"${table.database}.$tableName" - client.alterTable(qualifiedTableName, hiveTable) + shim.alterTable(client, qualifiedTableName, hiveTable) } override def createPartitions( @@ -535,7 +536,7 @@ private[hive] class HiveClientImpl( table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { val hiveTable = toHiveTable(getTable(db, table), Some(conf)) - client.alterPartitions(table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 153f1673c96f6..76568f599078d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -28,8 +28,10 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, Function => HiveFunction, FunctionType} +import org.apache.hadoop.hive.metastore.api.{MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.io.AcidUtils import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition, Table} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} @@ -82,6 +84,10 @@ private[client] sealed abstract class Shim { def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + def alterTable(hive: Hive, tableName: String, table: Table): Unit + + def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit + def createPartitions( hive: Hive, db: String, @@ -158,6 +164,10 @@ private[client] sealed abstract class Shim { } private[client] class Shim_v0_12 extends Shim with Logging { + // See HIVE-12224, HOLD_DDLTIME was broken as soon as it landed + protected lazy val holdDDLTime = JBoolean.FALSE + // deletes the underlying data along with metadata + protected lazy val deleteDataInDropIndex = JBoolean.TRUE private lazy val startMethod = findStaticMethod( @@ -240,6 +250,18 @@ private[client] class Shim_v0_12 extends Shim with Logging { classOf[String], classOf[String], JBoolean.TYPE) + private lazy val alterTableMethod = + findMethod( + classOf[Hive], + "alterTable", + classOf[String], + classOf[Table]) + private lazy val alterPartitionsMethod = + findMethod( + classOf[Hive], + "alterPartitions", + classOf[String], + classOf[JList[Partition]]) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -341,7 +363,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { tableName: String, replace: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime) } override def loadDynamicPartitions( @@ -353,11 +375,11 @@ private[client] class Shim_v0_12 extends Shim with Logging { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean) } override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { - dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + dropIndexMethod.invoke(hive, dbName, tableName, indexName, deleteDataInDropIndex) } override def dropTable( @@ -373,6 +395,14 @@ private[client] class Shim_v0_12 extends Shim with Logging { hive.dropTable(dbName, tableName, deleteData, ignoreIfNotExists) } + override def alterTable(hive: Hive, tableName: String, table: Table): Unit = { + alterTableMethod.invoke(hive, tableName, table) + } + + override def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit = { + alterPartitionsMethod.invoke(hive, tableName, newParts) + } + override def dropPartition( hive: Hive, dbName: String, @@ -520,7 +550,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } FunctionResource(FunctionResourceType.fromString(resourceType), uri.getUri()) } - new CatalogFunction(name, hf.getClassName, resources) + CatalogFunction(name, hf.getClassName, resources) } override def getFunctionOption(hive: Hive, db: String, name: String): Option[CatalogFunction] = { @@ -638,6 +668,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { private[client] class Shim_v0_14 extends Shim_v0_13 { + // true if this is an ACID operation + protected lazy val isAcid = JBoolean.FALSE + // true if list bucketing enabled + protected lazy val isSkewedStoreAsSubdir = JBoolean.FALSE + private lazy val loadPartitionMethod = findMethod( classOf[Hive], @@ -700,8 +735,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - JBoolean.FALSE, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - isSrcLocal: JBoolean, JBoolean.FALSE) + holdDDLTime, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, isAcid) } override def loadTable( @@ -710,8 +745,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { tableName: String, replace: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE, - isSrcLocal: JBoolean, JBoolean.FALSE, JBoolean.FALSE) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime, + isSrcLocal: JBoolean, isSkewedStoreAsSubdir, isAcid) } override def loadDynamicPartitions( @@ -723,7 +758,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean, isAcid) } override def dropTable( @@ -752,6 +787,9 @@ private[client] class Shim_v1_0 extends Shim_v0_14 { private[client] class Shim_v1_1 extends Shim_v1_0 { + // throws an exception if the index does not exist + protected lazy val throwExceptionInDropIndex = JBoolean.TRUE + private lazy val dropIndexMethod = findMethod( classOf[Hive], @@ -763,13 +801,17 @@ private[client] class Shim_v1_1 extends Shim_v1_0 { JBoolean.TYPE) override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { - dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + dropIndexMethod.invoke(hive, dbName, tableName, indexName, throwExceptionInDropIndex, + deleteDataInDropIndex) } } private[client] class Shim_v1_2 extends Shim_v1_1 { + // txnId can be 0 unless isAcid == true + protected lazy val txnIdInLoadDynamicPartitions: JLong = 0L + private lazy val loadDynamicPartitionsMethod = findMethod( classOf[Hive], @@ -806,8 +848,8 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE, - 0L: JLong) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean, isAcid, + txnIdInLoadDynamicPartitions) } override def dropPartition( @@ -872,7 +914,106 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - isSrcLocal: JBoolean, JBoolean.FALSE) + isSrcLocal: JBoolean, isAcid) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + isSrcLocal: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, + isSkewedStoreAsSubdir, isAcid) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, listBucketingEnabled: JBoolean, isAcid, txnIdInLoadDynamicPartitions) + } + +} + +private[client] class Shim_v2_1 extends Shim_v2_0 { + + // true if there is any following stats task + protected lazy val hasFollowingStatsTask = JBoolean.FALSE + // TODO: Now, always set environmentContext to null. In the future, we should avoid setting + // hive-generated stats to -1 when altering tables by using environmentContext. See Hive-12730 + protected lazy val environmentContextInAlterTable = null + + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE, + JBoolean.TYPE, + classOf[AcidUtils.Operation]) + private lazy val alterTableMethod = + findMethod( + classOf[Hive], + "alterTable", + classOf[String], + classOf[Table], + classOf[EnvironmentContext]) + private lazy val alterPartitionsMethod = + findMethod( + classOf[Hive], + "alterPartitions", + classOf[String], + classOf[JList[Partition]], + classOf[EnvironmentContext]) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean, + isSrcLocal: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, isAcid, hasFollowingStatsTask) } override def loadTable( @@ -882,7 +1023,7 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { replace: Boolean, isSrcLocal: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, - JBoolean.FALSE, JBoolean.FALSE) + isSkewedStoreAsSubdir, isAcid, hasFollowingStatsTask) } override def loadDynamicPartitions( @@ -894,7 +1035,15 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) + numDP: JInteger, listBucketingEnabled: JBoolean, isAcid, txnIdInLoadDynamicPartitions, + hasFollowingStatsTask, AcidUtils.Operation.NOT_ACID) } + override def alterTable(hive: Hive, tableName: String, table: Table): Unit = { + alterTableMethod.invoke(hive, tableName, table, environmentContextInAlterTable) + } + + override def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit = { + alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 6f69a4adf29d5..e95f9ea480431 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -95,6 +95,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.1" | "1.1.0" => hive.v1_1 case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 + case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 790ad74e6639e..f9635e36549e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -67,7 +67,11 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0) + case object v2_1 extends HiveVersion("2.1.1", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index b8536d0c1bd58..3682dc850790e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -149,7 +149,7 @@ case class InsertIntoHiveTable( // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index cb1386111035a..7aff49c0fc3b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -108,7 +109,7 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") private var client: HiveClient = null @@ -120,10 +121,12 @@ class VersionsSuite extends SparkFunSuite with Logging { System.gc() // Hack to avoid SEGV on some JVM versions. val hadoopConf = new Configuration() hadoopConf.set("test", "success") - // Hive changed the default of datanucleus.schema.autoCreateAll from true to false since 2.0 - // For details, see the JIRA HIVE-6113 - if (version == "2.0") { + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and + // hive.metastore.schema.verification from false to true since 2.0 + // For details, see the JIRA HIVE-6113 and HIVE-12463 + if (version == "2.0" || version == "2.1") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") } client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) if (versionSpark != null) versionSpark.reset() @@ -572,6 +575,14 @@ class VersionsSuite extends SparkFunSuite with Logging { withTable("tbl") { versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) + val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) + val totalSize = tableMeta.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty) + } else { + assert(totalSize.nonEmpty && totalSize.get > 0) + } } } From e1ac553402ab82bbc72fd64e5943b71c16b4b37d Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 14 Mar 2017 22:30:16 -0700 Subject: [PATCH 0022/1765] [SPARK-19817][SS] Make it clear that `timeZone` is a general option in DataStreamReader/Writer ## What changes were proposed in this pull request? As timezone setting can also affect partition values, it works for all formats, we should make it clear. ## How was this patch tested? N/A Author: Liwei Lin Closes #17299 from lw-lin/timezone. --- python/pyspark/sql/readwriter.py | 8 ++--- python/pyspark/sql/streaming.py | 32 ++++++++++++++----- .../apache/spark/sql/DataFrameReader.scala | 6 ++-- .../apache/spark/sql/DataFrameWriter.scala | 6 ++-- .../sql/streaming/DataStreamReader.scala | 22 ++++++++++--- .../sql/streaming/DataStreamWriter.scala | 18 +++++++++++ 6 files changed, 70 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 705803791d894..122e17f2020f4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -112,7 +112,7 @@ def option(self, key, value): You can set the following option(s) for reading files: * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps - in the JSON/CSV datasources or parttion values. + in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ self._jreader = self._jreader.option(key, to_str(value)) @@ -124,7 +124,7 @@ def options(self, **options): You can set the following option(s) for reading files: * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps - in the JSON/CSV datasources or parttion values. + in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ for k in options: @@ -530,7 +530,7 @@ def option(self, key, value): You can set the following option(s) for writing files: * ``timeZone``: sets the string that indicates a timezone to be used to format - timestamps in the JSON/CSV datasources or parttion values. + timestamps in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ self._jwrite = self._jwrite.option(key, to_str(value)) @@ -542,7 +542,7 @@ def options(self, **options): You can set the following option(s) for writing files: * ``timeZone``: sets the string that indicates a timezone to be used to format - timestamps in the JSON/CSV datasources or parttion values. + timestamps in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ for k in options: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 625fb9ba385af..288cc1e4f64dc 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -373,6 +373,11 @@ def schema(self, schema): def option(self, key, value): """Adds an input option for the underlying data source. + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. >>> s = spark.readStream.option("x", 1) @@ -384,6 +389,11 @@ def option(self, key, value): def options(self, **options): """Adds input options for the underlying data source. + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. >>> s = spark.readStream.options(x="1", y=2) @@ -429,7 +439,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None, wholeFile=None): + wholeFile=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -486,8 +496,6 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -503,7 +511,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) + timestampFormat=timestampFormat, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -561,7 +569,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. @@ -619,8 +627,6 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``-1`` meaning unlimited length. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record, and puts the malformed string into a field configured by \ @@ -653,7 +659,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) @@ -721,6 +727,11 @@ def format(self, source): def option(self, key, value): """Adds an output option for the underlying data source. + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. """ self._jwrite = self._jwrite.option(key, to_str(value)) @@ -730,6 +741,11 @@ def option(self, key, value): def options(self, **options): """Adds output options for the underlying data source. + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. """ for k in options: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 309654c804148..88fbfb4c92a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * You can set the following option(s): *
      *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or parttion values.
    • + * to be used to parse timestamps in the JSON/CSV datasources or partition values. *
    * * @since 1.4.0 @@ -110,7 +110,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * You can set the following option(s): *
      *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or parttion values.
    • + * to be used to parse timestamps in the JSON/CSV datasources or partition values. *
    * * @since 1.4.0 @@ -126,7 +126,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * You can set the following option(s): *
      *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or parttion values.
    • + * to be used to parse timestamps in the JSON/CSV datasources or partition values. *
    * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 608160a214fba..deaa8006945c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -93,7 +93,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * You can set the following option(s): *
      *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps in the JSON/CSV datasources or parttion values.
    • + * to be used to format timestamps in the JSON/CSV datasources or partition values. *
    * * @since 1.4.0 @@ -130,7 +130,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * You can set the following option(s): *
      *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps in the JSON/CSV datasources or parttion values.
    • + * to be used to format timestamps in the JSON/CSV datasources or partition values. *
    * * @since 1.4.0 @@ -146,7 +146,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * You can set the following option(s): *
      *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps in the JSON/CSV datasources or parttion values.
    • + * to be used to format timestamps in the JSON/CSV datasources or partition values. *
    * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index aed8074a64d5b..388ef182ce3a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -61,6 +61,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Adds an input option for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def option(key: String, value: String): DataStreamReader = { @@ -92,6 +98,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * (Scala-specific) Adds input options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def options(options: scala.collection.Map[String, String]): DataStreamReader = { @@ -102,6 +114,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Adds input options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def options(options: java.util.Map[String, String]): DataStreamReader = { @@ -186,8 +204,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, * per file
  • * @@ -239,8 +255,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index c8fda8cd83598..fe52013badb65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -145,6 +145,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Adds an output option for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def option(key: String, value: String): DataStreamWriter[T] = { @@ -176,6 +182,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * (Scala-specific) Adds output options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { @@ -186,6 +198,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Adds output options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { From ee36bc1c9043ead3c3ba4fba7e68c6c47ad7ae7a Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 14 Mar 2017 23:57:54 -0700 Subject: [PATCH 0023/1765] [SPARK-19877][SQL] Restrict the nested level of a view ## What changes were proposed in this pull request? We should restrict the nested level of a view, to avoid stack overflow exception during the view resolution. ## How was this patch tested? Add new test case in `SQLViewSuite`. Author: jiangxingbo Closes #17241 from jiangxb1987/view-depth. --- .../sql/catalyst/SimpleCatalystConf.scala | 3 ++- .../sql/catalyst/analysis/Analyzer.scala | 13 +++++++--- .../apache/spark/sql/internal/SQLConf.scala | 15 +++++++++++ .../spark/sql/execution/SQLViewSuite.scala | 25 +++++++++++++++++++ 4 files changed, 51 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala index 746f84459de26..0d4903e03bf5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -41,7 +41,8 @@ case class SimpleCatalystConf( override val joinReorderEnabled: Boolean = false, override val joinReorderDPThreshold: Int = 12, override val warehousePath: String = "/user/hive/warehouse", - override val sessionLocalTimeZone: String = TimeZone.getDefault().getID) + override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, + override val maxNestedViewDepth: Int = 100) extends SQLConf { override def clone(): SimpleCatalystConf = this.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a3764d8c843dd..68a4746a54d96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -58,13 +58,12 @@ object SimpleAnalyzer extends Analyzer( * * @param defaultDatabase The default database used in the view resolution, this overrules the * current catalog database. - * @param nestedViewLevel The nested level in the view resolution, this enables us to limit the + * @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the * depth of nested views. - * TODO Limit the depth of nested views. */ case class AnalysisContext( defaultDatabase: Option[String] = None, - nestedViewLevel: Int = 0) + nestedViewDepth: Int = 0) object AnalysisContext { private val value = new ThreadLocal[AnalysisContext]() { @@ -77,7 +76,7 @@ object AnalysisContext { def withAnalysisContext[A](database: Option[String])(f: => A): A = { val originContext = value.get() val context = AnalysisContext(defaultDatabase = database, - nestedViewLevel = originContext.nestedViewLevel + 1) + nestedViewDepth = originContext.nestedViewDepth + 1) set(context) try f finally { set(originContext) } } @@ -598,6 +597,12 @@ class Analyzer( case view @ View(desc, _, child) if !child.resolved => // Resolve all the UnresolvedRelations and Views in the child. val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) { + if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { + view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + + "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + + "aroud this.") + } execute(child) } view.copy(child = newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 315bedb12e716..8f65672d5a839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -571,6 +571,19 @@ object SQLConf { .booleanConf .createWithDefault(true) + val MAX_NESTED_VIEW_DEPTH = + buildConf("spark.sql.view.maxNestedViewDepth") + .internal() + .doc("The maximum depth of a view reference in a nested view. A nested view may reference " + + "other nested views, the dependencies are organized in a directed acyclic graph (DAG). " + + "However the DAG depth may become too large and cause unexpected behavior. This " + + "configuration puts a limit on this: when the depth of a view exceeds this value during " + + "analysis, we terminate the resolution to avoid potential errors.") + .intConf + .checkValue(depth => depth > 0, "The maximum depth of a view reference in a nested view " + + "must be positive.") + .createWithDefault(100) + val STREAMING_FILE_COMMIT_PROTOCOL_CLASS = buildConf("spark.sql.streaming.commitProtocolClass") .internal() @@ -932,6 +945,8 @@ class SQLConf extends Serializable with Logging { def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 2ca2206bb9d44..d32716c18ddfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -644,4 +644,29 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { "-> `default`.`view2` -> `default`.`view1`)")) } } + + test("restrict the nested level of a view") { + val viewNames = Array.range(0, 11).map(idx => s"view$idx") + withView(viewNames: _*) { + sql("CREATE VIEW view0 AS SELECT * FROM jt") + Array.range(0, 10).foreach { idx => + sql(s"CREATE VIEW view${idx + 1} AS SELECT * FROM view$idx") + } + + withSQLConf("spark.sql.view.maxNestedViewDepth" -> "10") { + val e = intercept[AnalysisException] { + sql("SELECT * FROM view10") + }.getMessage + assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + + "resolution depth (10). Analysis is aborted to avoid errors. Increase the value " + + "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + } + + val e = intercept[IllegalArgumentException] { + withSQLConf("spark.sql.view.maxNestedViewDepth" -> "0") {} + }.getMessage + assert(e.contains("The maximum depth of a view reference in a nested view must be " + + "positive.")) + } + } } From 9ff85be3bd6bf3a782c0e52fa9c2598d79f310bb Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 15 Mar 2017 10:46:05 +0100 Subject: [PATCH 0024/1765] [SPARK-19889][SQL] Make TaskContext callbacks thread safe ## What changes were proposed in this pull request? It is sometimes useful to use multiple threads in a task to parallelize tasks. These threads might register some completion/failure listeners to clean up when the task completes or fails. We currently cannot register such a callback and be sure that it will get called, because the context might be in the process of invoking its callbacks, when the the callback gets registered. This PR improves this by making sure that you cannot add a completion/failure listener from a different thread when the context is being marked as completed/failed in another thread. This is done by synchronizing these methods on the task context itself. Failure listeners were called only once. Completion listeners now follow the same pattern; this lifts the idempotency requirement for completion listeners and makes it easier to implement them. In some cases we can (accidentally) add a completion/failure listener after the fact, these listeners will be called immediately in order make sure we can safely clean-up after a task. As a result of this change we could make the `failure` and `completed` flags non-volatile. The `isCompleted()` method now uses synchronization to ensure that updates are visible across threads. ## How was this patch tested? Adding tests to `TaskContestSuite` to test adding listeners to a completed/failed context. Author: Herman van Hovell Closes #17244 from hvanhovell/SPARK-19889. --- .../scala/org/apache/spark/TaskContext.scala | 16 ++-- .../org/apache/spark/TaskContextImpl.scala | 85 +++++++++++++------ .../spark/scheduler/TaskContextSuite.scala | 26 ++++++ 3 files changed, 93 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index f0867ecb16ea3..5acfce17593b3 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable { /** * Adds a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable { /** * Adds a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable { } /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(listener: TaskFailureListener): TaskContext /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { addTaskFailureListener(new TaskFailureListener { diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index dc0d12878550a..ea8dcdfd5d7d9 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.Properties +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ +/** + * A [[TaskContext]] implementation. + * + * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes + * sure that updates are always visible across threads. The complete & failed flags and their + * callbacks are protected by locking on the context instance. For instance, this ensures + * that you cannot add a completion listener in one thread while we are completing (and calling + * the completion listeners) in another thread. Other state is immutable, however the exposed + * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe. + */ private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, @@ -52,62 +63,79 @@ private[spark] class TaskContextImpl( @volatile private var interrupted: Boolean = false // Whether the task has completed. - @volatile private var completed: Boolean = false + private var completed: Boolean = false // Whether the task has failed. - @volatile private var failed: Boolean = false + private var failed: Boolean = false + + // Throwable that caused the task to fail + private var failure: Throwable = _ // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't // hide the exception. See SPARK-19276 @volatile private var _fetchFailedException: Option[FetchFailedException] = None - override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener + @GuardedBy("this") + override def addTaskCompletionListener(listener: TaskCompletionListener) + : this.type = synchronized { + if (completed) { + listener.onTaskCompletion(this) + } else { + onCompleteCallbacks += listener + } this } - override def addTaskFailureListener(listener: TaskFailureListener): this.type = { - onFailureCallbacks += listener + @GuardedBy("this") + override def addTaskFailureListener(listener: TaskFailureListener) + : this.type = synchronized { + if (failed) { + listener.onTaskFailure(this, failure) + } else { + onFailureCallbacks += listener + } this } /** Marks the task as failed and triggers the failure listeners. */ - private[spark] def markTaskFailed(error: Throwable): Unit = { - // failure callbacks should only be called once + @GuardedBy("this") + private[spark] def markTaskFailed(error: Throwable): Unit = synchronized { if (failed) return failed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process failure callbacks in the reverse order of registration - onFailureCallbacks.reverse.foreach { listener => - try { - listener.onTaskFailure(this, error) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskFailureListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs, Option(error)) + failure = error + invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) { + _.onTaskFailure(this, error) } } /** Marks the task as completed and triggers the completion listeners. */ - private[spark] def markTaskCompleted(): Unit = { + @GuardedBy("this") + private[spark] def markTaskCompleted(): Unit = synchronized { + if (completed) return completed = true + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + _.onTaskCompletion(this) + } + } + + private def invokeListeners[T]( + listeners: Seq[T], + name: String, + error: Option[Throwable])( + callback: T => Unit): Unit = { val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => + // Process callbacks in the reverse order of registration + listeners.reverse.foreach { listener => try { - listener.onTaskCompletion(this) + callback(listener) } catch { case e: Throwable => errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) + logError(s"Error in $name", e) } } if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) + throw new TaskCompletionListenerException(errorMsgs, error) } } @@ -116,7 +144,8 @@ private[spark] class TaskContextImpl( interrupted = true } - override def isCompleted(): Boolean = completed + @GuardedBy("this") + override def isCompleted(): Boolean = synchronized(completed) override def isRunningLocally(): Boolean = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 7004128308af9..8f576daa77d15 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(res === Array("testPropValue,testPropValue")) } + test("immediately call a completion listener if the context is completed") { + var invocations = 0 + val context = TaskContext.empty() + context.markTaskCompleted() + context.addTaskCompletionListener(_ => invocations += 1) + assert(invocations == 1) + context.markTaskCompleted() + assert(invocations == 1) + } + + test("immediately call a failure listener if the context has failed") { + var invocations = 0 + var lastError: Throwable = null + val error = new RuntimeException + val context = TaskContext.empty() + context.markTaskFailed(error) + context.addTaskFailureListener { (_, e) => + lastError = e + invocations += 1 + } + assert(lastError == error) + assert(invocations == 1) + context.markTaskFailed(error) + assert(lastError == error) + assert(invocations == 1) + } } private object TaskContextSuite { From 7387126f83dc0489eb1df734bfeba705709b7861 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 15 Mar 2017 10:17:18 -0700 Subject: [PATCH 0025/1765] [SPARK-19872] [PYTHON] Use the correct deserializer for RDD construction for coalesce/repartition ## What changes were proposed in this pull request? This PR proposes to use the correct deserializer, `BatchedSerializer` for RDD construction for coalesce/repartition when the shuffle is enabled. Currently, it is passing `UTF8Deserializer` as is not `BatchedSerializer` from the copied one. with the file, `text.txt` below: ``` a b d e f g h i j k l ``` - Before ```python >>> sc.textFile('text.txt').repartition(1).collect() ``` ``` UTF8Deserializer(True) Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/rdd.py", line 811, in collect return list(_load_from_socket(port, self._jrdd_deserializer)) File ".../spark/python/pyspark/serializers.py", line 549, in load_stream yield self.loads(stream) File ".../spark/python/pyspark/serializers.py", line 544, in loads return s.decode("utf-8") if self.use_unicode else s File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/encodings/utf_8.py", line 16, in decode return codecs.utf_8_decode(input, errors, True) UnicodeDecodeError: 'utf8' codec can't decode byte 0x80 in position 0: invalid start byte ``` - After ```python >>> sc.textFile('text.txt').repartition(1).collect() ``` ``` [u'a', u'b', u'', u'd', u'e', u'f', u'g', u'h', u'i', u'j', u'k', u'l', u''] ``` ## How was this patch tested? Unit test in `python/pyspark/tests.py`. Author: hyukjinkwon Closes #17282 from HyukjinKwon/SPARK-19872. --- python/pyspark/rdd.py | 4 +++- python/pyspark/tests.py | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a5e6e2b054963..291c1caaaed57 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2072,10 +2072,12 @@ def coalesce(self, numPartitions, shuffle=False): batchSize = min(10, self.ctx._batchSize or 1024) ser = BatchedSerializer(PickleSerializer(), batchSize) selfCopy = self._reserialize(ser) + jrdd_deserializer = selfCopy._jrdd_deserializer jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) else: + jrdd_deserializer = self._jrdd_deserializer jrdd = self._jrdd.coalesce(numPartitions, shuffle) - return RDD(jrdd, self.ctx, self._jrdd_deserializer) + return RDD(jrdd, self.ctx, jrdd_deserializer) def zip(self, other): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c6c87a9ea5555..bb13de563cdd4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1037,6 +1037,12 @@ def test_repartition_no_skewed(self): zeros = len([x for x in l if x == 0]) self.assertTrue(zeros == 0) + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) self.assertEqual(rdd.getNumPartitions(), 10) From 02c274eaba0a8e7611226e0d4e93d3c36253f4ce Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Wed, 15 Mar 2017 20:18:39 +0100 Subject: [PATCH 0026/1765] [SPARK-13450] Introduce ExternalAppendOnlyUnsafeRowArray. Change CartesianProductExec, SortMergeJoin, WindowExec to use it ## What issue does this PR address ? Jira: https://issues.apache.org/jira/browse/SPARK-13450 In `SortMergeJoinExec`, rows of the right relation having the same value for a join key are buffered in-memory. In case of skew, this causes OOMs (see comments in SPARK-13450 for more details). Heap dump from a failed job confirms this : https://issues.apache.org/jira/secure/attachment/12846382/heap-dump-analysis.png . While its possible to increase the heap size to workaround, Spark should be resilient to such issues as skews can happen arbitrarily. ## Change proposed in this pull request - Introduces `ExternalAppendOnlyUnsafeRowArray` - It holds `UnsafeRow`s in-memory upto a certain threshold. - After the threshold is hit, it switches to `UnsafeExternalSorter` which enables spilling of the rows to disk. It does NOT sort the data. - Allows iterating the array multiple times. However, any alteration to the array (using `add` or `clear`) will invalidate the existing iterator(s) - `WindowExec` was already using `UnsafeExternalSorter` to support spilling. Changed it to use the new array - Changed `SortMergeJoinExec` to use the new array implementation - NOTE: I have not changed FULL OUTER JOIN to use this new array implementation. Changing that will need more surgery and I will rather put up a separate PR for that once this gets in. - Changed `CartesianProductExec` to use the new array implementation #### Note for reviewers The diff can be divided into 3 parts. My motive behind having all the changes in a single PR was to demonstrate that the API is sane and supports 2 use cases. If reviewing as 3 separate PRs would help, I am happy to make the split. ## How was this patch tested ? #### Unit testing - Added unit tests `ExternalAppendOnlyUnsafeRowArray` to validate all its APIs and access patterns - Added unit test for `SortMergeExec` - with and without spill for inner join, left outer join, right outer join to confirm that the spill threshold config behaves as expected and output is as expected. - This PR touches the scanning logic in `SortMergeExec` for _all_ joins (except FULL OUTER JOIN). However, I expect existing test cases to cover that there is no regression in correctness. - Added unit test for `WindowExec` to check behavior of spilling and correctness of results. #### Stress testing - Confirmed that OOM is gone by running against a production job which used to OOM - Since I cannot share details about prod workload externally, created synthetic data to mimic the issue. Ran before and after the fix to demonstrate the issue and query success with this PR Generating the synthetic data ``` ./bin/spark-shell --driver-memory=6G import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("DROP TABLE IF EXISTS spark_13450_large_table").collect hc.sql("DROP TABLE IF EXISTS spark_13450_one_row_table").collect val df1 = (0 until 1).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df1.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_one_row_table") val df2 = (0 until 3000000).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df2.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_large_table") ``` Ran this against trunk VS local build with this PR. OOM repros with trunk and with the fix this query runs fine. ``` ./bin/spark-shell --driver-java-options="-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp/spark.driver.heapdump.hprof" import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("SET spark.sql.autoBroadcastJoinThreshold=1") hc.sql("SET spark.sql.sortMergeJoinExec.buffer.spill.threshold=10000") hc.sql("DROP TABLE IF EXISTS spark_13450_result").collect hc.sql(""" CREATE TABLE spark_13450_result AS SELECT a.i AS a_i, a.j AS a_j, a.str1 AS a_str1, a.str2 AS a_str2, b.i AS b_i, b.j AS b_j, b.str1 AS b_str1, b.str2 AS b_str2 FROM spark_13450_one_row_table a JOIN spark_13450_large_table b ON a.i=b.i AND a.j=b.j """) ``` ## Performance comparison ### Macro-benchmark I ran a SMB join query over two real world tables (2 trillion rows (40 TB) and 6 million rows (120 GB)). Note that this dataset does not have skew so no spill happened. I saw improvement in CPU time by 2-4% over version without this PR. This did not add up as I was expected some regression. I think allocating array of capacity of 128 at the start (instead of starting with default size 16) is the sole reason for the perf. gain : https://github.com/tejasapatil/spark/blob/SPARK-13450_smb_buffer_oom/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala#L43 . I could remove that and rerun, but effectively the change will be deployed in this form and I wanted to see the effect of it over large workload. ### Micro-benchmark Two types of benchmarking can be found in `ExternalAppendOnlyUnsafeRowArrayBenchmark`: [A] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `ArrayBuffer` when all rows fit in-memory and there is no spill ``` Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 7821 / 7941 33.5 29.8 1.0X ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 19200 / 19206 25.6 39.1 1.0X ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 5949 / 6028 17.2 58.1 1.0X ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X ``` [B] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `UnsafeExternalSorter` when there is spilling of data ``` Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X ``` Author: Tejas Patil Closes #16909 from tejasapatil/SPARK-13450_smb_buffer_oom. --- .../apache/spark/sql/internal/SQLConf.scala | 30 ++ .../ExternalAppendOnlyUnsafeRowArray.scala | 243 ++++++++++++ .../joins/CartesianProductExec.scala | 52 +-- .../execution/joins/SortMergeJoinExec.scala | 117 +++--- .../sql/execution/window/RowBuffer.scala | 115 ------ .../sql/execution/window/WindowExec.scala | 72 +--- .../window/WindowFunctionFrame.scala | 97 +++-- .../org/apache/spark/sql/JoinSuite.scala | 136 ++++++- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 233 ++++++++++++ ...xternalAppendOnlyUnsafeRowArraySuite.scala | 351 ++++++++++++++++++ .../execution/SQLWindowFunctionSuite.scala | 33 ++ 11 files changed, 1187 insertions(+), 292 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8f65672d5a839..a85f87aece45b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -715,6 +716,27 @@ object SQLConf { .stringConf .createWithDefault(TimeZone.getDefault().getID()) + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in window operator") + .intConf + .createWithDefault(4096) + + val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in sort merge join operator") + .intConf + .createWithDefault(Int.MaxValue) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in cartesian product operator") + .intConf + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -945,6 +967,14 @@ class SQLConf extends Serializable with Logging { def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + + def sortMergeJoinExecBufferSpillThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + + def cartesianProductExecBufferSpillThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) + def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala new file mode 100644 index 0000000000000..458ac4ba3637c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} + +/** + * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined + * threshold of rows is reached. + * + * Setting spill threshold faces following trade-off: + * + * - If the spill threshold is too high, the in-memory array may occupy more memory than is + * available, resulting in OOM. + * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. + * This may lead to a performance regression compared to the normal case of using an + * [[ArrayBuffer]] or [[Array]]. + */ +private[sql] class ExternalAppendOnlyUnsafeRowArray( + taskMemoryManager: TaskMemoryManager, + blockManager: BlockManager, + serializerManager: SerializerManager, + taskContext: TaskContext, + initialSize: Int, + pageSizeBytes: Long, + numRowsSpillThreshold: Int) extends Logging { + + def this(numRowsSpillThreshold: Int) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsSpillThreshold) + } + + private val initialSizeOfInMemoryBuffer = + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + + private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { + new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) + } else { + null + } + + private var spillableArray: UnsafeExternalSorter = _ + private var numRows = 0 + + // A counter to keep track of total modifications done to this array since its creation. + // This helps to invalidate iterators when there are changes done to the backing array. + private var modificationsCount: Long = 0 + + private var numFieldsPerRow = 0 + + def length: Int = numRows + + def isEmpty: Boolean = numRows == 0 + + /** + * Clears up resources (eg. memory) held by the backing storage + */ + def clear(): Unit = { + if (spillableArray != null) { + // The last `spillableArray` of this task will be cleaned up via task completion listener + // inside `UnsafeExternalSorter` + spillableArray.cleanupResources() + spillableArray = null + } else if (inMemoryBuffer != null) { + inMemoryBuffer.clear() + } + numFieldsPerRow = 0 + numRows = 0 + modificationsCount += 1 + } + + def add(unsafeRow: UnsafeRow): Unit = { + if (numRows < numRowsSpillThreshold) { + inMemoryBuffer += unsafeRow.copy() + } else { + if (spillableArray == null) { + logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + s"${classOf[UnsafeExternalSorter].getName}") + + // We will not sort the rows, so prefixComparator and recordComparator are null + spillableArray = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + null, + null, + initialSize, + pageSizeBytes, + numRowsSpillThreshold, + false) + + // populate with existing in-memory buffered rows + if (inMemoryBuffer != null) { + inMemoryBuffer.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryBuffer.clear() + } + numFieldsPerRow = unsafeRow.numFields() + } + + spillableArray.insertRecord( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0, + false) + } + + numRows += 1 + modificationsCount += 1 + } + + /** + * Creates an [[Iterator]] for the current rows in the array starting from a user provided index + * + * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of + * the iterator, then the iterator is invalidated thus saving clients from thinking that they + * have read all the data while there were new rows added to this array. + */ + def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { + if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + + if (spillableArray == null) { + new InMemoryBufferIterator(startIndex) + } else { + new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex) + } + } + + def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) + + private[this] + abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { + private val expectedModificationsCount = modificationsCount + + protected def isModified(): Boolean = expectedModificationsCount != modificationsCount + + protected def throwExceptionIfModified(): Unit = { + if (expectedModificationsCount != modificationsCount) { + throw new ConcurrentModificationException( + s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " + + s"since the creation of this Iterator") + } + } + } + + private[this] class InMemoryBufferIterator(startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private var currentIndex = startIndex + + override def hasNext(): Boolean = !isModified() && currentIndex < numRows + + override def next(): UnsafeRow = { + throwExceptionIfModified() + val result = inMemoryBuffer(currentIndex) + currentIndex += 1 + result + } + } + + private[this] class SpillableArrayIterator( + iterator: UnsafeSorterIterator, + numFieldPerRow: Int, + startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private val currentRow = new UnsafeRow(numFieldPerRow) + + def init(): Unit = { + var i = 0 + while (i < startIndex) { + if (iterator.hasNext) { + iterator.loadNext() + } else { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + i += 1 + } + } + + // Traverse upto the given [[startIndex]] + init() + + override def hasNext(): Boolean = !isModified() && iterator.hasNext + + override def next(): UnsafeRow = { + throwExceptionIfModified() + iterator.loadNext() + currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength) + currentRow + } + } +} + +private[sql] object ExternalAppendOnlyUnsafeRowArray { + val DefaultInitialSizeOfInMemoryBuffer = 128 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 8341fe2ffd078..f380986951317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, * will be much faster than building the right partition for every row in left RDD, it also * materialize the right RDD (in case of the right RDD is nondeterministic). */ -class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) +class UnsafeCartesianRDD( + left : RDD[UnsafeRow], + right : RDD[UnsafeRow], + numFieldsOfRight: Int, + spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - // We will not sort the rows, so prefixComparator and recordComparator are null. - val sorter = UnsafeExternalSorter.create( - context.taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - context, - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) val partition = split.asInstanceOf[CartesianPartition] - for (y <- rdd2.iterator(partition.s2, context)) { - sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false) - } + rdd2.iterator(partition.s2, context).foreach(rowArray.add) - // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] - def createIter(): Iterator[UnsafeRow] = { - val iter = sorter.getIterator - val unsafeRow = new UnsafeRow(numFieldsOfRight) - new Iterator[UnsafeRow] { - override def hasNext: Boolean = { - iter.hasNext - } - override def next(): UnsafeRow = { - iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - unsafeRow - } - } - } + // Create an iterator from rowArray + def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator() val resultIter = for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources()) + resultIter, rowArray.clear()) } } @@ -97,7 +71,9 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold + + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ca9c0ed8cec32..bcdc4dcdf7d99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, +ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -95,9 +96,13 @@ case class SortMergeJoinExec( private def createRightKeyGenerator(): Projection = UnsafeProjection.create(rightKeys, right.output) + private def getSpillThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferSpillThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - + val spillThreshold = getSpillThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -115,39 +120,39 @@ case class SortMergeJoinExec( case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ - private[this] var currentMatchIdx: Int = -1 + private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } override def advanceNext(): Boolean = { - while (currentMatchIdx >= 0) { - if (currentMatchIdx == currentRightMatches.length) { + while (rightMatchesIterator != null) { + if (!rightMatchesIterator.hasNext) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } else { currentRightMatches = null currentLeftRow = null - currentMatchIdx = -1 + rightMatchesIterator = null return false } } - joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { numOutputRows += 1 return true @@ -165,7 +170,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + bufferedIter = RowIterator.fromScala(rightIter), + spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -177,7 +183,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + bufferedIter = RowIterator.fromScala(leftIter), + spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -209,7 +216,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -217,14 +225,15 @@ case class SortMergeJoinExec( while (smjScanner.findNextInnerJoinRows()) { val currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - var i = 0 - while (i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } } - i += 1 } } false @@ -241,7 +250,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -249,17 +259,16 @@ case class SortMergeJoinExec( while (smjScanner.findNextOuterJoinRows()) { currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches - if (currentRightMatches == null) { + if (currentRightMatches == null || currentRightMatches.length == 0) { return true } - var i = 0 var found = false - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } if (!found) { numOutputRows += 1 @@ -281,7 +290,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -290,14 +300,13 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches var found = false - if (currentRightMatches != null) { - var i = 0 - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } } result.setBoolean(0, found) @@ -376,8 +385,11 @@ case class SortMergeJoinExec( // A list to hold all matched rows from right side. val matches = ctx.freshName("matches") - val clsName = classOf[java.util.ArrayList[InternalRow]].getName - ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + + val spillThreshold = getSpillThreshold + + ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -428,7 +440,7 @@ case class SortMergeJoinExec( | } | $leftRow = null; | } else { - | $matches.add($rightRow.copy()); + | $matches.add((UnsafeRow) $rightRow); | $rightRow = null;; | } | } while ($leftRow != null); @@ -517,8 +529,7 @@ case class SortMergeJoinExec( val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) - val size = ctx.freshName("size") - val i = ctx.freshName("i") + val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. @@ -551,10 +562,10 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | int $size = $matches.size(); | ${beforeLoop.trim} - | for (int $i = 0; $i < $size; $i ++) { - | InternalRow $rightRow = (InternalRow) $matches.get($i); + | scala.collection.Iterator $iterator = $matches.generateIterator(); + | while ($iterator.hasNext()) { + | InternalRow $rightRow = (InternalRow) $iterator.next(); | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} @@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + bufferedIter: RowIterator, + bufferThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner( def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. @@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) advancedBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } @@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator( protected[this] val joinedRow: JoinedRow = new JoinedRow() // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row - private[this] var bufferIndex: Int = 0 + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) @@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator( * @return whether there are more rows in the stream to consume. */ private def advanceStream(): Boolean = { - bufferIndex = 0 + rightMatchesIterator = null if (smjScanner.findNextOuterJoinRows()) { setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { @@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator( */ private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { - setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + if (rightMatchesIterator == null) { + rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator() + } + + while (!foundMatch && rightMatchesIterator.hasNext) { + setBufferedSideOutput(rightMatchesIterator.next()) foundMatch = boundCondition(joinedRow) - bufferIndex += 1 } foundMatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala deleted file mode 100644 index ee36c84251519..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.window - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} - - -/** - * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the - * row buffer is used to materialize a partition of rows since we need to repeatedly scan these - * rows in window function processing. - */ -private[window] abstract class RowBuffer { - - /** Number of rows. */ - def size: Int - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer -} - -/** - * A row buffer based on ArrayBuffer (the number of rows is limited). - */ -private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - - private[this] var cursor: Int = -1 - - /** Number of rows. */ - override def size: Int = buffer.length - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - cursor += 1 - if (cursor < buffer.length) { - buffer(cursor) - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - cursor += n - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ArrayRowBuffer(buffer) - } -} - -/** - * An external buffer of rows based on UnsafeExternalSorter. - */ -private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) - extends RowBuffer { - - private[this] val iter: UnsafeSorterIterator = sorter.getIterator - - private[this] val currentRow = new UnsafeRow(numFields) - - /** Number of rows. */ - override def size: Int = iter.getNumRecords() - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - if (iter.hasNext) { - iter.loadNext() - currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - currentRow - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - var i = 0 - while (i < n && iter.hasNext) { - iter.loadNext() - i += 1 - } - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ExternalRowBuffer(sorter, numFields) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 80b87d5ffa797..950a6794a74a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.IntegerType -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -284,6 +282,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. child.execute().mapPartitions { stream => @@ -310,10 +309,12 @@ case class WindowExec( fetchNextRow() // Manage the current partition. - val rows = ArrayBuffer.empty[UnsafeRow] val inputFields = child.output.length - var sorter: UnsafeExternalSorter = null - var rowBuffer: RowBuffer = null + + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length @@ -323,78 +324,43 @@ case class WindowExec( val currentGroup = nextGroup.copy() // clear last partition - if (sorter != null) { - // the last sorter of this task will be cleaned up via task completion listener - sorter.cleanupResources() - sorter = null - } else { - rows.clear() - } + buffer.clear() while (nextRowAvailable && nextGroup == currentGroup) { - if (sorter == null) { - rows += nextRow.copy() - - if (rows.length >= 4096) { - // We will not sort the rows, so prefixComparator and recordComparator are null. - sorter = UnsafeExternalSorter.create( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) - rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) - } - rows.clear() - } - } else { - sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) - } + buffer.add(nextRow) fetchNextRow() } - if (sorter != null) { - rowBuffer = new ExternalRowBuffer(sorter, inputFields) - } else { - rowBuffer = new ArrayRowBuffer(rows) - } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rowBuffer.copy()) + frames(i).prepare(buffer) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rowBuffer.size + bufferIterator = buffer.generateIterator() } // Iteration var rowIndex = 0 - var rowsSize = 0L - override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable val join = new JoinedRow override final def next(): InternalRow = { // Load the next partition if we need to. - if (rowIndex >= rowsSize && nextRowAvailable) { + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { fetchNextPartition() } - if (rowIndex < rowsSize) { + if (bufferIterator.hasNext) { + val current = bufferIterator.next() + // Get the results for the window frames. var i = 0 - val current = rowBuffer.next() while (i < numFrames) { frames(i).write(rowIndex, current) i += 1 @@ -406,7 +372,9 @@ case class WindowExec( // Return the projection. result(join) - } else throw new NoSuchElementException + } else { + throw new NoSuchElementException + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 70efc0f78ddb0..af2b4fb92062b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -22,6 +22,7 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray /** @@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame { * * @param rows to calculate the frame results for. */ - def prepare(rows: RowBuffer): Unit + def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit /** * Write the current results to the target row. @@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame { def write(index: Int, current: InternalRow): Unit } +object WindowFunctionFrame { + def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = { + if (iterator.hasNext) iterator.next() else null + } +} + /** * The offset window frame calculates frames containing LEAD/LAG statements. * @@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 @@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame( newMutableProjection(boundExpressions, Nil).target(target) } - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows + inputIterator = input.generateIterator() // drain the first few rows if offset is larger than zero inputIndex = 0 while (inputIndex < offset) { - input.next() + if (inputIterator.hasNext) inputIterator.next() inputIndex += 1 } inputIndex = offset } override def write(index: Int, current: InternalRow): Unit = { - if (inputIndex >= 0 && inputIndex < input.size) { - val r = input.next() + if (inputIndex >= 0 && inputIndex < input.length) { + val r = WindowFunctionFrame.getNextOrNull(inputIterator) projection(r) } else { // Use default values since the offset row does not exist. @@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame( private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() + inputIterator = input.generateIterator() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex = 0 inputLowIndex = 0 buffer.clear() @@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { buffer.add(nextRow.copy()) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex += 1 bufferUpdated = true } @@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) + processor.initialize(input.length) val iter = buffer.iterator() while (iter.hasNext) { processor.update(iter.next()) @@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame( extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: RowBuffer): Unit = { - val size = rows.size - processor.initialize(size) - var i = 0 - while (i < size) { - processor.update(rows.next()) - i += 1 + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { + processor.initialize(rows.length) + + val iterator = rows.generateIterator() + while (iterator.hasNext) { + processor.update(iterator.next()) } } @@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() inputIndex = 0 - processor.initialize(input.size) + inputIterator = input.generateIterator() + if (inputIterator.hasNext) { + nextRow = inputIterator.next() + } + + processor.initialize(input.length) } /** Write the frame columns for the current row to the given target row. */ @@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { processor.update(nextRow) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputIndex += 1 bufferUpdated = true } @@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null /** * Index of the first input row with a value equal to or greater than the lower bound of the @@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows inputIndex = 0 } @@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Duplicate the input to have a new iterator - val tmp = input.copy() - - // Drop all rows from the buffer for which the input row value is smaller than + // Ignore all the rows from the buffer for which the input row value is smaller than // the output row lower bound. - tmp.skip(inputIndex) - var nextRow = tmp.next() + val iterator = input.generateIterator(startIndex = inputIndex) + + var nextRow = WindowFunctionFrame.getNextOrNull(iterator) while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { - nextRow = tmp.next() inputIndex += 1 bufferUpdated = true + nextRow = WindowFunctionFrame.getNextOrNull(iterator) } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) - while (nextRow != null) { + processor.initialize(input.length) + if (nextRow != null) { processor.update(nextRow) - nextRow = tmp.next() + } + while (iterator.hasNext) { + processor.update(iterator.next()) } processor.evaluate(target) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2e006735d123e..1a66aa85f5a02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.mutable.ListBuffer import scala.language.existentials import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext - +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext { cartesianQueries.foreach(checkCartesianDetection) } + + test("test SortMergeJoin (without spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) { + + assertNotSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertNotSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertNotSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } + + test("test SortMergeJoin (with spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + + assertSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]] + // so should not cause any spill + assertNotSpilled(sparkContext, "full outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |FULL OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala new file mode 100644 index 0000000000000..00c5f2550cbb1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +object ExternalAppendOnlyUnsafeRowArrayBenchmark { + + def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = { + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create an + // in-memory buffer of size `numSpillThreshold`. This will mimic that + val initialSize = + Math.min( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + + benchmark.addCase("ArrayBuffer") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ArrayBuffer[UnsafeRow](initialSize) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a + // copy of the row. This will mimic that + rows.foreach(x => array += x.copy()) + + var i = 0 + val n = array.length + while (i < n) { + sum = sum + array(i).getLong(0) + i += 1 + } + array.clear() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def testAgainstRawUnsafeExternalSorter( + numSpillThreshold: Int, + numRows: Int, + iterations: Int): Unit = { + + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows) + + benchmark.addCase("UnsafeExternalSorter") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numSpillThreshold, + false) + + rows.foreach(x => + array.insertRecord( + x.getBaseObject, + x.getBaseOffset, + x.getSizeInBytes, + 0, + false)) + + val unsafeRow = new UnsafeRow(1) + val iter = array.getIterator + while (iter.hasNext) { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + sum = sum + unsafeRow.getLong(0) + } + array.cleanupResources() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def main(args: Array[String]): Unit = { + + // ========================================================================================= // + // WITHOUT SPILL + // ========================================================================================= // + + val spillThreshold = 100 * 1000 + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 7821 / 7941 33.5 29.8 1.0X + ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X + */ + testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 19200 / 19206 25.6 39.1 1.0X + ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 5949 / 6028 17.2 58.1 1.0X + ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10) + + // ========================================================================================= // + // WITH SPILL + // ========================================================================================= // + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X + ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X + */ + testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X + ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X + */ + testAgainstRawUnsafeExternalSorter( + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala new file mode 100644 index 0000000000000..53c41639942b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext { + private val random = new java.util.Random() + private var taskContext: TaskContext = _ + + override def afterAll(): Unit = TaskContext.unset() + + private def withExternalArray(spillThreshold: Int) + (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { + sc = new SparkContext("local", "test", new SparkConf(false)) + + taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + + val array = new ExternalAppendOnlyUnsafeRowArray( + taskContext.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + taskContext, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + spillThreshold) + try f(array) finally { + array.clear() + } + } + + private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = { + val valueInserted = random.nextLong() + + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, valueInserted) + array.add(row) + valueInserted + } + + private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = { + assert(iterator.hasNext) + val actualRow = iterator.next() + assert(actualRow.getLong(0) == expectedValue) + assert(actualRow.getSizeInBytes == 16) + } + + private def validateData( + array: ExternalAppendOnlyUnsafeRowArray, + expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = { + val iterator = array.generateIterator() + for (value <- expectedValues) { + checkIfValueExists(iterator, value) + } + + assert(!iterator.hasNext) + iterator + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int): ArrayBuffer[Long] = { + val populatedValues = new ArrayBuffer[Long] + populateRows(array, numRowsToBePopulated, populatedValues) + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int, + populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = { + for (_ <- 0 until numRowsToBePopulated) { + populatedValues.append(insertRow(array)) + } + populatedValues + } + + private def getNumBytesSpilled: Long = { + TaskContext.get().taskMetrics().memoryBytesSpilled + } + + private def assertNoSpill(): Unit = { + assert(getNumBytesSpilled == 0) + } + + private def assertSpill(): Unit = { + assert(getNumBytesSpilled > 0) + } + + test("insert rows less than the spillThreshold") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + assert(array.isEmpty) + + val expectedValues = populateRows(array, 1) + assert(!array.isEmpty) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) + // Verify that NO spill has happened + populateRows(array, spillThreshold - 1, expectedValues) + assert(array.length == spillThreshold) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + + assert(!iterator1.hasNext) + assert(!iterator2.hasNext) + } + } + + test("insert rows more than the spillThreshold to force spill") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + val numValuesInserted = 20 * spillThreshold + + assert(array.isEmpty) + val expectedValues = populateRows(array, 1) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Populate more rows to trigger spill. Verify that spill has happened + populateRows(array, numValuesInserted - 1, expectedValues) + assert(array.length == numValuesInserted) + assertSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("iterator on an empty array should be empty") { + withExternalArray(spillThreshold = 10) { array => + val iterator = array.generateIterator() + assert(array.isEmpty) + assert(array.length == 0) + assert(!iterator.hasNext) + } + } + + test("generate iterator with negative start index") { + withExternalArray(spillThreshold = 2) { array => + val exception = + intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array") + ) + } + } + + test("generate iterator with start index exceeding array's size (without spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold / 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with start index exceeding array's size (with spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold * 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with custom start index (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold) + val startIndex = spillThreshold / 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("generate iterator with custom start index (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold * 10) + val startIndex = spillThreshold * 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("test iterator invalidation (without spill)") { + withExternalArray(spillThreshold = 10) { array => + // insert 2 rows, iterate until the first row + populateRows(array, 2) + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("test iterator invalidation (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows so that spill has happens + populateRows(array, spillThreshold * 2) + assertSpill() + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear on an empty the array") { + withExternalArray(spillThreshold = 2) { array => + val iterator = array.generateIterator() + assert(!iterator.hasNext) + + // multiple clear'ing should not have an side-effect + array.clear() + array.clear() + array.clear() + assert(array.isEmpty) + assert(array.length == 0) + + // Clearing an empty array should also invalidate any old iterators + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear array (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate rows ... but not enough to trigger spill + populateRows(array, spillThreshold / 2) + assertNoSpill() + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate few rows so that there is no spill + // Verify the data. Verify that there was no spill + val expectedValues = populateRows(array, spillThreshold / 3) + validateData(array, expectedValues) + assertNoSpill() + + // Populate more rows .. enough to not trigger a spill. + // Verify the data. Verify that there was no spill + populateRows(array, spillThreshold / 3, expectedValues) + validateData(array, expectedValues) + assertNoSpill() + } + } + + test("clear array (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows to trigger spill + populateRows(array, spillThreshold * 2) + val bytesSpilled = getNumBytesSpilled + assert(bytesSpilled > 0) + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate the array ... but NOT upto the point that there is spill. + // Verify data. Verify that there was NO "extra" spill + val expectedValues = populateRows(array, spillThreshold / 2) + validateData(array, expectedValues) + assert(getNumBytesSpilled == bytesSpilled) + + // Populate more rows to trigger spill + // Verify the data. Verify that there was "extra" spill + populateRows(array, spillThreshold * 2, expectedValues) + validateData(array, expectedValues) + assert(getNumBytesSpilled > bytesSpilled) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index afd47897ed4b2..52e4f047225de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.TestUtils.assertSpilled case class WindowData(month: Int, area: String, product: Int) @@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(1, 3, null) :: Row(2, null, 4) :: Nil) } + + test("test with low buffer spill threshold") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + + val expected = + Row(1, 1, 1) :: + Row(0, 2, 3) :: + Row(1, 3, 6) :: + Row(0, 4, 10) :: + Row(1, 5, 15) :: + Row(0, 6, 21) :: + Row(1, 7, 28) :: + Row(0, 8, 36) :: + Row(1, 9, 45) :: + Row(0, 10, 55) :: Nil + + val actual = sql( + """ + |SELECT y, x, sum(x) OVER w1 AS running_sum + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) + """.stripMargin) + + withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + assertSpilled(sparkContext, "test with low buffer spill threshold") { + checkAnswer(actual, expected) + } + } + + spark.catalog.dropTempView("nums") + } } From 97cc5e5a5555519d221d0ca78645dde9bb8ea40b Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 15 Mar 2017 14:58:19 -0700 Subject: [PATCH 0027/1765] [SPARK-19960][CORE] Move `SparkHadoopWriter` to `internal/io/` ## What changes were proposed in this pull request? This PR introduces the following changes: 1. Move `SparkHadoopWriter` to `core/internal/io/`, so that it's in the same directory with `SparkHadoopMapReduceWriter`; 2. Move `SparkHadoopWriterUtils` to a separated file. After this PR is merged, we may consolidate `SparkHadoopWriter` and `SparkHadoopMapReduceWriter`, and make the new commit protocol support the old `mapred` package's committer; ## How was this patch tested? Tested by existing test cases. Author: jiangxingbo Closes #17304 from jiangxb1987/writer. --- .../io/SparkHadoopMapReduceWriter.scala | 59 ------------ .../{ => internal/io}/SparkHadoopWriter.scala | 7 +- .../internal/io/SparkHadoopWriterUtils.scala | 93 +++++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 3 +- .../OutputCommitCoordinatorSuite.scala | 1 + 5 files changed, 99 insertions(+), 64 deletions(-) rename core/src/main/scala/org/apache/spark/{ => internal/io}/SparkHadoopWriter.scala (97%) create mode 100644 core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 659ad5d0bad8c..376ff9bb19f74 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -179,62 +179,3 @@ object SparkHadoopMapReduceWriter extends Logging { } } } - -private[spark] -object SparkHadoopWriterUtils { - - private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 - - def createJobID(time: Date, id: Int): JobID = { - val jobtrackerID = createJobTrackerID(time) - new JobID(jobtrackerID, id) - } - - def createJobTrackerID(time: Date): String = { - new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - - // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation - // setting can take effect: - def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { - val validationDisabled = disableOutputSpecValidation.value - val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) - enabledInConf && !validationDisabled - } - - // TODO: these don't seem like the right abstractions. - // We should abstract the duplicate code in a less awkward way. - - def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { - val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - (context.taskMetrics().outputMetrics, bytesWrittenCallback) - } - - def maybeUpdateOutputMetrics( - outputMetrics: OutputMetrics, - callback: () => Long, - recordsWritten: Long): Unit = { - if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - } - - /** - * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case - * basis; see SPARK-4835 for more details. - */ - val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) -} diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala rename to core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 46e22b215b8ee..acc9c38571007 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.internal.io import java.io.IOException -import java.text.NumberFormat -import java.text.SimpleDateFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType +import org.apache.spark.SerializableWritable import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.SparkHadoopWriterUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD import org.apache.spark.util.SerializableJobConf diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala new file mode 100644 index 0000000000000..de828a6d6156e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.util.DynamicVariable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics + +/** + * A helper object that provide common utils used during saving an RDD using a Hadoop OutputFormat + * (both from the old mapred API and the new mapreduce API) + */ +private[spark] +object SparkHadoopWriterUtils { + + private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + def createJobID(time: Date, id: Int): JobID = { + val jobtrackerID = createJobTrackerID(time) + new JobID(jobtrackerID, id) + } + + def createJobTrackerID(time: Date): String = { + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { + val validationDisabled = disableOutputSpecValidation.value + val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } + + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() + (context.taskMetrics().outputMetrics, bytesWrittenCallback) + } + + def maybeUpdateOutputMetrics( + outputMetrics: OutputMetrics, + callback: () => Long, + recordsWritten: Long): Unit = { + if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + } + } + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. + */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 52ce03ff8cde9..58762cc0838cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -37,7 +37,8 @@ import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, + SparkHadoopWriterUtils} import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 83ed12752074d..38b9d40329d48 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -31,6 +31,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.internal.io.SparkHadoopWriter import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} From 54a3697f1fb562ef9ed8fed9caffc62b84763049 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 15 Mar 2017 15:01:16 -0700 Subject: [PATCH 0028/1765] [MINOR][CORE] Fix a info message of `prunePartitions` ## What changes were proposed in this pull request? `PrunedInMemoryFileIndex.prunePartitions` shows `pruned NaN% partitions` for the following case. ```scala scala> Seq.empty[(String, String)].toDF("a", "p").write.partitionBy("p").saveAsTable("t1") scala> sc.setLogLevel("INFO") scala> spark.table("t1").filter($"p" === "1").select($"a").show ... 17/03/13 00:33:04 INFO PrunedInMemoryFileIndex: Selected 0 partitions out of 0, pruned NaN% partitions. ``` After this PR, the message looks like this. ```scala 17/03/15 10:39:48 INFO PrunedInMemoryFileIndex: Selected 0 partitions out of 0, pruned 0 partitions. ``` ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #17273 from dongjoon-hyun/SPARK-EMPTY-PARTITION. --- .../sql/execution/datasources/PartitioningAwareFileIndex.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index a5fa8b3f9385e..db8bbc52aaf4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -186,7 +186,8 @@ abstract class PartitioningAwareFileIndex( val total = partitions.length val selectedSize = selected.length val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 - s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." + s"Selected $selectedSize partitions out of $total, " + + s"pruned ${if (total == 0) "0" else s"$percentPruned%"} partitions." } selected From 046b8d4aef00b0701cf7e4b99aeaf450cacb42fe Mon Sep 17 00:00:00 2001 From: erenavsarogullari Date: Wed, 15 Mar 2017 15:57:51 -0700 Subject: [PATCH 0029/1765] [SPARK-18066][CORE][TESTS] Add Pool usage policies test coverage for FIFO & FAIR Schedulers ## What changes were proposed in this pull request? The following FIFO & FAIR Schedulers Pool usage cases need to have unit test coverage : - FIFO Scheduler just uses **root pool** so even if `spark.scheduler.pool` property is set, related pool is not created and `TaskSetManagers` are added to **root pool**. - FAIR Scheduler uses `default pool` when `spark.scheduler.pool` property is not set. This can be happened when - `Properties` object is **null**, - `Properties` object is **empty**(`new Properties()`), - **default pool** is set(`spark.scheduler.pool=default`). - FAIR Scheduler creates a **new pool** with **default values** when `spark.scheduler.pool` property points a **non-existent** pool. This can be happened when **scheduler allocation file** is not set or it does not contain related pool. ## How was this patch tested? New Unit tests are added. Author: erenavsarogullari Closes #15604 from erenavsarogullari/SPARK-18066. --- .../spark/scheduler/SchedulableBuilder.scala | 7 +- .../apache/spark/scheduler/PoolSuite.scala | 97 +++++++++++++++++-- 2 files changed, 96 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index e53c4fb5b4778..20cedaf060420 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -191,8 +191,11 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) rootPool.addSchedulable(parentPool) - logInfo("Created pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + + "configured. This can happen when the file that pools are read from isn't set, or " + + s"when that file doesn't contain $poolName. Created $poolName with default " + + s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + + s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") } } parentPool.addSchedulable(manager) diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 520736ab64270..cddff3dd35861 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -31,6 +31,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { val LOCAL = "local" val APP_NAME = "PoolSuite" val SCHEDULER_ALLOCATION_FILE_PROPERTY = "spark.scheduler.allocation.file" + val TEST_POOL = "testPool" def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { @@ -40,7 +41,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null), 0) } - def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int) { + def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int): Unit = { val taskSetQueue = rootPool.getSortedTaskSetQueue val nextTaskSetToSchedule = taskSetQueue.find(t => (t.runningTasks + t.tasksSuccessful) < t.numTasks) @@ -201,12 +202,96 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { verifyPool(rootPool, "pool_with_surrounded_whitespace", 3, 2, FAIR) } + /** + * spark.scheduler.pool property should be ignored for the FIFO scheduler, + * because pools are only needed for fair scheduling. + */ + test("FIFO scheduler uses root pool and not spark.scheduler.pool property") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FIFO, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty("spark.scheduler.pool", TEST_POOL) + + // When FIFO Scheduler is used and task sets are submitted, they should be added to + // the root pool, and no additional pools should be created + // (even though there's a configured default pool). + schedulableBuilder.addTaskSetManager(taskSetManager0, properties) + schedulableBuilder.addTaskSetManager(taskSetManager1, properties) + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + assert(rootPool.schedulableQueue.size === 2) + assert(rootPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + assert(rootPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler uses default pool when spark.scheduler.pool property is not set") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + // Submit a new task set manager with pool properties set to null. This should result + // in the task set manager getting added to the default pool. + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + + val defaultPool = rootPool.getSchedulableByName(schedulableBuilder.DEFAULT_POOL_NAME) + assert(defaultPool !== null) + assert(defaultPool.schedulableQueue.size === 1) + assert(defaultPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + + // When a task set manager is submitted with spark.scheduler.pool unset, it should be added to + // the default pool (as above). + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager1, new Properties()) + + assert(defaultPool.schedulableQueue.size === 2) + assert(defaultPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler creates a new pool when spark.scheduler.pool property points to " + + "a non-existent pool") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + + val taskSetManager = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty(schedulableBuilder.FAIR_SCHEDULER_PROPERTIES, TEST_POOL) + + // The fair scheduler should create a new pool with default values when spark.scheduler.pool + // points to a pool that doesn't exist yet (this can happen when the file that pools are read + // from isn't set, or when that file doesn't contain the pool name specified + // by spark.scheduler.pool). + schedulableBuilder.addTaskSetManager(taskSetManager, properties) + + verifyPool(rootPool, TEST_POOL, schedulableBuilder.DEFAULT_MINIMUM_SHARE, + schedulableBuilder.DEFAULT_WEIGHT, schedulableBuilder.DEFAULT_SCHEDULING_MODE) + val testPool = rootPool.getSchedulableByName(TEST_POOL) + assert(testPool.getSchedulableByName(taskSetManager.name) === taskSetManager) + } + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { - assert(rootPool.getSchedulableByName(poolName) != null) - assert(rootPool.getSchedulableByName(poolName).minShare === expectedInitMinShare) - assert(rootPool.getSchedulableByName(poolName).weight === expectedInitWeight) - assert(rootPool.getSchedulableByName(poolName).schedulingMode === expectedSchedulingMode) + val selectedPool = rootPool.getSchedulableByName(poolName) + assert(selectedPool !== null) + assert(selectedPool.minShare === expectedInitMinShare) + assert(selectedPool.weight === expectedInitWeight) + assert(selectedPool.schedulingMode === expectedSchedulingMode) } - } From 7d734a658349e8691d8b4294454c9cd98d555014 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 16 Mar 2017 08:18:36 +0800 Subject: [PATCH 0030/1765] [SPARK-19931][SQL] InMemoryTableScanExec should rewrite output partitioning and ordering when aliasing output attributes ## What changes were proposed in this pull request? Now `InMemoryTableScanExec` simply takes the `outputPartitioning` and `outputOrdering` from the associated `InMemoryRelation`'s `child.outputPartitioning` and `outputOrdering`. However, `InMemoryTableScanExec` can alias the output attributes. In this case, its `outputPartitioning` and `outputOrdering` are not correct and its parent operators can't correctly determine its data distribution. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17175 from viirya/ensure-no-unnecessary-shuffle. --- .../columnar/InMemoryTableScanExec.scala | 21 ++++++++++++--- .../columnar/InMemoryColumnarQuerySuite.scala | 26 +++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 9028caa446e8c..214e8d309de11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType @@ -41,11 +41,26 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes + private def updateAttribute(expr: Expression): Expression = { + val attrMap = AttributeMap(relation.child.output.zip(output)) + expr.transform { + case attr: Attribute => attrMap.getOrElse(attr, attr) + } + } + // The cached version does not change the outputPartitioning of the original SparkPlan. - override def outputPartitioning: Partitioning = relation.child.outputPartitioning + // But the cached version could alias output, so we need to replace output. + override def outputPartitioning: Partitioning = { + relation.child.outputPartitioning match { + case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] + case _ => relation.child.outputPartitioning + } + } // The cached version does not change the outputOrdering of the original SparkPlan. - override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + // But the cached version could alias output, so we need to replace output. + override def outputOrdering: Seq[SortOrder] = + relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0250a53fe2324..1e6a6a8ba3362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 +21,9 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -388,4 +391,27 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } } + test("InMemoryTableScanExec should return correct output ordering and partitioning") { + val df1 = Seq((0, 0), (1, 1)).toDF + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist + val df2 = Seq((0, 0), (1, 1)).toDF + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist + + // Because two cached dataframes have the same logical plan, this is a self-join actually. + // So we force one of in-memory relation to alias its output. Then we can test if original and + // aliased in-memory relations have correct ordering and partitioning. + val joined = df1.joinWith(df2, df1("_1") === df2("_1")) + + val inMemoryScans = joined.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + inMemoryScans.foreach { inMemoryScan => + val sortedAttrs = AttributeSet(inMemoryScan.outputOrdering.flatMap(_.references)) + assert(sortedAttrs.subsetOf(inMemoryScan.outputSet)) + + val partitionedAttrs = + inMemoryScan.outputPartitioning.asInstanceOf[HashPartitioning].references + assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) + } + } } From 339b237dc18d4367b0735236b4b8be2901fcad79 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 16 Mar 2017 08:20:47 +0800 Subject: [PATCH 0031/1765] [SPARK-19948] Document that saveAsTable uses catalog as source of truth for table existence. It is quirky behaviour that saveAsTable to e.g. a JDBC source with SaveMode other than Overwrite will nevertheless overwrite the table in the external source, if that table was not a catalog table. Author: Juliusz Sompolski Closes #17289 from juliuszsompolski/saveAsTableDoc. --- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index deaa8006945c1..3e975ef6a3c24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -337,6 +337,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * +---+---+ * }}} * + * In this method, save mode is used to determine the behavior if the data source table exists in + * Spark catalog. We will always overwrite the underlying data of data source (e.g. a table in + * JDBC data source) if the table doesn't exist in Spark catalog, and will always append to the + * underlying data of data source if the table already exists. + * * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC * and Parquet), the table is persisted in a Hive compatible format, which means other systems From fc9314671c8a082ae339fd6df177a2b684c65d40 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 16 Mar 2017 08:44:57 +0800 Subject: [PATCH 0032/1765] [SPARK-19961][SQL][MINOR] unify a erro msg when drop databse for HiveExternalCatalog and InMemoryCatalog ## What changes were proposed in this pull request? unify a exception erro msg for dropdatabase when the database still have some tables for HiveExternalCatalog and InMemoryCatalog ## How was this patch tested? N/A Author: windpiger Closes #17305 from windpiger/unifyErromsg. --- .../spark/sql/catalyst/catalog/InMemoryCatalog.scala | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 5cc6b0abc6fde..cdf618aef97c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -127,7 +127,7 @@ class InMemoryCatalog( if (!cascade) { // If cascade is false, make sure the database is empty. if (catalog(db).tables.nonEmpty) { - throw new AnalysisException(s"Database '$db' is not empty. One or more tables exist.") + throw new AnalysisException(s"Database $db is not empty. One or more tables exist.") } if (catalog(db).functions.nonEmpty) { throw new AnalysisException(s"Database '$db' is not empty. One or more functions exist.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6eed10ec51464..dd76fdde06cdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -617,12 +617,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName RESTRICT") }.getMessage - // TODO: Unify the exception. - if (isUsingHiveMetastore) { - assert(message.contains(s"Database $dbName is not empty. One or more tables exist")) - } else { - assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) - } + assert(message.contains(s"Database $dbName is not empty. One or more tables exist")) + catalog.dropTable(tableIdent1, ignoreIfNotExists = false, purge = false) From 21f333c635465069b7657d788052d510ffb0779a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 16 Mar 2017 08:50:01 +0800 Subject: [PATCH 0033/1765] [SPARK-19751][SQL] Throw an exception if bean class has one's own class in fields ## What changes were proposed in this pull request? The current master throws `StackOverflowError` in `createDataFrame`/`createDataset` if bean has one's own class in fields; ``` public class SelfClassInFieldBean implements Serializable { private SelfClassInFieldBean child; ... } ``` This pr added code to throw `UnsupportedOperationException` in that case as soon as possible. ## How was this patch tested? Added tests in `JavaDataFrameSuite` and `JavaDatasetSuite`. Author: Takeshi Yamamuro Closes #17188 from maropu/SPARK-19751. --- .../sql/catalyst/JavaTypeInference.scala | 19 ++-- .../apache/spark/sql/JavaDataFrameSuite.java | 32 +++++++ .../apache/spark/sql/JavaDatasetSuite.java | 87 +++++++++++++++++++ 3 files changed, 132 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e9d9508e5adfe..4ff87edde139a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -69,7 +69,8 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty) + : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) @@ -104,26 +105,32 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) case _ if typeToken.isArray => - val (dataType, nullable) = inferDataType(typeToken.getComponentType) + val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) (ArrayType(dataType, nullable), true) case _ if iterableType.isAssignableFrom(typeToken) => - val (dataType, nullable) = inferDataType(elementType(typeToken)) + val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet) (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val (keyDataType, _) = inferDataType(keyType) - val (valueDataType, nullable) = inferDataType(valueType) + val (keyDataType, _) = inferDataType(keyType, seenTypeSet) + val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet) (MapType(keyDataType, valueDataType, nullable), true) case other => + if (seenTypeSet.contains(other)) { + throw new UnsupportedOperationException( + "Cannot have circular references in bean class, but got the circular reference " + + s"of class $other") + } + // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(other) val fields = properties.map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType) + val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other) new StructField(property.getName, dataType, nullable) } (new StructType(fields), true) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index be8d95d0d9124..b007093dad84b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -423,4 +423,36 @@ public void testJsonRDDToDataFrame() { Assert.assertEquals(1L, df.count()); Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0)); } + + public class CircularReference1Bean implements Serializable { + private CircularReference2Bean child; + + public CircularReference2Bean getChild() { + return child; + } + + public void setChild(CircularReference2Bean child) { + this.child = child; + } + } + + public class CircularReference2Bean implements Serializable { + private CircularReference1Bean child; + + public CircularReference1Bean getChild() { + return child; + } + + public void setChild(CircularReference1Bean child) { + this.child = child; + } + } + + // Checks a simple case for DataFrame here and put exhaustive tests for the issue + // of circular references in `JavaDatasetSuite`. + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean() { + CircularReference1Bean bean = new CircularReference1Bean(); + spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d06e35bb44d08..439cac3dfbcb7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1291,4 +1291,91 @@ public void testEmptyBean() { Assert.assertEquals(df.schema().length(), 0); Assert.assertEquals(df.collectAsList().size(), 1); } + + public class CircularReference1Bean implements Serializable { + private CircularReference2Bean child; + + public CircularReference2Bean getChild() { + return child; + } + + public void setChild(CircularReference2Bean child) { + this.child = child; + } + } + + public class CircularReference2Bean implements Serializable { + private CircularReference1Bean child; + + public CircularReference1Bean getChild() { + return child; + } + + public void setChild(CircularReference1Bean child) { + this.child = child; + } + } + + public class CircularReference3Bean implements Serializable { + private CircularReference3Bean[] child; + + public CircularReference3Bean[] getChild() { + return child; + } + + public void setChild(CircularReference3Bean[] child) { + this.child = child; + } + } + + public class CircularReference4Bean implements Serializable { + private Map child; + + public Map getChild() { + return child; + } + + public void setChild(Map child) { + this.child = child; + } + } + + public class CircularReference5Bean implements Serializable { + private String id; + private List child; + + public String getId() { + return id; + } + + public List getChild() { + return child; + } + + public void setId(String id) { + this.id = id; + } + + public void setChild(List child) { + this.child = child; + } + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean1() { + CircularReference1Bean bean = new CircularReference1Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference1Bean.class)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean2() { + CircularReference3Bean bean = new CircularReference3Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference3Bean.class)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean3() { + CircularReference4Bean bean = new CircularReference4Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class)); + } } From 1472cac4bb31c1886f82830778d34c4dd9030d7a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 16 Mar 2017 12:06:20 +0800 Subject: [PATCH 0034/1765] [SPARK-19830][SQL] Add parseTableSchema API to ParserInterface ### What changes were proposed in this pull request? Specifying the table schema in DDL formats is needed for different scenarios. For example, - [specifying the schema in SQL function `from_json` using DDL formats](https://issues.apache.org/jira/browse/SPARK-19637), which is suggested by marmbrus , - [specifying the customized JDBC data types](https://github.com/apache/spark/pull/16209). These two PRs need users to use the JSON format to specify the table schema. This is not user friendly. This PR is to provide a `parseTableSchema` API in `ParserInterface`. ### How was this patch tested? Added a test suite `TableSchemaParserSuite` Author: Xiao Li Closes #17171 from gatorsmile/parseDDLStmt. --- .../sql/catalyst/parser/ParseDriver.scala | 10 ++- .../sql/catalyst/parser/ParserInterface.scala | 7 ++ .../parser/TableSchemaParserSuite.scala | 88 +++++++++++++++++++ 3 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index d687a85c18b63..f704b0998cada 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} /** * Base SQL parsing infrastructure. @@ -49,6 +49,14 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) } + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser => + StructType(astBuilder.visitColTypeList(parser.colTypeList())) + } + /** Creates LogicalPlan for a given SQL string. */ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => astBuilder.visitSingleStatement(parser.singleStatement()) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 7f35d650b9571..6edbe253970e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StructType /** * Interface for a parser. @@ -33,4 +34,10 @@ trait ParserInterface { /** Creates TableIdentifier for a given SQL string. */ def parseTableIdentifier(sqlText: String): TableIdentifier + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + def parseTableSchema(sqlText: String): StructType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala new file mode 100644 index 0000000000000..da1041d617086 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -0,0 +1,88 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class TableSchemaParserSuite extends SparkFunSuite { + + def parse(sql: String): StructType = CatalystSqlParser.parseTableSchema(sql) + + def checkTableSchema(tableSchemaString: String, expectedDataType: DataType): Unit = { + test(s"parse $tableSchemaString") { + assert(parse(tableSchemaString) === expectedDataType) + } + } + + def assertError(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseTableSchema(sql)) + + checkTableSchema("a int", new StructType().add("a", "int")) + checkTableSchema("A int", new StructType().add("A", "int")) + checkTableSchema("a INT", new StructType().add("a", "int")) + checkTableSchema("`!@#$%.^&*()` string", new StructType().add("!@#$%.^&*()", "string")) + checkTableSchema("a int, b long", new StructType().add("a", "int").add("b", "long")) + checkTableSchema("a STRUCT", + StructType( + StructField("a", StructType( + StructField("intType", IntegerType) :: + StructField("ts", TimestampType) :: Nil)) :: Nil)) + checkTableSchema( + "a int comment 'test'", + new StructType().add("a", "int", nullable = true, "test")) + + test("complex hive type") { + val tableSchemaString = + """ + |complexStructCol struct< + |struct:struct, + |MAP:Map, + |arrAy:Array, + |anotherArray:Array> + """.stripMargin.replace("\n", "") + + val builder = new MetadataBuilder + builder.putString(HIVE_TYPE_STRING, + "struct," + + "MAP:map,arrAy:array,anotherArray:array>") + + val expectedDataType = + StructType( + StructField("complexStructCol", StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.USER_DEFAULT) :: + StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: + StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("arrAy", ArrayType(DoubleType)) :: + StructField("anotherArray", ArrayType(StringType)) :: Nil), + nullable = true, + builder.build()) :: Nil) + + assert(parse(tableSchemaString) === expectedDataType) + } + + // Negative cases + assertError("") + assertError("a") + assertError("a INT b long") + assertError("a INT,, b long") + assertError("a INT, b long,,") + assertError("a INT, b long, c int,") +} From d647aae278ef31a07fc64715eb07e48294d94bb8 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 16 Mar 2017 12:49:59 +0200 Subject: [PATCH 0035/1765] [SPARK-13568][ML] Create feature transformer to impute missing values ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-13568 It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to e.g. Imputer in scikit-learn. Initially, options for imputation could include mean, median and most frequent, but we could add various other approaches, where possible existing DataFrame code can be used (e.g. for approximate quantiles etc). Currently this PR supports imputation for Double and Vector (null and NaN in Vector). ## How was this patch tested? new unit tests and manual test Author: Yuhao Yang Author: Yuhao Yang Author: Yuhao Closes #11601 from hhbyyh/imputer. --- .../org/apache/spark/ml/feature/Imputer.scala | 259 ++++++++++++++++++ .../spark/ml/feature/ImputerSuite.scala | 185 +++++++++++++ 2 files changed, 444 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala new file mode 100644 index 0000000000000..b1a802ee13fc4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.HasInputCols +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Params for [[Imputer]] and [[ImputerModel]]. + */ +private[feature] trait ImputerParams extends Params with HasInputCols { + + /** + * The imputation strategy. + * If "mean", then replace missing values using the mean value of the feature. + * If "median", then replace missing values using the approximate median value of the feature. + * Default: mean + * + * @group param + */ + final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " + + s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " + + s"If ${Imputer.median}, then replace missing values using the median value of the feature.", + ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median))) + + /** @group getParam */ + def getStrategy: String = $(strategy) + + /** + * The placeholder for the missing values. All occurrences of missingValue will be imputed. + * Note that null values are always treated as missing. + * Default: Double.NaN + * + * @group param + */ + final val missingValue: DoubleParam = new DoubleParam(this, "missingValue", + "The placeholder for the missing values. All occurrences of missingValue will be imputed") + + /** @group getParam */ + def getMissingValue: Double = $(missingValue) + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", + "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + + s" duplicates: (${$(inputCols).mkString(", ")})") + require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" + + s" duplicates: (${$(outputCols).mkString(", ")})") + require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" + + s" and outputCols(${$(outputCols).length}) should have the same length") + val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => + val inputField = schema(inputCol) + SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType)) + StructField(outputCol, inputField.dataType, inputField.nullable) + } + StructType(schema ++ outputFields) + } +} + +/** + * :: Experimental :: + * Imputation estimator for completing missing values, either using the mean or the median + * of the column in which the missing values are located. The input column should be of + * DoubleType or FloatType. Currently Imputer does not support categorical features yet + * (SPARK-15041) and possibly creates incorrect values for a categorical feature. + * + * Note that the mean/median value is computed after filtering out missing values. + * All Null values in the input column are treated as missing, and so are also imputed. For + * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. + */ +@Experimental +class Imputer @Since("2.2.0")(override val uid: String) + extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable { + + @Since("2.2.0") + def this() = this(Identifiable.randomUID("imputer")) + + /** @group setParam */ + @Since("2.2.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.2.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** + * Imputation strategy. Available options are ["mean", "median"]. + * @group setParam + */ + @Since("2.2.0") + def setStrategy(value: String): this.type = set(strategy, value) + + /** @group setParam */ + @Since("2.2.0") + def setMissingValue(value: Double): this.type = set(missingValue, value) + + setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN) + + override def fit(dataset: Dataset[_]): ImputerModel = { + transformSchema(dataset.schema, logging = true) + val spark = dataset.sparkSession + import spark.implicits._ + val surrogates = $(inputCols).map { inputCol => + val ic = col(inputCol) + val filtered = dataset.select(ic.cast(DoubleType)) + .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN) + if(filtered.take(1).length == 0) { + throw new SparkException(s"surrogate cannot be computed. " + + s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})") + } + val surrogate = $(strategy) match { + case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first() + case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head + } + surrogate + } + + val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates))) + val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false))) + val surrogateDF = spark.createDataFrame(rows, schema) + copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): Imputer = defaultCopy(extra) +} + +@Since("2.2.0") +object Imputer extends DefaultParamsReadable[Imputer] { + + /** strategy names that Imputer currently supports. */ + private[ml] val mean = "mean" + private[ml] val median = "median" + + @Since("2.2.0") + override def load(path: String): Imputer = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[Imputer]]. + * + * @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are + * used to replace the missing values in the input DataFrame. + */ +@Experimental +class ImputerModel private[ml]( + override val uid: String, + val surrogateDF: DataFrame) + extends Model[ImputerModel] with ImputerParams with MLWritable { + + import ImputerModel._ + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + var outputDF = dataset + val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq + + $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + case ((inputCol, outputCol), surrogate) => + val inputType = dataset.schema(inputCol).dataType + val ic = col(inputCol) + outputDF = outputDF.withColumn(outputCol, + when(ic.isNull, surrogate) + .when(ic === $(missingValue), surrogate) + .otherwise(ic) + .cast(inputType)) + } + outputDF.toDF() + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): ImputerModel = { + val copied = new ImputerModel(uid, surrogateDF) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.2.0") + override def write: MLWriter = new ImputerModelWriter(this) +} + + +@Since("2.2.0") +object ImputerModel extends MLReadable[ImputerModel] { + + private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.surrogateDF.repartition(1).write.parquet(dataPath) + } + } + + private class ImputerReader extends MLReader[ImputerModel] { + + private val className = classOf[ImputerModel].getName + + override def load(path: String): ImputerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val surrogateDF = sqlContext.read.parquet(dataPath) + val model = new ImputerModel(metadata.uid, surrogateDF) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("2.2.0") + override def read: MLReader[ImputerModel] = new ImputerReader + + @Since("2.2.0") + override def load(path: String): ImputerModel = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala new file mode 100644 index 0000000000000..ee2ba73fa96d5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("Imputer for Double with default missing Value NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0), + (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0), + (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0), + (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0) + )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1", + "expected_mean_value2", "expected_median_value2") + val imputer = new Imputer() + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1", "out2")) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 3.0) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setMissingValue(-1.0) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer for Float with missing Value -1.0") { + val df = spark.createDataFrame( Seq( + (0, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setMissingValue(-1) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer should impute null as well as 'missingValue'") { + val rawDf = spark.createDataFrame( Seq( + (0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0), + (4, -1.0, 8.0, 10.0) + )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer throws exception when surrogate cannot be computed") { + val df = spark.createDataFrame( Seq( + (0, Double.NaN, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + Seq("mean", "median").foreach { strategy => + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setStrategy(strategy) + withClue("Imputer should fail all the values are invalid") { + val e: SparkException = intercept[SparkException] { + val model = imputer.fit(df) + } + assert(e.getMessage.contains("surrogate cannot be computed")) + } + } + } + + test("Imputer input & output column validation") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value1", "value2", "value3") + Seq("mean", "median").foreach { strategy => + withClue("Imputer should fail if inputCols and outputCols are different length") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("should have the same length")) + } + + withClue("Imputer should fail if inputCols contains duplicates") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value1")) + .setOutputCols(Array("out1", "out2")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("inputCols contains duplicates")) + } + + withClue("Imputer should fail if outputCols contains duplicates") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1", "out1")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("outputCols contains duplicates")) + } + } + } + + test("Imputer read/write") { + val t = new Imputer() + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + .setMissingValue(-1.0) + testDefaultReadWrite(t) + } + + test("ImputerModel read/write") { + val spark = this.spark + import spark.implicits._ + val surrogateDF = Seq(1.234).toDF("myInputCol") + + val instance = new ImputerModel( + "myImputer", surrogateDF) + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns) + assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) + } + +} + +object ImputerSuite { + + /** + * Imputation strategy. Available options are ["mean", "median"]. + * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" + */ + def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { + val inputCols = imputer.getInputCols + + Seq("mean", "median").foreach { strategy => + imputer.setStrategy(strategy) + val model = imputer.fit(df) + val resultDF = model.transform(df) + imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { + case Row(exp: Float, out: Float) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + } + } +} From ee91a0decc389572099ea7c038149cc50375a2ef Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Thu, 16 Mar 2017 15:25:45 +0100 Subject: [PATCH 0036/1765] [SPARK-19946][TESTING] DebugFilesystem.assertNoOpenStreams should report the open streams to help debugging ## What changes were proposed in this pull request? DebugFilesystem.assertNoOpenStreams throws an exception with a cause exception that actually shows the code line which leaked the stream. ## How was this patch tested? New test in SparkContextSuite to check there is a cause exception. Author: Bogdan Raducanu Closes #17292 from bogdanrdc/SPARK-19946. --- .../org/apache/spark/DebugFilesystem.scala | 3 ++- .../org/apache/spark/SparkContextSuite.scala | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index fb8d701ebda8a..72aea841117cc 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -44,7 +44,8 @@ object DebugFilesystem extends Logging { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } - throw new RuntimeException(s"There are $numOpen possibly leaked file streams.") + throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", + openStreams.values().asScala.head) } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index f97a112ec1276..d08a162feda03 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import java.io.File -import java.net.MalformedURLException +import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -26,6 +26,8 @@ import scala.concurrent.duration._ import scala.concurrent.Await import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} @@ -538,6 +540,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + + "open streams to help debugging") { + val fs = new DebugFilesystem() + fs.initialize(new URI("file:///"), new Configuration()) + val file = File.createTempFile("SPARK19446", "temp") + Files.write(Array.ofDim[Byte](1000), file) + val path = new Path("file:///" + file.getCanonicalPath) + val stream = fs.open(path) + val exc = intercept[RuntimeException] { + DebugFilesystem.assertNoOpenStreams() + } + assert(exc != null) + assert(exc.getCause() != null) + stream.close() + } + } object SparkContextSuite { From 8e8f898335f5019c0d4f3944c4aefa12a185db70 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 16 Mar 2017 11:34:13 -0700 Subject: [PATCH 0037/1765] [SPARK-19945][SQL] add test suite for SessionCatalog with HiveExternalCatalog ## What changes were proposed in this pull request? Currently `SessionCatalogSuite` is only for `InMemoryCatalog`, there is no suite for `HiveExternalCatalog`. And there are some ddl function is not proper to test in `ExternalCatalogSuite`, because some logic are not full implement in `ExternalCatalog`, these ddl functions are full implement in `SessionCatalog`(e.g. merge the same logic from `ExternalCatalog` up to `SessionCatalog` ). It is better to test it in `SessionCatalogSuite` for this situation. So we should add a test suite for `SessionCatalog` with `HiveExternalCatalog` The main change is that in `SessionCatalogSuite` add two functions: `withBasicCatalog` and `withEmptyCatalog` And replace the code like `val catalog = new SessionCatalog(newBasicCatalog)` with above two functions ## How was this patch tested? add `HiveExternalSessionCatalogSuite` Author: windpiger Closes #17287 from windpiger/sessioncatalogsuit. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 1907 +++++++++-------- .../HiveExternalSessionCatalogSuite.scala | 40 + 3 files changed, 1049 insertions(+), 900 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index bfcdb70fe47c1..25aa8d3ba921f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -48,7 +48,7 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - externalCatalog: ExternalCatalog, + val externalCatalog: ExternalCatalog, globalTempViewManager: GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: CatalystConf, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 7e74dcdef0e27..bb87763e0bbb0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -27,41 +27,67 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +class InMemorySessionCatalogSuite extends SessionCatalogSuite { + protected val utils = new CatalogTestUtils { + override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" + override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" + override val defaultProvider: String = "parquet" + override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog + } +} + /** - * Tests for [[SessionCatalog]] that assume that [[InMemoryCatalog]] is correctly implemented. + * Tests for [[SessionCatalog]] * * Note: many of the methods here are very similar to the ones in [[ExternalCatalogSuite]]. * This is because [[SessionCatalog]] and [[ExternalCatalog]] share many similar method * signatures but do not extend a common parent. This is largely by design but * unfortunately leads to very similar test code in two places. */ -class SessionCatalogSuite extends PlanTest { - private val utils = new CatalogTestUtils { - override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" - override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override val defaultProvider: String = "parquet" - override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog - } +abstract class SessionCatalogSuite extends PlanTest { + protected val utils: CatalogTestUtils + + protected val isHiveExternalCatalog = false import utils._ + private def withBasicCatalog(f: SessionCatalog => Unit): Unit = { + val catalog = new SessionCatalog(newBasicCatalog()) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + try { + f(catalog) + } finally { + catalog.reset() + } + } + + private def withEmptyCatalog(f: SessionCatalog => Unit): Unit = { + val catalog = new SessionCatalog(newEmptyCatalog()) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + try { + f(catalog) + } finally { + catalog.reset() + } + } // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- test("basic create and list databases") { - val catalog = new SessionCatalog(newEmptyCatalog()) - catalog.createDatabase(newDb("default"), ignoreIfExists = true) - assert(catalog.databaseExists("default")) - assert(!catalog.databaseExists("testing")) - assert(!catalog.databaseExists("testing2")) - catalog.createDatabase(newDb("testing"), ignoreIfExists = false) - assert(catalog.databaseExists("testing")) - assert(catalog.listDatabases().toSet == Set("default", "testing")) - catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) - assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) - assert(catalog.databaseExists("testing2")) - assert(!catalog.databaseExists("does_not_exist")) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + assert(catalog.databaseExists("default")) + assert(!catalog.databaseExists("testing")) + assert(!catalog.databaseExists("testing2")) + catalog.createDatabase(newDb("testing"), ignoreIfExists = false) + assert(catalog.databaseExists("testing")) + assert(catalog.listDatabases().toSet == Set("default", "testing")) + catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) + assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) + assert(catalog.databaseExists("testing2")) + assert(!catalog.databaseExists("does_not_exist")) + } } def testInvalidName(func: (String) => Unit) { @@ -76,121 +102,141 @@ class SessionCatalogSuite extends PlanTest { } test("create databases using invalid names") { - val catalog = new SessionCatalog(newEmptyCatalog()) - testInvalidName(name => catalog.createDatabase(newDb(name), ignoreIfExists = true)) + withEmptyCatalog { catalog => + testInvalidName( + name => catalog.createDatabase(newDb(name), ignoreIfExists = true)) + } } test("get database when a database exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabaseMetadata("db1") - assert(db1.name == "db1") - assert(db1.description.contains("db1")) + withBasicCatalog { catalog => + val db1 = catalog.getDatabaseMetadata("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } } test("get database should throw exception when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getDatabaseMetadata("db_that_does_not_exist") + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getDatabaseMetadata("db_that_does_not_exist") + } } } test("list databases without pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3")) + withBasicCatalog { catalog => + assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3")) + } } test("list databases with pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listDatabases("db").toSet == Set.empty) - assert(catalog.listDatabases("db*").toSet == Set("db1", "db2", "db3")) - assert(catalog.listDatabases("*1").toSet == Set("db1")) - assert(catalog.listDatabases("db2").toSet == Set("db2")) + withBasicCatalog { catalog => + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2", "db3")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } } test("drop database") { - val catalog = new SessionCatalog(newBasicCatalog()) - catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) - assert(catalog.listDatabases().toSet == Set("default", "db2", "db3")) + withBasicCatalog { catalog => + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("default", "db2", "db3")) + } } test("drop database when the database is not empty") { // Throw exception if there are functions left - val externalCatalog1 = newBasicCatalog() - val sessionCatalog1 = new SessionCatalog(externalCatalog1) - externalCatalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false, purge = false) - externalCatalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false, purge = false) - intercept[AnalysisException] { - sessionCatalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + catalog.externalCatalog.dropTable("db2", "tbl1", ignoreIfNotExists = false, purge = false) + catalog.externalCatalog.dropTable("db2", "tbl2", ignoreIfNotExists = false, purge = false) + intercept[AnalysisException] { + catalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } } - - // Throw exception if there are tables left - val externalCatalog2 = newBasicCatalog() - val sessionCatalog2 = new SessionCatalog(externalCatalog2) - externalCatalog2.dropFunction("db2", "func1") - intercept[AnalysisException] { - sessionCatalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + // Throw exception if there are tables left + catalog.externalCatalog.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } } - // When cascade is true, it should drop them - val externalCatalog3 = newBasicCatalog() - val sessionCatalog3 = new SessionCatalog(externalCatalog3) - externalCatalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) - assert(sessionCatalog3.listDatabases().toSet == Set("default", "db1", "db3")) + withBasicCatalog { catalog => + // When cascade is true, it should drop them + catalog.externalCatalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog.listDatabases().toSet == Set("default", "db1", "db3")) + } } test("drop database when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + // TODO: fix this inconsistent between HiveExternalCatalog and InMemoryCatalog + if (isHiveExternalCatalog) { + val e = intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + }.getMessage + assert(e.contains( + "org.apache.hadoop.hive.metastore.api.NoSuchObjectException: db_that_does_not_exist")) + } else { + intercept[NoSuchDatabaseException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + } + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) } - catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) } test("drop current database and drop default database") { - val catalog = new SessionCatalog(newBasicCatalog()) - catalog.setCurrentDatabase("db1") - assert(catalog.getCurrentDatabase == "db1") - catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = true) - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "db1"), ignoreIfExists = false) - } - catalog.setCurrentDatabase("default") - assert(catalog.getCurrentDatabase == "default") - intercept[AnalysisException] { - catalog.dropDatabase("default", ignoreIfNotExists = false, cascade = true) + withBasicCatalog { catalog => + catalog.setCurrentDatabase("db1") + assert(catalog.getCurrentDatabase == "db1") + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = true) + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "db1"), ignoreIfExists = false) + } + catalog.setCurrentDatabase("default") + assert(catalog.getCurrentDatabase == "default") + intercept[AnalysisException] { + catalog.dropDatabase("default", ignoreIfNotExists = false, cascade = true) + } } } test("alter database") { - val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabaseMetadata("db1") - // Note: alter properties here because Hive does not support altering other fields - catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) - val newDb1 = catalog.getDatabaseMetadata("db1") - assert(db1.properties.isEmpty) - assert(newDb1.properties.size == 2) - assert(newDb1.properties.get("k") == Some("v3")) - assert(newDb1.properties.get("good") == Some("true")) + withBasicCatalog { catalog => + val db1 = catalog.getDatabaseMetadata("db1") + // Note: alter properties here because Hive does not support altering other fields + catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) + val newDb1 = catalog.getDatabaseMetadata("db1") + assert(db1.properties.isEmpty) + assert(newDb1.properties.size == 2) + assert(newDb1.properties.get("k") == Some("v3")) + assert(newDb1.properties.get("good") == Some("true")) + } } test("alter database should throw exception when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterDatabase(newDb("unknown_db")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterDatabase(newDb("unknown_db")) + } } } test("get/set current database") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getCurrentDatabase == "default") - catalog.setCurrentDatabase("db2") - assert(catalog.getCurrentDatabase == "db2") - intercept[NoSuchDatabaseException] { + withBasicCatalog { catalog => + assert(catalog.getCurrentDatabase == "default") + catalog.setCurrentDatabase("db2") + assert(catalog.getCurrentDatabase == "db2") + intercept[NoSuchDatabaseException] { + catalog.setCurrentDatabase("deebo") + } + catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) catalog.setCurrentDatabase("deebo") + assert(catalog.getCurrentDatabase == "deebo") } - catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) - catalog.setCurrentDatabase("deebo") - assert(catalog.getCurrentDatabase == "deebo") } // -------------------------------------------------------------------------- @@ -198,346 +244,360 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("create table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db1").isEmpty) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) - sessionCatalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) - assert(externalCatalog.listTables("db1").toSet == Set("tbl3")) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) - // Create table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db1") - sessionCatalog.createTable(newTable("tbl4"), ignoreIfExists = false) - assert(externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db1").isEmpty) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) + catalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) + assert(catalog.externalCatalog.listTables("db1").toSet == Set("tbl3")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + // Create table without explicitly specifying database + catalog.setCurrentDatabase("db1") + catalog.createTable(newTable("tbl4"), ignoreIfExists = false) + assert(catalog.externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + } } test("create tables using invalid names") { - val catalog = new SessionCatalog(newEmptyCatalog()) - testInvalidName(name => catalog.createTable(newTable(name, "db1"), ignoreIfExists = false)) + withEmptyCatalog { catalog => + testInvalidName(name => catalog.createTable(newTable(name, "db1"), ignoreIfExists = false)) + } } test("create table when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - // Creating table in non-existent database should always fail - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) - } - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) - } - // Table already exists - intercept[TableAlreadyExistsException] { - catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + withBasicCatalog { catalog => + // Creating table in non-existent database should always fail + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) + } + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) + } + // Table already exists + intercept[TableAlreadyExistsException] { + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + } + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) } - catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) } test("create temp table") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable1 = Range(1, 10, 1, 10) - val tempTable2 = Range(1, 20, 2, 10) - catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) - catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempView("tbl1") == Option(tempTable1)) - assert(catalog.getTempView("tbl2") == Option(tempTable2)) - assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists - intercept[TempTableAlreadyExistsException] { + withBasicCatalog { catalog => + val tempTable1 = Range(1, 10, 1, 10) + val tempTable2 = Range(1, 20, 2, 10) catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) + assert(catalog.getTempView("tbl1") == Option(tempTable1)) + assert(catalog.getTempView("tbl2") == Option(tempTable2)) + assert(catalog.getTempView("tbl3").isEmpty) + // Temporary table already exists + intercept[TempTableAlreadyExistsException] { + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + } + // Temporary table already exists but we override it + catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) + assert(catalog.getTempView("tbl1") == Option(tempTable2)) } - // Temporary table already exists but we override it - catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempView("tbl1") == Option(tempTable2)) } test("drop table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, - purge = false) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) - // Drop table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false, purge = false) - assert(externalCatalog.listTables("db2").isEmpty) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, + purge = false) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + // Drop table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false, purge = false) + assert(catalog.externalCatalog.listTables("db2").isEmpty) + } } test("drop table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - // Should always throw exception when the database does not exist - intercept[NoSuchDatabaseException] { - catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false, - purge = false) - } - intercept[NoSuchDatabaseException] { - catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true, - purge = false) - } - intercept[NoSuchTableException] { - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false, + withBasicCatalog { catalog => + // Should always throw exception when the database does not exist + intercept[NoSuchDatabaseException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false, + purge = false) + } + intercept[NoSuchDatabaseException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true, + purge = false) + } + intercept[NoSuchTableException] { + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false, + purge = false) + } + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true, purge = false) } - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true, - purge = false) } test("drop temp table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10) - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is not specified, temp table should be dropped first - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.getTempView("tbl1") == None) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If temp table does not exist, the table in the current database should be dropped - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) - // If database is specified, temp tables are never dropped - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) - sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, - purge = false) - assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be dropped first + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.getTempView("tbl1") == None) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If temp table does not exist, the table in the current database should be dropped + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + // If database is specified, temp tables are never dropped + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, + purge = false) + assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + } } test("rename table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) - // Rename table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) - // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match - intercept[AnalysisException] { - sessionCatalog.renameTable( - TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) - } - // The new table already exists - intercept[TableAlreadyExistsException] { - sessionCatalog.renameTable( - TableIdentifier("tblone", Some("db2")), - TableIdentifier("table_two")) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) + catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) + // Rename table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) + // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match + intercept[AnalysisException] { + catalog.renameTable( + TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) + } + // The new table already exists + intercept[TableAlreadyExistsException] { + catalog.renameTable( + TableIdentifier("tblone", Some("db2")), + TableIdentifier("table_two")) + } } } test("rename tables to an invalid name") { - val catalog = new SessionCatalog(newBasicCatalog()) - testInvalidName( - name => catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier(name))) + withBasicCatalog { catalog => + testInvalidName( + name => catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier(name))) + } } test("rename table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) - } - intercept[NoSuchTableException] { - catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) + } + intercept[NoSuchTableException] { + catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) + } } } test("rename temp table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10) - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is not specified, temp table should be renamed first - sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) - assert(sessionCatalog.getTempView("tbl1").isEmpty) - assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is specified, temp tables are never renamed - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) - assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) - assert(sessionCatalog.getTempView("tbl4").isEmpty) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + assert(catalog.getTempView("tbl1") == Option(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be renamed first + catalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) + assert(catalog.getTempView("tbl1").isEmpty) + assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is specified, temp tables are never renamed + catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) + assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(catalog.getTempView("tbl4").isEmpty) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + } } test("alter table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tbl1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) - val newTbl1 = externalCatalog.getTable("db2", "tbl1") - assert(!tbl1.properties.contains("toh")) - assert(newTbl1.properties.size == tbl1.properties.size + 1) - assert(newTbl1.properties.get("toh") == Some("frem")) - // Alter table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) - val newestTbl1 = externalCatalog.getTable("db2", "tbl1") - assert(newestTbl1 == tbl1) + withBasicCatalog { catalog => + val tbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) + val newTbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + assert(!tbl1.properties.contains("toh")) + assert(newTbl1.properties.size == tbl1.properties.size + 1) + assert(newTbl1.properties.get("toh") == Some("frem")) + // Alter table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) + val newestTbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + // For hive serde table, hive metastore will set transient_lastDdlTime in table's properties, + // and its value will be modified, here we ignore it when comparing the two tables. + assert(newestTbl1.copy(properties = Map.empty) == tbl1.copy(properties = Map.empty)) + } } test("alter table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterTable(newTable("tbl1", "unknown_db")) - } - intercept[NoSuchTableException] { - catalog.alterTable(newTable("unknown_table", "db2")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterTable(newTable("tbl1", "unknown_db")) + } + intercept[NoSuchTableException] { + catalog.alterTable(newTable("unknown_table", "db2")) + } } } test("get table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) - == externalCatalog.getTable("db2", "tbl1")) - // Get table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1")) - == externalCatalog.getTable("db2", "tbl1")) + withBasicCatalog { catalog => + assert(catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) + == catalog.externalCatalog.getTable("db2", "tbl1")) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getTableMetadata(TableIdentifier("tbl1")) + == catalog.externalCatalog.getTable("db2", "tbl1")) + } } test("get table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) - } - intercept[NoSuchTableException] { - catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + } } } test("get option of table metadata") { - val externalCatalog = newBasicCatalog() - val catalog = new SessionCatalog(externalCatalog) - assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) - == Option(externalCatalog.getTable("db2", "tbl1"))) - assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + withBasicCatalog { catalog => + assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) + == Option(catalog.externalCatalog.getTable("db2", "tbl1"))) + assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) + intercept[NoSuchDatabaseException] { + catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + } } } test("lookup table relation") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable1 = Range(1, 10, 1, 10) - val metastoreTable1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.createTempView("tbl1", tempTable1, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - // If we explicitly specify the database, we'll look up the relation in that database - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) - // Otherwise, we'll first look up a temporary table with the same name - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", tempTable1)) - // Then, if that does not exist, look up the relation in the current database - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + withBasicCatalog { catalog => + val tempTable1 = Range(1, 10, 1, 10) + val metastoreTable1 = catalog.externalCatalog.getTable("db2", "tbl1") + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + // If we explicitly specify the database, we'll look up the relation in that database + assert(catalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + // Otherwise, we'll first look up a temporary table with the same name + assert(catalog.lookupRelation(TableIdentifier("tbl1")) + == SubqueryAlias("tbl1", tempTable1)) + // Then, if that does not exist, look up the relation in the current database + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.lookupRelation(TableIdentifier("tbl1")).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + } } test("look up view relation") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val metadata = externalCatalog.getTable("db3", "view1") - sessionCatalog.setCurrentDatabase("default") - // Look up a view. - assert(metadata.viewText.isDefined) - val view = View(desc = metadata, output = metadata.schema.toAttributes, - child = CatalystSqlParser.parsePlan(metadata.viewText.get)) - comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view)) - // Look up a view using current database of the session catalog. - sessionCatalog.setCurrentDatabase("db3") - comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view)) + withBasicCatalog { catalog => + val metadata = catalog.externalCatalog.getTable("db3", "view1") + catalog.setCurrentDatabase("default") + // Look up a view. + assert(metadata.viewText.isDefined) + val view = View(desc = metadata, output = metadata.schema.toAttributes, + child = CatalystSqlParser.parsePlan(metadata.viewText.get)) + comparePlans(catalog.lookupRelation(TableIdentifier("view1", Some("db3"))), + SubqueryAlias("view1", view)) + // Look up a view using current database of the session catalog. + catalog.setCurrentDatabase("db3") + comparePlans(catalog.lookupRelation(TableIdentifier("view1")), + SubqueryAlias("view1", view)) + } } test("table exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) - assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) - assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) - assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) - assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) - // If database is explicitly specified, do not check temporary tables - val tempTable = Range(1, 10, 1, 10) - assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) - // If database is not explicitly specified, check the current database - catalog.setCurrentDatabase("db2") - assert(catalog.tableExists(TableIdentifier("tbl1"))) - assert(catalog.tableExists(TableIdentifier("tbl2"))) - - catalog.createTempView("tbl3", tempTable, overrideIfExists = false) - // tableExists should not check temp view. - assert(!catalog.tableExists(TableIdentifier("tbl3"))) + withBasicCatalog { catalog => + assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) + assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) + assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) + // If database is explicitly specified, do not check temporary tables + val tempTable = Range(1, 10, 1, 10) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + // If database is not explicitly specified, check the current database + catalog.setCurrentDatabase("db2") + assert(catalog.tableExists(TableIdentifier("tbl1"))) + assert(catalog.tableExists(TableIdentifier("tbl2"))) + + catalog.createTempView("tbl3", tempTable, overrideIfExists = false) + // tableExists should not check temp view. + assert(!catalog.tableExists(TableIdentifier("tbl3"))) + } } test("getTempViewOrPermanentTableMetadata on temporary views") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) - }.getMessage + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) + }.getMessage - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) - }.getMessage + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage - catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.getTempViewOrPermanentTableMetadata( - TableIdentifier("view1")).identifier.table == "view1") - assert(catalog.getTempViewOrPermanentTableMetadata( - TableIdentifier("view1")).schema(0).name == "id") + catalog.createTempView("view1", tempTable, overrideIfExists = false) + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).identifier.table == "view1") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).schema(0).name == "id") - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) - }.getMessage + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage + } } test("list tables without pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - catalog.createTempView("tbl1", tempTable, overrideIfExists = false) - catalog.createTempView("tbl4", tempTable, overrideIfExists = false) - assert(catalog.listTables("db1").toSet == - Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) - assert(catalog.listTables("db2").toSet == - Set(TableIdentifier("tbl1"), - TableIdentifier("tbl4"), - TableIdentifier("tbl1", Some("db2")), - TableIdentifier("tbl2", Some("db2")))) - intercept[NoSuchDatabaseException] { - catalog.listTables("unknown_db") + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) + assert(catalog.listTables("db2").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + intercept[NoSuchDatabaseException] { + catalog.listTables("unknown_db") + } } } test("list tables with pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - catalog.createTempView("tbl1", tempTable, overrideIfExists = false) - catalog.createTempView("tbl4", tempTable, overrideIfExists = false) - assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) - assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) - assert(catalog.listTables("db2", "tbl*").toSet == - Set(TableIdentifier("tbl1"), - TableIdentifier("tbl4"), - TableIdentifier("tbl1", Some("db2")), - TableIdentifier("tbl2", Some("db2")))) - assert(catalog.listTables("db2", "*1").toSet == - Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) - intercept[NoSuchDatabaseException] { - catalog.listTables("unknown_db", "*") + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) + assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) + assert(catalog.listTables("db2", "tbl*").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + assert(catalog.listTables("db2", "*1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) + intercept[NoSuchDatabaseException] { + catalog.listTables("unknown_db", "*") + } } } @@ -546,451 +606,477 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("basic create and list partitions") { - val externalCatalog = newEmptyCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) - sessionCatalog.createPartitions( - TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) - // Create partitions without explicitly specifying database - sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createPartitions( - TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) - assert(catalogPartitionsEqual( - externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) + catalog.createPartitions( + TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) + // Create partitions without explicitly specifying database + catalog.setCurrentDatabase("mydb") + catalog.createPartitions( + TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) + } } test("create partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.createPartitions( - TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) - } - intercept[NoSuchTableException] { - catalog.createPartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.createPartitions( + TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) + } + intercept[NoSuchTableException] { + catalog.createPartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + } } } test("create partitions that already exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + withBasicCatalog { catalog => + intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + } catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } test("create partitions with invalid part spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1, partWithLessColumns), ignoreIfExists = false) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1, partWithMoreColumns), ignoreIfExists = true) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithUnknownColumns, part1), ignoreIfExists = true) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithEmptyValue, part1), ignoreIfExists = true) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithLessColumns), ignoreIfExists = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithMoreColumns), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithEmptyValue, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("drop partitions") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) - sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part2)) - // Drop partitions without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.dropPartitions( - TableIdentifier("tbl2"), - Seq(part2.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) - // Drop multiple partitions at once - sessionCatalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) - sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1.spec, part2.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) - } - - test("drop partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) catalog.dropPartitions( - TableIdentifier("tbl1", Some("unknown_db")), - Seq(), + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), ignoreIfNotExists = false, purge = false, retainData = false) - } - intercept[NoSuchTableException] { + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part2)) + // Drop partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") catalog.dropPartitions( - TableIdentifier("does_not_exist", Some("db2")), - Seq(), + TableIdentifier("tbl2"), + Seq(part2.spec), ignoreIfNotExists = false, purge = false, retainData = false) - } - } - - test("drop partitions that do not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + assert(catalog.externalCatalog.listPartitions("db2", "tbl2").isEmpty) + // Drop multiple partitions at once + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) catalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), - Seq(part3.spec), + Seq(part1.spec, part2.spec), ignoreIfNotExists = false, purge = false, retainData = false) + assert(catalog.externalCatalog.listPartitions("db2", "tbl2").isEmpty) } - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part3.spec), - ignoreIfNotExists = true, - purge = false, - retainData = false) } - test("drop partitions with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithMoreColumns.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) + test("drop partitions when database/table does not exist") { + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.dropPartitions( + TableIdentifier("tbl1", Some("unknown_db")), + Seq(), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + intercept[NoSuchTableException] { + catalog.dropPartitions( + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } } - assert(e.getMessage.contains( - "Partition spec is invalid. The spec (a, b, c) must be contained within " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { + } + + test("drop partitions that do not exist") { + withBasicCatalog { catalog => + intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } catalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), - Seq(partWithUnknownColumns.spec), - ignoreIfNotExists = false, + Seq(part3.spec), + ignoreIfNotExists = true, purge = false, retainData = false) } - assert(e.getMessage.contains( - "Partition spec is invalid. The spec (a, unknown) must be contained within " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithEmptyValue.spec, part1.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) + } + + test("drop partitions with invalid partition spec") { + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithMoreColumns.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, b, c) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, unknown) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithEmptyValue.spec, part1.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("get partition") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) - // Get partition without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) - assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) - // Get non-existent partition - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + withBasicCatalog { catalog => + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) + // Get partition without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) + // Get non-existent partition + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + } } } test("get partition when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) - } - intercept[NoSuchTableException] { - catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) + } + intercept[NoSuchTableException] { + catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + } } } test("get partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("rename partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) - val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) - val newSpecs = Seq(newPart1.spec, newPart2.spec) - catalog.renamePartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - } - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - } - // Rename partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) - assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) - assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec === part2.spec) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) - } - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + withBasicCatalog { catalog => + val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) + val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) + val newSpecs = Seq(newPart1.spec, newPart2.spec) + catalog.renamePartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + } + // Rename partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec === part2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + } } } test("rename partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) - } - intercept[NoSuchTableException] { - catalog.renamePartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) + } + intercept[NoSuchTableException] { + catalog.renamePartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + } } } test("rename partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithLessColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithLessColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("alter partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val newLocation = newUriForDatabase() - // Alter but keep spec the same - val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( - oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), - oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) - val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - assert(newPart1.storage.locationUri == Some(newLocation)) - assert(newPart2.storage.locationUri == Some(newLocation)) - assert(oldPart1.storage.locationUri != Some(newLocation)) - assert(oldPart2.storage.locationUri != Some(newLocation)) - // Alter partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) - val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) - val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) - assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) - assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) - // Alter but change spec, should fail because new partition specs do not exist yet - val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) - val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) - intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + withBasicCatalog { catalog => + val newLocation = newUriForDatabase() + // Alter but keep spec the same + val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( + oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), + oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) + val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + assert(newPart1.storage.locationUri == Some(newLocation)) + assert(newPart2.storage.locationUri == Some(newLocation)) + assert(oldPart1.storage.locationUri != Some(newLocation)) + assert(oldPart2.storage.locationUri != Some(newLocation)) + // Alter partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) + val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) + val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) + assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) + assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) + // Alter but change spec, should fail because new partition specs do not exist yet + val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) + val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) + intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + } } } test("alter partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) - } - intercept[NoSuchTableException] { - catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) + } + intercept[NoSuchTableException] { + catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + } } } test("alter partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partition names") { - val catalog = new SessionCatalog(newBasicCatalog()) - val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") - assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == - expectedPartitionNames) - // List partition names without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + withBasicCatalog { catalog => + val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") + assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == + expectedPartitionNames) + // List partition names without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + } } test("list partition names with partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert( - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == - Seq("a=1/b=2")) + withBasicCatalog { catalog => + assert( + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == + Seq("a=1/b=2")) + } } test("list partition names with invalid partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalogPartitionsEqual( - catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) - // List partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) + // List partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) + } } test("list partitions with partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalogPartitionsEqual( - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + } } test("list partitions with invalid partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), - Some(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), + Some(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) - } - intercept[NoSuchTableException] { - catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + } } } @@ -999,8 +1085,17 @@ class SessionCatalogSuite extends PlanTest { expectedParts: CatalogTablePartition*): Boolean = { // ExternalCatalog may set a default location for partitions, here we ignore the partition // location when comparing them. - actualParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet == - expectedParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet + // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) + // in table's parameters and storage's properties, here we also ignore them. + val actualPartsNormalize = actualParts.map(p => + p.copy(parameters = Map.empty, storage = p.storage.copy( + properties = Map.empty, locationUri = None, serde = None))).toSet + + val expectedPartsNormalize = expectedParts.map(p => + p.copy(parameters = Map.empty, storage = p.storage.copy( + properties = Map.empty, locationUri = None, serde = None))).toSet + + actualPartsNormalize == expectedPartsNormalize } // -------------------------------------------------------------------------- @@ -1008,248 +1103,258 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("basic create and list functions") { - val externalCatalog = newEmptyCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) - assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) - // Create function without explicitly specifying database - sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) - assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + // Create function without explicitly specifying database + catalog.setCurrentDatabase("mydb") + catalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + } } test("create function when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.createFunction( - newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.createFunction( + newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + } } } test("create function that already exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[FunctionAlreadyExistsException] { - catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[FunctionAlreadyExistsException] { + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + } + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } - catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } test("create temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempFunc1 = (e: Seq[Expression]) => e.head - val tempFunc2 = (e: Seq[Expression]) => e.last - val info1 = new ExpressionInfo("tempFunc1", "temp1") - val info2 = new ExpressionInfo("tempFunc2", "temp2") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) - val arguments = Seq(Literal(1), Literal(2), Literal(3)) - assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) - assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) - // Temporary function does not exist. - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) - } - val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) - val info3 = new ExpressionInfo("tempFunc3", "temp1") - // Temporary function already exists - intercept[TempFunctionAlreadyExistsException] { - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) - } - // Temporary function is overridden - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) - assert( - catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) + withBasicCatalog { catalog => + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + val info1 = new ExpressionInfo("tempFunc1", "temp1") + val info2 = new ExpressionInfo("tempFunc2", "temp2") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) + assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) + // Temporary function does not exist. + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) + } + val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) + val info3 = new ExpressionInfo("tempFunc3", "temp1") + // Temporary function already exists + intercept[TempFunctionAlreadyExistsException] { + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) + } + // Temporary function is overridden + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + assert( + catalog.lookupFunction( + FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) + } } test("isTemporaryFunction") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - - // Returns false when the function does not exist - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1"))) + withBasicCatalog { catalog => + // Returns false when the function does not exist + assert(!catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) - val tempFunc1 = (e: Seq[Expression]) => e.head - val info1 = new ExpressionInfo("tempFunc1", "temp1") - sessionCatalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + val tempFunc1 = (e: Seq[Expression]) => e.head + val info1 = new ExpressionInfo("tempFunc1", "temp1") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - // Returns true when the function is temporary - assert(sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1"))) + // Returns true when the function is temporary + assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) - // Returns false when the function is permanent - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2")))) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("db2.func1"))) - sessionCatalog.setCurrentDatabase("db2") - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1"))) + // Returns false when the function is permanent + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2")))) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("db2.func1"))) + catalog.setCurrentDatabase("db2") + assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1"))) - // Returns false when the function is built-in or hive - assert(FunctionRegistry.builtin.functionExists("sum")) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("sum"))) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) + // Returns false when the function is built-in or hive + assert(FunctionRegistry.builtin.functionExists("sum")) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("sum"))) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) + } } test("drop function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.dropFunction( - FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) - assert(externalCatalog.listFunctions("db2", "*").isEmpty) - // Drop function without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) - sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) - assert(externalCatalog.listFunctions("db2", "*").isEmpty) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction( + FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").isEmpty) + // Drop function without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) + catalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").isEmpty) + } } test("drop function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.dropFunction( - FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) - } - intercept[NoSuchFunctionException] { - catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.dropFunction( + FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) + } + intercept[NoSuchFunctionException] { + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) + } + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } - catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } test("drop temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info = new ExpressionInfo("tempFunc", "func1") - val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) - val arguments = Seq(Literal(1), Literal(2), Literal(3)) - assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) - catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("func1"), arguments) - } - intercept[NoSuchTempFunctionException] { + withBasicCatalog { catalog => + val info = new ExpressionInfo("tempFunc", "func1") + val tempFunc = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("func1"), arguments) + } + intercept[NoSuchTempFunctionException] { + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + } + catalog.dropTempFunction("func1", ignoreIfNotExists = true) } - catalog.dropTempFunction("func1", ignoreIfNotExists = true) } test("get function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val expected = - CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, - Seq.empty[FunctionResource]) - assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) - // Get function without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + withBasicCatalog { catalog => + val expected = + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[FunctionResource]) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) + // Get function without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + } } test("get function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) - } - intercept[NoSuchFunctionException] { - catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) + } + intercept[NoSuchFunctionException] { + catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + } } } test("lookup temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info1 = new ExpressionInfo("tempFunc1", "func1") - val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - assert(catalog.lookupFunction( - FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) - catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) + withBasicCatalog { catalog => + val info1 = new ExpressionInfo("tempFunc1", "func1") + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + assert(catalog.lookupFunction( + FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) + } } } test("list functions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info1 = new ExpressionInfo("tempFunc1", "func1") - val info2 = new ExpressionInfo("tempFunc2", "yes_me") - val tempFunc1 = (e: Seq[Expression]) => e.head - val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) - catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) - assert(catalog.listFunctions("db1", "*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("yes_me"))) - assert(catalog.listFunctions("db2", "*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("yes_me"), - FunctionIdentifier("func1", Some("db2")), - FunctionIdentifier("func2", Some("db2")), - FunctionIdentifier("not_me", Some("db2")))) - assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("func1", Some("db2")), - FunctionIdentifier("func2", Some("db2")))) + withBasicCatalog { catalog => + val info1 = new ExpressionInfo("tempFunc1", "func1") + val info2 = new ExpressionInfo("tempFunc2", "yes_me") + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) + assert(catalog.listFunctions("db1", "*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"))) + assert(catalog.listFunctions("db2", "*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")), + FunctionIdentifier("not_me", Some("db2")))) + assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")))) + } } test("list functions when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.listFunctions("unknown_db", "func*") + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.listFunctions("unknown_db", "func*") + } } } test("clone SessionCatalog - temp views") { - val externalCatalog = newEmptyCatalog() - val original = new SessionCatalog(externalCatalog) - val tempTable1 = Range(1, 10, 1, 10) - original.createTempView("copytest1", tempTable1, overrideIfExists = false) + withEmptyCatalog { original => + val tempTable1 = Range(1, 10, 1, 10) + original.createTempView("copytest1", tempTable1, overrideIfExists = false) - // check if tables copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - assert(original ne clone) - assert(clone.getTempView("copytest1") == Some(tempTable1)) + // check if tables copied over + val clone = original.newSessionCatalogWith( + SimpleCatalystConf(caseSensitiveAnalysis = true), + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getTempView("copytest1") == Some(tempTable1)) - // check if clone and original independent - clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) - assert(original.getTempView("copytest1") == Some(tempTable1)) + // check if clone and original independent + clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) + assert(original.getTempView("copytest1") == Some(tempTable1)) - val tempTable2 = Range(1, 20, 2, 10) - original.createTempView("copytest2", tempTable2, overrideIfExists = false) - assert(clone.getTempView("copytest2").isEmpty) + val tempTable2 = Range(1, 20, 2, 10) + original.createTempView("copytest2", tempTable2, overrideIfExists = false) + assert(clone.getTempView("copytest2").isEmpty) + } } test("clone SessionCatalog - current db") { - val externalCatalog = newEmptyCatalog() - val db1 = "db1" - val db2 = "db2" - val db3 = "db3" - - externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true) - externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true) - externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true) - - val original = new SessionCatalog(externalCatalog) - original.setCurrentDatabase(db1) - - // check if current db copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - assert(original ne clone) - assert(clone.getCurrentDatabase == db1) - - // check if clone and original independent - clone.setCurrentDatabase(db2) - assert(original.getCurrentDatabase == db1) - original.setCurrentDatabase(db3) - assert(clone.getCurrentDatabase == db2) + withEmptyCatalog { original => + val db1 = "db1" + val db2 = "db2" + val db3 = "db3" + + original.externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true) + original.externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true) + original.externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true) + + original.setCurrentDatabase(db1) + + // check if current db copied over + val clone = original.newSessionCatalogWith( + SimpleCatalystConf(caseSensitiveAnalysis = true), + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getCurrentDatabase == db1) + + // check if clone and original independent + clone.setCurrentDatabase(db2) + assert(original.getCurrentDatabase == db1) + original.setCurrentDatabase(db3) + assert(clone.getCurrentDatabase == db2) + } } test("SPARK-19737: detect undefined functions without triggering relation resolution") { @@ -1258,18 +1363,22 @@ class SessionCatalogSuite extends PlanTest { Seq(true, false) foreach { caseSensitive => val conf = SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) - val analyzer = new Analyzer(catalog, conf) - - // The analyzer should report the undefined function rather than the undefined table first. - val cause = intercept[AnalysisException] { - analyzer.execute( - UnresolvedRelation(TableIdentifier("undefined_table")).select( - UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + try { + val analyzer = new Analyzer(catalog, conf) + + // The analyzer should report the undefined function rather than the undefined table first. + val cause = intercept[AnalysisException] { + analyzer.execute( + UnresolvedRelation(TableIdentifier("undefined_table")).select( + UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + ) ) - ) - } + } - assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + } finally { + catalog.reset() + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala new file mode 100644 index 0000000000000..285f35b0b0eac --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalogSuite} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveSingleton { + + protected override val isHiveExternalCatalog = true + + private val externalCatalog = { + val catalog = spark.sharedState.externalCatalog + catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog + } + + protected val utils = new CatalogTestUtils { + override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" + override val tableOutputFormat: String = + "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat" + override val defaultProvider: String = "hive" + override def newEmptyCatalog(): ExternalCatalog = externalCatalog + } +} From 2ea214dd05da929840c15891e908384cfa695ca8 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Thu, 16 Mar 2017 13:05:36 -0700 Subject: [PATCH 0038/1765] [SPARK-19721][SS] Good error message for version mismatch in log files ## Problem There are several places where we write out version identifiers in various logs for structured streaming (usually `v1`). However, in the places where we check for this, we throw a confusing error message. ## What changes were proposed in this pull request? This patch made two major changes: 1. added a `parseVersion(...)` method, and based on this method, fixed the following places the way they did version checking (no other place needed to do this checking): ``` HDFSMetadataLog - CompactibleFileStreamLog ------------> fixed with this patch - FileStreamSourceLog ---------------> inherited the fix of `CompactibleFileStreamLog` - FileStreamSinkLog -----------------> inherited the fix of `CompactibleFileStreamLog` - OffsetSeqLog ------------------------> fixed with this patch - anonymous subclass in KafkaSource ---> fixed with this patch ``` 2. changed the type of `FileStreamSinkLog.VERSION`, `FileStreamSourceLog.VERSION` etc. from `String` to `Int`, so that we can identify newer versions via `version > 1` instead of `version != "v1"` - note this didn't break any backwards compatibility -- we are still writing out `"v1"` and reading back `"v1"` ## Exception message with this patch ``` java.lang.IllegalStateException: Failed to read log file /private/var/folders/nn/82rmvkk568sd8p3p8tb33trw0000gn/T/spark-86867b65-0069-4ef1-b0eb-d8bd258ff5b8/0. UnsupportedLogVersion: maximum supported log version is v1, but encountered v99. The log file was produced by a newer version of Spark and cannot be read by this version. Please upgrade. at org.apache.spark.sql.execution.streaming.HDFSMetadataLog.get(HDFSMetadataLog.scala:202) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3$$anonfun$apply$mcV$sp$2.apply(OffsetSeqLogSuite.scala:78) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3$$anonfun$apply$mcV$sp$2.apply(OffsetSeqLogSuite.scala:75) at org.apache.spark.sql.test.SQLTestUtils$class.withTempDir(SQLTestUtils.scala:133) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite.withTempDir(OffsetSeqLogSuite.scala:26) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3.apply$mcV$sp(OffsetSeqLogSuite.scala:75) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3.apply(OffsetSeqLogSuite.scala:75) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3.apply(OffsetSeqLogSuite.scala:75) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) ``` ## How was this patch tested? unit tests Author: Liwei Lin Closes #17070 from lw-lin/better-msg. --- .../spark/sql/kafka010/KafkaSource.scala | 14 +++---- .../spark/sql/kafka010/KafkaSourceSuite.scala | 9 ++++- .../streaming/CompactibleFileStreamLog.scala | 9 ++--- .../streaming/FileStreamSinkLog.scala | 4 +- .../streaming/FileStreamSourceLog.scala | 4 +- .../execution/streaming/HDFSMetadataLog.scala | 36 +++++++++++++++++ .../execution/streaming/OffsetSeqLog.scala | 10 ++--- .../CompactibleFileStreamLogSuite.scala | 40 ++++++++++++++++--- .../streaming/FileStreamSinkLogSuite.scala | 8 ++-- .../streaming/HDFSMetadataLogSuite.scala | 27 +++++++++++++ .../streaming/OffsetSeqLogSuite.scala | 17 ++++++++ 11 files changed, 143 insertions(+), 35 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 92b5d91ba435e..1fb0a338299b7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -100,7 +100,7 @@ private[kafka010] class KafkaSource( override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write(VERSION) + writer.write("v" + VERSION + "\n") writer.write(metadata.json) writer.flush } @@ -111,13 +111,13 @@ private[kafka010] class KafkaSource( // HDFSMetadataLog guarantees that it never creates a partial file. assert(content.length != 0) if (content(0) == 'v') { - if (content.startsWith(VERSION)) { - KafkaSourceOffset(SerializedOffset(content.substring(VERSION.length))) + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) } else { - val versionInFile = content.substring(0, content.indexOf("\n")) throw new IllegalStateException( - s"Unsupported format. Expected version is ${VERSION.stripLineEnd} " + - s"but was $versionInFile. Please upgrade your Spark.") + s"Log file was malformed: failed to detect the log file version line.") } } else { // The log was generated by Spark 2.1.0 @@ -351,7 +351,7 @@ private[kafka010] object KafkaSource { | source option "failOnDataLoss" to "false". """.stripMargin - private val VERSION = "v1\n" + private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { val bm = sc.env.blockManager diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index bf6aad671a18e..7b6396e0291c9 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -205,7 +205,7 @@ class KafkaSourceSuite extends KafkaSourceTest { override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { out.write(0) val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v0\n${metadata.json}") + writer.write(s"v99999\n${metadata.json}") writer.flush } } @@ -227,7 +227,12 @@ class KafkaSourceSuite extends KafkaSourceTest { source.getOffset.get // Read initial offset } - assert(e.getMessage.contains("Please upgrade your Spark")) + Seq( + s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 5a6f9e87f6eaa..408c8f81f17ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.SparkSession * doing a compaction, it will read all old log files and merge them with the new batch. */ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends HDFSMetadataLog[Array[T]](sparkSession, path) { @@ -134,7 +134,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( override def serialize(logData: Array[T], out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(metadataLogVersion.getBytes(UTF_8)) + out.write(("v" + metadataLogVersion).getBytes(UTF_8)) logData.foreach { data => out.write('\n') out.write(Serialization.write(data).getBytes(UTF_8)) @@ -146,10 +146,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines.next() - if (version != metadataLogVersion) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } + val version = parseVersion(lines.next(), metadataLogVersion) lines.map(Serialization.read[T]).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index eb6eed87eca7b..8d718b2164d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -77,7 +77,7 @@ object SinkFileStatus { * (drops the deleted files). */ class FileStreamSinkLog( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) { @@ -106,7 +106,7 @@ class FileStreamSinkLog( } object FileStreamSinkLog { - val VERSION = "v1" + val VERSION = 1 val DELETE_ACTION = "delete" val ADD_ACTION = "add" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 81908c0cefdfa..33e6a1d5d6e18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry import org.apache.spark.sql.internal.SQLConf class FileStreamSourceLog( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[FileEntry](metadataLogVersion, sparkSession, path) { @@ -120,5 +120,5 @@ class FileStreamSourceLog( } object FileStreamSourceLog { - val VERSION = "v1" + val VERSION = 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index f9e1f7de9ec08..60ce64261c4a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -195,6 +195,11 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: val input = fileManager.open(batchMetadataFile) try { Some(deserialize(input)) + } catch { + case ise: IllegalStateException => + // re-throw the exception with the log file path added + throw new IllegalStateException( + s"Failed to read log file $batchMetadataFile. ${ise.getMessage}", ise) } finally { IOUtils.closeQuietly(input) } @@ -268,6 +273,37 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: new FileSystemManager(metadataPath, hadoopConf) } } + + /** + * Parse the log version from the given `text` -- will throw exception when the parsed version + * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", + * "v123xyz" etc.) + */ + private[sql] def parseVersion(text: String, maxSupportedVersion: Int): Int = { + if (text.length > 0 && text(0) == 'v') { + val version = + try { + text.substring(1, text.length).toInt + } catch { + case _: NumberFormatException => + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } + if (version > 0) { + if (version > maxSupportedVersion) { + throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " + + s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " + + s"by a newer version of Spark and cannot be read by this version. Please upgrade.") + } else { + return version + } + } + } + + // reaching here means we failed to read the correct log version + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } } object HDFSMetadataLog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 3210d8ad64e22..4f8cd116f610e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -55,10 +55,8 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines.next() - if (version != OffsetSeqLog.VERSION) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } + + val version = parseVersion(lines.next(), OffsetSeqLog.VERSION) // read metadata val metadata = lines.next().trim match { @@ -70,7 +68,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(OffsetSeqLog.VERSION.getBytes(UTF_8)) + out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8)) // write metadata out.write('\n') @@ -88,6 +86,6 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) } object OffsetSeqLog { - private val VERSION = "v1" + private[streaming] val VERSION = 1 private val SERIALIZED_VOID_OFFSET = "-" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 24d92a96237e3..20ac06f048c6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -122,7 +122,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext defaultMinBatchesToRetain = 1, compactibleLog => { val logs = Array("entry_1", "entry_2", "entry_3") - val expected = s"""${FakeCompactibleFileStreamLog.VERSION} + val expected = s"""v${FakeCompactibleFileStreamLog.VERSION} |"entry_1" |"entry_2" |"entry_3"""".stripMargin @@ -132,7 +132,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext baos.reset() compactibleLog.serialize(Array(), baos) - assert(FakeCompactibleFileStreamLog.VERSION === baos.toString(UTF_8.name())) + assert(s"v${FakeCompactibleFileStreamLog.VERSION}" === baos.toString(UTF_8.name())) }) } @@ -142,7 +142,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext defaultCompactInterval = 3, defaultMinBatchesToRetain = 1, compactibleLog => { - val logs = s"""${FakeCompactibleFileStreamLog.VERSION} + val logs = s"""v${FakeCompactibleFileStreamLog.VERSION} |"entry_1" |"entry_2" |"entry_3"""".stripMargin @@ -152,10 +152,36 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext assert(Nil === compactibleLog.deserialize( - new ByteArrayInputStream(FakeCompactibleFileStreamLog.VERSION.getBytes(UTF_8)))) + new ByteArrayInputStream(s"v${FakeCompactibleFileStreamLog.VERSION}".getBytes(UTF_8)))) }) } + test("deserialization log written by future version") { + withTempDir { dir => + def newFakeCompactibleFileStreamLog(version: Int): FakeCompactibleFileStreamLog = + new FakeCompactibleFileStreamLog( + version, + _fileCleanupDelayMs = Long.MaxValue, // this param does not matter here in this test case + _defaultCompactInterval = 3, // this param does not matter here in this test case + _defaultMinBatchesToRetain = 1, // this param does not matter here in this test case + spark, + dir.getCanonicalPath) + + val writer = newFakeCompactibleFileStreamLog(version = 2) + val reader = newFakeCompactibleFileStreamLog(version = 1) + writer.add(0, Array("entry")) + val e = intercept[IllegalStateException] { + reader.get(0) + } + Seq( + "maximum supported log version is v1, but encountered v2", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("compact") { withFakeCompactibleFileStreamLog( fileCleanupDelayMs = Long.MaxValue, @@ -219,6 +245,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext ): Unit = { withTempDir { file => val compactibleLog = new FakeCompactibleFileStreamLog( + FakeCompactibleFileStreamLog.VERSION, fileCleanupDelayMs, defaultCompactInterval, defaultMinBatchesToRetain, @@ -230,17 +257,18 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext } object FakeCompactibleFileStreamLog { - val VERSION = "test_version" + val VERSION = 1 } class FakeCompactibleFileStreamLog( + metadataLogVersion: Int, _fileCleanupDelayMs: Long, _defaultCompactInterval: Int, _defaultMinBatchesToRetain: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[String]( - FakeCompactibleFileStreamLog.VERSION, + metadataLogVersion, sparkSession, path ) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 340d2945acd4a..dd3a414659c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -74,7 +74,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { action = FileStreamSinkLog.ADD_ACTION)) // scalastyle:off - val expected = s"""$VERSION + val expected = s"""v$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -84,14 +84,14 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === baos.toString(UTF_8.name())) baos.reset() sinkLog.serialize(Array(), baos) - assert(VERSION === baos.toString(UTF_8.name())) + assert(s"v$VERSION" === baos.toString(UTF_8.name())) } } test("deserialize") { withFileStreamSinkLog { sinkLog => // scalastyle:off - val logs = s"""$VERSION + val logs = s"""v$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -125,7 +125,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) - assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8)))) + assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(s"v$VERSION".getBytes(UTF_8)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 55750b9202982..662c4466b21b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -127,6 +127,33 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } + test("HDFSMetadataLog: parseVersion") { + withTempDir { dir => + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) + def assertLogFileMalformed(func: => Int): Unit = { + val e = intercept[IllegalStateException] { func } + assert(e.getMessage.contains(s"Log file was malformed: failed to read correct log version")) + } + assertLogFileMalformed { metadataLog.parseVersion("", 100) } + assertLogFileMalformed { metadataLog.parseVersion("xyz", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v10.x", 100) } + assertLogFileMalformed { metadataLog.parseVersion("10", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v0", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v-10", 100) } + + assert(metadataLog.parseVersion("v10", 10) === 10) + assert(metadataLog.parseVersion("v10", 100) === 10) + + val e = intercept[IllegalStateException] { metadataLog.parseVersion("v200", 100) } + Seq( + "maximum supported log version is v100, but encountered v200", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("HDFSMetadataLog: restart") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 5ae8b2484d2ef..f7f0dade8717e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming import java.io.File import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { @@ -70,6 +71,22 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } } + test("deserialization log written by future version") { + withTempDir { dir => + stringToFile(new File(dir, "0"), "v99999") + val log = new OffsetSeqLog(spark, dir.getCanonicalPath) + val e = intercept[IllegalStateException] { + log.get(0) + } + Seq( + s"maximum supported log version is v${OffsetSeqLog.VERSION}, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("read Spark 2.1.0 log format") { val (batchId, offsetSeq) = readFromResource("offset-log-version-2.1.0") assert(batchId === 0) From 4c3200546c5c55e671988a957011417ba76a0600 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 16 Mar 2017 17:10:15 -0700 Subject: [PATCH 0039/1765] [SPARK-19635][ML] DataFrame-based API for chi square test ## What changes were proposed in this pull request? Wrapper taking and return a DataFrame ## How was this patch tested? Copied unit tests from RDD-based API Author: Joseph K. Bradley Closes #17110 from jkbradley/df-hypotests. --- .../org/apache/spark/ml/stat/ChiSquare.scala | 81 +++++++++++++++ .../spark/mllib/stat/test/ChiSqTest.scala | 8 +- .../apache/spark/ml/stat/ChiSquareSuite.scala | 98 +++++++++++++++++++ .../mllib/stat/HypothesisTestSuite.scala | 11 ++- 4 files changed, 192 insertions(+), 6 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala new file mode 100644 index 0000000000000..c3865ce6a9e2a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col + + +/** + * :: Experimental :: + * + * Chi-square hypothesis testing for categorical data. + * + * See Wikipedia for more information + * on the Chi-squared test. + */ +@Experimental +@Since("2.2.0") +object ChiSquare { + + /** Used to construct output schema of tests */ + private case class ChiSquareResult( + pValues: Vector, + degreesOfFreedom: Array[Int], + statistics: Vector) + + /** + * Conduct Pearson's independence test for every feature against the label across the input RDD. + * For each feature, the (feature, label) pairs are converted into a contingency matrix for which + * the Chi-squared statistic is computed. All label and feature values must be categorical. + * + * The null hypothesis is that the occurrence of the outcomes is statistically independent. + * + * @param dataset DataFrame of categorical labels and categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * @param featuresCol Name of features column in dataset, of type `Vector` (`VectorUDT`) + * @param labelCol Name of label column in dataset, of any numerical type + * @return DataFrame containing the test result for every feature against the label. + * This DataFrame will contain a single Row with the following fields: + * - `pValues: Vector` + * - `degreesOfFreedom: Array[Int]` + * - `statistics: Vector` + * Each of these fields has one value per feature. + */ + @Since("2.2.0") + def test(dataset: DataFrame, featuresCol: String, labelCol: String): DataFrame = { + val spark = dataset.sparkSession + import spark.implicits._ + + SchemaUtils.checkColumnType(dataset.schema, featuresCol, new VectorUDT) + SchemaUtils.checkNumericType(dataset.schema, labelCol) + val rdd = dataset.select(col(labelCol).cast("double"), col(featuresCol)).as[(Double, Vector)] + .rdd.map { case (label, features) => OldLabeledPoint(label, OldVectors.fromML(features)) } + val testResults = OldStatistics.chiSqTest(rdd) + val pValues: Vector = Vectors.dense(testResults.map(_.pValue)) + val degreesOfFreedom: Array[Int] = testResults.map(_.degreesOfFreedom) + val statistics: Vector = Vectors.dense(testResults.map(_.statistic)) + spark.createDataFrame(Seq(ChiSquareResult(pValues, degreesOfFreedom, statistics))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 9a63b8a5d63db..ee51248e53556 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -41,7 +41,7 @@ import org.apache.spark.rdd.RDD * * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test */ -private[stat] object ChiSqTest extends Logging { +private[spark] object ChiSqTest extends Logging { /** * @param name String name for the method. @@ -70,6 +70,11 @@ private[stat] object ChiSqTest extends Logging { } } + /** + * Max number of categories when indexing labels and features + */ + private[spark] val maxCategories: Int = 10000 + /** * Conduct Pearson's independence test for each feature against the label across the input RDD. * The contingency table is constructed from the raw (feature, label) pairs and used to conduct @@ -78,7 +83,6 @@ private[stat] object ChiSqTest extends Logging { */ def chiSquaredFeatures(data: RDD[LabeledPoint], methodName: String = PEARSON.name): Array[ChiSqTestResult] = { - val maxCategories = 10000 val numCols = data.first().features.size val results = new Array[ChiSqTestResult](numCols) var labels: Map[Double, Int] = null diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala new file mode 100644 index 0000000000000..b4bed82e4d00f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import java.util.Random + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ChiSquareSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("test DataFrame of labeled points") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val df = spark.createDataFrame(sc.parallelize(data, numParts)) + val chi = ChiSquare.test(df, "features", "label") + val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = + chi.select("pValues", "degreesOfFreedom", "statistics") + .as[(Vector, Array[Int], Vector)].head() + assert(pValues ~== Vectors.dense(0.6873, 0.6823) relTol 1e-4) + assert(degreesOfFreedom === Array(2, 3)) + assert(statistics ~== Vectors.dense(0.75, 1.5) relTol 1e-4) + } + } + + test("large number of features (SPARK-3087)") { + // Test that the right number of results is returned + val numCols = 1001 + val sparseData = Array( + LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) + val df = spark.createDataFrame(sparseData) + val chi = ChiSquare.test(df, "features", "label") + val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = + chi.select("pValues", "degreesOfFreedom", "statistics") + .as[(Vector, Array[Int], Vector)].head() + assert(pValues.size === numCols) + assert(degreesOfFreedom.length === numCols) + assert(statistics.size === numCols) + assert(pValues(1000) !== null) // SPARK-3087 + } + + test("fail on continuous features or labels") { + val tooManyCategories: Int = 100000 + assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + + "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") + + val random = new Random(11L) + val continuousLabel = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + withClue("ChiSquare should throw an exception when given a continuous-valued label") { + intercept[SparkException] { + val df = spark.createDataFrame(continuousLabel) + ChiSquare.test(df, "features", "label") + } + } + val continuousFeature = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + withClue("ChiSquare should throw an exception when given continuous-valued features") { + intercept[SparkException] { + val df = spark.createDataFrame(continuousFeature) + ChiSquare.test(df, "features", "label") + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 46fcebe132749..992b876561896 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -145,14 +145,17 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(chi(1000) != null) // SPARK-3087 // Detect continuous features or labels + val tooManyCategories: Int = 100000 + assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + + "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") val random = new Random(11L) - val continuousLabel = - Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + val continuousLabel = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) intercept[SparkException] { Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) } - val continuousFeature = - Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + val continuousFeature = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) intercept[SparkException] { Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) } From 8537c00e0a17eff2a8c6745fbdd1d08873c0434d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 Mar 2017 18:31:57 -0700 Subject: [PATCH 0040/1765] [SPARK-19987][SQL] Pass all filters into FileIndex ## What changes were proposed in this pull request? This is a tiny teeny refactoring to pass data filters also to the FileIndex, so FileIndex can have a more global view on predicates. ## How was this patch tested? Change should be covered by existing test cases. Author: Reynold Xin Closes #17322 from rxin/SPARK-19987. --- .../sql/execution/DataSourceScanExec.scala | 23 +++++++++++-------- .../execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../datasources/CatalogFileIndex.scala | 5 ++-- .../sql/execution/datasources/FileIndex.scala | 15 ++++++++---- .../datasources/FileSourceStrategy.scala | 5 +--- .../PartitioningAwareFileIndex.scala | 8 ++++--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +--- 7 files changed, 35 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8ebad676ca310..bfe9c8e351abc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -23,18 +23,18 @@ import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils trait DataSourceScanExec extends LeafExecNode with CodegenSupport { @@ -135,7 +135,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan. * @param outputSchema Output schema of the scan. * @param partitionFilters Predicates to use for partition pruning. - * @param dataFilters Data source filters to use for filtering data within partitions. + * @param dataFilters Filters on non-partition columns. * @param metastoreTableIdentifier identifier for the table in the metastore. */ case class FileSourceScanExec( @@ -143,7 +143,7 @@ case class FileSourceScanExec( output: Seq[Attribute], outputSchema: StructType, partitionFilters: Seq[Expression], - dataFilters: Seq[Filter], + dataFilters: Seq[Expression], override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -156,7 +156,8 @@ case class FileSourceScanExec( false } - @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters) + @transient private lazy val selectedPartitions = + relation.location.listFiles(partitionFilters, dataFilters) override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { @@ -225,6 +226,10 @@ case class FileSourceScanExec( } } + @transient + private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + // These metadata values make scan plans uniquely identifiable for equality checking. override val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -237,7 +242,7 @@ case class FileSourceScanExec( "ReadSchema" -> outputSchema.catalogString, "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), - "PushedFilters" -> seqToString(dataFilters), + "PushedFilters" -> seqToString(pushedDownFilters), "Location" -> locationDesc) val withOptPartitionCount = relation.partitionSchemaOption.map { _ => @@ -255,7 +260,7 @@ case class FileSourceScanExec( dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, requiredSchema = outputSchema, - filters = dataFilters, + filters = pushedDownFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 769deb1890b6d..3c046ce494285 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -98,7 +98,7 @@ case class OptimizeMetadataOnlyQuery( relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(filters = Nil) + val partitionData = fsRelation.location.listFiles(Nil, Nil) LocalRelation(partAttrs, partitionData.map(_.values)) case relation: CatalogRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index d6c4b97ebd080..db0254f8d5581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -54,8 +54,9 @@ class CatalogFileIndex( override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq - override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { - filterPartitions(filters).listFiles(Nil) + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { + filterPartitions(partitionFilters).listFiles(Nil, dataFilters) } override def refresh(): Unit = fileStatusCache.invalidateAll() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 277223d52ec52..6b99d38fe5729 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -46,12 +46,17 @@ trait FileIndex { * Returns all valid files grouped into partitions when the data is partitioned. If the data is * unpartitioned, this will return a single partition with no partition values. * - * @param filters The filters used to prune which partitions are returned. These filters must - * only refer to partition columns and this method will only return files - * where these predicates are guaranteed to evaluate to `true`. Thus, these - * filters will not need to be evaluated again on the returned data. + * @param partitionFilters The filters used to prune which partitions are returned. These filters + * must only refer to partition columns and this method will only return + * files where these predicates are guaranteed to evaluate to `true`. + * Thus, these filters will not need to be evaluated again on the + * returned data. + * @param dataFilters Filters that can be applied on non-partitioned columns. The implementation + * does not need to guarantee these filters are applied, i.e. the execution + * engine will ensure these filters are still applied on the returned files. */ - def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] + def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] /** * Returns the list of files that will be read when scanning this relation. This call may be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 26e1380eca499..17f7e0e601c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -100,9 +100,6 @@ object FileSourceStrategy extends Strategy with Logging { val outputSchema = readDataColumns.toStructType logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}") - val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) - logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - val outputAttributes = readDataColumns ++ partitionColumns val scan = @@ -111,7 +108,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, - pushedDownFilters, + dataFilters, table.map(_.identifier)) val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index db8bbc52aaf4d..71500a010581e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -54,17 +54,19 @@ abstract class PartitioningAwareFileIndex( override def partitionSchema: StructType = partitionSpec().partitionColumns - protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) + protected val hadoopConf: Configuration = + sparkSession.sessionState.newHadoopConfWithOptions(parameters) protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] - override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { PartitionDirectory(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil } else { - prunePartitions(filters, partitionSpec()).map { + prunePartitions(partitionFilters, partitionSpec()).map { case PartitionPath(values, path) => val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9f0d1ceb28fca..2e060ab9f6801 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive -import java.net.URI - import scala.util.control.NonFatal import com.google.common.util.concurrent.Striped @@ -248,7 +246,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log .inferSchema( sparkSession, options, - fileIndex.listFiles(Nil).flatMap(_.files)) + fileIndex.listFiles(Nil, Nil).flatMap(_.files)) .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _)) inferredSchema match { From 13538cf3dd089222c7e12a3cd6e72ac836fa51ac Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 17 Mar 2017 16:43:42 +0800 Subject: [PATCH 0041/1765] [SPARK-19882][SQL] Pivot with null as a distinct pivot value throws NPE ## What changes were proposed in this pull request? Allows null values of the pivot column to be included in the pivot values list without throwing NPE Note this PR was made as an alternative to #17224 but preserves the two phase aggregate operation that is needed for good performance. ## How was this patch tested? Additional unit test Author: Andrew Ray Closes #17226 from aray/pivot-null. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../expressions/aggregate/PivotFirst.scala | 18 +++++++++--------- .../apache/spark/sql/DataFramePivotSuite.scala | 14 ++++++++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 68a4746a54d96..8cf4073826192 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -524,7 +524,7 @@ class Analyzer( } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, value), expr, Literal(null)) + If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 9ad31243e4122..523714869242d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -91,14 +91,12 @@ case class PivotFirst( override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = { val pivotColValue = pivotColumn.eval(inputRow) - if (pivotColValue != null) { - // We ignore rows whose pivot column value is not in the list of pivot column values. - val index = pivotIndex.getOrElse(pivotColValue, -1) - if (index >= 0) { - val value = valueColumn.eval(inputRow) - if (value != null) { - updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) - } + // We ignore rows whose pivot column value is not in the list of pivot column values. + val index = pivotIndex.getOrElse(pivotColValue, -1) + if (index >= 0) { + val value = valueColumn.eval(inputRow) + if (value != null) { + updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) } } } @@ -140,7 +138,9 @@ case class PivotFirst( override val aggBufferAttributes: Seq[AttributeReference] = - pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)()) + pivotIndex.toList.sortBy(_._2).map { kv => + AttributeReference(Option(kv._1).getOrElse("null").toString, valueDataType)() + } override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 51ffe34172714..ca3cb5676742e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -216,4 +216,18 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil ) } + + test("pivot with null should not throw NPE") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a").groupBy($"a").pivot("a").count(), + Row(null, 1, null) :: Row(1, null, 1) :: Nil) + } + + test("pivot with null and aggregate type not supported by PivotFirst returns correct result") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a") + .withColumn("b", expr("array(a, 7)")) + .groupBy($"a").pivot("a").agg(min($"b")), + Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil) + } } From 7b5d873aef672aa0aee41e338bab7428101e1ad3 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Fri, 17 Mar 2017 09:33:45 -0500 Subject: [PATCH 0042/1765] [SPARK-13369] Add config for number of consecutive fetch failures The previously hardcoded max 4 retries per stage is not suitable for all cluster configurations. Since spark retries a stage at the sign of the first fetch failure, you can easily end up with many stage retries to discover all the failures. In particular, two scenarios this value should change are (1) if there are more than 4 executors per node; in that case, it may take 4 retries to discover the problem with each executor on the node and (2) during cluster maintenance on large clusters, where multiple machines are serviced at once, but you also cannot afford total cluster downtime. By making this value configurable, cluster managers can tune this value to something more appropriate to their cluster configuration. Unit tests Author: Sital Kedia Closes #17307 from sitalkedia/SPARK-13369. --- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++++++++++++-- .../org/apache/spark/scheduler/Stage.scala | 18 +----------------- .../spark/scheduler/DAGSchedulerSuite.scala | 16 ++++++++-------- docs/configuration.md | 5 +++++ 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 692ed8083475c..d944f268755de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -187,6 +187,13 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + /** + * Number of consecutive stage attempts allowed before a stage is aborted. + */ + private[scheduler] val maxConsecutiveStageAttempts = + sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", + DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") @@ -1282,8 +1289,9 @@ class DAGScheduler( s"longer running") } + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = - failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId) || + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || disallowStageRetryForTest if (shouldAbortStage) { @@ -1292,7 +1300,7 @@ class DAGScheduler( } else { s"""$failedStage (${failedStage.name}) |has failed the maximum allowable number of - |times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. + |times: $maxConsecutiveStageAttempts. |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ") } abortStage(failedStage, abortMessage, None) @@ -1726,4 +1734,7 @@ private[spark] object DAGScheduler { // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in val RESUBMIT_TIMEOUT = 200 + + // Number of consecutive stage attempts allowed before a stage is aborted + val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 32e5df6d75f4f..290fd073caf27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -87,23 +87,12 @@ private[scheduler] abstract class Stage( * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - private val fetchFailedAttemptIds = new HashSet[Int] + val fetchFailedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { fetchFailedAttemptIds.clear() } - /** - * Check whether we should abort the failedStage due to multiple consecutive fetch failures. - * - * This method updates the running set of failed stage attempts and returns - * true if the number of failures exceeds the allowable number of failures. - */ - private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { - fetchFailedAttemptIds.add(stageAttemptId) - fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES - } - /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -128,8 +117,3 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ def findMissingPartitions(): Seq[Int] } - -private[scheduler] object Stage { - // The number of consecutive failures allowed before a stage is aborted - val MAX_CONSECUTIVE_FETCH_FAILURES = 4 -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8eaf9dfcf49b1..dfad5db68a914 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -801,7 +801,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) @@ -813,7 +813,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // map output, for the next iteration through the loop scheduler.resubmitFailedStages() - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + if (attempt < scheduler.maxConsecutiveStageAttempts - 1) { assert(scheduler.runningStages.nonEmpty) assert(!ended) } else { @@ -847,11 +847,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, // stage 2 fails. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + if (attempt < scheduler.maxConsecutiveStageAttempts / 2) { // Now we should have a new taskSet, for a new attempt of stage 1. // Fail all these tasks with FetchFailure completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) @@ -859,8 +859,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) // Fail stage 2 - completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, - shuffleDepTwo) + completeNextStageWithFetchFailure(2, + attempt - scheduler.maxConsecutiveStageAttempts / 2, shuffleDepTwo) } // this will trigger a resubmission of stage 0, since we've lost some of its @@ -872,7 +872,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) // Succeed stage2 with a "42" - completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2) + completeNextResultStageWithSuccess(2, scheduler.maxConsecutiveStageAttempts / 2) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -895,7 +895,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou submit(finalRdd, Array(0)) // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts - 1) { // Make each task in stage 0 success completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) diff --git a/docs/configuration.md b/docs/configuration.md index 63392a741a1f0..4729f1b0404c1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1506,6 +1506,11 @@ Apart from these, the following properties are also available, and may be useful of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering an executor unusable. + spark.stage.maxConsecutiveAttempts + 4 + + Number of consecutive stage attempts allowed before a stage is aborted. + From 376d782164437573880f0ad58cecae1cb5f212f2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 17 Mar 2017 11:12:23 -0700 Subject: [PATCH 0043/1765] [SPARK-19986][TESTS] Make pyspark.streaming.tests.CheckpointTests more stable ## What changes were proposed in this pull request? Sometimes, CheckpointTests will hang on a busy machine because the streaming jobs are too slow and cannot catch up. I observed the scheduled delay was keeping increasing for dozens of seconds locally. This PR increases the batch interval from 0.5 seconds to 2 seconds to generate less Spark jobs. It should make `pyspark.streaming.tests.CheckpointTests` more stable. I also replaced `sleep` with `awaitTerminationOrTimeout` so that if the streaming job fails, it will also fail the test. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17323 from zsxwing/SPARK-19986. --- python/pyspark/streaming/tests.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2e8ed698278d0..1bec33509580c 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -903,11 +903,11 @@ def updater(vs, s): def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, 0.5) + ssc = StreamingContext(sc, 2) dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") - wc.checkpoint(.5) + wc.checkpoint(2) self.setupCalled = True return ssc @@ -921,21 +921,22 @@ def setup(): def check_output(n): while not os.listdir(outputd): - time.sleep(0.01) + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) while True: + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") p = os.path.join(outputd, max(os.listdir(outputd))) if '_SUCCESS' not in os.listdir(p): # not finished - time.sleep(0.01) continue ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: - time.sleep(0.01) continue self.assertEqual(10, len(d)) s = set(d) From bfdeea5c68f963ce60d48d0aa4a4c8c582169950 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 17 Mar 2017 14:23:07 -0700 Subject: [PATCH 0044/1765] [SPARK-18847][GRAPHX] PageRank gives incorrect results for graphs with sinks ## What changes were proposed in this pull request? Graphs with sinks (vertices with no outgoing edges) don't have the expected rank sum of n (or 1 for personalized). We fix this by normalizing to the expected sum at the end of each implementation. Additionally this fixes the dynamic version of personal pagerank which gave incorrect answers that were not detected by existing unit tests. ## How was this patch tested? Revamped existing and additional unit tests with reference values (and reproduction code) from igraph and NetworkX. Note that for comparison on personal pagerank we use the arpack algorithm in igraph as prpack (the current default) redistributes rank to all vertices uniformly instead of just to the personalization source. We could take the alternate convention (redistribute rank to all vertices uniformly) but that would involve more extensive changes to the algorithms (the dynamic version would no longer be able to use Pregel). Author: Andrew Ray Closes #16483 from aray/pagerank-sink2. --- .../apache/spark/graphx/lib/PageRank.scala | 45 +++-- .../spark/graphx/lib/PageRankSuite.scala | 158 +++++++++++++----- 2 files changed, 144 insertions(+), 59 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 37b6e453592e5..13b2b57719188 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -162,7 +162,8 @@ object PageRank extends Logging { iteration += 1 } - rankGraph + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + normalizeRankSum(rankGraph, personalized) } /** @@ -179,7 +180,8 @@ object PageRank extends Logging { * @param resetProb The random reset probability * @param sources The list of sources to compute personalized pagerank from * @return the graph with vertex attributes - * containing the pagerank relative to all starting nodes (as a sparse vector) and + * containing the pagerank relative to all starting nodes (as a sparse vector + * indexed by the position of nodes in the sources list) and * edge attributes the normalized edge weight */ def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], @@ -194,6 +196,8 @@ object PageRank extends Logging { // TODO if one sources vertex id is outside of the int range // we won't be able to store its activations in a sparse vector + require(sources.max <= Int.MaxValue.toLong, + s"This implementation currently only works for source vertex ids at most ${Int.MaxValue}") val zero = Vectors.sparse(sources.size, List()).asBreeze val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze @@ -245,8 +249,10 @@ object PageRank extends Logging { i += 1 } + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + val rankSums = rankGraph.vertices.values.fold(zero)(_ :+ _) rankGraph.mapVertices { (vid, attr) => - Vectors.fromBreeze(attr) + Vectors.fromBreeze(attr :/ rankSums) } } @@ -307,7 +313,7 @@ object PageRank extends Logging { .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to (initialPR, delta = 0) .mapVertices { (id, attr) => - if (id == src) (1.0, Double.NegativeInfinity) else (0.0, 0.0) + if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0) } .cache() @@ -322,13 +328,12 @@ object PageRank extends Logging { def personalizedVertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { val (oldPR, lastDelta) = attr - var teleport = oldPR - val delta = if (src==id) resetProb else 0.0 - teleport = oldPR*delta - - val newPR = teleport + (1.0 - resetProb) * msgSum - val newDelta = if (lastDelta == Double.NegativeInfinity) newPR else newPR - oldPR - (newPR, newDelta) + val newPR = if (lastDelta == Double.NegativeInfinity) { + 1.0 + } else { + oldPR + (1.0 - resetProb) * msgSum + } + (newPR, newPR - oldPR) } def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { @@ -353,9 +358,23 @@ object PageRank extends Logging { vertexProgram(id, attr, msgSum) } - Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( + val rankGraph = Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( vp, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) - } // end of deltaPageRank + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + normalizeRankSum(rankGraph, personalized) + } + + // Normalizes the sum of ranks to n (or 1 if personalized) + private def normalizeRankSum(rankGraph: Graph[Double, Double], personalized: Boolean) = { + val rankSum = rankGraph.vertices.values.sum() + if (personalized) { + rankGraph.mapVertices((id, rank) => rank / rankSum) + } else { + val numVertices = rankGraph.numVertices + val correctionFactor = numVertices.toDouble / rankSum + rankGraph.mapVertices((id, rank) => rank * correctionFactor) + } + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 6afbb5a959894..9779553ce85d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -50,7 +50,8 @@ object GridPageRank { inNbrs(ind).map( nbr => oldPr(nbr) / outDegree(nbr)).sum } } - (0L until (nRows * nCols)).zip(pr) + val prSum = pr.sum + (0L until (nRows * nCols)).zip(pr.map(_ * pr.length / prSum)) } } @@ -68,26 +69,34 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val nVertices = 100 val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val resetProb = 0.15 + val tol = 0.0001 + val numIter = 2 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPageRank(numIter = 2, resetProb).vertices - val staticRanks2 = starGraph.staticPageRank(numIter = 3, resetProb).vertices.cache() + val staticRanks = starGraph.staticPageRank(numIter, resetProb).vertices.cache() + val staticRanks2 = starGraph.staticPageRank(numIter + 1, resetProb).vertices - // Static PageRank should only take 3 iterations to converge - val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => + // Static PageRank should only take 2 iterations to converge + val notMatching = staticRanks.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 }.map { case (vid, test) => test }.sum() assert(notMatching === 0) - val staticErrors = staticRanks2.map { case (vid, pr) => - val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) - val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) - if (!correct) 1 else 0 - } - assert(staticErrors.sum === 0) + val dynamicRanks = starGraph.pageRank(tol, resetProb).vertices.cache() + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in")) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)])) + // We multiply by the number of vertices to account for difference in normalization + val centerRank = 0.462394787 * nVertices + val othersRank = 0.005430356 * nVertices + val igraphPR = centerRank +: Seq.fill(nVertices - 1)(othersRank) + val ranks = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR)) + assert(compareRanks(staticRanks, ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) - val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache() - assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -96,51 +105,62 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val nVertices = 100 val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val resetProb = 0.15 + val tol = 0.0001 + val numIter = 2 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) - .vertices.cache() - - // Static PageRank should only take 2 iterations to converge - val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => - if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum - assert(notMatching === 0) + val staticRanks = starGraph.staticPersonalizedPageRank(0, numIter, resetProb).vertices.cache() - val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == 0.0) || - (vid == 0 && pr == resetProb) - if (!correct) 1 else 0 - } - assert(staticErrors.sum === 0) - - val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() - assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + val dynamicRanks = starGraph.personalizedPageRank(0, tol, resetProb).vertices.cache() + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) - val parallelStaticRanks1 = starGraph - .staticParallelPersonalizedPageRank(Array(0), 1, resetProb).mapVertices { + val parallelStaticRanks = starGraph + .staticParallelPersonalizedPageRank(Array(0), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(0) }.vertices.cache() - assert(compareRanks(staticRanks1, parallelStaticRanks1) < errorTol) + assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in"), personalized = c(1, rep(0, 99)), algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)]), + // personalization=dict([(x, 1 if x == 0 else 0) for x in range(0,100)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR0 = 1.0 +: Seq.fill(nVertices - 1)(0.0) + val ranks0 = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR0)) + assert(compareRanks(staticRanks, ranks0) < errorTol) + assert(compareRanks(dynamicRanks, ranks0) < errorTol) - val parallelStaticRanks2 = starGraph - .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices { - case (vertexId, vector) => vector(0) - }.vertices.cache() - assert(compareRanks(staticRanks2, parallelStaticRanks2) < errorTol) // We have one outbound edge from 1 to 0 - val otherStaticRanks2 = starGraph.staticPersonalizedPageRank(1, numIter = 2, resetProb) + val otherStaticRanks = starGraph.staticPersonalizedPageRank(1, numIter, resetProb) .vertices.cache() - val otherDynamicRanks = starGraph.personalizedPageRank(1, 0, resetProb).vertices.cache() - val otherParallelStaticRanks2 = starGraph - .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices { + val otherDynamicRanks = starGraph.personalizedPageRank(1, tol, resetProb).vertices.cache() + val otherParallelStaticRanks = starGraph + .staticParallelPersonalizedPageRank(Array(0, 1), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(1) }.vertices.cache() - assert(compareRanks(otherDynamicRanks, otherStaticRanks2) < errorTol) - assert(compareRanks(otherStaticRanks2, otherParallelStaticRanks2) < errorTol) - assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks2) < errorTol) + assert(compareRanks(otherDynamicRanks, otherStaticRanks) < errorTol) + assert(compareRanks(otherStaticRanks, otherParallelStaticRanks) < errorTol) + assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in"), + // personalized = c(0, 1, rep(0, 98)), algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)]), + // personalization=dict([(x, 1 if x == 1 else 0) for x in range(0,100)])) + val centerRank = 0.4594595 + val sourceRank = 0.5405405 + val igraphPR1 = centerRank +: sourceRank +: Seq.fill(nVertices - 2)(0.0) + val ranks1 = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR1)) + assert(compareRanks(otherStaticRanks, ranks1) < errorTol) + assert(compareRanks(otherDynamicRanks, ranks1) < errorTol) + assert(compareRanks(otherParallelStaticRanks, ranks1) < errorTol) } } // end of test Star PersonalPageRank @@ -229,4 +249,50 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { } } + + test("Loop with sink PageRank") { + withSpark { sc => + val edges = sc.parallelize((1L, 2L) :: (2L, 3L) :: (3L, 1L) :: (1L, 4L) :: Nil) + val g = Graph.fromEdgeTuples(edges, 1) + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 20 + val errorTol = 1.0e-5 + + val staticRanks = g.staticPageRank(numIter, resetProb).vertices.cache() + val dynamicRanks = g.pageRank(tol, resetProb).vertices.cache() + + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ A -+ D)) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,1),(1,4)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR = Seq(0.3078534, 0.2137622, 0.2646223, 0.2137622).map(_ * 4) + val ranks = VertexRDD(sc.parallelize(1L to 4L zip igraphPR)) + assert(compareRanks(staticRanks, ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) + + val p1staticRanks = g.staticPersonalizedPageRank(1, numIter, resetProb).vertices.cache() + val p1dynamicRanks = g.personalizedPageRank(1, tol, resetProb).vertices.cache() + val p1parallelDynamicRanks = + g.staticParallelPersonalizedPageRank(Array(1, 2, 3, 4), numIter, resetProb) + .vertices.mapValues(v => v(0)).cache() + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ A -+ D), personalized = c(1, 0, 0, 0), + // algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,1),(1,4)]), personalization={1:1, 2:0, 3:0, 4:0}) + val igraphPR2 = Seq(0.4522329, 0.1921990, 0.1633691, 0.1921990) + val ranks2 = VertexRDD(sc.parallelize(1L to 4L zip igraphPR2)) + assert(compareRanks(p1staticRanks, ranks2) < errorTol) + assert(compareRanks(p1dynamicRanks, ranks2) < errorTol) + assert(compareRanks(p1parallelDynamicRanks, ranks2) < errorTol) + + } + } } From 7de66bae58733595cb88ec899640f7acf734d5c4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Mar 2017 14:51:59 -0700 Subject: [PATCH 0045/1765] [SPARK-19967][SQL] Add from_json in FunctionRegistry ## What changes were proposed in this pull request? This pr added entries in `FunctionRegistry` and supported `from_json` in SQL. ## How was this patch tested? Added tests in `JsonFunctionsSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17320 from maropu/SPARK-19967. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/jsonExpressions.scala | 36 +++++- .../sql-tests/inputs/json-functions.sql | 13 +++ .../sql-tests/results/json-functions.sql.out | 107 +++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 36 ++++++ 5 files changed, 189 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0dcb44081f608..0486e67dbdf86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -426,6 +426,7 @@ object FunctionRegistry { // json expression[StructToJson]("to_json"), + expression[JsonToStruct]("from_json"), // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b5f2f7ed2e8..37e4bb5060436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -26,6 +26,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} @@ -483,6 +484,17 @@ case class JsonTuple(children: Seq[Expression]) /** * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`.", + extended = """ + Examples: + > SELECT _FUNC_('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT _FUNC_('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); + {"time":"2015-08-26 00:00:00.0"} + """) +// scalastyle:on line.size.limit case class JsonToStruct( schema: DataType, options: Map[String, String], @@ -494,6 +506,21 @@ case class JsonToStruct( def this(schema: DataType, options: Map[String, String], child: Expression) = this(schema, options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression, schema: Expression) = + this( + schema = JsonExprUtils.validateSchemaLiteral(schema), + options = Map.empty[String, String], + child = child, + timeZoneId = None) + + def this(child: Expression, schema: Expression, options: Expression) = + this( + schema = JsonExprUtils.validateSchemaLiteral(schema), + options = JsonExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + override def checkInputDataTypes(): TypeCheckResult = schema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() @@ -589,7 +616,7 @@ case class StructToJson( def this(child: Expression) = this(Map.empty, child, None) def this(child: Expression, options: Expression) = this( - options = StructToJson.convertToMapData(options), + options = JsonExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -634,7 +661,12 @@ case class StructToJson( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil } -object StructToJson { +object JsonExprUtils { + + def validateSchemaLiteral(exp: Expression): StructType = exp match { + case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + case e => throw new AnalysisException(s"Expected a string literal instead of $e") + } def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 9308560451bf5..83243c5e5a12f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -5,4 +5,17 @@ select to_json(named_struct('a', 1, 'b', 2)); select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); -- Check if errors handled select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)); select to_json(); + +-- from_json +describe function from_json; +describe function extended from_json; +select from_json('{"a":1}', 'a INT'); +select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select from_json('{"a":1}', 1); +select from_json('{"a":1}', 'a InvalidType'); +select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')); +select from_json('{"a":1}', 'a INT', map('mode', 1)); +select from_json(); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d8aa4fb9fa788..b57cbbc1d843b 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 16 -- !query 0 @@ -55,9 +55,112 @@ Must use a map() function for options;; line 1 pos 7 -- !query 5 -select to_json() +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) -- !query 5 schema struct<> -- !query 5 output org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 + + +-- !query 6 +select to_json() +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException Invalid number of arguments for function to_json; line 1 pos 7 + + +-- !query 7 +describe function from_json +-- !query 7 schema +struct +-- !query 7 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 8 +describe function extended from_json +-- !query 8 schema +struct +-- !query 8 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct +Extended Usage: + Examples: + > SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); + {"time":"2015-08-26 00:00:00.0"} + +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 9 +select from_json('{"a":1}', 'a INT') +-- !query 9 schema +struct> +-- !query 9 output +{"a":1} + + +-- !query 10 +select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) +-- !query 10 schema +struct> +-- !query 10 output +{"time":2015-08-26 00:00:00.0} + + +-- !query 11 +select from_json('{"a":1}', 1) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Expected a string literal instead of 1;; line 1 pos 7 + + +-- !query 12 +select from_json('{"a":1}', 'a InvalidType') +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException + +DataType invalidtype() is not supported.(line 1, pos 2) + +== SQL == +a InvalidType +--^^^ +; line 1 pos 7 + + +-- !query 13 +select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')) +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 14 +select from_json('{"a":1}', 'a INT', map('mode', 1)) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 + + +-- !query 15 +select from_json() +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function from_json; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index cdea3b9a0f79f..2345b82081161 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -220,4 +220,40 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg2.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-19967 Support from_json in SQL") { + val df1 = Seq("""{"a": 1}""").toDS() + checkAnswer( + df1.selectExpr("from_json(value, 'a INT')"), + Row(Row(1)) :: Nil) + + val df2 = Seq("""{"c0": "a", "c1": 1, "c2": {"c20": 3.8, "c21": 8}}""").toDS() + checkAnswer( + df2.selectExpr("from_json(value, 'c0 STRING, c1 INT, c2 STRUCT')"), + Row(Row("a", 1, Row(3.8, 8))) :: Nil) + + val df3 = Seq("""{"time": "26/08/2015 18:00"}""").toDS() + checkAnswer( + df3.selectExpr( + "from_json(value, 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy HH:mm'))"), + Row(Row(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))) + + val errMsg1 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 1)") + } + assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + val errMsg2 = intercept[AnalysisException] { + df3.selectExpr("""from_json(value, 'time InvalidType')""") + } + assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported")) + val errMsg3 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") + } + assert(errMsg3.getMessage.startsWith("Must use a map() function for options")) + val errMsg4 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 'time Timestamp', map('a', 1))") + } + assert(errMsg4.getMessage.startsWith( + "A type of keys and values in map() must be string, but got")) + } } From 3783539d7ab83a2a632a9f35ca66ae39d01c28b6 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Fri, 17 Mar 2017 16:16:22 -0700 Subject: [PATCH 0046/1765] [SPARK-19873][SS] Record num shuffle partitions in offset log and enforce in next batch. ## What changes were proposed in this pull request? If the user changes the shuffle partition number between batches, Streaming aggregation will fail. Here are some possible cases: - Change "spark.sql.shuffle.partitions" - Use "repartition" and change the partition number in codes - RangePartitioner doesn't generate deterministic partitions. Right now it's safe as we disallow sort before aggregation. Not sure if we will add some operators using RangePartitioner in future. ## How was this patch tested? - Unit tests - Manual tests - forward compatibility tested by using the new `OffsetSeqMetadata` json with Spark v2.1.0 Author: Kunal Khamar Closes #17216 from kunalkhamar/num-partitions. --- .../sql/execution/streaming/OffsetSeq.scala | 8 +- .../execution/streaming/StreamExecution.scala | 60 ++++++++--- .../sql/streaming/StreamingQueryManager.scala | 8 +- .../checkpoint-version-2.1.0/metadata | 1 + .../checkpoint-version-2.1.0/offsets/0 | 3 + .../checkpoint-version-2.1.0/offsets/1 | 3 + .../state/0/0/1.delta | Bin 0 -> 46 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/1.delta | Bin 0 -> 79 bytes .../state/0/1/2.delta | Bin 0 -> 79 bytes .../state/0/2/1.delta | Bin 0 -> 79 bytes .../state/0/2/2.delta | Bin 0 -> 79 bytes .../state/0/3/1.delta | Bin 0 -> 73 bytes .../state/0/3/2.delta | Bin 0 -> 79 bytes .../state/0/4/1.delta | Bin 0 -> 79 bytes .../state/0/4/2.delta | Bin 0 -> 46 bytes .../state/0/5/1.delta | Bin 0 -> 46 bytes .../state/0/5/2.delta | Bin 0 -> 46 bytes .../state/0/6/1.delta | Bin 0 -> 46 bytes .../state/0/6/2.delta | Bin 0 -> 79 bytes .../state/0/7/1.delta | Bin 0 -> 46 bytes .../state/0/7/2.delta | Bin 0 -> 79 bytes .../state/0/8/1.delta | Bin 0 -> 46 bytes .../state/0/8/2.delta | Bin 0 -> 46 bytes .../state/0/9/1.delta | Bin 0 -> 46 bytes .../state/0/9/2.delta | Bin 0 -> 79 bytes .../streaming/OffsetSeqLogSuite.scala | 38 +++++-- .../spark/sql/streaming/StreamSuite.scala | 101 +++++++++++++++++- .../StreamingQueryManagerSuite.scala | 10 -- .../test/DataStreamReaderWriterSuite.scala | 22 ++-- 30 files changed, 207 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index e5a1997d6b808..8249adab4bba8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming import org.json4s.NoTypeHints import org.json4s.jackson.Serialization - /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance @@ -70,8 +69,12 @@ object OffsetSeq { * bound the lateness of data that will processed. Time unit: milliseconds * @param batchTimestampMs: The current batch processing timestamp. * Time unit: milliseconds + * @param conf: Additional conf_s to be persisted across batches, e.g. number of shuffle partitions. */ -case class OffsetSeqMetadata(var batchWatermarkMs: Long = 0, var batchTimestampMs: Long = 0) { +case class OffsetSeqMetadata( + batchWatermarkMs: Long = 0, + batchTimestampMs: Long = 0, + conf: Map[String, String] = Map.empty) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } @@ -79,4 +82,3 @@ object OffsetSeqMetadata { private implicit val format = Serialization.formats(NoTypeHints) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 529263805c0aa..40faddccc2423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -117,7 +118,9 @@ class StreamExecution( } /** Metadata associated with the offset seq of a batch in the query. */ - protected var offsetSeqMetadata = OffsetSeqMetadata() + protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, + conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> + sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString)) override val id: UUID = UUID.fromString(streamMetadata.id) @@ -256,6 +259,15 @@ class StreamExecution( updateStatusMessage("Initializing sources") // force initialization of the logical plan so that the sources can be created logicalPlan + + // Isolated spark session to run the batches with. + val sparkSessionToRunBatches = sparkSession.cloneSession() + // Adaptive execution can change num shuffle partitions, disallow + sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") + offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, + conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> + sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key))) + if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` initializationLatch.countDown() @@ -268,7 +280,7 @@ class StreamExecution( reportTimeTaken("triggerExecution") { if (currentBatchId < 0) { // We'll do this initialization only once - populateStartOffsets() + populateStartOffsets(sparkSessionToRunBatches) logDebug(s"Stream running from $committedOffsets to $availableOffsets") } else { constructNextBatch() @@ -276,7 +288,7 @@ class StreamExecution( if (dataAvailable) { currentStatus = currentStatus.copy(isDataAvailable = true) updateStatusMessage("Processing new data") - runBatch() + runBatch(sparkSessionToRunBatches) } } @@ -381,13 +393,32 @@ class StreamExecution( * - committedOffsets * - availableOffsets */ - private def populateStartOffsets(): Unit = { + private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { offsetLog.getLatest() match { case Some((batchId, nextOffsets)) => logInfo(s"Resuming streaming query, starting with batch $batchId") currentBatchId = batchId availableOffsets = nextOffsets.toStreamProgress(sources) - offsetSeqMetadata = nextOffsets.metadata.getOrElse(OffsetSeqMetadata()) + + // update offset metadata + nextOffsets.metadata.foreach { metadata => + val shufflePartitionsSparkSession: Int = + sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS) + val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, { + // For backward compatibility, if # partitions was not recorded in the offset log, + // then ensure it is not missing. The new value is picked up from the conf. + logWarning("Number of shuffle partitions from previous run not found in checkpoint. " + + s"Using the value from the conf, $shufflePartitionsSparkSession partitions.") + shufflePartitionsSparkSession + }) + offsetSeqMetadata = OffsetSeqMetadata( + metadata.batchWatermarkMs, metadata.batchTimestampMs, + metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString)) + // Update conf with correct number of shuffle partitions + sparkSessionToRunBatches.conf.set( + SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) + } + logDebug(s"Found possibly unprocessed offsets $availableOffsets " + s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}") @@ -444,8 +475,7 @@ class StreamExecution( } } if (hasNewData) { - // Current batch timestamp in milliseconds - offsetSeqMetadata.batchTimestampMs = triggerClock.getTimeMillis() + var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs // Update the eventTime watermark if we find one in the plan. if (lastExecution != null) { lastExecution.executedPlan.collect { @@ -453,16 +483,19 @@ class StreamExecution( logDebug(s"Observed event time stats: ${e.eventTimeStats.value}") e.eventTimeStats.value.max - e.delayMs }.headOption.foreach { newWatermarkMs => - if (newWatermarkMs > offsetSeqMetadata.batchWatermarkMs) { + if (newWatermarkMs > batchWatermarkMs) { logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - offsetSeqMetadata.batchWatermarkMs = newWatermarkMs + batchWatermarkMs = newWatermarkMs } else { logDebug( s"Event time didn't move: $newWatermarkMs < " + - s"${offsetSeqMetadata.batchWatermarkMs}") + s"$batchWatermarkMs") } } } + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = batchWatermarkMs, + batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { @@ -505,8 +538,9 @@ class StreamExecution( /** * Processes any data available between `availableOffsets` and `committedOffsets`. + * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ - private def runBatch(): Unit = { + private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -551,7 +585,7 @@ class StreamExecution( reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( - sparkSession, + sparkSessionToRunBatch, triggerLogicalPlan, outputMode, checkpointFile("state"), @@ -561,7 +595,7 @@ class StreamExecution( } val nextBatch = - new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) reportTimeTaken("addBatch") { sink.addBatch(currentBatchId, nextBatch) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 38edb40dfb781..7810d9f6e9642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ @@ -40,7 +41,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} */ @Experimental @InterfaceStability.Evolving -class StreamingQueryManager private[sql] (sparkSession: SparkSession) { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -234,9 +235,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { } if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { - throw new AnalysisException( - s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + - "is not supported in streaming DataFrames/Datasets") + logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + + "is not supported in streaming DataFrames/Datasets and will be disabled.") } new StreamingQueryWrapper(new StreamExecution( diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata new file mode 100644 index 0000000000000..3492220e36b8d --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata @@ -0,0 +1 @@ +{"id":"dddc5e7f-1e71-454c-8362-de184444fb5a"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 new file mode 100644 index 0000000000000..cbde042e79af1 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1489180207737} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 new file mode 100644 index 0000000000000..10b5774746de9 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1489180209261} +2 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..7dc49cb3e47fd7a4001ff7ddc96094e754117c44 GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0AjiN4z6GzEx^FYAk56c;0R>PuraWUFbFd8F)RS`fZ#t6 M_&{}vLWCeB023(<4FCWD literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..8b566e81f48663efa0ebda2dbf694d65e28def72 GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0OEK1WO;*uv;YGmgD^7(gCmeF!^Xfa!XU`R$FKm%1A_lR M-~-hu3K4>k06a_$wEzGB literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..ca2a7ed033f3baf749f5a93522e951c15729e4c6 GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0AjUmuFSzeT7ZF(L70Vu!4b%oVPjwyVGv~GV^{#>0l|MD M@PX0l|MD M@PXPuraWUFbFd8F)RS`fZ#t6 M_&{}vLWCeB00o>3)Bpeg literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..fe521b8c07504adc9e174c9116497607365cca7e GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0OE5h9UQ?xT7ZF(L70hy!4b%oVPjwyVGv~GV^{#>0l|MD M@PX0l|MD M@PXgCmeF!^Xfa!XU`V$FKm%1A_lR M-~-hu3K4>k06R$yw*UYD literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..0c9b6ac5c863d06d63c46c8a1fc51da716a5fdbd GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0OIOpUzvh|v;YGmgD@KhgCmeF!^Xfa!XU`R$FKm%1A_lR M-~-hu3K4>k083;I`Tzg` literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index f7f0dade8717e..dc556322beddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -21,6 +21,7 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { @@ -29,12 +30,37 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { case class StringOffset(override val json: String) extends Offset test("OffsetSeqMetadata - deserialization") { - assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}""")) - assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) - assert( - OffsetSeqMetadata(1, 2) === - OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + val key = SQLConf.SHUFFLE_PARTITIONS.key + + def getConfWith(shufflePartitions: Int): Map[String, String] = { + Map(key -> shufflePartitions.toString) + } + + // None set + assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + + // One set + assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === + OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) + + // Two set + assert(OffsetSeqMetadata(1, 2, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) + assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}""")) + + // All set + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}""")) + + // Drop unknown fields + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata( + s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}},"unknown":1""")) } test("OffsetSeqLog - serialization - deserialization") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 6dfcd8baba20e..e867fc40f7f1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,17 +17,20 @@ package org.apache.spark.sql.streaming -import java.io.{InterruptedIOException, IOException} +import java.io.{File, InterruptedIOException, IOException} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import scala.reflect.ClassTag import scala.util.control.ControlThrowable +import org.apache.commons.io.FileUtils + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -389,6 +392,102 @@ class StreamSuite extends StreamTest { query.stop() assert(query.exception.isEmpty) } + + test("SPARK-19873: streaming aggregation with change in number of partitions") { + val inputData = MemoryStream[(Int, Int)] + val agg = inputData.toDS().groupBy("_1").count() + + testStream(agg, OutputMode.Complete())( + AddData(inputData, (1, 0), (2, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")), + CheckAnswer((1, 1), (2, 1)), + StopStream, + AddData(inputData, (3, 0), (2, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")), + CheckAnswer((1, 1), (2, 2), (3, 1)), + StopStream, + AddData(inputData, (3, 0), (1, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")), + CheckAnswer((1, 2), (2, 2), (3, 2))) + } + + test("recover from a Spark v2.1 checkpoint") { + var inputData: MemoryStream[Int] = null + var query: DataStreamWriter[Row] = null + + def prepareMemoryStream(): Unit = { + inputData = MemoryStream[Int] + inputData.addData(1, 2, 3, 4) + inputData.addData(3, 4, 5, 6) + inputData.addData(5, 6, 7, 8) + + query = inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .writeStream + .outputMode("complete") + .format("memory") + } + + // Get an existing checkpoint generated by Spark v2.1. + // v2.1 does not record # shuffle partitions in the offset metadata. + val resourceUri = + this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI + val checkpointDir = new File(resourceUri) + + // 1 - Test if recovery from the checkpoint is successful. + prepareMemoryStream() + withTempDir { dir => + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(checkpointDir, dir) + + // Checkpoint data was generated by a query with 10 shuffle partitions. + // In order to test reading from the checkpoint, the checkpoint must have two or more batches, + // since the last batch may be rerun. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + var streamingQuery: StreamingQuery = null + try { + streamingQuery = + query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start() + streamingQuery.processAllAvailable() + inputData.addData(9) + streamingQuery.processAllAvailable() + + QueryTest.checkAnswer(spark.table("counts").toDF(), + Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: + Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() + } + } + } + } + + // 2 - Check recovery with wrong num shuffle partitions + prepareMemoryStream() + withTempDir { dir => + FileUtils.copyDirectory(checkpointDir, dir) + + // Since the number of partitions is greater than 10, should throw exception. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { + var streamingQuery: StreamingQuery = null + try { + intercept[StreamingQueryException] { + streamingQuery = + query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start() + streamingQuery.processAllAvailable() + } + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() + } + } + } + } + } } abstract class FakeSource extends StreamSourceProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index f05e9d1fda73f..b49efa6890236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -239,16 +239,6 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { } } - test("SPARK-19268: Adaptive query execution should be disallowed") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val e = intercept[AnalysisException] { - MemoryStream[Int].toDS.writeStream.queryName("test-query").format("memory").start() - } - assert(e.getMessage.contains(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key) && - e.getMessage.contains("not supported")) - } - } - /** Run a body of code by defining a query on each dataset */ private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = { failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index f61dcdcbcf718..341ab0eb923da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter @@ -370,21 +371,22 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .option("checkpointLocation", checkpointLocationURI.toString) .trigger(ProcessingTime(10.seconds)) .start() + q.processAllAvailable() q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - s"$checkpointLocationURI/sources/0", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) + any(), + meq(s"$checkpointLocationURI/sources/0"), + meq(None), + meq("org.apache.spark.sql.streaming.test"), + meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - s"$checkpointLocationURI/sources/1", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) + any(), + meq(s"$checkpointLocationURI/sources/1"), + meq(None), + meq("org.apache.spark.sql.streaming.test"), + meq(Map.empty)) } private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath From 6326d406b98a34e9cc8afa6743b23ee1cced8611 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 17 Mar 2017 21:55:10 -0700 Subject: [PATCH 0047/1765] [SQL][MINOR] Fix scaladoc for UDFRegistration ## What changes were proposed in this pull request? Fix scaladoc for UDFRegistration ## How was this patch tested? local build Author: Jacek Laskowski Closes #17337 from jaceklaskowski/udfregistration-scaladoc. --- .../main/scala/org/apache/spark/sql/UDFRegistration.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7abfa4ea37a74..a57673334c10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -36,7 +36,11 @@ import org.apache.spark.sql.types.{DataType, DataTypes} import org.apache.spark.util.Utils /** - * Functions for registering user-defined functions. Use `SQLContext.udf` to access this. + * Functions for registering user-defined functions. Use `SparkSession.udf` to access this: + * + * {{{ + * spark.udf + * }}} * * @note The user-defined functions must be deterministic. * From c083b6b7dec337d680b54dabeaa40e7a0f69ae69 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 18 Mar 2017 14:07:25 +0800 Subject: [PATCH 0048/1765] [SPARK-19915][SQL] Exclude cartesian product candidates to reduce the search space ## What changes were proposed in this pull request? We have some concerns about removing size in the cost model [in the previous pr](https://github.com/apache/spark/pull/17240). It's a tradeoff between code structure and algorithm completeness. I tend to keep the size and thus create this new pr without changing cost model. What this pr does: 1. We only consider consecutive inner joinable items, thus excluding cartesian products in reordering procedure. This significantly reduces the search space and memory overhead of memo. Otherwise every combination of items will exist in the memo. 2. This pr also includes a bug fix: if a leaf item is a project(_, child), current solution will miss the project. ## How was this patch tested? Added test cases. Author: wangzhenhua Closes #17286 from wzhfy/joinReorder3. --- .../optimizer/CostBasedJoinReorder.scala | 191 +++++++++--------- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../catalyst/optimizer/JoinReorderSuite.scala | 41 +++- 3 files changed, 143 insertions(+), 100 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index b694561e5372d..1b32bda72bc9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf /** @@ -31,19 +31,21 @@ import org.apache.spark.sql.catalyst.rules.Rule * We may have several join reorder algorithms in the future. This class is the entry of these * algorithms, and chooses which one to use. */ -case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { +case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.cboEnabled || !conf.joinReorderEnabled) { plan } else { - val result = plan transform { - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) => - reorder(p, p.outputSet) - case j @ Join(_, _, _: InnerLike, _) => + val result = plan transformDown { + // Start reordering with a joinable item, which is an InnerLike join with conditions. + case j @ Join(_, _, _: InnerLike, Some(cond)) => reorder(j, j.outputSet) + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) + if projectList.forall(_.isInstanceOf[Attribute]) => + reorder(p, p.outputSet) } // After reordering is finished, convert OrderedJoin back to Join - result transform { + result transformDown { case oj: OrderedJoin => oj.join } } @@ -56,7 +58,7 @@ case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] wi // We also need to check if costs of all items can be evaluated. if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && items.forall(_.stats(conf).rowCount.isDefined)) { - JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan) + JoinReorderDP.search(conf, items, conditions, output) } else { plan } @@ -70,25 +72,26 @@ case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] wi */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, cond) => + case Join(left, right, _: InnerLike, Some(cond)) => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) - (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++ + (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) => - extractInnerJoins(join) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) + if projectList.forall(_.isInstanceOf[Attribute]) => + extractInnerJoins(j) case _ => (Seq(plan), Set()) } } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, _: InnerLike, cond) => + case j @ Join(left, right, _: InnerLike, Some(cond)) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) - case p @ Project(_, join) => - p.copy(child = replaceWithOrderedJoin(join)) + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => + p.copy(child = replaceWithOrderedJoin(j)) case _ => plan } @@ -128,10 +131,10 @@ case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] wi object JoinReorderDP extends PredicateHelper { def search( - conf: CatalystConf, + conf: SQLConf, items: Seq[LogicalPlan], conditions: Set[Expression], - topOutput: AttributeSet): Option[LogicalPlan] = { + topOutput: AttributeSet): LogicalPlan = { // Level i maintains all found plans for i + 1 items. // Create the initial plans: each plan is a single item with zero cost. @@ -140,26 +143,22 @@ object JoinReorderDP extends PredicateHelper { case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) }.toMap) - for (lev <- 1 until items.length) { + // Build plans for next levels until the last level has only one plan. This plan contains + // all items that can be joined, so there's no need to continue. + while (foundPlans.size < items.length && foundPlans.last.size > 1) { // Build plans for the next level. foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) } - val plansLastLevel = foundPlans(items.length - 1) - if (plansLastLevel.isEmpty) { - // Failed to find a plan, fall back to the original plan - None - } else { - // There must be only one plan at the last level, which contains all items. - assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length) - Some(plansLastLevel.head._2.plan) - } + // The last level must have one and only one plan, because all items are joinable. + assert(foundPlans.size == items.length && foundPlans.last.size == 1) + foundPlans.last.head._2.plan } /** Find all possible plans at the next level, based on existing levels. */ private def searchLevel( existingLevels: Seq[JoinPlanMap], - conf: CatalystConf, + conf: SQLConf, conditions: Set[Expression], topOutput: AttributeSet): JoinPlanMap = { @@ -185,11 +184,14 @@ object JoinReorderDP extends PredicateHelper { // Should not join two overlapping item sets. if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) { val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) - // Check if it's the first plan for the item set, or it's a better plan than - // the existing one due to lower cost. - val existingPlan = nextLevel.get(joinPlan.itemIds) - if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) { - nextLevel.update(joinPlan.itemIds, joinPlan) + if (joinPlan.isDefined) { + val newJoinPlan = joinPlan.get + // Check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. + val existingPlan = nextLevel.get(newJoinPlan.itemIds) + if (existingPlan.isEmpty || newJoinPlan.betterThan(existingPlan.get, conf)) { + nextLevel.update(newJoinPlan.itemIds, newJoinPlan) + } } } } @@ -203,64 +205,46 @@ object JoinReorderDP extends PredicateHelper { private def buildJoin( oneJoinPlan: JoinPlan, otherJoinPlan: JoinPlan, - conf: CatalystConf, + conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): JoinPlan = { + topOutput: AttributeSet): Option[JoinPlan] = { val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan - // Now both onePlan and otherPlan become intermediate joins, so the cost of the - // new join should also include their own cardinalities and sizes. - val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) { - // We consider cartesian product very expensive, thus set a very large cost for it. - // This enables to plan all the cartesian products at the end, because having a cartesian - // product as an intermediate join will significantly increase a plan's cost, making it - // impossible to be selected as the best plan for the items, unless there's no other choice. - Cost( - rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue), - size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue)) - } else { - val onePlanStats = onePlan.stats(conf) - val otherPlanStats = otherPlan.stats(conf) - Cost( - rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get + - otherJoinPlan.cost.rows + otherPlanStats.rowCount.get, - size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes + - otherJoinPlan.cost.size + otherPlanStats.sizeInBytes) - } - - // Put the deeper side on the left, tend to build a left-deep tree. - val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { - (onePlan, otherPlan) - } else { - (otherPlan, onePlan) - } val joinConds = conditions .filterNot(l => canEvaluate(l, onePlan)) .filterNot(r => canEvaluate(r, otherPlan)) .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet)) - // We use inner join whether join condition is empty or not. Since cross join is - // equivalent to inner join without condition. - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) - val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds - val remainingConds = conditions -- collectedJoinConds - val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput - val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) - val newPlan = - if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { - Project(neededFromNewJoin.toSeq, newJoin) + if (joinConds.isEmpty) { + // Cartesian product is very expensive, so we exclude them from candidate plans. + // This also significantly reduces the search space. + None + } else { + // Put the deeper side on the left, tend to build a left-deep tree. + val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) } else { - newJoin + (otherPlan, onePlan) } + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin.toSeq, newJoin) + } else { + newJoin + } - val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) - JoinPlan(itemIds, newPlan, collectedJoinConds, newCost) - } - - private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match { - case Join(_, _, _, None) => true - case Project(_, Join(_, _, _, None)) => true - case _ => false + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf + // item), so the cost of the new join should also include its own cost. + val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + + otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) + Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) + } } /** Map[set of item ids, join plan for these items] */ @@ -272,26 +256,39 @@ object JoinReorderDP extends PredicateHelper { * @param itemIds Set of item ids participating in this partial plan. * @param plan The plan tree with the lowest cost for these items found so far. * @param joinConds Join conditions included in the plan. - * @param cost The cost of this plan is the sum of costs of all intermediate joins. + * @param planCost The cost of this plan tree is the sum of costs of all intermediate joins. */ - case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost) -} + case class JoinPlan( + itemIds: Set[Int], + plan: LogicalPlan, + joinConds: Set[Expression], + planCost: Cost) { -/** This class defines the cost model. */ -case class Cost(rows: BigInt, size: BigInt) { - /** - * An empirical value for the weights of cardinality (number of rows) in the cost formula: - * cost = rows * weight + size * (1 - weight), usually cardinality is more important than size. - */ - val weight = 0.7 + /** Get the cost of the root node of this plan tree. */ + def rootCost(conf: SQLConf): Cost = { + if (itemIds.size > 1) { + val rootStats = plan.stats(conf) + Cost(rootStats.rowCount.get, rootStats.sizeInBytes) + } else { + // If the plan is a leaf item, it has zero cost. + Cost(0, 0) + } + } - def lessThan(other: Cost): Boolean = { - if (other.rows == 0 || other.size == 0) { - false - } else { - val relativeRows = BigDecimal(rows) / BigDecimal(other.rows) - val relativeSize = BigDecimal(size) / BigDecimal(other.size) - relativeRows * weight + relativeSize * (1 - weight) < 1 + def betterThan(other: JoinPlan, conf: SQLConf): Boolean = { + if (other.planCost.rows == 0 || other.planCost.size == 0) { + false + } else { + val relativeRows = BigDecimal(this.planCost.rows) / BigDecimal(other.planCost.rows) + val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size) + relativeRows * conf.joinReorderCardWeight + + relativeSize * (1 - conf.joinReorderCardWeight) < 1 + } } } } + +/** This class defines the cost model. */ +case class Cost(rows: BigInt, size: BigInt) { + def +(other: Cost): Cost = Cost(this.rows + other.rows, this.size + other.size) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a85f87aece45b..d2ac4b88ee8fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -710,6 +710,15 @@ object SQLConf { .intConf .createWithDefault(12) + val JOIN_REORDER_CARD_WEIGHT = + buildConf("spark.sql.cbo.joinReorder.card.weight") + .internal() + .doc("The weight of cardinality (number of rows) for plan cost comparison in join reorder: " + + "rows * weight + size * (1 - weight).") + .doubleConf + .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") + .createWithDefault(0.7) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""") @@ -967,6 +976,8 @@ class SQLConf extends Serializable with Logging { def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 1b2f7a66b6a0b..5607bcd16f3ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -38,6 +38,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: @@ -58,6 +59,10 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4) )) @@ -92,6 +97,13 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { size = Some(100 * (8 + 4)), attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + // Table t5: small table with two columns + private val t5 = StatsTestPlan( + outputList = Seq("t5.k-1-5", "t5.v-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t5.k-1-5", "t5.v-1-5").map(nameToColInfo))) + test("reorder 3 tables") { val originalPlan = t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && @@ -110,13 +122,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } - test("reorder 3 tables - put cross join at the end") { + test("put unjoinable item at the end and reorder 3 joinable tables") { + // The ReorderJoin rule puts the unjoinable item at the end, and then CostBasedJoinReorder + // reorders other joinable items. val originalPlan = - t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")) + t1.join(t2).join(t4).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) val bestPlan = t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - .join(t2, Inner, None) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4) assertEqualPlans(originalPlan, bestPlan) } @@ -136,6 +152,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } + test("reorder 3 tables - one of the leaf items is a project") { + val originalPlan = + t1.join(t5).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t5.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + // Items: t1, t3, project(t5.k-1-5, t5) + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t5.select(nameToAttr("t5.k-1-5")), Inner, + Some(nameToAttr("t1.k-1-2") === nameToAttr("t5.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + test("don't reorder if project contains non-attribute") { val originalPlan = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) @@ -187,6 +220,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { case (j1: Join, j2: Join) => (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => + (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } case _ => plan1 == plan2 } From ccba622e35741d8344ec8d74b6750529b2c7219b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 18 Mar 2017 14:40:16 +0800 Subject: [PATCH 0049/1765] [SPARK-19896][SQL] Throw an exception if case classes have circular references in toDS ## What changes were proposed in this pull request? If case classes have circular references below, it throws StackOverflowError; ``` scala> :pasge case class classA(i: Int, cls: classB) case class classB(cls: classA) scala> Seq(classA(0, null)).toDS() java.lang.StackOverflowError at scala.reflect.internal.Symbols$Symbol.info(Symbols.scala:1494) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.scala$reflect$runtime$SynchronizedSymbols$SynchronizedSymbol$$super$info(JavaMirrors.scala:66) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127) at scala.reflect.runtime.Gil$class.gilSynchronized(Gil.scala:19) at scala.reflect.runtime.JavaUniverse.gilSynchronized(JavaUniverse.scala:16) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.gilSynchronizedIfNotThreadsafe(SynchronizedSymbols.scala:123) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.gilSynchronizedIfNotThreadsafe(JavaMirrors.scala:66) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.info(SynchronizedSymbols.scala:127) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.info(JavaMirrors.scala:66) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:48) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) ``` This pr added code to throw UnsupportedOperationException in that case as follows; ``` scala> :paste case class A(cls: B) case class B(cls: A) scala> Seq(A(null)).toDS() java.lang.UnsupportedOperationException: cannot have circular references in class, but got the circular reference of class B at org.apache.spark.sql.catalyst.ScalaReflection$.org$apache$spark$sql$catalyst$ScalaReflection$$serializerFor(ScalaReflection.scala:627) at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:644) at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:632) at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241) ``` ## How was this patch tested? Added tests in `DatasetSuite`. Author: Takeshi Yamamuro Closes #17318 from maropu/SPARK-19896. --- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7f7dd51aa2650..c4af284f73d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -470,14 +470,15 @@ object ScalaReflection extends ScalaReflection { private def serializerFor( inputObject: Expression, tpe: `Type`, - walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + walkedTypePath: Seq[String], + seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, dt) + MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt) case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => @@ -511,7 +512,7 @@ object ScalaReflection extends ScalaReflection { val className = getClassNameFromType(optType) val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath) + serializerFor(unwrapped, optType, newPath, seenTypeSet) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -534,9 +535,9 @@ object ScalaReflection extends ScalaReflection { ExternalMapToCatalyst( inputObject, dataTypeFor(keyType), - serializerFor(_, keyType, keyPath), + serializerFor(_, keyType, keyPath, seenTypeSet), dataTypeFor(valueType), - serializerFor(_, valueType, valuePath), + serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) case t if t <:< localTypeOf[String] => @@ -622,6 +623,11 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => + if (seenTypeSet.contains(t)) { + throw new UnsupportedOperationException( + s"cannot have circular references in class, but got the circular reference of class $t") + } + val params = getConstructorParameters(t) val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { @@ -634,7 +640,8 @@ object ScalaReflection extends ScalaReflection { returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + expressions.Literal(fieldName) :: + serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) @@ -643,7 +650,6 @@ object ScalaReflection extends ScalaReflection { throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } - } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b37bf131e8dce..6417e7a8b6038 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1136,6 +1136,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) } + + test("SPARK-19896: cannot have circular references in in case class") { + val errMsg1 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassA(null)).toDS + } + assert(errMsg1.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg2 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassC(null)).toDS + } + assert(errMsg2.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg3 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassD(null)).toDS + } + assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) @@ -1214,3 +1232,9 @@ object DatasetTransform { case class Route(src: String, dest: String, cost: Int) case class GroupedRoutes(src: String, dest: String, routes: Seq[Route]) + +case class CircularReferenceClassA(cls: CircularReferenceClassB) +case class CircularReferenceClassB(cls: CircularReferenceClassA) +case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) +case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) +case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) From 54e61df2634163382c7d01a2ad40ffb5e7270abc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 18 Mar 2017 18:01:24 +0100 Subject: [PATCH 0050/1765] [SPARK-16599][CORE] java.util.NoSuchElementException: None.get at at org.apache.spark.storage.BlockInfoManager.releaseAllLocksForTask ## What changes were proposed in this pull request? Avoid None.get exception in (rare?) case that no readLocks exist Note that while this would resolve the immediate cause of the exception, it's not clear it is the root problem. ## How was this patch tested? Existing tests Author: Sean Owen Closes #17290 from srowen/SPARK-16599. --- .../scala/org/apache/spark/storage/BlockInfoManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index dd8f5bacb9f6e..490d45d12b8e3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag -import com.google.common.collect.ConcurrentHashMultiset +import com.google.common.collect.{ConcurrentHashMultiset, ImmutableMultiset} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -340,7 +340,7 @@ private[storage] class BlockInfoManager extends Logging { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).get + readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) } val writeLocks = synchronized { writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) From 5c165596dac136b9b3a88cfb3578b2423d227eb7 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sat, 18 Mar 2017 16:26:48 -0700 Subject: [PATCH 0051/1765] [SPARK-19654][SPARKR][SS] Structured Streaming API for R ## What changes were proposed in this pull request? Add "experimental" API for SS in R ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #16982 from felixcheung/rss. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 13 ++ R/pkg/R/DataFrame.R | 104 ++++++++++- R/pkg/R/SQLContext.R | 50 +++++ R/pkg/R/generics.R | 41 +++- R/pkg/R/streaming.R | 208 +++++++++++++++++++++ R/pkg/R/utils.R | 11 +- R/pkg/inst/tests/testthat/test_streaming.R | 150 +++++++++++++++ 8 files changed, 573 insertions(+), 5 deletions(-) create mode 100644 R/pkg/R/streaming.R create mode 100644 R/pkg/inst/tests/testthat/test_streaming.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index cc471edc376b3..1635f71489aa3 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -54,5 +54,6 @@ Collate: 'types.R' 'utils.R' 'window.R' + 'streaming.R' RoxygenNote: 5.0.1 VignetteBuilder: knitr diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 871f8e41a0f23..78344ce9ff08b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -121,6 +121,7 @@ exportMethods("arrange", "insertInto", "intersect", "isLocal", + "isStreaming", "join", "limit", "merge", @@ -169,6 +170,7 @@ exportMethods("arrange", "write.json", "write.orc", "write.parquet", + "write.stream", "write.text", "write.ml") @@ -365,6 +367,7 @@ export("as.DataFrame", "read.json", "read.orc", "read.parquet", + "read.stream", "read.text", "spark.lapply", "spark.addFile", @@ -402,6 +405,16 @@ export("partitionBy", export("windowPartitionBy", "windowOrderBy") +exportClasses("StreamingQuery") + +export("awaitTermination", + "isActive", + "lastProgress", + "queryName", + "status", + "stopQuery") + + S3method(print, jobj) S3method(print, structField) S3method(print, structType) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 97e0c9edeab48..bc81633815c65 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -133,9 +133,6 @@ setMethod("schema", #' #' Print the logical and physical Catalyst plans to the console for debugging. #' -#' @param x a SparkDataFrame. -#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan. -#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases explain,SparkDataFrame-method #' @rdname explain @@ -3515,3 +3512,104 @@ setMethod("getNumPartitions", function(x) { callJMethod(callJMethod(x@sdf, "rdd"), "getNumPartitions") }) + +#' isStreaming +#' +#' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data +#' as it arrives. +#' +#' @param x A SparkDataFrame +#' @return TRUE if this SparkDataFrame is from a streaming source +#' @family SparkDataFrame functions +#' @aliases isStreaming,SparkDataFrame-method +#' @rdname isStreaming +#' @name isStreaming +#' @seealso \link{read.stream} \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' } +#' @note isStreaming since 2.2.0 +#' @note experimental +setMethod("isStreaming", + signature(x = "SparkDataFrame"), + function(x) { + callJMethod(x@sdf, "isStreaming") + }) + +#' Write the streaming SparkDataFrame to a data source. +#' +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, \code{outputMode} specifies how data of a streaming SparkDataFrame is written to a +#' output data source. There are three modes: +#' \itemize{ +#' \item append: Only the new rows in the streaming SparkDataFrame will be written out. This +#' output mode can be only be used in queries that do not contain any aggregation. +#' \item complete: All the rows in the streaming SparkDataFrame will be written out every time +#' there are some updates. This output mode can only be used in queries that +#' contain aggregations. +#' \item update: Only the rows that were updated in the streaming SparkDataFrame will be written +#' out every time there are some updates. If the query doesn't contain aggregations, +#' it will be equivalent to \code{append} mode. +#' } +#' +#' @param df a streaming SparkDataFrame. +#' @param source a name for external data source. +#' @param outputMode one of 'append', 'complete', 'update'. +#' @param ... additional argument(s) passed to the method. +#' +#' @family SparkDataFrame functions +#' @seealso \link{read.stream} +#' @aliases write.stream,SparkDataFrame-method +#' @rdname write.stream +#' @name write.stream +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' wordCounts <- count(group_by(df, "value")) +#' +#' # console +#' q <- write.stream(wordCounts, "console", outputMode = "complete") +#' # text stream +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' # memory stream +#' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") +#' head(sql("SELECT * from outs")) +#' queryName(q) +#' +#' stopQuery(q) +#' } +#' @note write.stream since 2.2.0 +#' @note experimental +setMethod("write.stream", + signature(df = "SparkDataFrame"), + function(df, source = NULL, outputMode = NULL, ...) { + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (!is.null(outputMode) && !is.character(outputMode)) { + stop("outputMode should be charactor or omitted.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) + write <- handledCallJMethod(df@sdf, "writeStream") + write <- callJMethod(write, "format", source) + if (!is.null(outputMode)) { + write <- callJMethod(write, "outputMode", outputMode) + } + write <- callJMethod(write, "options", options) + ssq <- handledCallJMethod(write, "start") + streamingQuery(ssq) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 8354f705f6dea..b75fb0159d503 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -937,3 +937,53 @@ read.jdbc <- function(url, tableName, } dataFrame(sdf) } + +#' Load a streaming SparkDataFrame +#' +#' Returns the dataset in a data source as a SparkDataFrame +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param source The name of external data source +#' @param schema The data schema defined in structType, this is required for file-based streaming +#' data source +#' @param ... additional external data source specific named options, for instance \code{path} for +#' file-based streaming data source +#' @return SparkDataFrame +#' @rdname read.stream +#' @name read.stream +#' @seealso \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' } +#' @name read.stream +#' @note read.stream since 2.2.0 +#' @note experimental +read.stream <- function(source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) + read <- callJMethod(sparkSession, "readStream") + read <- callJMethod(read, "format", source) + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + read <- callJMethod(read, "schema", schema$jobj) + } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") + dataFrame(callJMethod(sdf, "toDF")) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 45bc12746511c..029771289fd53 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,6 +539,9 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain #' @export +#' @param x a SparkDataFrame or a StreamingQuery. +#' @param extended Logical. If extended is FALSE, prints only the physical plan. +#' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except @@ -577,6 +580,10 @@ setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) +#' @rdname isStreaming +#' @export +setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) + #' @rdname limit #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) @@ -682,6 +689,12 @@ setGeneric("write.parquet", function(x, path, ...) { #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +#' @rdname write.stream +#' @export +setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { + standardGeneric("write.stream") +}) + #' @rdname write.text #' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) @@ -1428,10 +1441,36 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) - #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) + + +###################### Streaming Methods ########################## + +#' @rdname awaitTermination +#' @export +setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") }) + +#' @rdname isActive +#' @export +setGeneric("isActive", function(x) { standardGeneric("isActive") }) + +#' @rdname lastProgress +#' @export +setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) + +#' @rdname queryName +#' @export +setGeneric("queryName", function(x) { standardGeneric("queryName") }) + +#' @rdname status +#' @export +setGeneric("status", function(x) { standardGeneric("status") }) + +#' @rdname stopQuery +#' @export +setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R new file mode 100644 index 0000000000000..e353d2dd07c3d --- /dev/null +++ b/R/pkg/R/streaming.R @@ -0,0 +1,208 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# streaming.R - Structured Streaming / StreamingQuery class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R +NULL + +#' S4 class that represents a StreamingQuery +#' +#' StreamingQuery can be created by using read.stream() and write.stream() +#' +#' @rdname StreamingQuery +#' @seealso \link{read.stream} +#' +#' @param ssq A Java object reference to the backing Scala StreamingQuery +#' @export +#' @note StreamingQuery since 2.2.0 +#' @note experimental +setClass("StreamingQuery", + slots = list(ssq = "jobj")) + +setMethod("initialize", "StreamingQuery", function(.Object, ssq) { + .Object@ssq <- ssq + .Object +}) + +streamingQuery <- function(ssq) { + stopifnot(class(ssq) == "jobj") + new("StreamingQuery", ssq) +} + +#' @rdname show +#' @export +#' @note show(StreamingQuery) since 2.2.0 +setMethod("show", "StreamingQuery", + function(object) { + name <- callJMethod(object@ssq, "name") + if (!is.null(name)) { + cat(paste0("StreamingQuery '", name, "'\n")) + } else { + cat("StreamingQuery", "\n") + } + }) + +#' queryName +#' +#' Returns the user-specified name of the query. This is specified in +#' \code{write.stream(df, queryName = "query")}. This name, if set, must be unique across all active +#' queries. +#' +#' @param x a StreamingQuery. +#' @return The name of the query, or NULL if not specified. +#' @rdname queryName +#' @name queryName +#' @aliases queryName,StreamingQuery-method +#' @family StreamingQuery methods +#' @seealso \link{write.stream} +#' @export +#' @examples +#' \dontrun{ queryName(sq) } +#' @note queryName(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("queryName", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "name") + }) + +#' @rdname explain +#' @name explain +#' @aliases explain,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ explain(sq) } +#' @note explain(StreamingQuery) since 2.2.0 +setMethod("explain", + signature(x = "StreamingQuery"), + function(x, extended = FALSE) { + cat(callJMethod(x@ssq, "explainInternal", extended), "\n") + }) + +#' lastProgress +#' +#' Prints the most recent progess update of this streaming query in JSON format. +#' +#' @param x a StreamingQuery. +#' @rdname lastProgress +#' @name lastProgress +#' @aliases lastProgress,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ lastProgress(sq) } +#' @note lastProgress(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("lastProgress", + signature(x = "StreamingQuery"), + function(x) { + p <- callJMethod(x@ssq, "lastProgress") + if (is.null(p)) { + cat("Streaming query has no progress") + } else { + cat(callJMethod(p, "toString"), "\n") + } + }) + +#' status +#' +#' Prints the current status of the query in JSON format. +#' +#' @param x a StreamingQuery. +#' @rdname status +#' @name status +#' @aliases status,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ status(sq) } +#' @note status(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("status", + signature(x = "StreamingQuery"), + function(x) { + cat(callJMethod(callJMethod(x@ssq, "status"), "toString"), "\n") + }) + +#' isActive +#' +#' Returns TRUE if this query is actively running. +#' +#' @param x a StreamingQuery. +#' @return TRUE if query is actively running, FALSE if stopped. +#' @rdname isActive +#' @name isActive +#' @aliases isActive,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ isActive(sq) } +#' @note isActive(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("isActive", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "isActive") + }) + +#' awaitTermination +#' +#' Waits for the termination of the query, either by \code{stopQuery} or by an error. +#' +#' If the query has terminated, then all subsequent calls to this method will return TRUE +#' immediately. +#' +#' @param x a StreamingQuery. +#' @param timeout time to wait in milliseconds +#' @return TRUE if query has terminated within the timeout period. +#' @rdname awaitTermination +#' @name awaitTermination +#' @aliases awaitTermination,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ awaitTermination(sq, 10000) } +#' @note awaitTermination(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("awaitTermination", + signature(x = "StreamingQuery"), + function(x, timeout) { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + }) + +#' stopQuery +#' +#' Stops the execution of this query if it is running. This method blocks until the execution is +#' stopped. +#' +#' @param x a StreamingQuery. +#' @rdname stopQuery +#' @name stopQuery +#' @aliases stopQuery,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ stopQuery(sq) } +#' @note stopQuery(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("stopQuery", + signature(x = "StreamingQuery"), + function(x) { + invisible(callJMethod(x@ssq, "stop")) + }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 1f7848f2b413f..810de9917e0ba 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -823,7 +823,16 @@ captureJVMException <- function(e, method) { stacktrace <- rawmsg } - if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { + # StreamingQueryException could wrap an IllegalArgumentException, so look for that first + if (any(grep("org.apache.spark.sql.streaming.StreamingQueryException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.streaming.StreamingQueryException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "streaming query error - ", first), call. = FALSE) + } else if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]] # Extract "Error in ..." message. rmsg <- msg[1] diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R new file mode 100644 index 0000000000000..03b1bd3dc1f44 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("Structured Streaming") + +# Tests for Structured Streaming functions in SparkR + +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +jsonSubDir <- file.path("sparkr-test", "json", "") +if (.Platform$OS.type == "windows") { + # file.path removes the empty separator on Windows, adds it back + jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) +} +jsonDir <- file.path(tempdir(), jsonSubDir) +dir.create(jsonDir, recursive = TRUE) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern = jsonSubDir, fileext = ".tmp") +writeLines(mockLines, jsonPath) + +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}") +jsonPathNa <- tempfile(pattern = jsonSubDir, fileext = ".tmp") + +schema <- structType(structField("name", "string"), + structField("age", "integer"), + structField("count", "double")) + +test_that("read.stream, write.stream, awaitTermination, stopQuery", { + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) + + writeLines(mockLinesNa, jsonPathNa) + awaitTermination(q, 5 * 1000) + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) +}) + +test_that("print from explain, lastProgress, status, isActive", { + df <- read.stream("json", path = jsonDir, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") + + awaitTermination(q, 5 * 1000) + + expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") + expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) + expect_true(any(grepl("\"isTriggerActive\" : ", capture.output(status(q))))) + + expect_equal(queryName(q), "people2") + expect_true(isActive(q)) + + stopQuery(q) +}) + +test_that("Stream other format", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_equal(queryName(q), "people3") + expect_true(any(grepl("\"description\" : \"FileStreamSource[[:print:]]+parquet", + capture.output(lastProgress(q))))) + expect_true(isActive(q)) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) + expect_false(isActive(q)) + + unlink(parquetPath) +}) + +test_that("Non-streaming DataFrame", { + c <- as.DataFrame(cars) + expect_false(isStreaming(c)) + + expect_error(write.stream(c, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(writeStream : analysis error - 'writeStream' can be called only on ", + "streaming Dataset/DataFrame).*")) +}) + +test_that("Unsupported operation", { + # memory sink without aggregation + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(start : analysis error - Complete output mode not supported when there ", + "are no streaming aggregations on streaming DataFrames/Datasets).*")) +}) + +test_that("Terminated by error", { + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) + counts <- count(group_by(df, "name")) + # This would not fail before returning with a StreamingQuery, + # but could dump error log at just about the same time + expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), + NA) + + expect_error(awaitTermination(q, 1), + paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", + " 'maxFilesPerTrigger', must be a positive integer).*")) + + expect_true(any(grepl("\"message\" : \"Terminated with exception: Invalid value", + capture.output(status(q))))) + expect_true(any(grepl("Streaming query has no progress", capture.output(lastProgress(q))))) + expect_equal(queryName(q), "people4") + expect_false(isActive(q)) + + stopQuery(q) +}) + +unlink(jsonPath) +unlink(jsonPathNa) + +sparkR.session.stop() From 60262bc951864a7a3874ab3570b723198e99d613 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 19 Mar 2017 10:30:34 -0700 Subject: [PATCH 0052/1765] [MINOR][R] Reorder `Collate` fields in DESCRIPTION file ## What changes were proposed in this pull request? It seems cran check scripts corrects `R/pkg/DESCRIPTION` and follows the order in `Collate` fields. This PR proposes to fix this so that running this script does not show up a diff in this file. ## How was this patch tested? Manually via `./R/check-cran.sh`. Author: hyukjinkwon Closes #17349 from HyukjinKwon/minor-cran. --- R/pkg/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 1635f71489aa3..2ea90f7d3666e 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -51,9 +51,9 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'streaming.R' 'types.R' 'utils.R' 'window.R' - 'streaming.R' RoxygenNote: 5.0.1 VignetteBuilder: knitr From 422aa67d1bb84f913b06e6d94615adb6557e2870 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 19 Mar 2017 10:37:15 -0700 Subject: [PATCH 0053/1765] [SPARK-18817][SPARKR][SQL] change derby log output to temp dir ## What changes were proposed in this pull request? Passes R `tempdir()` (this is the R session temp dir, shared with other temp files/dirs) to JVM, set System.Property for derby home dir to move derby.log ## How was this patch tested? Manually, unit tests With this, these are relocated to under /tmp ``` # ls /tmp/RtmpG2M0cB/ derby.log ``` And they are removed automatically when the R session is ended. Author: Felix Cheung Closes #16330 from felixcheung/rderby. --- R/pkg/R/sparkR.R | 15 +++++++- R/pkg/inst/tests/testthat/test_sparkSQL.R | 34 +++++++++++++++++++ R/pkg/tests/run-all.R | 6 ++++ .../scala/org/apache/spark/api/r/RRDD.scala | 9 +++++ 4 files changed, 63 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 61773ed3ee8c0..d0a12b7ecec65 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -322,10 +322,19 @@ sparkRHive.init <- function(jsc = NULL) { #' SparkSession or initializes a new SparkSession. #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. -#' When called in an interactive session, this checks for the Spark installation, and, if not +#' +#' When called in an interactive session, this method checks for the Spark installation, and, if not #' found, it will be downloaded and cached automatically. Alternatively, \code{install.spark} can #' be called manually. #' +#' A default warehouse is created automatically in the current directory when a managed table is +#' created via \code{sql} statement \code{CREATE TABLE}, for example. To change the location of the +#' warehouse, set the named parameter \code{spark.sql.warehouse.dir} to the SparkSession. Along with +#' the warehouse, an accompanied metastore may also be automatically created in the current +#' directory when a new SparkSession is initialized with \code{enableHiveSupport} set to +#' \code{TRUE}, which is the default. For more details, refer to Hive configuration at +#' \url{http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables}. +#' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. #' @@ -381,6 +390,10 @@ sparkR.session <- function( deployMode <- sparkConfigMap[["spark.submit.deployMode"]] } + if (!exists("spark.r.sql.derby.temp.dir", envir = sparkConfigMap)) { + sparkConfigMap[["spark.r.sql.derby.temp.dir"]] <- tempdir() + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f7081cb1d4e50..32856b399cdd1 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -60,6 +60,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR +filesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) @@ -2909,6 +2910,39 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) }) +compare_list <- function(list1, list2) { + # get testthat to show the diff by first making the 2 lists equal in length + expect_equal(length(list1), length(list2)) + l <- max(length(list1), length(list2)) + length(list1) <- l + length(list2) <- l + expect_equal(sort(list1, na.last = TRUE), sort(list2, na.last = TRUE)) +} + +# This should always be the **very last test** in this test file. +test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + # Check that it is not creating any extra file. + # Does not check the tempdir which would be cleaned up after. + filesAfter <- list.files(path = sparkRDir, all.files = TRUE) + + expect_true(length(sparkRFilesBefore) > 0) + # first, ensure derby.log is not there + expect_false("derby.log" %in% filesAfter) + # second, ensure only spark-warehouse is created when calling SparkSession, enableHiveSupport = F + # note: currently all other test files have enableHiveSupport = F, so we capture the list of files + # before creating a SparkSession with enableHiveSupport = T at the top of this test file + # (filesBefore). The test here is to compare that (filesBefore) against the list of files before + # any test is run in run-all.R (sparkRFilesBefore). + # sparkRWhitelistSQLDirs is also defined in run-all.R, and should contain only 2 whitelisted dirs, + # here allow the first value, spark-warehouse, in the diff, everything else should be exactly the + # same as before any test is run. + compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRWhitelistSQLDirs[[1]])) + # third, ensure only spark-warehouse and metastore_db are created when enableHiveSupport = T + # note: as the note above, after running all tests in this file while enableHiveSupport = T, we + # check the list of files again. This time we allow both whitelisted dirs to be in the diff. + compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRWhitelistSQLDirs)) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index ab8d1ca019941..cefaadda6e215 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -22,6 +22,12 @@ library(SparkR) options("warn" = 2) # Setup global test environment +sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) +sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") +invisible(lapply(sparkRWhitelistSQLDirs, + function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) + install.spark() test_package("SparkR") diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index a1a5eb8cf55e8..72ae0340aa3d1 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.r +import java.io.File import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -127,6 +128,14 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.toString, value.toString) } + if (sparkEnvirMap.containsKey("spark.r.sql.derby.temp.dir") && + System.getProperty("derby.stream.error.file") == null) { + // This must be set before SparkContext is instantiated. + System.setProperty("derby.stream.error.file", + Seq(sparkEnvirMap.get("spark.r.sql.derby.temp.dir").toString, "derby.log") + .mkString(File.separator)) + } + val jsc = new JavaSparkContext(sparkConf) jars.foreach { jar => jsc.addJar(jar) From 0ee9fbf51ac863e015d57ae7824a39bd3b36141a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sun, 19 Mar 2017 13:52:22 -0700 Subject: [PATCH 0054/1765] [SPARK-19990][TEST] Use the database after Hive's current Database is dropped ### What changes were proposed in this pull request? This PR is to fix the following test failure in maven and the PR https://github.com/apache/spark/pull/15363. > org.apache.spark.sql.hive.orc.OrcSourceSuite SPARK-19459/SPARK-18220: read char/varchar column written by Hive The[ test history](https://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.sql.hive.orc.OrcSourceSuite&test_name=SPARK-19459%2FSPARK-18220%3A+read+char%2Fvarchar+column+written+by+Hive) shows all the maven builds failed this test case with the same error message. ``` FAILED: SemanticException [Error 10072]: Database does not exist: db2 org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10072]: Database does not exist: db2 at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:637) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:621) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:288) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:229) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:228) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:271) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:621) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:611) at org.apache.spark.sql.hive.orc.OrcSuite$$anonfun$7.apply$mcV$sp(OrcSourceSuite.scala:160) at org.apache.spark.sql.hive.orc.OrcSuite$$anonfun$7.apply(OrcSourceSuite.scala:155) at org.apache.spark.sql.hive.orc.OrcSuite$$anonfun$7.apply(OrcSourceSuite.scala:155) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) ``` ### How was this patch tested? N/A Author: Xiao Li Closes #17344 from gatorsmile/testtest. --- .../spark/sql/hive/orc/OrcSourceSuite.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 11dda5425cf94..6bfb88c0c1af5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -157,19 +157,21 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA val location = Utils.createTempDir() val uri = location.toURI try { + hiveClient.runSqlHive("USE default") hiveClient.runSqlHive( """ - |CREATE EXTERNAL TABLE hive_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc""".stripMargin) + |CREATE EXTERNAL TABLE hive_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc""".stripMargin) // Hive throws an exception if I assign the location in the create table statement. hiveClient.runSqlHive( s"ALTER TABLE hive_orc SET LOCATION '$uri'") hiveClient.runSqlHive( - """INSERT INTO TABLE hive_orc + """ + |INSERT INTO TABLE hive_orc |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) |FROM (SELECT 1) t""".stripMargin) From 990af630d0d569880edd9c7ce9932e10037a28ab Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 19 Mar 2017 14:07:49 -0700 Subject: [PATCH 0055/1765] [SPARK-19067][SS] Processing-time-based timeout in MapGroupsWithState ## What changes were proposed in this pull request? When a key does not get any new data in `mapGroupsWithState`, the mapping function is never called on it. So we need a timeout feature that calls the function again in such cases, so that the user can decide whether to continue waiting or clean up (remove state, save stuff externally, etc.). Timeouts can be either based on processing time or event time. This JIRA is for processing time, but defines the high level API design for both. The usage would look like this. ``` def stateFunction(key: K, value: Iterator[V], state: KeyedState[S]): U = { ... state.setTimeoutDuration(10000) ... } dataset // type is Dataset[T] .groupByKey[K](keyingFunc) // generates KeyValueGroupedDataset[K, T] .mapGroupsWithState[S, U]( func = stateFunction, timeout = KeyedStateTimeout.withProcessingTime) // returns Dataset[U] ``` Note the following design aspects. - The timeout type is provided as a param in mapGroupsWithState as a parameter global to all the keys. This is so that the planner knows this at planning time, and accordingly optimize the execution based on whether to saves extra info in state or not (e.g. timeout durations or timestamps). - The exact timeout duration is provided inside the function call so that it can be customized on a per key basis. - When the timeout occurs for a key, the function is called with no values, and KeyedState.isTimingOut() set to true. - The timeout is reset for key every time the function is called on the key, that is, when the key has new data, or the key has timed out. So the user has to set the timeout duration everytime the function is called, otherwise there will not be any timeout set. Guarantees provided on timeout of key, when timeout duration is D ms: - Timeout will never be called before real clock time has advanced by D ms - Timeout will be called eventually when there is a trigger with any data in it (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. For example, if there is no data in the stream (for any key) for a while, then the timeout will not be hit. Implementation details: - Added new param to `mapGroupsWithState` for timeout - Added new method to `StateStore` to filter data based on timeout timestamp - Changed the internal map type of `HDFSBackedStateStore` from Java's `HashMap` to `ConcurrentHashMap` as the latter allows weakly-consistent fail-safe iterators on the map data. See comments in code for more details. - Refactored logic of `MapGroupsWithStateExec` to - Save timeout info to state store for each key that has data. - Then, filter states that should be timed out based on the current batch processing timestamp. - Moved KeyedState for `o.a.s.sql` to `o.a.s.sql.streaming`. I remember that this was a feedback in the MapGroupsWithState PR that I had forgotten to address. ## How was this patch tested? New unit tests in - MapGroupsWithStateSuite for timeouts. - StateStoreSuite for new APIs in StateStore. Author: Tathagata Das Closes #17179 from tdas/mapgroupwithstate-timeout. --- .../sql/streaming/KeyedStateTimeout.java | 42 ++ .../expressions/objects/objects.scala | 2 +- .../sql/catalyst/plans/logical/object.scala | 30 +- .../streaming/JavaKeyedStateTimeoutSuite.java | 29 + .../analysis/UnsupportedOperationsSuite.scala | 80 +-- .../FlatMapGroupsWithStateFunction.java | 2 +- .../function/MapGroupsWithStateFunction.java | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 137 +++-- .../org/apache/spark/sql/KeyedState.scala | 140 ----- .../spark/sql/execution/SparkStrategies.scala | 20 +- .../sql/execution/command/commands.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 258 +++++++++ .../streaming/IncrementalExecution.scala | 16 +- .../execution/streaming/KeyedStateImpl.scala | 104 +++- .../execution/streaming/StreamExecution.scala | 2 +- .../state/HDFSBackedStateStoreProvider.scala | 19 +- .../streaming/state/StateStore.scala | 9 + .../streaming/statefulOperators.scala | 97 +--- .../spark/sql/streaming/KeyedState.scala | 214 +++++++ .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../streaming/state/StateStoreSuite.scala | 24 + .../FlatMapGroupsWithStateSuite.scala | 546 ++++++++++++++++-- 22 files changed, 1353 insertions(+), 429 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java new file mode 100644 index 0000000000000..cf112f2e02a95 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout; +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; + +/** + * Represents the type of timeouts possible for the Dataset operations + * `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on + * `KeyedState` for more details. + * + * @since 2.2.0 + */ +@Experimental +@InterfaceStability.Evolving +public class KeyedStateTimeout { + + /** Timeout based on processing time. */ + public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + + /** No timeout */ + public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 36bf3017d4cdb..771ac28e5107a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -951,7 +951,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = { val result = child.eval(input) if (result == null) { - throw new RuntimeException(errMsg); + throw new RuntimeException(errMsg) } result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 7f4462e583607..d1f95faf2db0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode } import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -353,6 +353,10 @@ case class MapGroups( /** Internal class representing State */ trait LogicalKeyedState[S] +/** Possible types of timeouts used in FlatMapGroupsWithState */ +case object NoTimeout extends KeyedStateTimeout +case object ProcessingTimeTimeout extends KeyedStateTimeout + /** Factory for constructing new `MapGroupsWithState` nodes. */ object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( @@ -361,7 +365,10 @@ object FlatMapGroupsWithState { dataAttributes: Seq[Attribute], outputMode: OutputMode, isMapGroupsWithState: Boolean, + timeout: KeyedStateTimeout, child: LogicalPlan): LogicalPlan = { + val encoder = encoderFor[S] + val mapped = new FlatMapGroupsWithState( func, UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), @@ -369,11 +376,11 @@ object FlatMapGroupsWithState { groupingAttributes, dataAttributes, CatalystSerde.generateObjAttr[U], - encoderFor[S].resolveAndBind().deserializer, - encoderFor[S].namedExpressions, + encoder.asInstanceOf[ExpressionEncoder[Any]], outputMode, - child, - isMapGroupsWithState) + isMapGroupsWithState, + timeout, + child) CatalystSerde.serialize[U](mapped) } } @@ -384,15 +391,16 @@ object FlatMapGroupsWithState { * Func is invoked with an object representation of the grouping key an iterator containing the * object representation of all the rows with that key. * + * @param func function called on each group * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param outputObjAttr used to define the output object - * @param stateDeserializer used to deserialize state before calling `func` - * @param stateSerializer used to serialize updated state after calling `func` + * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method + * @param timeout used to timeout groups that have not received data in a while */ case class FlatMapGroupsWithState( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], @@ -401,11 +409,11 @@ case class FlatMapGroupsWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], + stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - child: LogicalPlan, - isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer { + isMapGroupsWithState: Boolean = false, + timeout: KeyedStateTimeout, + child: LogicalPlan) extends UnaryNode with ObjectProducer { if (isMapGroupsWithState) { assert(outputMode == OutputMode.Update) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java new file mode 100644 index 0000000000000..02c94b0b32449 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; +import org.junit.Test; + +public class JavaKeyedStateTimeoutSuite { + + @Test + public void testTimeouts() { + assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 200c39f43a6b4..08216e2660400 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -144,14 +144,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertSupportedInBatchPlan( s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)) + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, + batchRelation)) assertSupportedInBatchPlan( s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))) + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, + null, batchRelation))) } // FlatMapGroupsWithState(Update) in streaming without aggregation @@ -159,14 +161,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in update mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), outputMode = Update) assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), outputMode = Append, expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append")) @@ -174,7 +178,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in complete mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), outputMode = Complete, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. @@ -186,7 +191,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " + s"with aggregation in $outputMode mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation")) @@ -197,14 +202,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + "on streaming relation without aggregation in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), outputMode = Append) assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + "on streaming relation without aggregation in update mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), outputMode = Update, expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update")) @@ -217,7 +224,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { Seq(attributeWithWatermark), aggExprs("c"), FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation)), outputMode = outputMode) } @@ -225,7 +233,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + s"on streaming relation after aggregation in $outputMode mode", - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation")) @@ -235,7 +244,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - " + "flatMapGroupsWithState(Update) on streaming relation in complete mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), outputMode = Complete, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. @@ -248,7 +258,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " + s"streaming relation in $outputMode output mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation), + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, + null, batchRelation), outputMode = outputMode ) } @@ -258,19 +269,20 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertSupportedInStreamingPlan( "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " + "in append mode", - FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, - FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, streamRelation)), outputMode = Append) assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" + " are not in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation)), outputMode = Append, expectedMsgs = Seq("multiple flatMapGroupsWithState", "append")) @@ -279,8 +291,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - mapGroupsWithState " + "on streaming relation without aggregation in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation), outputMode = Append, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. @@ -290,8 +302,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - mapGroupsWithState " + "on streaming relation without aggregation in complete mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation), outputMode = Complete, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. @@ -301,10 +313,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertNotSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState on streaming relation " + s"with aggregation in $outputMode mode", - FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, - Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation), - isMapGroupsWithState = true), + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + isMapGroupsWithState = true, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, expectedMsgs = Seq("mapGroupsWithState", "with aggregation")) } @@ -314,11 +325,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " + "in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = true), - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation)), outputMode = Append, expectedMsgs = Seq("multiple mapGroupsWithStates")) @@ -327,11 +337,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - " + "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = false), - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation) + ), outputMode = Append, expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index d44af7ef48157..29af78c4f6a85 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.KeyedState; +import org.apache.spark.sql.streaming.KeyedState; /** * ::Experimental:: diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 75986d1706209..70f3f01a8e9da 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.KeyedState; +import org.apache.spark.sql.streaming.KeyedState; /** * ::Experimental:: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index ab956ffd642e7..96437f868a6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode} /** * :: Experimental :: @@ -228,13 +228,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving @@ -249,42 +250,49 @@ class KeyValueGroupedDataset[K, V] private[sql]( dataAttributes, OutputMode.Update, isMapGroupsWithState = true, + KeyedStateTimeout.NoTimeout, child = logicalPlan)) } /** * ::Experimental:: - * (Java-specific) + * (Scala-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. + * @param func Function to be called on every group. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) - )(stateEncoder, outputEncoder) + def mapGroupsWithState[S: Encoder, U: Encoder]( + timeoutConf: KeyedStateTimeout)( + func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + timeoutConf, + child = logicalPlan)) } /** * ::Experimental:: - * (Scala-specific) + * (Java-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -294,33 +302,27 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = { - if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { - throw new IllegalArgumentException("The output mode of function should be append or update") - } - Dataset[U]( - sparkSession, - FlatMapGroupsWithState[K, V, S, U]( - func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], - groupingAttributes, - dataAttributes, - outputMode, - isMapGroupsWithState = false, - child = logicalPlan)) + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + )(stateEncoder, outputEncoder) } /** * ::Experimental:: - * (Scala-specific) + * (Java-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -330,22 +332,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = { - flatMapGroupsWithState(func, InternalOutputModes(outputMode)) + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: KeyedStateTimeout): Dataset[U] = { + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + )(stateEncoder, outputEncoder) } /** * ::Experimental:: - * (Java-specific) + * (Scala-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -355,25 +364,32 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[S, U]( - func: FlatMapGroupsWithStateFunction[K, V, S, U], + def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala, - outputMode - )(stateEncoder, outputEncoder) + timeoutConf: KeyedStateTimeout)( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = logicalPlan)) } /** @@ -392,18 +408,21 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param outputMode The output mode of the function. * @param stateEncoder Encoder for the state type. * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], - outputMode: String, + outputMode: OutputMode, stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder) + outputEncoder: Encoder[U], + timeoutConf: KeyedStateTimeout): Dataset[U] = { + val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala deleted file mode 100644 index 71efa4384211f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState - -/** - * :: Experimental :: - * - * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[KeyValueGroupedDataset]]. - * - * Detail description on `[map/flatMap]GroupsWithState` operation - * ------------------------------------------------------------ - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] - * will invoke the user-given function on each group (defined by the grouping function in - * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `streaming.StreamingQuery`, - * the function will be invoked once for each group that has data in the batch. - * - * The function is invoked with following parameters. - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * In case of a batch Dataset, there is only one invocation and state object will be empty as - * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` - * is equivalent to `[map/flatMap]Groups`. - * - * Important points to note about the function. - * - In a trigger, the function will be called only the groups present in the batch. So do not - * assume that the function will be called in every trigger for every group that has state. - * - There is no guaranteed ordering of values in the iterator in the function, neither with - * batch, nor with streaming Datasets. - * - All the data will be shuffled before applying the function. - * - * Important points to note about using KeyedState. - * - The value of the state cannot be null. So updating state with null will throw - * `IllegalArgumentException`. - * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, - * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` - * - After that, if `update(newState)` is called, then `exists()` will again return `true`, - * `get()` and `getOption()`will return the updated value. - * - * Scala example of using KeyedState in `mapGroupsWithState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { - * // Check if state exists - * if (state.exists) { - * val existingState = state.get // Get the existing state - * val shouldRemove = ... // Decide whether to remove the state - * if (shouldRemove) { - * state.remove() // Remove the state - * } else { - * val newState = ... - * state.update(newState) // Set the new state - * } - * } else { - * val initialState = ... - * state.update(initialState) // Set the initial state - * } - * ... // return something - * } - * - * }}} - * - * Java example of using `KeyedState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * MapGroupsWithStateFunction mappingFunction = - * new MapGroupsWithStateFunction() { - * - * @Override - * public String call(String key, Iterator value, KeyedState state) { - * if (state.exists()) { - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * } - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * } - * ... // return something - * } - * }; - * }}} - * - * @tparam S User-defined type of the state to be stored for each key. Must be encodable into - * Spark SQL types (see [[Encoder]] for more details). - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -trait KeyedState[S] extends LogicalKeyedState[S] { - - /** Whether state exists or not. */ - def exists: Boolean - - /** Get the state value if it exists, or throw NoSuchElementException. */ - @throws[NoSuchElementException]("when state does not exist") - def get: S - - /** Get the state value as a scala Option. */ - def getOption: Option[S] - - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") - def update(newState: S): Unit - - /** Remove this keyed state. */ - def remove(): Unit -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0f7aa3709c1cf..9e58e8ce3d5f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -329,22 +329,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. */ - object MapGroupsWithStateStrategy extends Strategy { + object FlatMapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FlatMapGroupsWithState( - f, - keyDeser, - valueDeser, - groupAttr, - dataAttr, - outputAttr, - stateDeser, - stateSer, - outputMode, - child, - _) => + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, + timeout, child) => val execPlan = FlatMapGroupsWithStateExec( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, + timeout, batchTimestampMs = KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP, planLater(child)) execPlan :: Nil case _ => @@ -392,7 +384,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, child, _) => + f, key, value, grouping, data, output, _, _, _, _, child) => execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 5de45b159684c..41d91d877d4c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ -import org.apache.spark.sql.execution.streaming.IncrementalExecution +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ @@ -106,7 +106,8 @@ case class ExplainCommand( if (logicalPlan.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. - new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0, 0) + new IncrementalExecution( + sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala new file mode 100644 index 0000000000000..991d8ef707567 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `FlatMapGroupsWithState.` + * + * @param func function called on each group + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr used to define the output object + * @param stateEncoder used to serialize/deserialize state before calling `func` + * @param outputMode the output mode of `func` + * @param timeout used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + */ +case class FlatMapGroupsWithStateExec( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateId: Option[OperatorStateId], + stateEncoder: ExpressionEncoder[Any], + outputMode: OutputMode, + timeout: KeyedStateTimeout, + batchTimestampMs: Long, + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { + + private val isTimeoutEnabled = timeout == ProcessingTimeTimeout + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + + import KeyedStateImpl._ + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + /** Ordering needed for using GroupingIterator */ + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + groupingAttributes.toStructType, + stateAttributes.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) => + val updater = new StateStoreUpdater(store) + + // Generate a iterator that returns the rows grouped by the grouping function + // Note that this code ensures that the filtering for timeout occurs only after + // all the data has been processed. This is to ensure that the timeout information of all + // the keys with data is updated before they are processed for timeouts. + val outputIterator = + updater.updateStateForKeysWithData(iterator) ++ updater.updateStateForTimedOutKeys() + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterator, + { + store.commit() + longMetric("numTotalStateRows") += store.numKeys() + } + ) + } + } + + /** Helper class to update the state store */ + class StateStoreUpdater(store: StateStore) { + + // Converters for translating input keys, values, output data between rows and Java objects + private val getKeyObj = + ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + private val getValueObj = + ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + + // Converter for translating state rows to Java objects + private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateEncoder.resolveAndBind().deserializer, stateAttributes) + + // Converter for translating state Java objects to rows + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (isTimeoutEnabled) { + encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET) + } else { + encoderSerializer + } + } + private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Index of the additional metadata fields in the state row + private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) + + // Metrics + private val numUpdatedStateRows = longMetric("numUpdatedStateRows") + private val numOutputRows = longMetric("numOutputRows") + + /** + * For every group, get the key, values and corresponding state and call the function, + * and return an iterator of rows + */ + def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + groupedIter.flatMap { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + callFunctionAndUpdateState( + keyUnsafeRow, + valueRowIter, + store.get(keyUnsafeRow), + hasTimedOut = false) + } + } + + /** Find the groups that have timeout set and are timing out right now, and call the function */ + def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timingOutKeys = store.filter { case (_, stateRow) => + val timeoutTimestamp = getTimeoutTimestamp(stateRow) + timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < batchTimestampMs + } + timingOutKeys.flatMap { case (keyRow, stateRow) => + callFunctionAndUpdateState( + keyRow, + Iterator.empty, + Some(stateRow), + hasTimedOut = true) + } + } else Iterator.empty + } + + /** + * Call the user function on a key's data, update the state store, and return the return data + * iterator. Note that the store updating is lazy, that is, the store will be updated only + * after the returned iterator is fully consumed. + */ + private def callFunctionAndUpdateState( + keyRow: UnsafeRow, + valueRowIter: Iterator[InternalRow], + prevStateRowOption: Option[UnsafeRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + + val keyObj = getKeyObj(keyRow) // convert key to objects + val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects + val stateObjOption = getStateObj(prevStateRowOption) + val keyedState = new KeyedStateImpl( + stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut) + + // Call function, get the returned objects and convert them to rows + val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + numOutputRows += 1 + getOutputRow(obj) + } + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + // Has the timeout information changed + + if (keyedState.hasRemoved) { + store.remove(keyRow) + numUpdatedStateRows += 1 + + } else { + val previousTimeoutTimestamp = prevStateRowOption match { + case Some(row) => getTimeoutTimestamp(row) + case None => TIMEOUT_TIMESTAMP_NOT_SET + } + + val stateRowToWrite = if (keyedState.hasUpdated) { + getStateRow(keyedState.get) + } else { + prevStateRowOption.orNull + } + + val hasTimeoutChanged = keyedState.getTimeoutTimestamp != previousTimeoutTimestamp + val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + + if (shouldWriteState) { + if (stateRowToWrite == null) { + // This should never happen because checks in KeyedStateImpl should avoid cases + // where empty state would need to be written + throw new IllegalStateException( + "Attempting to write empty state") + } + setTimeoutTimestamp(stateRowToWrite, keyedState.getTimeoutTimestamp) + store.put(keyRow.copy(), stateRowToWrite.copy()) + numUpdatedStateRows += 1 + } + } + } + + // Return an iterator of rows such that fully consumed, the updated state value will be saved + CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) + } + + /** Returns the state as Java object if defined */ + def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = { + stateRowOption.map(getStateObjFromRow) + } + + /** Returns the row for an updated state */ + def getStateRow(obj: Any): UnsafeRow = { + getStateRowFromObj(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { + if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else TIMEOUT_TIMESTAMP_NOT_SET + } + + /** Set the timestamp in a state row */ + def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 610ce5e1ebf5d..a934c75a02457 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,13 +37,13 @@ class IncrementalExecution( val outputMode: OutputMode, val checkpointLocation: String, val currentBatchId: Long, - val currentEventTimeWatermark: Long) + offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { // TODO: make this always part of planning. val streamingExtraStrategies = sparkSession.sessionState.planner.StatefulAggregationStrategy +: - sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: + sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -88,12 +88,13 @@ class IncrementalExecution( keys, Some(stateId), Some(outputMode), - Some(currentEventTimeWatermark), + Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, Some(stateId), child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) @@ -102,13 +103,12 @@ class IncrementalExecution( keys, child, Some(stateId), - Some(currentEventTimeWatermark)) - case FlatMapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => + Some(offsetSeqMetadata.batchWatermarkMs)) + + case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - FlatMapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) + m.copy(stateId = Some(stateId), batchTimestampMs = offsetSeqMetadata.batchTimestampMs) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index eee7ec45dd77b..ac421d395beb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -17,15 +17,37 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.KeyedState +import org.apache.commons.lang3.StringUtils -/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */ -private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { +import org.apache.spark.sql.streaming.KeyedState +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. + * @param optionalValue Optional value of the state + * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp + * for processing time timeouts + * @param isTimeoutEnabled Whether timeout is enabled. This will be used to check whether the user + * is allowed to configure timeouts. + * @param hasTimedOut Whether the key for which this state wrapped is being created is + * getting timed out or not. + */ +private[sql] class KeyedStateImpl[S]( + optionalValue: Option[S], + batchProcessingTimeMs: Long, + isTimeoutEnabled: Boolean, + override val hasTimedOut: Boolean) extends KeyedState[S] { + + import KeyedStateImpl._ + + // Constructor to create dummy state when using mapGroupsWithState in a batch query + def this(optionalValue: Option[S]) = this( + optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined - private var updated: Boolean = false - // whether value has been updated (but not removed) + private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed + private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET // ========= Public API ========= override def exists: Boolean = defined @@ -60,6 +82,55 @@ private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedStat defined = false updated = false removed = true + timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET + } + + override def setTimeoutDuration(durationMs: Long): Unit = { + if (!isTimeoutEnabled) { + throw new UnsupportedOperationException( + "Cannot set timeout information without enabling timeout in map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout information without any state value, " + + "state has either not been initialized, or has already been removed") + } + + if (durationMs <= 0) { + throw new IllegalArgumentException("Timeout duration must be positive") + } + if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) { + timeoutTimestamp = durationMs + batchProcessingTimeMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. + } + } + + override def setTimeoutDuration(duration: String): Unit = { + if (StringUtils.isBlank(duration)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (duration.startsWith("interval")) { + duration + } else { + "interval " + duration + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided duration ($duration) is not valid.") + } + if (cal.milliseconds < 0 || cal.months < 0) { + throw new IllegalArgumentException("Timeout duration must be positive") + } + + val delayMs = { + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + cal.milliseconds + cal.months * millisPerMonth + } + setTimeoutDuration(delayMs) } override def toString: String = { @@ -69,12 +140,21 @@ private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedStat // ========= Internal API ========= /** Whether the state has been marked for removing */ - def isRemoved: Boolean = { - removed - } + def hasRemoved: Boolean = removed - /** Whether the state has been been updated */ - def isUpdated: Boolean = { - updated - } + /** Whether the state has been updated */ + def hasUpdated: Boolean = updated + + /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ + def getTimeoutTimestamp: Long = timeoutTimestamp +} + + +private[sql] object KeyedStateImpl { + // Value used in the state row to represent the lack of any timeout timestamp + val TIMEOUT_TIMESTAMP_NOT_SET = -1L + + // Value to represent that no batch processing timestamp is passed to KeyedStateImpl. This is + // used in batch queries where there are no streaming batches and timeouts. + val NO_BATCH_PROCESSING_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 40faddccc2423..60d5283e6b211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -590,7 +590,7 @@ class StreamExecution( outputMode, checkpointFile("state"), currentBatchId, - offsetSeqMetadata.batchWatermarkMs) + offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ab1204a750fac..f9dd80230e488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -73,7 +73,12 @@ private[state] class HDFSBackedStateStoreProvider( hadoopConf: Configuration ) extends StateStoreProvider with Logging { - type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] + // ConcurrentHashMap is used because it generates fail-safe iterators on filtering + // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in + // the map when the iterator was created + // - Any updates to the map while iterating through the filtered iterator does not throw + // java.util.ConcurrentModificationException + type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType) @@ -99,6 +104,16 @@ private[state] class HDFSBackedStateStoreProvider( Option(mapToUpdate.get(key)) } + override def filter( + condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + mapToUpdate + .entrySet + .asScala + .iterator + .filter { entry => condition(entry.getKey, entry.getValue) } + .map { entry => (entry.getKey, entry.getValue) } + } + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot put after already committed or aborted") @@ -227,7 +242,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dcb24b26f78f3..eaa558eb6d0ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -50,6 +50,15 @@ trait StateStore { /** Get the current value of a key. */ def get(key: UnsafeRow): Option[UnsafeRow] + /** + * Return an iterator of key-value pairs that satisfy a certain condition. + * Note that the iterator must be fail-safe towards modification to the store, that is, + * it must be based on the snapshot of store the time of this call, and any change made to the + * store while iterating through iterator should not cause the iterator to fail or have + * any affect on the values in the iterator. + */ + def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] + /** Put a new value for a key. */ def put(key: UnsafeRow, value: UnsafeRow): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c3075a3eacaac..6d2de441eb44c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{DataType, NullType, StructType} +import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator @@ -256,94 +257,6 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } - -/** Physical operator for executing streaming flatMapGroupsWithState. */ -case class FlatMapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - outputObjAttr: Attribute, - stateId: Option[OperatorStateId], - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { - - override def outputPartitioning: Partitioning = child.outputPartitioning - - /** Distribute by grouping attributes */ - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil - - /** Ordering needed for using GroupingIterator */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - getStateId.batchId, - groupingAttributes.toStructType, - child.output.toStructType, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val numOutputRows = longMetric("numOutputRows") - - // Generate a iterator that returns the rows grouped by the grouping function - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - - // Converters to and from object and rows - val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val getStateObj = - ObjectOperator.deserializeRowToObject(stateDeserializer) - val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // For every group, get the key, values and corresponding state and call the function, - // and return an iterator of rows - val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => - - val key = keyRow.asInstanceOf[UnsafeRow] - val keyObj = getKeyObj(keyRow) // convert key to objects - val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = store.get(key).map(getStateObj) // get existing state if any - val wrappedState = new KeyedStateImpl(stateObjOption) - val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj => - numOutputRows += 1 - getOutputRow(obj) // convert back to rows - } - - // Return an iterator of rows generated this key, - // such that fully consumed, the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]]( - mappedIterator, { - // When the iterator is consumed, then write changes to state - if (wrappedState.isRemoved) { - store.remove(key) - numUpdatedStateRows += 1 - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get)) - numUpdatedStateRows += 1 - } - }) - } - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumer, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, { - store.commit() - numTotalStateRows += store.numKeys() - }) - } - } -} - - /** Physical operator for executing streaming Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala new file mode 100644 index 0000000000000..6b4b1ced98a34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset} +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on + * [[KeyValueGroupedDataset]]. + * + * Detail description on `[map/flatMap]GroupsWithState` operation + * -------------------------------------------------------------- + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * will invoke the user-given function on each group (defined by the grouping function in + * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger. + * That is, in every batch of the `streaming.StreamingQuery`, + * the function will be invoked once for each group that has data in the trigger. Furthermore, + * if timeout is set, then the function will invoked on timed out keys (more detail below). + * + * The function is invoked with following parameters. + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * In case of a batch Dataset, there is only one invocation and state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` + * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have + * no effect. + * + * Important points to note about the function. + * - In a trigger, the function will be called only the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * - If timeout is set, then the function will also be called with no values. + * See more details on KeyedStateTimeout` below. + * + * Important points to note about using `KeyedState`. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - If `remove()` is called, then `exists()` will return `false`, + * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()`will return the updated value. + * + * Important points to note about using `KeyedStateTimeout`. + * - The timeout type is a global param across all the keys (set as `timeout` param in + * `[map|flatMap]GroupsWithState`, but the exact timeout duration is configurable per key + * (by calling `setTimeout...()` in `KeyedState`). + * - When the timeout occurs for a key, the function is called with no values, and + * `KeyedState.hasTimedOut()` set to true. + * - The timeout is reset for key every time the function is called on the key, that is, + * when the key has new data, or the key has timed out. So the user has to set the timeout + * duration every time the function is called, otherwise there will not be any timeout set. + * - Guarantees provided on processing-time-based timeout of key, when timeout duration is D ms: + * - Timeout will never be called before real clock time has advanced by D ms + * - Timeout will be called eventually when there is a trigger in the query + * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. + * For example, the trigger interval of the query will affect when the timeout is actually hit. + * If there is no data in the stream (for any key) for a while, then their will not be + * any trigger and timeout will not be hit until there is data. + * + * Scala example of using KeyedState in `mapGroupsWithState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. + * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { + * + * if (state.hasTimedOut) { // If called when timing out, remove the state + * state.remove() + * + * } else if (state.exists) { // If state exists, use it for processing + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * ... + * // return something + * } + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout)(mappingFunction) + * }}} + * + * Java example of using `KeyedState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. + * MapGroupsWithStateFunction mappingFunction = + * new MapGroupsWithStateFunction() { + * + * @Override + * public String call(String key, Iterator value, KeyedState state) { + * if (state.hasTimedOut()) { // If called when timing out, remove the state + * state.remove(); + * + * } else if (state.exists()) { // If state exists, use it for processing + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * ... +* // return something + * } + * }; + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState( + * mappingFunction, Encoders.INT, Encoders.STRING, KeyedStateTimeout.ProcessingTimeTimeout); + * }}} + * + * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * Spark SQL types (see [[Encoder]] for more details). + * @since 2.2.0 + */ +@Experimental +@InterfaceStability.Evolving +trait KeyedState[S] extends LogicalKeyedState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") + def get: S + + /** Get the state value as a scala Option. */ + def getOption: Option[S] + + /** + * Update the value of the state. Note that `null` is not a valid value, and it throws + * IllegalArgumentException. + */ + @throws[IllegalArgumentException]("when updating with null") + def update(newState: S): Unit + + /** Remove this keyed state. Note that this resets any timeout configuration as well. */ + def remove(): Unit + + /** + * Whether the function has been called because the key has timed out. + * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + */ + def hasTimedOut: Boolean + + /** + * Set the timeout duration in ms for this key. + * @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + */ + @throws[IllegalArgumentException]("if 'durationMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + def setTimeoutDuration(durationMs: Long): Unit + + /** + * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. + * @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + */ + @throws[IllegalArgumentException]("if 'duration' is not a valid duration") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + def setTimeoutDuration(duration: String): Unit +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 439cac3dfbcb7..ca9e5ad2ea86b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,7 @@ import java.sql.Timestamp; import java.util.*; +import org.apache.spark.sql.streaming.KeyedStateTimeout; import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; @@ -208,7 +209,8 @@ public void testGroupBy() { }, OutputMode.Append(), Encoders.LONG(), - Encoders.STRING()); + Encoders.STRING(), + KeyedStateTimeout.NoTimeout()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index e848f74e3159f..ebb7422765ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -123,6 +123,30 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) } + test("filter and concurrent updates") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator.isEmpty) + val store = provider.getStore(0) + put(store, "a", 1) + put(store, "b", 2) + + // Updates should work while iterating of filtered entries + val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" } + filtered.foreach { case (keyRow, valueRow) => + store.put(keyRow, intToRow(rowToInt(valueRow) + 1)) + } + assert(get(store, "a") === Some(2)) + + // Removes should work while iterating of filtered entries + val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" } + filtered2.foreach { case (keyRow, _) => + store.remove(keyRow) + } + assert(get(store, "b") === None) + } + test("updates iterator with all combos of updates and removes") { val provider = newStoreProvider() var currentVersion: Int = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 902b842e97aa9..7daa5e6a0f61f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,20 +17,33 @@ package org.apache.spark.sql.streaming +import java.util +import java.util.concurrent.ConcurrentHashMap + import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.KeyedState +import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState +import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore +import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ case class RunningCount(count: Long) +case class Result(key: Long, count: Int) + class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ + import KeyedStateImpl._ override def afterAll(): Unit = { super.afterAll() @@ -54,8 +67,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } assert(state.getOption === expectedData) - assert(state.isUpdated === shouldBeUpdated) - assert(state.isRemoved === shouldBeRemoved) + assert(state.hasUpdated === shouldBeUpdated) + assert(state.hasRemoved === shouldBeRemoved) } // Updating empty state @@ -83,6 +96,79 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + test("KeyedState - setTimeoutDuration, hasTimedOut") { + import KeyedStateImpl._ + var state: KeyedStateImpl[Int] = null + + // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed + for (initState <- Seq(None, Some(5))) { + // for different initial state + state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false) + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + intercept[UnsupportedOperationException] { + state.setTimeoutDuration(1000) + } + intercept[UnsupportedOperationException] { + state.setTimeoutDuration("1 day") + } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + + def testTimeoutNotAllowed(): Unit = { + intercept[IllegalStateException] { + state.setTimeoutDuration(1000) + } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + intercept[IllegalStateException] { + state.setTimeoutDuration("2 second") + } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + + // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the + // state is be defined + state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false) + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + testTimeoutNotAllowed() + + // After state has been set, setTimeoutDuration() is allowed, and + // getTimeoutTimestamp returned correct timestamp + state.update(5) + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + state.setTimeoutDuration(1000) + assert(state.getTimeoutTimestamp === 2000) + state.setTimeoutDuration("2 second") + assert(state.getTimeoutTimestamp === 3000) + assert(state.hasTimedOut === false) + + // setTimeoutDuration() with negative values or 0 is not allowed + def testIllegalTimeout(body: => Unit): Unit = { + intercept[IllegalArgumentException] { body } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutDuration(-1000) } + testIllegalTimeout { state.setTimeoutDuration(0) } + testIllegalTimeout { state.setTimeoutDuration("-2 second") } + testIllegalTimeout { state.setTimeoutDuration("-1 month") } + testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } + + // Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that + state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + state.remove() + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + testTimeoutNotAllowed() + + // Test hasTimedOut = true + state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true) + assert(state.hasTimedOut === true) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + test("KeyedState - primitive type") { var intState = new KeyedStateImpl[Int](None) intercept[NoSuchElementException] { @@ -100,6 +186,151 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + // Values used for testing StateStoreUpdater + val currentTimestamp = 1000 + val beforeCurrentTimestamp = 999 + val afterCurrentTimestamp = 1001 + + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled + for (priorState <- Seq(None, Some(0))) { + val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" + val testName = s"timeout disabled - $priorStateStr - " + + testStateUpdateWithData( + testName + "no update", + stateUpdates = state => { /* do nothing */ }, + timeoutType = KeyedStateTimeout.NoTimeout, + priorState = priorState, + expectedState = priorState) // should not change + + testStateUpdateWithData( + testName + "state updated", + stateUpdates = state => { state.update(5) }, + timeoutType = KeyedStateTimeout.NoTimeout, + priorState = priorState, + expectedState = Some(5)) // should change + + testStateUpdateWithData( + testName + "state removed", + stateUpdates = state => { state.remove() }, + timeoutType = KeyedStateTimeout.NoTimeout, + priorState = priorState, + expectedState = None) // should be removed + } + + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled + for (priorState <- Seq(None, Some(0))) { + for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) { + var testName = s"timeout enabled - " + if (priorState.nonEmpty) { + testName += "prior state set, " + if (priorTimeoutTimestamp == 1000) { + testName += "prior timeout set - " + } else { + testName += "no prior timeout - " + } + } else { + testName += "no prior state - " + } + + testStateUpdateWithData( + testName + "no update", + stateUpdates = state => { /* do nothing */ }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = priorState, // state should not change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithData( + testName + "state updated", + stateUpdates = state => { state.update(5) }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithData( + testName + "state removed", + stateUpdates = state => { state.remove() }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None) // state should be removed + + testStateUpdateWithData( + testName + "timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change + } + } + + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + val preTimeoutState = Some(5) + + testStateUpdateWithTimeout( + "should not timeout", + stateUpdates = state => { assert(false, "function called without timeout") }, + priorTimeoutTimestamp = afterCurrentTimestamp, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change + + testStateUpdateWithTimeout( + "should timeout - no update/remove", + stateUpdates = state => { /* do nothing */ }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithTimeout( + "should timeout - update state", + stateUpdates = state => { state.update(5) }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithTimeout( + "should timeout - remove state", + stateUpdates = state => { state.remove() }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = None, // state should be removed + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) + + testStateUpdateWithTimeout( + "should timeout - timeout updated", + stateUpdates = state => { state.setTimeoutDuration(2000) }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + + testStateUpdateWithTimeout( + "should timeout - timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + + test("StateStoreUpdater - rows are cloned before writing to StateStore") { + // function for running count + val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + state.update(state.getOption.getOrElse(0) + values.size) + Iterator.empty + } + val store = newStateStore() + val plan = newFlatMapGroupsWithStateExec(func) + val updater = new plan.StateStoreUpdater(store) + val data = Seq(1, 1, 2) + val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow)) + returnIter.size // consume the iterator to force store updates + val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet + assert(storeData === Set((1, 2), (2, 1))) + } + test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything @@ -119,7 +350,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a"), @@ -162,8 +393,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - + .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), @@ -178,59 +408,118 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("flatMapGroupsWithState - streaming + aggregation") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator(key -> "-1") + } else { + state.update(RunningCount(count)) + Iterator(key -> count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc) + .groupByKey(_._1) + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckLastBatch(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckLastBatch(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckLastBatch(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + ) + } + test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } - checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF, - Seq(("a", 2), ("b", 1)).toDF) + val df = Seq("a", "a", "b").toDS + .groupByKey(x => x) + .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF + checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("mapGroupsWithState - streaming") { + test("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { + if (state.hasTimedOut) { state.remove() - (key, "-1") + Iterator((key, "-1")) } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) - (key, count.toString) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) } } + val clock = new StreamManualClock val inputData = MemoryStream[String] + val timeout = KeyedStateTimeout.ProcessingTimeTimeout val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .flatMapGroupsWithState(Update, timeout)(stateFunc) testStream(result, Update)( + StartStream(ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), CheckLastBatch(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), + StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) + StartStream(ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(20 * 1000), + CheckLastBatch(("b", "-1"), ("c", "1")), + assertNumStateRows(total = 1, updated = 2), + + AddData(inputData, "c"), + AdvanceManualClock(20 * 1000), + CheckLastBatch(("c", "2")), + assertNumStateRows(total = 1, updated = 1) ) } - test("flatMapGroupsWithState - streaming + aggregation") { + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { @@ -238,10 +527,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - Iterator(key -> "-1") + (key, "-1") } else { state.update(RunningCount(count)) - Iterator(key -> count.toString) + (key, count.toString) } } @@ -249,28 +538,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str) - .groupByKey(_._1) - .count() + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - testStream(result, Complete)( + testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "b"), - // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; - // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "c"), - // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; - // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) ) } @@ -322,23 +608,185 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("output partitioning is unknown") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key + val inputData = MemoryStream[String] + val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) + result + testStream(result, Update)( + AddData(inputData, "a"), + CheckLastBatch("a"), + AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) + ) + } + test("disallow complete mode") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => { Iterator[String]() } var e = intercept[IllegalArgumentException] { - MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete) + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( + OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc) } assert(e.getMessage === "The output mode of function should be append or update") + val javaStateFunc = new FlatMapGroupsWithStateFunction[String, String, Int, String] { + import java.util.{Iterator => JIterator} + override def call( + key: String, + values: JIterator[String], + state: KeyedState[Int]): JIterator[String] = { null } + } e = intercept[IllegalArgumentException] { - MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete") + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( + javaStateFunc, OutputMode.Complete, + implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout) } assert(e.getMessage === "The output mode of function should be append or update") } + + def testStateUpdateWithData( + testName: String, + stateUpdates: KeyedState[Int] => Unit, + timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + priorState: Option[Int], + priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET, + expectedState: Option[Int] = None, + expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + + if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) { + return // there can be no prior timestamp, when there is no prior state + } + test(s"StateStoreUpdater - updates with data - $testName") { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + assert(state.hasTimedOut === false, "hasTimedOut not false") + assert(values.nonEmpty, "Some value is expected") + stateUpdates(state) + Iterator.empty + } + testStateUpdate( + testTimeoutUpdates = false, mapGroupsFunc, timeoutType, + priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + } + } + + def testStateUpdateWithTimeout( + testName: String, + stateUpdates: KeyedState[Int] => Unit, + priorTimeoutTimestamp: Long, + expectedState: Option[Int], + expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + + test(s"StateStoreUpdater - updates for timeout - $testName") { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + assert(state.hasTimedOut === true, "hasTimedOut not true") + assert(values.isEmpty, "values not empty") + stateUpdates(state) + Iterator.empty + } + testStateUpdate( + testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout, + preTimeoutState, priorTimeoutTimestamp, + expectedState, expectedTimeoutTimestamp) + } + } + + def testStateUpdate( + testTimeoutUpdates: Boolean, + mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], + timeoutType: KeyedStateTimeout, + priorState: Option[Int], + priorTimeoutTimestamp: Long, + expectedState: Option[Int], + expectedTimeoutTimestamp: Long): Unit = { + + val store = newStateStore() + val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( + mapGroupsFunc, timeoutType, currentTimestamp) + val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val key = intToRow(0) + // Prepare store with prior state configs + if (priorState.nonEmpty) { + val row = updater.getStateRow(priorState.get) + updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) + store.put(key.copy(), row.copy()) + } + + // Call updating function to update state store + val returnedIter = if (testTimeoutUpdates) { + updater.updateStateForTimedOutKeys() + } else { + updater.updateStateForKeysWithData(Iterator(key)) + } + returnedIter.size // consumer the iterator to force state updates + + // Verify updated state in store + val updatedStateRow = store.get(key) + assert( + updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + "final state not as expected") + if (updatedStateRow.nonEmpty) { + assert( + updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") + } + } + + def newFlatMapGroupsWithStateExec( + func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], + timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = { + MemoryStream[Int] + .toDS + .groupByKey(x => x) + .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func) + .logicalPlan.collectFirst { + case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => + FlatMapGroupsWithStateExec( + f, k, v, g, d, o, None, s, m, t, currentTimestamp, + RDDScanExec(g, null, "rdd")) + }.get + } + + def newStateStore(): StateStore = new MemoryStateStore() + + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToInt(row: UnsafeRow): Int = row.getInt(0) } object FlatMapGroupsWithStateSuite { + var failInTask = true + + class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { + map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) } + } + + override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + iterator.filter { case (k, v) => c(k, v) } + } + + override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key)) + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue) + override def remove(key: UnsafeRow): Unit = { map.remove(key) } + override def remove(condition: (UnsafeRow) => Boolean): Unit = { + iterator.map(_._1).filter(condition).foreach(map.remove) + } + override def commit(): Long = version + 1 + override def abort(): Unit = { } + override def id: StateStoreId = null + override def version: Long = 0 + override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException } + override def numKeys(): Long = map.size + override def hasCommitted: Boolean = true + } } From 0cdcf9114527a2c359c25e46fd6556b3855bfb28 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 19 Mar 2017 22:33:01 -0700 Subject: [PATCH 0056/1765] [SPARK-19849][SQL] Support ArrayType in to_json to produce JSON array ## What changes were proposed in this pull request? This PR proposes to support an array of struct type in `to_json` as below: ```scala import org.apache.spark.sql.functions._ val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a") df.select(to_json($"a").as("json")).show() ``` ``` +----------+ | json| +----------+ |[{"_1":1}]| +----------+ ``` Currently, it throws an exception as below (a newline manually inserted for readability): ``` org.apache.spark.sql.AnalysisException: cannot resolve 'structtojson(`array`)' due to data type mismatch: structtojson requires that the expression is a struct expression.;; ``` This allows the roundtrip with `from_json` as below: ```scala import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val df = Seq("""[{"a":1}, {"a":2}]""").toDF("json").select(from_json($"json", schema).as("array")) df.show() // Read back. df.select(to_json($"array").as("json")).show() ``` ``` +----------+ | array| +----------+ |[[1], [2]]| +----------+ +-----------------+ | json| +-----------------+ |[{"a":1},{"a":2}]| +-----------------+ ``` Also, this PR proposes to rename from `StructToJson` to `StructsToJson ` and `JsonToStruct` to `JsonToStructs`. ## How was this patch tested? Unit tests in `JsonFunctionsSuite` and `JsonExpressionsSuite` for Scala, doctest for Python and test in `test_sparkSQL.R` for R. Author: hyukjinkwon Closes #17192 from HyukjinKwon/SPARK-19849. --- R/pkg/R/functions.R | 18 ++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 + python/pyspark/sql/functions.py | 15 ++- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/jsonExpressions.scala | 70 +++++++++----- .../sql/catalyst/json/JacksonGenerator.scala | 23 +++-- .../expressions/JsonExpressionsSuite.scala | 77 ++++++++++----- .../org/apache/spark/sql/functions.scala | 34 ++++--- .../sql-tests/inputs/json-functions.sql | 1 + .../sql-tests/results/json-functions.sql.out | 96 ++++++++++--------- .../apache/spark/sql/JsonFunctionsSuite.scala | 26 ++++- 11 files changed, 236 insertions(+), 132 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9867f2d5b7c51..2cff3ac08c3ae 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1795,10 +1795,10 @@ setMethod("to_date", #' to_json #' -#' Converts a column containing a \code{structType} into a Column of JSON string. -#' Resolving the Column can fail if an unsupported type is encountered. +#' Converts a column containing a \code{structType} or array of \code{structType} into a Column +#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. #' -#' @param x Column containing the struct +#' @param x Column containing the struct or array of the structs #' @param ... additional named properties to control how it is converted, accepts the same options #' as the JSON data source. #' @@ -1809,8 +1809,13 @@ setMethod("to_date", #' @export #' @examples #' \dontrun{ -#' to_json(df$t, dateFormat = 'dd/MM/yyyy') -#' select(df, to_json(df$t)) +#' # Converts a struct into a JSON object +#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' +#' # Converts an array of structs into a JSON array +#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' select(df, to_json(df$people)) #'} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), @@ -2433,7 +2438,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_json #' #' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema}. If the string is unparseable, the Column will contains the value NA. +#' \code{schema} or array of \code{structType} if \code{asJsonArray} is set to \code{TRUE}. +#' If the string is unparseable, the Column will contains the value NA. #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 32856b399cdd1..9c38e0d866aa3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1340,6 +1340,10 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) # Test to_json(), from_json() + df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- read.json(mapTypeJsonPath) j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 376b86ea69bd4..f9121e60f35b8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1774,10 +1774,11 @@ def json_tuple(col, *fields): def from_json(col, schema, options={}): """ Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] - with the specified schema. Returns `null`, in the case of an unparseable string. + of [[StructType]]s with the specified schema. Returns `null`, in the case of an unparseable + string. :param col: string column in json format - :param schema: a StructType or ArrayType to use when parsing the json column + :param schema: a StructType or ArrayType of StructType to use when parsing the json column :param options: options to control parsing. accepts the same options as the json datasource >>> from pyspark.sql.types import * @@ -1802,10 +1803,10 @@ def from_json(col, schema, options={}): @since(2.1) def to_json(col, options={}): """ - Converts a column containing a [[StructType]] into a JSON string. Throws an exception, - in the case of an unsupported type. + Converts a column containing a [[StructType]] or [[ArrayType]] of [[StructType]]s into a + JSON string. Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct + :param col: name of column containing the struct or array of the structs :param options: options to control converting. accepts the same options as the json datasource >>> from pyspark.sql import Row @@ -1814,6 +1815,10 @@ def to_json(col, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'{"age":2,"name":"Alice"}')] + >>> data = [(1, [Row(name='Alice', age=2), Row(name='Bob', age=3)])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0486e67dbdf86..e1d83a86f99dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -425,8 +425,8 @@ object FunctionRegistry { expression[BitwiseXor]("^"), // json - expression[StructToJson]("to_json"), - expression[JsonToStruct]("from_json"), + expression[StructsToJson]("to_json"), + expression[JsonToStructs]("from_json"), // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 37e4bb5060436..e4e08a8665a5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -482,7 +482,8 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. + * Converts an json input string to a [[StructType]] or [[ArrayType]] of [[StructType]]s + * with the specified schema. */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -495,7 +496,7 @@ case class JsonTuple(children: Seq[Expression]) {"time":"2015-08-26 00:00:00.0"} """) // scalastyle:on line.size.limit -case class JsonToStruct( +case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, @@ -590,7 +591,7 @@ case class JsonToStruct( } /** - * Converts a [[StructType]] to a json output string. + * Converts a [[StructType]] or [[ArrayType]] of [[StructType]]s to a json output string. */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -601,9 +602,11 @@ case class JsonToStruct( {"a":1,"b":2} > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + [{"a":1,"b":2}] """) // scalastyle:on line.size.limit -case class StructToJson( +case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) @@ -624,41 +627,58 @@ case class StructToJson( lazy val writer = new CharArrayWriter() @transient - lazy val gen = - new JacksonGenerator( - child.dataType.asInstanceOf[StructType], - writer, - new JSONOptions(options, timeZoneId.get)) + lazy val gen = new JacksonGenerator( + rowSchema, writer, new JSONOptions(options, timeZoneId.get)) + + @transient + lazy val rowSchema = child.dataType match { + case st: StructType => st + case ArrayType(st: StructType, _) => st + } + + // This converts rows to the JSON output according to the given schema. + @transient + lazy val converter: Any => UTF8String = { + def getAndReset(): UTF8String = { + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + child.dataType match { + case _: StructType => + (row: Any) => + gen.write(row.asInstanceOf[InternalRow]) + getAndReset() + case ArrayType(_: StructType, _) => + (arr: Any) => + gen.write(arr.asInstanceOf[ArrayData]) + getAndReset() + } + } override def dataType: DataType = StringType - override def checkInputDataTypes(): TypeCheckResult = { - if (StructType.acceptsType(child.dataType)) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case _: StructType | ArrayType(_: StructType, _) => try { - JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType]) + JacksonUtils.verifySchema(rowSchema) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } - } else { - TypeCheckResult.TypeCheckFailure( - s"$prettyName requires that the expression is a struct expression.") - } + case _ => TypeCheckResult.TypeCheckFailure( + s"Input type ${child.dataType.simpleString} must be a struct or array of structs.") } override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(row: Any): Any = { - gen.write(row.asInstanceOf[InternalRow]) - gen.flush() - val json = writer.toString - writer.reset() - UTF8String.fromString(json) - } + override def nullSafeEval(value: Any): Any = converter(value) - override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } object JsonExprUtils { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index dec55279c9fc5..1d302aea6fd16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -37,6 +37,10 @@ private[sql] class JacksonGenerator( // `ValueWriter`s for all fields of the schema private val rootFieldWriters: Array[ValueWriter] = schema.map(_.dataType).map(makeWriter).toArray + // `ValueWriter` for array data storing rows of the schema. + private val arrElementWriter: ValueWriter = (arr: SpecializedGetters, i: Int) => { + writeObject(writeFields(arr.getStruct(i, schema.length), schema, rootFieldWriters)) + } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) @@ -185,17 +189,18 @@ private[sql] class JacksonGenerator( def flush(): Unit = gen.flush() /** - * Transforms a single InternalRow to JSON using Jackson + * Transforms a single `InternalRow` to JSON object using Jackson * * @param row The row to convert */ - def write(row: InternalRow): Unit = { - writeObject { - writeFields(row, schema, rootFieldWriters) - } - } + def write(row: InternalRow): Unit = writeObject(writeFields(row, schema, rootFieldWriters)) - def writeLineEnding(): Unit = { - gen.writeRaw('\n') - } + /** + * Transforms multiple `InternalRow`s to JSON array using Jackson + * + * @param array The array of rows to convert + */ + def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter)) + + def writeLineEnding(): Unit = gen.writeRaw('\n') } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 19d0c8eb92f1a..e4698d44636b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -21,7 +21,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -352,7 +352,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -361,13 +361,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStruct(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), null ) } @@ -376,62 +376,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } @@ -444,14 +444,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -461,7 +461,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) checkEvaluation( - JsonToStruct( + JsonToStructs( schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), @@ -469,7 +469,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( - JsonToStruct( + JsonToStructs( schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), @@ -483,25 +483,52 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } - test("to_json") { + test("to_json - struct") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(create_row(1), schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), """{"a":1}""" ) } + test("to_json - array") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(InternalRow(1) :: InternalRow(2) :: Nil) + val output = """[{"a":1},{"a":2}]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + + test("to_json - array with single empty row") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(InternalRow(null) :: Nil) + val output = """[{}]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + + test("to_json - empty array") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(Nil) + val output = """[]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + test("to_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(null, schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), null ) } @@ -514,16 +541,16 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create(create_row(c.getTimeInMillis * 1000L), schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), """{"t":"2016-01-01T00:00:00.000Z"}""" ) checkEvaluation( - StructToJson(Map.empty, struct, Option("PST")), + StructsToJson(Map.empty, struct, Option("PST")), """{"t":"2015-12-31T16:00:00.000-08:00"}""" ) checkEvaluation( - StructToJson( + StructsToJson( Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> gmtId.get), struct, @@ -531,7 +558,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { """{"t":"2016-01-01T00:00:00"}""" ) checkEvaluation( - StructToJson( + StructsToJson( Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> "PST"), struct, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 201f726db3fad..a9f089c850d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2978,7 +2978,8 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -2989,7 +2990,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStruct(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3009,7 +3010,8 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3036,7 +3038,7 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3049,7 +3051,7 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3062,10 +3064,11 @@ object functions { from_json(e, DataType.fromJson(schema), options) /** - * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s + * into a JSON string with the specified schema. Throws an exception, in the case of an + * unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3073,14 +3076,15 @@ object functions { * @since 2.1.0 */ def to_json(e: Column, options: Map[String, String]): Column = withExpr { - StructToJson(options, e.expr) + StructsToJson(options, e.expr) } /** - * (Java-specific) Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * (Java-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s + * into a JSON string with the specified schema. Throws an exception, in the case of an + * unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3091,10 +3095,10 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * Converts a column containing a `StructType` or `ArrayType` of `StructType`s into a JSON string + * with the specified schema. Throws an exception, in the case of an unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * * @group collection_funcs * @since 2.1.0 diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 83243c5e5a12f..b3cc2cea51d43 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -3,6 +3,7 @@ describe function to_json; describe function extended to_json; select to_json(named_struct('a', 1, 'b', 2)); select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +select to_json(array(named_struct('a', 1, 'b', 2))); -- Check if errors handled select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index b57cbbc1d843b..315e1730ce7df 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query 0 @@ -7,7 +7,7 @@ describe function to_json -- !query 0 schema struct -- !query 0 output -Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Class: org.apache.spark.sql.catalyst.expressions.StructsToJson Function: to_json Usage: to_json(expr[, options]) - Returns a json string with a given struct value @@ -17,13 +17,15 @@ describe function extended to_json -- !query 1 schema struct -- !query 1 output -Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Class: org.apache.spark.sql.catalyst.expressions.StructsToJson Extended Usage: Examples: > SELECT to_json(named_struct('a', 1, 'b', 2)); {"a":1,"b":2} > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} + > SELECT to_json(array(named_struct('a', 1, 'b', 2)); + [{"a":1,"b":2}] Function: to_json Usage: to_json(expr[, options]) - Returns a json string with a given struct value @@ -32,7 +34,7 @@ Usage: to_json(expr[, options]) - Returns a json string with a given struct valu -- !query 2 select to_json(named_struct('a', 1, 'b', 2)) -- !query 2 schema -struct +struct -- !query 2 output {"a":1,"b":2} @@ -40,54 +42,62 @@ struct -- !query 3 select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) -- !query 3 schema -struct +struct -- !query 3 output {"time":"26/08/2015"} -- !query 4 -select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +select to_json(array(named_struct('a', 1, 'b', 2))) -- !query 4 schema -struct<> +struct -- !query 4 output -org.apache.spark.sql.AnalysisException -Must use a map() function for options;; line 1 pos 7 +[{"a":1,"b":2}] -- !query 5 -select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) -- !query 5 schema struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +Must use a map() function for options;; line 1 pos 7 -- !query 6 -select to_json() +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 -- !query 7 -describe function from_json +select to_json() -- !query 7 schema -struct +struct<> -- !query 7 output -Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct -Function: from_json -Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function to_json; line 1 pos 7 -- !query 8 -describe function extended from_json +describe function from_json -- !query 8 schema struct -- !query 8 output -Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct +Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 9 +describe function extended from_json +-- !query 9 schema +struct +-- !query 9 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs Extended Usage: Examples: > SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); @@ -99,36 +109,36 @@ Function: from_json Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. --- !query 9 +-- !query 10 select from_json('{"a":1}', 'a INT') --- !query 9 schema -struct> --- !query 9 output +-- !query 10 schema +struct> +-- !query 10 output {"a":1} --- !query 10 +-- !query 11 select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) --- !query 10 schema -struct> --- !query 10 output +-- !query 11 schema +struct> +-- !query 11 output {"time":2015-08-26 00:00:00.0} --- !query 11 +-- !query 12 select from_json('{"a":1}', 1) --- !query 11 schema +-- !query 12 schema struct<> --- !query 11 output +-- !query 12 output org.apache.spark.sql.AnalysisException Expected a string literal instead of 1;; line 1 pos 7 --- !query 12 +-- !query 13 select from_json('{"a":1}', 'a InvalidType') --- !query 12 schema +-- !query 13 schema struct<> --- !query 12 output +-- !query 13 output org.apache.spark.sql.AnalysisException DataType invalidtype() is not supported.(line 1, pos 2) @@ -139,28 +149,28 @@ a InvalidType ; line 1 pos 7 --- !query 13 +-- !query 14 select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')) --- !query 13 schema +-- !query 14 schema struct<> --- !query 13 output +-- !query 14 output org.apache.spark.sql.AnalysisException Must use a map() function for options;; line 1 pos 7 --- !query 14 +-- !query 15 select from_json('{"a":1}', 'a INT', map('mode', 1)) --- !query 14 schema +-- !query 15 schema struct<> --- !query 14 output +-- !query 15 output org.apache.spark.sql.AnalysisException A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 --- !query 15 +-- !query 16 select from_json() --- !query 15 schema +-- !query 16 schema struct<> --- !query 15 output +-- !query 16 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_json; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 2345b82081161..170c238c53438 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,7 +156,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } - test("to_json") { + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer( @@ -164,6 +164,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("""{"_1":1}""") :: Nil) } + test("to_json - array") { + val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a") + + checkAnswer( + df.select(to_json($"a")), + Row("""[{"_1":1}]""") :: Nil) + } + test("to_json with option") { val df = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") val options = Map("timestampFormat" -> "dd/MM/yyyy HH:mm") @@ -184,7 +192,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Unable to convert column a of type calendarinterval to JSON.")) } - test("roundtrip in to_json and from_json") { + test("roundtrip in to_json and from_json - struct") { val dfOne = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType] val readBackOne = dfOne.select(to_json($"struct").as("json")) @@ -198,6 +206,20 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfTwo, readBackTwo) } + test("roundtrip in to_json and from_json - array") { + val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val schemaOne = dfOne.schema(0).dataType + val readBackOne = dfOne.select(to_json($"array").as("json")) + .select(from_json($"json", schemaOne).as("array")) + checkAnswer(dfOne, readBackOne) + + val dfTwo = Seq(Some("""[{"a":1}]"""), None).toDF("json") + val schemaTwo = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("array")) + .select(to_json($"array").as("json")) + checkAnswer(dfTwo, readBackTwo) + } + test("SPARK-19637 Support to_json in SQL") { val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer( From c40597720e8e66a6b11ca241b1ad387154a8fe72 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 19 Mar 2017 22:34:18 -0700 Subject: [PATCH 0057/1765] [SPARK-20020][SPARKR] DataFrame checkpoint API ## What changes were proposed in this pull request? Add checkpoint, setCheckpointDir API to R ## How was this patch tested? unit tests, manual tests Author: Felix Cheung Closes #17351 from felixcheung/rdfcheckpoint. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/DataFrame.R | 29 +++++++++++++++++++++++ R/pkg/R/RDD.R | 2 +- R/pkg/R/context.R | 21 +++++++++++++++- R/pkg/R/generics.R | 6 ++++- R/pkg/inst/tests/testthat/test_rdd.R | 4 ++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 11 +++++++++ 7 files changed, 70 insertions(+), 5 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 78344ce9ff08b..8be7875ad2d5f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -82,6 +82,7 @@ exportMethods("arrange", "as.data.frame", "attach", "cache", + "checkpoint", "coalesce", "collect", "colnames", @@ -369,6 +370,7 @@ export("as.DataFrame", "read.parquet", "read.stream", "read.text", + "setCheckpointDir", "spark.lapply", "spark.addFile", "spark.getSparkFilesRootDirectory", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index bc81633815c65..97786df4ae6a1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3613,3 +3613,32 @@ setMethod("write.stream", ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) }) + +#' checkpoint +#' +#' Returns a checkpointed version of this SparkDataFrame. Checkpointing can be used to truncate the +#' logical plan, which is especially useful in iterative algorithms where the plan may grow +#' exponentially. It will be saved to files inside the checkpoint directory set with +#' \code{setCheckpointDir} +#' +#' @param x A SparkDataFrame +#' @param eager whether to checkpoint this SparkDataFrame immediately +#' @return a new checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases checkpoint,SparkDataFrame-method +#' @rdname checkpoint +#' @name checkpoint +#' @seealso \link{setCheckpointDir} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#' df <- checkpoint(df) +#' } +#' @note checkpoint since 2.2.0 +setMethod("checkpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) + dataFrame(df) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 5667b9d788821..7ad3993e9ecbc 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -291,7 +291,7 @@ setMethod("unpersistRDD", #' @rdname checkpoint-methods #' @aliases checkpoint,RDD-method #' @noRd -setMethod("checkpoint", +setMethod("checkpointRDD", signature(x = "RDD"), function(x) { jrdd <- getJRDD(x) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 1a0dd65f450b9..cb0f83b2fa227 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -291,7 +291,7 @@ broadcast <- function(sc, object) { #' rdd <- parallelize(sc, 1:2, 2L) #' checkpoint(rdd) #'} -setCheckpointDir <- function(sc, dirName) { +setCheckpointDirSC <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } @@ -410,3 +410,22 @@ setLogLevel <- function(level) { sc <- getSparkContext() invisible(callJMethod(sc, "setLogLevel", level)) } + +#' Set checkpoint directory +#' +#' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be +#' a HDFS path if running on a cluster. +#' +#' @rdname setCheckpointDir +#' @param directory Directory path to checkpoint to +#' @seealso \link{checkpoint} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#'} +#' @note setCheckpointDir since 2.0.0 +setCheckpointDir <- function(directory) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(directory)))) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 029771289fd53..80283e48ced7b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -32,7 +32,7 @@ setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coa # @rdname checkpoint-methods # @export -setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) +setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) @@ -406,6 +406,10 @@ setGeneric("attach") #' @export setGeneric("cache", function(x) { standardGeneric("cache") }) +#' @rdname checkpoint +#' @export +setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) + #' @rdname coalesce #' @param x a Column or a SparkDataFrame. #' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 787ef51c501c0..b72c801dd958d 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -143,8 +143,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp expect_false(rdd2@env$isCached) tempDir <- tempfile(pattern = "checkpoint") - setCheckpointDir(sc, tempDir) - checkpoint(rdd2) + setCheckpointDirSC(sc, tempDir) + checkpointRDD(rdd2) expect_true(rdd2@env$isCheckpointed) rdd2 <- lapply(rdd2, function(x) x) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9c38e0d866aa3..cbc3569795d97 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -841,6 +841,17 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", expect_true(is.data.frame(collect(df))) }) +test_that("setCheckpointDir(), checkpoint() on a DataFrame", { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) From 965a5abcff3adccc10a53b0d97d06c43934df1a2 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Mon, 20 Mar 2017 14:37:23 +0800 Subject: [PATCH 0058/1765] [SPARK-19994][SQL] Wrong outputOrdering for right/full outer smj ## What changes were proposed in this pull request? For right outer join, values of the left key will be filled with nulls if it can't match the value of the right key, so `nullOrdering` of the left key can't be guaranteed. We should output right key order instead of left key order. For full outer join, neither left key nor right key guarantees `nullOrdering`. We should not output any ordering. In tests, besides adding three test cases for left/right/full outer sort merge join, this patch also reorganizes code in `PlannerSuite` by putting together tests for `Sort`, and also extracts common logic in Sort tests into a method. ## How was this patch tested? Corresponding test cases are added. Author: wangzhenhua Author: Zhenhua Wang Closes #17331 from wzhfy/wrongOrdering. --- .../execution/joins/SortMergeJoinExec.scala | 12 +- .../spark/sql/execution/PlannerSuite.scala | 233 ++++++++++-------- 2 files changed, 146 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index bcdc4dcdf7d99..02f4f55c7999a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -80,7 +80,17 @@ case class SortMergeJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + // There are null rows in both streams, so there is no order. + case FullOuter => Nil + case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 02ccebd22bdf9..f2232fc489b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -251,7 +251,9 @@ class PlannerSuite extends SharedSQLContext { } } - // --- Unit tests of EnsureRequirements --------------------------------------------------------- + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Exchange + /////////////////////////////////////////////////////////////////////////// // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, // there two dimensions that need to be considered: are the child partitionings compatible and @@ -384,93 +386,6 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements adds sort when there is no existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, - requiredChildOrdering = Seq(Seq(orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is semantically equal to " + - "existing ordering") { - val exprId: ExprId = NamedExpression.newExprId - val attribute1 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId, - qualifier = Some("col1_qualifier") - ) - - val attribute2 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId) - - val orderingA1 = SortOrder(attribute1, Ascending) - val orderingA2 = SortOrder(attribute2, Ascending) - - assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") - assert(orderingA1.semanticEquals(orderingA2), - s"$orderingA1 should be semantically equal to $orderingA2") - - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA2)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - // This is a regression test for SPARK-11135 - test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA, orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -481,7 +396,7 @@ class PlannerSuite extends SharedSQLContext { children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), requiredChildOrdering = Seq(Seq.empty)), - None) + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) @@ -510,8 +425,6 @@ class PlannerSuite extends SharedSQLContext { } } - // --------------------------------------------------------------------------------------------- - test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -525,12 +438,12 @@ class PlannerSuite extends SharedSQLContext { None) val inputPlan = SortMergeJoinExec( - Literal(1) :: Nil, - Literal(1) :: Nil, - Inner, - None, - shuffle, - shuffle) + Literal(1) :: Nil, + Literal(1) :: Nil, + Inner, + None, + shuffle, + shuffle) val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { @@ -557,6 +470,130 @@ class PlannerSuite extends SharedSQLContext { fail(s"Should have only two shuffles:\n$outputPlan") } } + + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Sort + /////////////////////////////////////////////////////////////////////////// + + private val exprA = Literal(1) + private val exprB = Literal(2) + private val orderingA = SortOrder(exprA, Ascending) + private val orderingB = SortOrder(exprB, Ascending) + private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: Nil, 5)) + private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + + assert(orderingA != orderingB) + + private def assertSortRequirementsAreSatisfied( + childPlan: SparkPlan, + requiredOrdering: Seq[SortOrder], + shouldHaveSort: Boolean): Unit = { + val inputPlan = DummySparkPlan( + children = childPlan :: Nil, + requiredChildOrdering = Seq(requiredOrdering), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (shouldHaveSort) { + if (outputPlan.collect { case s: SortExec => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } else { + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + } + + test("EnsureRequirements for sort operator after left outer sort merge join") { + // Only left key is sorted after left outer SMJ (thus doesn't need a sort). + val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) + Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = leftSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements for sort operator after right outer sort merge join") { + // Only right key is sorted after right outer SMJ (thus doesn't need a sort). + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB) + Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = rightSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements adds sort after full outer sort merge join") { + // Neither keys is sorted after full outer SMJ, so they both need sorts. + val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB) + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = fullSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = true) + } + } + + test("EnsureRequirements adds sort when there is no existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq.empty), + requiredOrdering = Seq(orderingB), + shouldHaveSort = true) + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)), + requiredOrdering = Seq(orderingA), + shouldHaveSort = false) + } + + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA1)), + requiredOrdering = Seq(orderingA2), + shouldHaveSort = false) + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA)), + requiredOrdering = Seq(orderingA, orderingB), + shouldHaveSort = true) + } } // Used for unit-testing EnsureRequirements From f14f81e900e2e6c216055799584148a2c944268d Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 19 Mar 2017 23:49:26 -0700 Subject: [PATCH 0059/1765] [SPARK-20020][SPARKR][FOLLOWUP] DataFrame checkpoint API fix version tag ## What changes were proposed in this pull request? doc only change ## How was this patch tested? manual Author: Felix Cheung Closes #17356 from felixcheung/rdfcheckpoint2. --- R/pkg/R/context.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index cb0f83b2fa227..1ca573e5bd614 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -424,7 +424,7 @@ setLogLevel <- function(level) { #'\dontrun{ #' setCheckpointDir("/checkpoint") #'} -#' @note setCheckpointDir since 2.0.0 +#' @note setCheckpointDir since 2.2.0 setCheckpointDir <- function(directory) { sc <- getSparkContext() invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(directory)))) From 81639115947a13017d1637549a8f66ba599b27b8 Mon Sep 17 00:00:00 2001 From: Ioana Delaney Date: Mon, 20 Mar 2017 16:04:58 +0800 Subject: [PATCH 0060/1765] [SPARK-17791][SQL] Join reordering using star schema detection ## What changes were proposed in this pull request? Star schema consists of one or more fact tables referencing a number of dimension tables. In general, queries against star schema are expected to run fast because of the established RI constraints among the tables. This design proposes a join reordering based on natural, generally accepted heuristics for star schema queries: - Finds the star join with the largest fact table and places it on the driving arm of the left-deep join. This plan avoids large tables on the inner, and thus favors hash joins. - Applies the most selective dimensions early in the plan to reduce the amount of data flow. The design document was included in SPARK-17791. Link to the google doc: [StarSchemaDetection](https://docs.google.com/document/d/1UAfwbm_A6wo7goHlVZfYK99pqDMEZUumi7pubJXETEA/edit?usp=sharing) ## How was this patch tested? A new test suite StarJoinSuite.scala was implemented. Author: Ioana Delaney Closes #15363 from ioana-delaney/starJoinReord2. --- .../sql/catalyst/SimpleCatalystConf.scala | 1 + .../optimizer/CostBasedJoinReorder.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../spark/sql/catalyst/optimizer/joins.scala | 350 ++++++++++- .../sql/catalyst/planning/patterns.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 16 + .../optimizer/JoinOptimizationSuite.scala | 4 +- .../catalyst/optimizer/JoinReorderSuite.scala | 29 +- .../optimizer/StarJoinReorderSuite.scala | 580 ++++++++++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 26 + 10 files changed, 978 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala index 0d4903e03bf5a..ac97987c55e08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -40,6 +40,7 @@ case class SimpleCatalystConf( override val cboEnabled: Boolean = false, override val joinReorderEnabled: Boolean = false, override val joinReorderDPThreshold: Int = 12, + override val starSchemaDetection: Boolean = false, override val warehousePath: String = "/user/hive/warehouse", override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, override val maxNestedViewDepth: Int = 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 1b32bda72bc9f..521c468fe18af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -53,6 +53,8 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) + // TODO: Compute the set of star-joins and use them in the join enumeration + // algorithm to prune un-optimal plan choices. val result = // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c8ed4190a13ad..d7524a57adbc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -82,7 +82,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Operator Optimizations", fixedPoint, // Operator push down PushProjectionThroughUnion, - ReorderJoin, + ReorderJoin(conf), EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index bfe529e21e9ad..58e4a230f4ef0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -20,19 +20,347 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Encapsulates star-schema join detection. + */ +case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { + + /** + * Star schema consists of one or more fact tables referencing a number of dimension + * tables. In general, star-schema joins are detected using the following conditions: + * 1. Informational RI constraints (reliable detection) + * + Dimension contains a primary key that is being joined to the fact table. + * + Fact table contains foreign keys referencing multiple dimension tables. + * 2. Cardinality based heuristics + * + Usually, the table with the highest cardinality is the fact table. + * + Table being joined with the most number of tables is the fact table. + * + * To detect star joins, the algorithm uses a combination of the above two conditions. + * The fact table is chosen based on the cardinality heuristics, and the dimension + * tables are chosen based on the RI constraints. A star join will consist of the largest + * fact table joined with the dimension tables on their primary keys. To detect that a + * column is a primary key, the algorithm uses table and column statistics. + * + * Since Catalyst only supports left-deep tree plans, the algorithm currently returns only + * the star join with the largest fact table. Choosing the largest fact table on the + * driving arm to avoid large inners is in general a good heuristic. This restriction can + * be lifted with support for bushy tree plans. + * + * The highlights of the algorithm are the following: + * + * Given a set of joined tables/plans, the algorithm first verifies if they are eligible + * for star join detection. An eligible plan is a base table access with valid statistics. + * A base table access represents Project or Filter operators above a LeafNode. Conservatively, + * the algorithm only considers base table access as part of a star join since they provide + * reliable statistics. + * + * If some of the plans are not base table access, or statistics are not available, the algorithm + * returns an empty star join plan since, in the absence of statistics, it cannot make + * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality + * (number of rows), which is assumed to be a fact table. + * + * Next, it computes the set of dimension tables for the current fact table. A dimension table + * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, + * the algorithm compares the number of distinct values with the total number of rows in the + * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted + * based on 1TB TPC-DS data), the column is assumed to be unique. + */ + def findStarJoins( + input: Seq[LogicalPlan], + conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = { + + val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]] + + if (!conf.starSchemaDetection || input.size < 2) { + emptyStarJoinPlan + } else { + // Find if the input plans are eligible for star join detection. + // An eligible plan is a base table access with valid statistics. + val foundEligibleJoin = input.forall { + case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case _ => false + } + + if (!foundEligibleJoin) { + // Some plans don't have stats or are complex plans. Conservatively, + // return an empty star join. This restriction can be lifted + // once statistics are propagated in the plan. + emptyStarJoinPlan + } else { + // Find the fact table using cardinality based heuristics i.e. + // the table with the largest number of rows. + val sortedFactTables = input.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.collect { case t @ TableAccessCardinality(_, Some(_)) => + t + }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) + + sortedFactTables match { + case Nil => + emptyStarJoinPlan + case table1 :: table2 :: _ + if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => + // If the top largest tables have comparable number of rows, return an empty star plan. + // This restriction will be lifted when the algorithm is generalized + // to return multiple star plans. + emptyStarJoinPlan + case TableAccessCardinality(factTable, _) :: rest => + // Find the fact table joins. + val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) + if findJoinConditions(factTable, plan, conditions).nonEmpty => + plan + } + + // Find the corresponding join conditions. + val allFactJoinCond = allFactJoins.flatMap { plan => + val joinCond = findJoinConditions(factTable, plan, conditions) + joinCond + } + + // Verify if the join columns have valid statistics. + // Allow any relational comparison between the tables. Later + // we will heuristically choose a subset of equi-join + // tables. + val areStatsAvailable = allFactJoins.forall { dimTable => + allFactJoinCond.exists { + case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs + hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) + case _ => false + } + } + + if (!areStatsAvailable) { + emptyStarJoinPlan + } else { + // Find the subset of dimension tables. A dimension table is assumed to be in a + // RI relationship with the fact table. Only consider equi-joins + // between a fact and a dimension table to avoid expanding joins. + val eligibleDimPlans = allFactJoins.filter { dimTable => + allFactJoinCond.exists { + case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + isUnique(dimCol, dimTable) + case _ => false + } + } + + if (eligibleDimPlans.isEmpty) { + // An eligible star join was not found because the join is not + // an RI join, or the star join is an expanding join. + emptyStarJoinPlan + } else { + Seq(factTable +: eligibleDimPlans) + } + } + } + } + } + } + + /** + * Reorders a star join based on heuristics: + * 1) Finds the star join with the largest fact table and places it on the driving + * arm of the left-deep tree. This plan avoids large table access on the inner, and + * thus favor hash joins. + * 2) Applies the most selective dimensions early in the plan to reduce the amount of + * data flow. + */ + def reorderStarJoins( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { + assert(input.size >= 2) + + val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] + + // Find the eligible star plans. Currently, it only returns + // the star join with the largest fact table. + val eligibleJoins = input.collect{ case (plan, Inner) => plan } + val starPlans = findStarJoins(eligibleJoins, conditions) + + if (starPlans.isEmpty) { + emptyStarJoinPlan + } else { + val starPlan = starPlans.head + val (factTable, dimTables) = (starPlan.head, starPlan.tail) + + // Only consider selective joins. This case is detected by observing local predicates + // on the dimension tables. In a star schema relationship, the join between the fact and the + // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce + // the result of a join. + // Also, conservatively assume that a fact table is joined with more than one dimension. + if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) { + val reorderDimTables = dimTables.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.sortBy(_.size).map { + case TableAccessCardinality(p1, _) => p1 + } + + val reorderStarPlan = factTable +: reorderDimTables + reorderStarPlan.map(plan => (plan, Inner)) + } else { + emptyStarJoinPlan + } + } + } + + /** + * Determines if a column referenced by a base table access is a primary key. + * A column is a PK if it is not nullable and has unique values. + * To determine if a column has unique values in the absence of informational + * RI constraints, the number of distinct values is compared to the total + * number of rows in the table. If their relative difference + * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based + * on TPCDS data results), the column is assumed to have unique values. + */ + private def isUnique( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.rowCount match { + case Some(rowCount) if rowCount >= 0 => + if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { + val colStats = stats.attributeStats.get(col) + if (colStats.get.nullCount > 0) { + false + } else { + val distinctCount = colStats.get.distinctCount + val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) + // ndvMaxErr adjusted based on TPCDS 1TB data results + relDiff <= conf.ndvMaxError * 2 + } + } else { + false + } + case None => false + } + case None => false + } + case _ => false + } + + /** + * Given a column over a base table access, it returns + * the leaf node column from which the input column is derived. + */ + @tailrec + private def findLeafNodeCol( + column: Attribute, + plan: LogicalPlan): Option[Attribute] = plan match { + case pl @ PhysicalOperation(_, _, _: LeafNode) => + pl match { + case t: LeafNode if t.outputSet.contains(column) => + Option(column) + case p: Project if p.outputSet.exists(_.semanticEquals(column)) => + val col = p.outputSet.find(_.semanticEquals(column)).get + findLeafNodeCol(col, p.child) + case f: Filter => + findLeafNodeCol(column, f.child) + case _ => None + } + case _ => None + } + + /** + * Checks if a column has statistics. + * The column is assumed to be over a base table access. + */ + private def hasStatistics( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) + case None => false + } + case _ => false + } + + /** + * Returns the join predicates between two input plans. It only + * considers basic comparison operators. + */ + @inline + private def findJoinConditions( + plan1: LogicalPlan, + plan2: LogicalPlan, + conditions: Seq[Expression]): Seq[Expression] = { + val refs = plan1.outputSet ++ plan2.outputSet + conditions.filter { + case BinaryComparison(_, _) => true + case _ => false + }.filterNot(canEvaluate(_, plan1)) + .filterNot(canEvaluate(_, plan2)) + .filter(_.references.subsetOf(refs)) + } + + /** + * Checks if a star join is a selective join. A star join is assumed + * to be selective if there are local predicates on the dimension + * tables. + */ + private def isSelectiveStarJoin( + dimTables: Seq[LogicalPlan], + conditions: Seq[Expression]): Boolean = dimTables.exists { + case plan @ PhysicalOperation(_, p, _: LeafNode) => + // Checks if any condition applies to the dimension tables. + // Exclude the IsNotNull predicates until predicate selectivity is available. + // In most cases, this predicate is artificially introduced by the Optimizer + // to enforce nullability constraints. + val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) + .exists(canEvaluate(_, plan)) + + // Checks if there are any predicates pushed down to the base table access. + val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) + + localPredicates || pushedDownPredicates + case _ => false + } + + /** + * Helper case class to hold (plan, rowCount) pairs. + */ + private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) + + /** + * Returns the cardinality of a base table access. A base table access represents + * a LeafNode, or Project or Filter operators above a LeafNode. + */ + private def getTableAccessCardinality( + input: LogicalPlan): Option[BigInt] = input match { + case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => + if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { + Option(input.stats(conf).rowCount.get) + } else { + Option(t.stats(conf).rowCount.get) + } + case _ => None + } +} /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least * one condition. * * The order of joins will not be changed if all of them already have at least one condition. + * + * If star schema detection is enabled, reorder the star join plans based on heuristics. */ -object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - +case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Join a list of plans together and push down the conditions into them. * @@ -42,7 +370,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * @param conditions a list of condition for join. */ @tailrec - def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) : LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { @@ -83,9 +411,19 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ ExtractFiltersAndInnerJoins(input, conditions) + case ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - createOrderedJoin(input, conditions) + if (conf.starSchemaDetection && !conf.cboEnabled) { + val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions) + if (starJoinPlan.nonEmpty) { + val rest = input.filterNot(starJoinPlan.contains(_)) + createOrderedJoin(starJoinPlan ++ rest, conditions) + } else { + createOrderedJoin(input, conditions) + } + } else { + createOrderedJoin(input, conditions) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 0893af26738bf..d39b0ef7e1d8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -167,8 +167,8 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { case Join(left, right, joinType: InnerLike, cond) => val (plans, conditions) = flattenJoin(left, joinType) - (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq) - + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d2ac4b88ee8fd..b6e0b8ccbeed6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -719,6 +719,18 @@ object SQLConf { .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") .createWithDefault(0.7) + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") + .doc("When true, it enables join reordering based on star schema detection. ") + .booleanConf + .createWithDefault(false) + + val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio") + .internal() + .doc("Specifies the upper limit of the ratio between the largest fact tables" + + " for a star join to be considered. ") + .doubleConf + .createWithDefault(0.9) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""") @@ -988,6 +1000,10 @@ class SQLConf extends Serializable with Logging { def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) + + def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 985e49069da90..61e81808147c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor - +import org.apache.spark.sql.catalyst.SimpleCatalystConf class JoinOptimizationSuite extends PlanTest { @@ -38,7 +38,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin, + ReorderJoin(SimpleCatalystConf(true)), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 5607bcd16f3ff..05b839b0119f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -22,10 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.catalyst.util._ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { @@ -38,7 +37,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin, + ReorderJoin(conf), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: @@ -203,27 +202,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { val optimized = Optimize.execute(originalPlan.analyze) - val normalized1 = normalizePlan(normalizeExprIds(optimized)) - val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Consider symmetry for joins when comparing plans. */ - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => - (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } - case _ => - plan1 == plan2 - } + val expected = groundTruthBestPlan.analyze + compareJoinOrder(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala new file mode 100644 index 0000000000000..93fdd98d1ac93 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} + + +class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = SimpleCatalystConf( + caseSensitiveAnalysis = true, starSchemaDetection = true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } + + // Table setup using star schema relationships: + // + // d1 - f1 - d2 + // | + // d3 - s3 + // + // Table f1 is the fact table. Tables d1, d2, and d3 are the dimension tables. + // Dimension d3 is further joined/normalized into table s3. + // Tables' cardinality: f1 > d3 > d1 > d2 > s3 + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 + attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D1 + attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // D2 + attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D3 + attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // S3 + attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // F11 + attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToColInfo))) + + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToAttr), + rowCount = 4, + size = Some(32), + attributeStats = AttributeMap(Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToAttr), + rowCount = 3, + size = Some(24), + attributeStats = AttributeMap(Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToAttr), + rowCount = 5, + size = Some(40), + attributeStats = AttributeMap(Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToColInfo))) + + private val s3 = StatsTestPlan( + outputList = Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToAttr), + rowCount = 2, + size = Some(17), + attributeStats = AttributeMap(Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToColInfo))) + + private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int) + + private val f11 = StatsTestPlan( + outputList = Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4") + .map(nameToColInfo))) + + private val subq = d3.select(sum('d3_fk1).as('col)) + + test("Test 1: Selective star-join on all dimensions") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, d2, f1, d3, s3 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d2, d1, d3, s3 + val query = + d1.join(d2).join(f1).join(d3).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star join on a subset of dimensions due to inequality joins") { + // Star join: + // (=) (<) + // d1 - f1 - d2 + // | + // | (=) + // d3 - s3 + // (=) + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2,, d3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star join on a subset of dimensions since join column is not unique") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c4 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, d3 + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2"))) + + + assertEqualPlans(query, expected) + } + + test("Test 4: Star join on a subset of dimensions since join column is nullable") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c2 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, s3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: Table stats not available for some of the joined tables") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3_ns - s3 + // + // select f1_fk1, f1_fk3 + // from d3_ns, f1, d1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d3_ns, f1, d1, d2, s3 + // Star join reordering: empty + + val query = + d3_ns.join(f1).join(d1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 6: Join with complex plans") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // (sub-query) + // + // select f1_fk1, f1_fk3 + // from (select sum(d3_fk1) as col from d3) subq, f1, d1, d2 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = sq.col + // + // Positional join reordering: d3, f1, d1, d2 + // Star join reordering: empty + + val query = + subq.join(f1).join(d1).join(d2) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === "col".attr)) + + val expected = + d3.select('d3_fk1).select(sum('d3_fk1).as('col)) + .join(f1, Inner, Some(nameToAttr("f1_fk3") === "col".attr)) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 7: Comparable fact table sizes") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // f11 - s3 + // + // select f1.f1_fk1, f1.f1_fk3 + // from d1, f11, f1, d2, s3 + // where f1.f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1.f1_fk1 = d1_pk1 + // and f1.f1_fk3 = f11.f1_fk3 + // and f11.f1_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, f11, d2, s3 + // Star join reordering: empty + + val query = + d1.join(f11).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("f11_fk3")) && + (nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(f11, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("f11_fk3"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 8: No RI joins") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_c4 and d2_c2 = 2 + // and f1_fk1 = d1_c4 + // and f1_fk3 = d3_c4 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_c4")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_c4"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_c4"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 9: Complex join predicates") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and abs(f1_fk1) = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 10: Less than two dimensions") { + // Star join: + // (<) (=) + // d1 - f1 - d2 + // |(<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), + Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 11: Expanding star join") { + // Star join: + // (<) (<) + // d1 - f1 - d2 + // | (<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // and d3_fk1 < s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 12: Non selective star join") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 5eb31413ad70f..2a9d0570148ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -106,4 +106,30 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { protected def compareExpressions(e1: Expression, e2: Expression): Unit = { comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. */ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => + (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } + case _ => + plan1 == plan2 + } + } } From 7ce30e00b236e77b5175f797f9c6fc6cf4ca7e93 Mon Sep 17 00:00:00 2001 From: windpiger Date: Mon, 20 Mar 2017 21:36:00 +0800 Subject: [PATCH 0061/1765] [SPARK-19990][SQL][TEST-MAVEN] create a temp file for file in test.jar's resource when run mvn test accross different modules ## What changes were proposed in this pull request? After we have merged the `HiveDDLSuite` and `DDLSuite` in [SPARK-19235](https://issues.apache.org/jira/browse/SPARK-19235), we have two subclasses of `DDLSuite`, that is `HiveCatalogedDDLSuite` and `InMemoryCatalogDDLSuite`. While `DDLSuite` is in `sql/core module`, and `HiveCatalogedDDLSuite` is in `sql/hive module`, if we mvn test `HiveCatalogedDDLSuite`, it will run the test in its parent class `DDLSuite`, this will cause some test case failed which will get and use the test file path in `sql/core module` 's `resource`. Because the test file path getted will start with 'jar:' like "jar:file:/home/jenkins/workspace/spark-master-test-maven-hadoop-2.6/sql/core/target/spark-sql_2.11-2.2.0-SNAPSHOT-tests.jar!/test-data/cars.csv", which will failed when new Path() in datasource.scala This PR fix this by copy file from resource to a temp dir. ## How was this patch tested? N/A Author: windpiger Closes #17338 from windpiger/fixtestfailemvn. --- .../sql/execution/command/DDLSuite.scala | 33 +++++++++++-------- .../apache/spark/sql/test/SQLTestUtils.scala | 17 +++++++++- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index dd76fdde06cdc..235c6bf6ad592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf @@ -699,21 +699,28 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create temporary view using") { - val csvFile = - Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString - withView("testview") { - sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + - "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + - s"OPTIONS (PATH '$csvFile')") + // when we test the HiveCatalogedDDLSuite, it will failed because the csvFile path above + // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it + // to a temp file by withResourceTempPath + withResourceTempPath("test-data/cars.csv") { tmpFile => + withView("testview") { + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + + s"OPTIONS (PATH '$tmpFile')") - checkAnswer( - sql("select c1, c2 from testview order by c1 limit 1"), + checkAnswer( + sql("select c1, c2 from testview order by c1 limit 1"), Row("1997", "Ford") :: Nil) - // Fails if creating a new view with the same name - intercept[TempTableAlreadyExistsException] { - sql(s"CREATE TEMPORARY VIEW testview USING " + - s"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat OPTIONS (PATH '$csvFile')") + // Fails if creating a new view with the same name + intercept[TempTableAlreadyExistsException] { + sql( + s""" + |CREATE TEMPORARY VIEW testview + |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat + |OPTIONS (PATH '$tmpFile') + """.stripMargin) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9201954b66d10..cab219216d1ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.test import java.io.File import java.net.URI +import java.nio.file.Files import java.util.UUID import scala.language.implicitConversions -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -123,6 +123,21 @@ private[sql] trait SQLTestUtils try f(path) finally Utils.deleteRecursively(path) } + /** + * Copy file in jar's resource to a temp file, then pass it to `f`. + * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of + * path of jar's resource which starts with 'jar:file:/' + */ + protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { + val inputStream = + Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) + withTempDir { dir => + val tmpFile = new File(dir, "tmp") + Files.copy(inputStream, tmpFile.toPath) + f(tmpFile) + } + } + /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. From fc7554599a4b6e5c22aa35e7296b424a653a420b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 20 Mar 2017 10:07:31 -0700 Subject: [PATCH 0062/1765] [SPARK-19970][SQL] Table owner should be USER instead of PRINCIPAL in kerberized clusters ## What changes were proposed in this pull request? In the kerberized hadoop cluster, when Spark creates tables, the owner of tables are filled with PRINCIPAL strings instead of USER names. This is inconsistent with Hive and causes problems when using [ROLE](https://cwiki.apache.org/confluence/display/Hive/SQL+Standard+Based+Hive+Authorization) in Hive. We had better to fix this. **BEFORE** ```scala scala> sql("create table t(a int)").show scala> sql("desc formatted t").show(false) ... |Owner: |sparkEXAMPLE.COM | | ``` **AFTER** ```scala scala> sql("create table t(a int)").show scala> sql("desc formatted t").show(false) ... |Owner: |spark | | ``` ## How was this patch tested? Manually do `create table` and `desc formatted` because this happens in Kerberized clusters. Author: Dongjoon Hyun Closes #17311 from dongjoon-hyun/SPARK-19970. --- .../scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 989fdc5564d39..13edcd051768c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -851,7 +851,7 @@ private[hive] object HiveClientImpl { hiveTable.setFields(schema.asJava) } hiveTable.setPartCols(partCols.asJava) - conf.foreach(c => hiveTable.setOwner(c.getUser)) + conf.foreach { _ => hiveTable.setOwner(SessionState.get().getAuthenticator().getUserName()) } hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc => From bec6b16c1900fe93def89cc5eb51cbef498196cb Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 20 Mar 2017 10:58:30 -0700 Subject: [PATCH 0063/1765] [SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth ## What changes were proposed in this pull request? Replaces `featuresCol` `Param` with `itemsCol`. See [SPARK-19899](https://issues.apache.org/jira/browse/SPARK-19899). ## How was this patch tested? Manual tests. Existing unit tests. Author: zero323 Closes #17321 from zero323/SPARK-19899. --- .../org/apache/spark/ml/fpm/FPGrowth.scala | 35 +++++++++++++------ .../apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++---- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index fa39dd954af57..e2bc270b38da7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} @@ -37,7 +37,20 @@ import org.apache.spark.sql.types._ /** * Common params for FPGrowth and FPGrowthModel */ -private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol { +private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { + + /** + * Items column name. + * Default: "items" + * @group param + */ + @Since("2.2.0") + val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name") + setDefault(itemsCol -> "items") + + /** @group getParam */ + @Since("2.2.0") + def getItemsCol: String = $(itemsCol) /** * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears @@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre */ @Since("2.2.0") protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputType = schema($(featuresCol)).dataType + val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") - SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType) + SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } @@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { - val data = dataset.select($(featuresCol)) - val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val data = dataset.select($(itemsCol)) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) @@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") ( val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( - StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false), + StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) @@ -198,7 +211,7 @@ class FPGrowthModel private[ml] ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -235,7 +248,7 @@ class FPGrowthModel private[ml] ( .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]] val brRules = dataset.sparkSession.sparkContext.broadcast(rules) - val dt = dataset.schema($(featuresCol)).dataType + val dt = dataset.schema($(itemsCol)).dataType // For each rule, examine the input items and summarize the consequents val predictUDF = udf((items: Seq[_]) => { if (items != null) { @@ -249,7 +262,7 @@ class FPGrowthModel private[ml] ( } else { Seq.empty }}, dt) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol)))) } @Since("2.2.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 910d4b07d1302..4603a618d2f93 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("FPGrowth fit and transform with different data types") { Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt => - val data = dataset.withColumn("features", col("features").cast(ArrayType(dt))) + val data = dataset.withColumn("items", col("items").cast(ArrayType(dt))) val model = new FPGrowth().setMinSupport(0.5).fit(data) val generatedRules = model.setMinConfidence(0.5).associationRules val expectedRules = spark.createDataFrame(Seq( @@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "3"), Array(2)) - )).toDF("id", "features", "prediction") - .withColumn("features", col("features").cast(ArrayType(dt))) + )).toDF("id", "items", "prediction") + .withColumn("items", col("items").cast(ArrayType(dt))) .withColumn("prediction", col("prediction").cast(ArrayType(dt))) assert(expectedTransformed.collect().toSet.equals( transformed.collect().toSet)) @@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1, Array("1", "2", "3", "5")), (2, Array("1", "2", "3", "4")), (3, null.asInstanceOf[Array[String]]) - )).toDF("id", "features") + )).toDF("id", "items") val model = new FPGrowth().setMinSupport(0.7).fit(dataset) val prediction = model.transform(df) assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) @@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val dataset = spark.createDataFrame(Seq( Array("1", "3"), Array("2", "3") - ).map(Tuple1(_))).toDF("features") + ).map(Tuple1(_))).toDF("items") val model = new FPGrowth().fit(dataset) val prediction = model.transform( - spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") ).first().getAs[Seq[String]]("prediction") assert(prediction === Seq("3")) @@ -127,7 +127,7 @@ object FPGrowthSuite { (0, Array("1", "2")), (0, Array("1", "2")), (0, Array("1", "3")) - )).toDF("id", "features") + )).toDF("id", "items") } /** From c2d1761a57f5d175913284533b3d0417e8718688 Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Mon, 20 Mar 2017 17:18:59 -0700 Subject: [PATCH 0064/1765] [SPARK-19906][SS][DOCS] Documentation describing how to write queries to Kafka ## What changes were proposed in this pull request? Add documentation that describes how to write streaming and batch queries to Kafka. zsxwing tdas Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tyson Condie Closes #17246 from tcondie/kafka-write-docs. --- .../structured-streaming-kafka-integration.md | 321 ++++++++++++++---- 1 file changed, 264 insertions(+), 57 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 522e669568678..217c1a91a16f3 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -3,9 +3,9 @@ layout: global title: Structured Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) --- -Structured Streaming integration for Kafka 0.10 to poll data from Kafka. +Structured Streaming integration for Kafka 0.10 to read data from and write data to Kafka. -### Linking +## Linking For Scala/Java applications using SBT/Maven project definitions, link your application with the following artifact: groupId = org.apache.spark @@ -15,40 +15,42 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -### Creating a Kafka Source Stream +## Reading Data from Kafka + +### Creating a Kafka Source for Streaming Queries
    {% highlight scala %} // Subscribe to 1 topic -val ds1 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to multiple topics -val ds2 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to a pattern -val ds3 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] {% endhighlight %} @@ -57,31 +59,31 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -Dataset ds1 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -Dataset ds2 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -Dataset ds3 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
    @@ -89,37 +91,37 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight python %} # Subscribe to 1 topic -ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to multiple topics -ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1,topic2") \ .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to a pattern -ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribePattern", "topic.*") \ .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
    -### Creating a Kafka Source Batch +### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, you can create an Dataset/DataFrame for a defined range of offsets. @@ -128,17 +130,17 @@ you can create an Dataset/DataFrame for a defined range of offsets. {% highlight scala %} // Subscribe to 1 topic defaults to the earliest and latest offsets -val ds1 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to multiple topics, specifying explicit Kafka offsets -val ds2 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -146,11 +148,11 @@ val ds2 = spark .option("startingOffsets", """{"topic1":{"0":23,"1":-2},"topic2":{"0":-2}}""") .option("endingOffsets", """{"topic1":{"0":50,"1":-1},"topic2":{"0":-1}}""") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to a pattern, at the earliest and latest offsets -val ds3 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -158,7 +160,7 @@ val ds3 = spark .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] {% endhighlight %} @@ -167,16 +169,16 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -Dataset ds1 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load(); -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -Dataset ds2 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -184,10 +186,10 @@ Dataset ds2 = spark .option("startingOffsets", "{\"topic1\":{\"0\":23,\"1\":-2},\"topic2\":{\"0\":-2}}") .option("endingOffsets", "{\"topic1\":{\"0\":50,\"1\":-1},\"topic2\":{\"0\":-1}}") .load(); -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -Dataset ds3 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -195,7 +197,7 @@ Dataset ds3 = spark .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .load(); -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% endhighlight %} @@ -203,16 +205,16 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% highlight python %} # Subscribe to 1 topic defaults to the earliest and latest offsets -ds1 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ .option("subscribe", "topic1") \ .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to multiple topics, specifying explicit Kafka offsets -ds2 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ @@ -220,10 +222,10 @@ ds2 = spark \ .option("startingOffsets", """{"topic1":{"0":23,"1":-2},"topic2":{"0":-2}}""") \ .option("endingOffsets", """{"topic1":{"0":50,"1":-1},"topic2":{"0":-1}}""") \ .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to a pattern, at the earliest and latest offsets -ds3 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ @@ -231,8 +233,7 @@ ds3 = spark \ .option("startingOffsets", "earliest") \ .option("endingOffsets", "latest") \ .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %} @@ -373,11 +374,213 @@ The following configurations are optional: +## Writing Data to Kafka + +Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that +Apache Kafka only supports at least once write semantics. Consequently, when writing---either Streaming Queries +or Batch Queries---to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs +to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record. +Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, +if writing the query is successful, then you can assume that the query output was written at least once. A possible +solution to remove duplicates when reading the written data could be to introduce a primary (unique) key +that can be used to perform de-duplication when reading. + +The Dataframe being written to Kafka should have the following columns in schema: + + + + + + + + + + + + + + +
    ColumnType
    key (optional)string or binary
    value (required)string or binary
    topic (*optional)string
    +\* The topic column is required if the "topic" configuration option is not specified.
    + +The value column is the only required option. If a key column is not specified then +a ```null``` valued key column will be automatically added (see Kafka semantics on +how ```null``` valued key values are handled). If a topic column exists then its value +is used as the topic when writing the given row to Kafka, unless the "topic" configuration +option is set i.e., the "topic" configuration option overrides the topic column. + +The following options must be set for the Kafka sink +for both batch and streaming queries. + + + + + + + + +
    Optionvaluemeaning
    kafka.bootstrap.serversA comma-separated list of host:portThe Kafka "bootstrap.servers" configuration.
    + +The following configurations are optional: + + + + + + + + + + +
    Optionvaluedefaultquery typemeaning
    topicstringnonestreaming and batchSets the topic that all rows will be written to in Kafka. This option overrides any + topic column that may exist in the data.
    + +### Creating a Kafka Sink for Streaming Queries + +
    +
    +{% highlight scala %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +val ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .start() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +val ds = df + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .start() + +{% endhighlight %} +
    +
    +{% highlight java %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +StreamingQuery ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .start() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +StreamingQuery ds = df + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .start() + +{% endhighlight %} +
    +
    +{% highlight python %} + +# Write key-value data from a DataFrame to a specific Kafka topic specified in an option +ds = df \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .start() + +# Write key-value data from a DataFrame to Kafka using a topic specified in the data +ds = df \ + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .start() + +{% endhighlight %} +
    +
    + +### Writing the output of Batch Queries to Kafka + +
    +
    +{% highlight scala %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .save() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .save() + +{% endhighlight %} +
    +
    +{% highlight java %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .save() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .save() + +{% endhighlight %} +
    +
    +{% highlight python %} + +# Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .save() + +# Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .save() + +{% endhighlight %} +
    +
    + + +## Kafka Specific Configurations + Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, -`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafkaParams, see -[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see +[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for +parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs) +for parameters related to writing data. -Note that the following Kafka params cannot be set and the Kafka source will throw an exception: +Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception: - **group.id**: Kafka source will create a unique group id for each query automatically. - **auto.offset.reset**: Set the source option `startingOffsets` to specify @@ -389,11 +592,15 @@ Note that the following Kafka params cannot be set and the Kafka source will thr DataFrame operations to explicitly deserialize the keys. - **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the values. +- **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use +DataFrame operations to explicitly serialize the keys into either strings or byte arrays. +- **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use +DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. -### Deploying +## Deploying As with any Spark applications, `spark-submit` is used to launch your application. `spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages`, such as, From 10691d36de902e3771af20aed40336b4f99de719 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 20 Mar 2017 18:25:59 -0700 Subject: [PATCH 0065/1765] [SPARK-19573][SQL] Make NaN/null handling consistent in approxQuantile ## What changes were proposed in this pull request? update `StatFunctions.multipleApproxQuantiles` to handle NaN/null ## How was this patch tested? existing tests and added tests Author: Zheng RuiFeng Closes #16971 from zhengruifeng/quantiles_nan. --- .../aggregate/ApproximatePercentile.scala | 3 +- .../sql/catalyst/util/QuantileSummaries.scala | 12 ++-- .../util/QuantileSummariesSuite.scala | 46 ++++++++++----- .../spark/sql/DataFrameStatFunctions.scala | 21 +++---- .../sql/execution/stat/StatFunctions.scala | 10 +++- .../apache/spark/sql/DataFrameStatSuite.scala | 57 ++++++++++++------- 6 files changed, 95 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index db062f1a543fe..1ec2e4a9e9319 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -245,7 +245,8 @@ object ApproximatePercentile { val result = new Array[Double](percentages.length) var i = 0 while (i < percentages.length) { - result(i) = summaries.query(percentages(i)) + // Since summaries.count != 0, the query here never return None. + result(i) = summaries.query(percentages(i)).get i += 1 } result diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 04f4ff2a92247..af543b04ba780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -176,17 +176,19 @@ class QuantileSummaries( * @param quantile the target quantile * @return */ - def query(quantile: Double): Double = { + def query(quantile: Double): Option[Double] = { require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") require(headSampled.isEmpty, "Cannot operate on an uncompressed summary, call compress() first") + if (sampled.isEmpty) return None + if (quantile <= relativeError) { - return sampled.head.value + return Some(sampled.head.value) } if (quantile >= 1 - relativeError) { - return sampled.last.value + return Some(sampled.last.value) } // Target rank @@ -200,11 +202,11 @@ class QuantileSummaries( minRank += curSample.g val maxRank = minRank + curSample.delta if (maxRank - targetError <= rank && rank <= minRank + targetError) { - return curSample.value + return Some(curSample.value) } i += 1 } - sampled.last.value + Some(sampled.last.value) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index 5e90970b1bb2e..df579d5ec1ddf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -55,15 +55,19 @@ class QuantileSummariesSuite extends SparkFunSuite { } private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { - val approx = summary.query(quant) - // The rank of the approximation. - val rank = data.count(_ < approx) // has to be <, not <= to be exact - val lower = math.floor((quant - summary.relativeError) * data.size) - val upper = math.ceil((quant + summary.relativeError) * data.size) - val msg = - s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" - assert(rank >= lower, msg) - assert(rank <= upper, msg) + if (data.nonEmpty) { + val approx = summary.query(quant).get + // The rank of the approximation. + val rank = data.count(_ < approx) // has to be <, not <= to be exact + val lower = math.floor((quant - summary.relativeError) * data.size) + val upper = math.ceil((quant + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) + } else { + assert(summary.query(quant).isEmpty) + } } for { @@ -74,9 +78,9 @@ class QuantileSummariesSuite extends SparkFunSuite { test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { val s = buildSummary(data, epsi, compression) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") } @@ -100,6 +104,18 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) } + + test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") { + val emptyData = Seq.empty[Double] + val s = buildSummary(emptyData, epsi, compression) + assert(s.count == 0, s"Found count=${s.count} but data size=0") + assert(s.sampled.isEmpty, s"if QuantileSummaries is empty, sampled should be empty") + checkQuantile(0.9999, emptyData, s) + checkQuantile(0.9, emptyData, s) + checkQuantile(0.5, emptyData, s) + checkQuantile(0.1, emptyData, s) + checkQuantile(0.001, emptyData, s) + } } // Tests for merging procedure @@ -118,9 +134,9 @@ class QuantileSummariesSuite extends SparkFunSuite { val s1 = buildSummary(data1, epsi, compression) val s2 = buildSummary(data2, epsi, compression) val s = s1.merge(s2) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") checkQuantile(0.9999, data, s) checkQuantile(0.9, data, s) @@ -137,9 +153,9 @@ class QuantileSummariesSuite extends SparkFunSuite { val s1 = buildSummary(data11, epsi, compression) val s2 = buildSummary(data12, epsi, compression) val s = s1.merge(s2) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") checkQuantile(0.9999, data, s) checkQuantile(0.9, data, s) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index bdcdf0c61ff36..c856d3099f6ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -64,7 +64,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return the approximate quantiles at the given probabilities * * @note null and NaN values will be removed from the numerical column before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * the dataframe is empty or the column only contains null or NaN, an empty array is returned. * * @since 2.0.0 */ @@ -72,8 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { col: String, probabilities: Array[Double], relativeError: Double): Array[Double] = { - val res = approxQuantile(Array(col), probabilities, relativeError) - Option(res).map(_.head).orNull + approxQuantile(Array(col), probabilities, relativeError).head } /** @@ -89,8 +88,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities of each column * - * @note Rows containing any null or NaN values will be removed before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * @note null and NaN values will be ignored in numerical columns before calculation. For + * columns only containing null or NaN values, an empty array is returned. * * @since 2.2.0 */ @@ -98,13 +97,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = { - // TODO: Update NaN/null handling to keep consistent with the single-column version - try { - StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols, - probabilities, relativeError).map(_.toArray).toArray - } catch { - case e: NoSuchElementException => null - } + StatFunctions.multipleApproxQuantiles( + df.select(cols.map(col): _*), + cols, + probabilities, + relativeError).map(_.toArray).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index c3d8859cb7a92..1debad03c93fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -54,6 +54,9 @@ object StatFunctions extends Logging { * Note that values greater than 1 are accepted but give the same result as 1. * * @return for each column, returns the requested approximations + * + * @note null and NaN values will be ignored in numerical columns before calculation. For + * a column only containing null or NaN values, an empty array is returned. */ def multipleApproxQuantiles( df: DataFrame, @@ -78,7 +81,10 @@ object StatFunctions extends Logging { def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { var i = 0 while (i < summaries.length) { - summaries(i) = summaries(i).insert(row.getDouble(i)) + if (!row.isNullAt(i)) { + val v = row.getDouble(i) + if (!v.isNaN) summaries(i) = summaries(i).insert(v) + } i += 1 } summaries @@ -91,7 +97,7 @@ object StatFunctions extends Logging { } val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) - summaries.map { summary => probabilities.map(summary.query) } + summaries.map { summary => probabilities.flatMap(summary.query) } } /** Calculate the Pearson Correlation Coefficient for the given columns */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index d0910e618a040..97890a035a62f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -171,15 +171,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0) } assert(e2.getMessage.contains("Relative Error must be non-negative")) - - // return null if the dataset is empty - val res1 = df.selectExpr("*").limit(0) - .stat.approxQuantile("singles", Array(q1, q2), epsilons.head) - assert(res1 === null) - - val res2 = df.selectExpr("*").limit(0) - .stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head) - assert(res2 === null) } test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") { @@ -214,20 +205,48 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val q1 = 0.5 val q2 = 0.8 val epsilon = 0.1 - val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0), - Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0), - Row(-1.0, null), Row(Double.NaN, null))) + val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0, Double.NaN), + Row(1.0, -1.0, null), Row(-1.0, Double.NaN, null), Row(Double.NaN, Double.NaN, null), + Row(null, null, Double.NaN), Row(null, 1.0, null), Row(-1.0, null, Double.NaN), + Row(Double.NaN, null, null))) val schema = StructType(Seq(StructField("input1", DoubleType, nullable = true), - StructField("input2", DoubleType, nullable = true))) + StructField("input2", DoubleType, nullable = true), + StructField("input3", DoubleType, nullable = true))) val dfNaN = spark.createDataFrame(rows, schema) - val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) - assert(resNaN.count(_.isNaN) === 0) - assert(resNaN.count(_ == null) === 0) - val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"), + val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(resNaN1.count(_.isNaN) === 0) + assert(resNaN1.count(_ == null) === 0) + + val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon) + assert(resNaN2.count(_.isNaN) === 0) + assert(resNaN2.count(_ == null) === 0) + + val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon) + assert(resNaN3.isEmpty) + + val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"), Array(q1, q2), epsilon) - assert(resNaN2.flatten.count(_.isNaN) === 0) - assert(resNaN2.flatten.count(_ == null) === 0) + assert(resNaNAll.flatten.count(_.isNaN) === 0) + assert(resNaNAll.flatten.count(_ == null) === 0) + + assert(resNaN1(0) === resNaNAll(0)(0)) + assert(resNaN1(1) === resNaNAll(0)(1)) + assert(resNaN2(0) === resNaNAll(1)(0)) + assert(resNaN2(1) === resNaNAll(1)(1)) + + // return empty array for columns only containing null or NaN values + assert(resNaNAll(2).isEmpty) + + // return empty array if the dataset is empty + val res1 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(res1.isEmpty) + + val res2 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile(Array("input1", "input2"), Array(q1, q2), epsilon) + assert(res2(0).isEmpty) + assert(res2(1).isEmpty) } test("crosstab") { From e9c91badce64731ffd3e53cbcd9f044a7593e6b8 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 21 Mar 2017 10:43:17 +0800 Subject: [PATCH 0066/1765] [SPARK-20010][SQL] Sort information is lost after sort merge join ## What changes were proposed in this pull request? After sort merge join for inner join, now we only keep left key ordering. However, after inner join, right key has the same value and order as left key. So if we need another smj on right key, we will unnecessarily add a sort which causes additional cost. As a more complicated example, A join B on A.key = B.key join C on B.key = C.key join D on A.key = D.key. We will unnecessarily add a sort on B.key when join {A, B} and C, and add a sort on A.key when join {A, B, C} and D. To fix this, we need to propagate all sorted information (equivalent expressions) from bottom up through `outputOrdering` and `SortOrder`. ## How was this patch tested? Test cases are added. Author: wangzhenhua Closes #17339 from wzhfy/sortEnhance. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +-- .../SubstituteUnresolvedOrdinals.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 4 +-- .../sql/catalyst/expressions/SortOrder.scala | 21 +++++++++++-- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 8 ++--- .../exchange/EnsureRequirements.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 26 ++++++++++++++-- .../spark/sql/execution/PlannerSuite.scala | 30 ++++++++++++++++++- 9 files changed, 81 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8cf4073826192..574f91b09912b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -966,9 +966,9 @@ class Analyzer( case s @ Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering) + SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index af0a565f73ae9..38a3d3de1288e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { - case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => + case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 35ca2a0aa53a2..75bf780d41424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -109,9 +109,9 @@ package object dsl { def cast(to: DataType): Expression = Cast(expr, to) def asc: SortOrder = SortOrder(expr, Ascending) - def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) def desc: SortOrder = SortOrder(expr, Descending) - def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst) + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3bebd552ef51a..abcb9a2b939b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{ /** * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. + * `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is + * derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge + * join. */ -case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) +case class SortOrder( + child: Expression, + direction: SortDirection, + nullOrdering: NullOrdering, + sameOrderExpressions: Set[Expression]) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. */ @@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql def isAscending: Boolean = direction == Ascending + + def satisfies(required: SortOrder): Boolean = { + (sameOrderExpressions + child).exists(required.child.semanticEquals) && + direction == required.direction && nullOrdering == required.nullOrdering + } } object SortOrder { - def apply(child: Expression, direction: SortDirection): SortOrder = { - new SortOrder(child, direction, direction.defaultNullOrdering) + def apply( + child: Expression, + direction: SortDirection, + sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = { + new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4c9fb2ec2774a..cd238e05d4102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } else { direction.defaultNullOrdering } - SortOrder(expression(ctx.expression), direction, nullOrdering) + SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 38029552d13bd..ae0703513cf42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) } + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } /** * Returns a descending ordering used in sorting, where null values appear after non-null values. @@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) } + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } /** * Returns an ascending ordering used in sorting. @@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) } + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } /** * Returns an ordering used in sorting, where null values appear after non-null values. @@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) } + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** * Prints the expression to the console for debugging purpose. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f17049949aa47..b91d077442557 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } else { requiredOrdering.zip(child.outputOrdering).forall { case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) + childOutputOrder.satisfies(requiredOrder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 02f4f55c7999a..c6aae1a4db2e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -81,17 +81,37 @@ case class SortMergeJoinExec( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { + // For inner join, orders of both sides keys should be kept. + case Inner => + val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => + // Also add the right key and its `sameOrderExpressions` + SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey + .sameOrderExpressions) + } // For left and right outer joins, the output is ordered by the streamed input's join keys. - case LeftOuter => requiredOrders(leftKeys) - case RightOuter => requiredOrders(rightKeys) + case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) + case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering) // There are null rows in both streams, so there is no order. case FullOuter => Nil - case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys) + case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") } + /** + * For SMJ, child's output must have been sorted on key or expressions with the same order as + * key, so we can get ordering for key from child's output ordering. + */ + private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder]) + : Seq[SortOrder] = { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key) + } + } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f2232fc489b78..4d155d538d637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext { private val exprA = Literal(1) private val exprB = Literal(2) + private val exprC = Literal(3) private val orderingA = SortOrder(exprA, Ascending) private val orderingB = SortOrder(exprB, Ascending) + private val orderingC = SortOrder(exprC, Ascending) private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), outputPartitioning = HashPartitioning(exprA :: Nil, 5)) private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + private val planC = DummySparkPlan(outputOrdering = Seq(orderingC), + outputPartitioning = HashPartitioning(exprC :: Nil, 5)) - assert(orderingA != orderingB) + assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC) private def assertSortRequirementsAreSatisfied( childPlan: SparkPlan, @@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + // Both left and right keys should be sorted after the SMJ. + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = innerSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + + test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + + "child SMJ") { + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC) + // After the second SMJ, exprA, exprB and exprC should all be sorted. + Seq(orderingA, orderingB, orderingC).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = parentSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + test("EnsureRequirements for sort operator after left outer sort merge join") { // Only left key is sorted after left outer SMJ (thus doesn't need a sort). val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) From 0ec1db5475f1a7839bdbf0d9cffe93ce6970a7fe Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Mar 2017 11:17:34 +0800 Subject: [PATCH 0067/1765] [SPARK-19980][SQL] Add NULL checks in Bean serializer ## What changes were proposed in this pull request? A Bean serializer in `ExpressionEncoder` could change values when Beans having NULL. A concrete example is as follows; ``` scala> :paste class Outer extends Serializable { private var cls: Inner = _ def setCls(c: Inner): Unit = cls = c def getCls(): Inner = cls } class Inner extends Serializable { private var str: String = _ def setStr(s: String): Unit = str = str def getStr(): String = str } scala> Seq("""{"cls":null}""", """{"cls": {"str":null}}""").toDF().write.text("data") scala> val encoder = Encoders.bean(classOf[Outer]) scala> val schema = encoder.schema scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder) scala> df.show +------+ | cls| +------+ |[null]| | null| +------+ scala> df.map(x => x)(encoder).show() +------+ | cls| +------+ |[null]| |[null]| // <-- Value changed +------+ ``` This is because the Bean serializer does not have the NULL-check expressions that the serializer of Scala's product types has. Actually, this value change does not happen in Scala's product types; ``` scala> :paste case class Outer(cls: Inner) case class Inner(str: String) scala> val encoder = Encoders.product[Outer] scala> val schema = encoder.schema scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder) scala> df.show +------+ | cls| +------+ |[null]| | null| +------+ scala> df.map(x => x)(encoder).show() +------+ | cls| +------+ |[null]| | null| +------+ ``` This pr added the NULL-check expressions in Bean serializer along with the serializer of Scala's product types. ## How was this patch tested? Added tests in `JavaDatasetSuite`. Author: Takeshi Yamamuro Closes #17347 from maropu/SPARK-19980. --- .../sql/catalyst/JavaTypeInference.scala | 11 +++++++++-- .../apache/spark/sql/JavaDatasetSuite.java | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 4ff87edde139a..9d4617dda555f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -343,7 +343,11 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) + serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { + case expressions.If(_, _, s: CreateNamedStruct) => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) + } } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { @@ -427,7 +431,7 @@ object JavaTypeInference { case other => val properties = getJavaBeanReadableAndWritableProperties(other) - CreateNamedStruct(properties.flatMap { p => + val nonNullOutput = CreateNamedStruct(properties.flatMap { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val fieldValue = Invoke( @@ -436,6 +440,9 @@ object JavaTypeInference { inferExternalType(fieldType.getRawType)) expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) + + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ca9e5ad2ea86b..ffb4c6273ff85 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1380,4 +1380,23 @@ public void testCircularReferenceBean3() { CircularReference4Bean bean = new CircularReference4Bean(); spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class)); } + + @Test(expected = RuntimeException.class) + public void testNullInTopLevelBean() { + NestedSmallBean bean = new NestedSmallBean(); + // We cannot set null in top-level bean + spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class)); + } + + @Test + public void testSerializeNull() { + NestedSmallBean bean = new NestedSmallBean(); + Encoder encoder = Encoders.bean(NestedSmallBean.class); + List beans = Arrays.asList(bean); + Dataset ds1 = spark.createDataset(beans, encoder); + Assert.assertEquals(beans, ds1.collectAsList()); + Dataset ds2 = + ds1.map((MapFunction) b -> b, encoder); + Assert.assertEquals(beans, ds2.collectAsList()); + } } From 7fa116f8fc77906202217c0cd2f9718a4e62632b Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Tue, 21 Mar 2017 11:51:22 +0800 Subject: [PATCH 0068/1765] [SPARK-17204][CORE] Fix replicated off heap storage (Jira: https://issues.apache.org/jira/browse/SPARK-17204) ## What changes were proposed in this pull request? There are a couple of bugs in the `BlockManager` with respect to support for replicated off-heap storage. First, the locally-stored off-heap byte buffer is disposed of when it is replicated. It should not be. Second, the replica byte buffers are stored as heap byte buffers instead of direct byte buffers even when the storage level memory mode is off-heap. This PR addresses both of these problems. ## How was this patch tested? `BlockManagerReplicationSuite` was enhanced to fill in the coverage gaps. It now fails if either of the bugs in this PR exist. Author: Michael Allman Closes #16499 from mallman/spark-17204-replicated_off_heap_storage. --- .../apache/spark/storage/BlockManager.scala | 23 ++++++-- .../apache/spark/storage/StorageUtils.scala | 52 ++++++++++++++++--- .../spark/util/ByteBufferInputStream.scala | 8 +-- .../spark/util/io/ChunkedByteBuffer.scala | 27 ++++++++-- .../BlockManagerReplicationSuite.scala | 20 +++++-- 5 files changed, 105 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 45b73380806dd..245d94ac4f8b1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -317,6 +317,9 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. + * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. */ override def putBlockData( blockId: BlockId, @@ -755,6 +758,9 @@ private[spark] class BlockManager( /** * Put a new block of serialized bytes to the block manager. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param encrypt If true, asks the block manager to encrypt the data block before storing, * when I/O encryption is enabled. This is required for blocks that have been * read from unencrypted sources, since all the BlockManager read APIs @@ -773,7 +779,7 @@ private[spark] class BlockManager( if (encrypt && securityManager.ioEncryptionKey.isDefined) { try { val data = bytes.toByteBuffer - val in = new ByteBufferInputStream(data, true) + val in = new ByteBufferInputStream(data) val byteBufOut = new ByteBufferOutputStream(data.remaining()) val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, securityManager.ioEncryptionKey.get) @@ -800,6 +806,9 @@ private[spark] class BlockManager( * * If the block already exists, this method will not overwrite it. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param keepReadLock if true, this method will hold the read lock when it returns (even if the * block already exists). If false, this method will hold no locks when it * returns. @@ -843,7 +852,15 @@ private[spark] class BlockManager( false } } else { - memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) + val memoryMode = level.memoryMode + memoryStore.putBytes(blockId, size, memoryMode, () => { + if (memoryMode == MemoryMode.OFF_HEAP && + bytes.chunks.exists(buffer => !buffer.isDirect)) { + bytes.copy(Platform.allocateDirectBuffer) + } else { + bytes + } + }) } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") @@ -1048,7 +1065,7 @@ private[spark] class BlockManager( try { replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { - bytesToReplicate.dispose() + bytesToReplicate.unmap() } logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e12f2e6095d5a..5efdd23f79a21 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -236,22 +236,60 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { + // Ewwww... Reflection!!! See the unmap method for justification + private val memoryMappedBufferFileDescriptorField = { + val mappedBufferClass = classOf[java.nio.MappedByteBuffer] + val fdField = mappedBufferClass.getDeclaredField("fd") + fdField.setAccessible(true) + fdField + } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun + * API that will cause errors if one attempts to read from the disposed buffer. However, neither + * the bytes allocated to direct buffers nor file descriptors opened for memory-mapped buffers put + * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of + * off-heap memory or huge numbers of open files. There's unfortunately no standard API to + * manually dispose of these kinds of buffers. + * + * See also [[unmap]] */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace(s"Unmapping $buffer") - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() + logTrace(s"Disposing of $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) + } + } + + /** + * Attempt to unmap a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that will + * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of + * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage + * collection may lead to huge numbers of open files. There's unfortunately no standard API to + * manually unmap memory-mapped buffers. + * + * See also [[dispose]] + */ + def unmap(buffer: ByteBuffer): Unit = { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the + // JDK does not provide a public API to distinguish between direct buffers and memory-mapped + // buffers. As an alternative, we peek beneath the curtains and look for a non-null file + // descriptor in mappedByteBuffer + if (memoryMappedBufferFileDescriptorField.get(buffer) != null) { + logTrace(s"Unmapping $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) } } } + private def cleanDirectBuffer(buffer: DirectBuffer) = { + val cleaner = buffer.cleaner() + if (cleaner != null) { + cleaner.clean() + } + } + /** * Update the given list of RDDInfo with the given list of storage statuses. * This method overwrites the old values stored in the RDDInfo's. diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index dce2ac63a664c..50dc948e6c410 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -23,11 +23,10 @@ import java.nio.ByteBuffer import org.apache.spark.storage.StorageUtils /** - * Reads data from a ByteBuffer, and optionally cleans it up using StorageUtils.dispose() - * at the end of the stream (e.g. to close a memory-mapped file). + * Reads data from a ByteBuffer. */ private[spark] -class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) +class ByteBufferInputStream(private var buffer: ByteBuffer) extends InputStream { override def read(): Int = { @@ -72,9 +71,6 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f */ private def cleanUp() { if (buffer != null) { - if (dispose) { - StorageUtils.dispose(buffer) - } buffer = null } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7572cac39317c..1667516663b35 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -86,7 +86,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Copy this buffer into a new ByteBuffer. + * Convert this buffer to a ByteBuffer. If this buffer is backed by a single chunk, its underlying + * data will not be copied. Instead, it will be duplicated. If this buffer is backed by multiple + * chunks, the data underlying this buffer will be copied into a new byte buffer. As a result, it + * is suggested to use this method only if the caller does not need to manage the memory + * underlying this buffer. * * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. */ @@ -132,10 +136,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. + * See [[StorageUtils.dispose]] for more information. + * + * See also [[unmap]] */ def dispose(): Unit = { if (!disposed) { @@ -143,6 +147,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { disposed = true } } + + /** + * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See + * [[StorageUtils.unmap]] for more information. + * + * See also [[dispose]] + */ + def unmap(): Unit = { + if (!disposed) { + chunks.foreach(StorageUtils.unmap) + disposed = true + } + } } /** diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 75dc04038debc..d907add920c8a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -374,7 +374,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite // Put the block into one of the stores val blockId = new TestBlockId( "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) - stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + val testValue = Array.fill[Byte](blockSize)(1) + stores(0).putSingle(blockId, testValue, storageLevel) // Assert that master know two locations for the block val blockLocations = master.getLocations(blockId).map(_.executorId).toSet @@ -386,12 +387,23 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite testStore => blockLocations.contains(testStore.blockManagerId.executorId) }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId - assert( - testStore.getLocalValues(blockId).isDefined, s"$blockId was not found in $testStoreName") - testStore.releaseLock(blockId) + val blockResultOpt = testStore.getLocalValues(blockId) + assert(blockResultOpt.isDefined, s"$blockId was not found in $testStoreName") + val localValues = blockResultOpt.get.data.toSeq + assert(localValues.size == 1) + assert(localValues.head === testValue) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") + val memoryStore = testStore.memoryStore + if (memoryStore.contains(blockId) && !storageLevel.deserialized) { + memoryStore.getBytes(blockId).get.chunks.foreach { byteBuffer => + assert(storageLevel.useOffHeap == byteBuffer.isDirect, + s"memory mode ${storageLevel.memoryMode} is not compatible with " + + byteBuffer.getClass.getSimpleName) + } + } + val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId) // Assert that block status in the master for this store has expected storage level From 21e366aea5a7f49e42e78dce06ff6b3ee1e36f06 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 21 Mar 2017 12:17:26 +0800 Subject: [PATCH 0069/1765] [SPARK-19912][SQL] String literals should be escaped for Hive metastore partition pruning ## What changes were proposed in this pull request? Since current `HiveShim`'s `convertFilters` does not escape the string literals. There exists the following correctness issues. This PR aims to return the correct result and also shows the more clear exception message. **BEFORE** ```scala scala> Seq((1, "p1", "q1"), (2, "p1\" and q=\"q1", "q2")).toDF("a", "p", "q").write.partitionBy("p", "q").saveAsTable("t1") scala> spark.table("t1").filter($"p" === "p1\" and q=\"q1").select($"a").show +---+ | a| +---+ +---+ scala> spark.table("t1").filter($"p" === "'\"").select($"a").show java.lang.RuntimeException: Caught Hive MetaException attempting to get partition metadata by filter from ... ``` **AFTER** ```scala scala> spark.table("t1").filter($"p" === "p1\" and q=\"q1").select($"a").show +---+ | a| +---+ | 2| +---+ scala> spark.table("t1").filter($"p" === "'\"").select($"a").show java.lang.UnsupportedOperationException: Partition filter cannot have both `"` and `'` characters ``` ## How was this patch tested? Pass the Jenkins test with new test cases. Author: Dongjoon Hyun Closes #17266 from dongjoon-hyun/SPARK-19912. --- .../apache/spark/sql/hive/client/HiveShim.scala | 16 ++++++++++++++-- .../spark/sql/hive/client/FiltersSuite.scala | 5 +++++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 76568f599078d..d55c41e5c9f29 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -596,13 +596,24 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { s"$v ${op.symbol} ${a.name}" case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" }.mkString(" and ") } + private def quoteStringLiteral(str: String): String = { + if (!str.contains("\"")) { + s""""$str"""" + } else if (!str.contains("'")) { + s"""'$str'""" + } else { + throw new UnsupportedOperationException( + """Partition filter cannot have both `"` and `'` characters""") + } + } + override def getPartitionsByFilter( hive: Hive, table: Table, @@ -611,6 +622,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { // Hive getPartitionsByFilter() takes a string that represents partition // predicates like "str_key=\"value\" and int_key=1 ..." val filter = convertFilters(table, predicates) + val partitions = if (filter.isEmpty) { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index cd96c85f3e209..031c1a5ec0ec3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -65,6 +65,11 @@ class FiltersSuite extends SparkFunSuite with Logging { (Literal("") === a("varchar", StringType)) :: Nil, "") + filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning", + (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) :: + (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, + """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { val converted = shim.convertFilters(testTable, filters) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 236135dcff523..55ff4bb115e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2057,4 +2057,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { + withTable("spark_19912") { + Seq( + (1, "p1", "q1"), + (2, "'", "q2"), + (3, "\"", "q3"), + (4, "p1\" and q=\"q1", "q4") + ).toDF("a", "p", "q").write.partitionBy("p", "q").saveAsTable("spark_19912") + + val table = spark.table("spark_19912") + checkAnswer(table.filter($"p" === "'").select($"a"), Row(2)) + checkAnswer(table.filter($"p" === "\"").select($"a"), Row(3)) + checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4)) + } + } } From 68d65fae71e475ad811a9716098aca03a2af9532 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 20 Mar 2017 21:43:14 -0700 Subject: [PATCH 0070/1765] [SPARK-19949][SQL] unify bad record handling in CSV and JSON ## What changes were proposed in this pull request? Currently JSON and CSV have exactly the same logic about handling bad records, this PR tries to abstract it and put it in a upper level to reduce code duplication. The overall idea is, we make the JSON and CSV parser to throw a BadRecordException, then the upper level, FailureSafeParser, handles bad records according to the parse mode. Behavior changes: 1. with PERMISSIVE mode, if the number of tokens doesn't match the schema, previously CSV parser will treat it as a legal record and parse as many tokens as possible. After this PR, we treat it as an illegal record, and put the raw record string in a special column, but we still parse as many tokens as possible. 2. all logging is removed as they are not very useful in practice. ## How was this patch tested? existing tests Author: Wenchen Fan Author: hyukjinkwon Author: Wenchen Fan Closes #17315 from cloud-fan/bad-record2. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 +- .../expressions/jsonExpressions.scala | 4 +- .../spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/catalyst/json/JacksonParser.scala | 122 +---------- .../sql/catalyst/util/FailureSafeParser.scala | 80 +++++++ .../apache/spark/sql/DataFrameReader.scala | 23 +- .../datasources/csv/CSVDataSource.scala | 17 +- .../datasources/csv/CSVFileFormat.scala | 7 +- .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/csv/UnivocityParser.scala | 197 ++++++------------ .../datasources/json/JsonDataSource.scala | 31 ++- .../datasources/json/JsonFileFormat.scala | 7 +- .../execution/datasources/csv/CSVSuite.scala | 2 +- .../datasources/json/JsonSuite.scala | 8 +- 14 files changed, 222 insertions(+), 285 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cbc3569795d97..394d1a04e09c3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1370,9 +1370,8 @@ test_that("column functions", { # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) - expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))), - error = function(e) { stop(e) }), - paste0(".*(java.lang.NumberFormatException: For input string:).*")) + s <- collect(select(df, from_json(df$col, schema2))) + expect_equal(s[[1]][[1]], NA) s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) expect_is(s[[1]][[1]]$date, "Date") expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e4e08a8665a5a..08af5522d822d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -583,7 +583,7 @@ case class JsonToStructs( CreateJacksonParser.utf8String, identity[UTF8String])) } catch { - case _: SparkSQLJsonProcessingException => null + case _: BadRecordException => null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f222ec602c99..355c26afa6f0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -65,7 +65,7 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 9b80c0fc87c93..fdb7d88d5bd7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -32,17 +32,14 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) - /** * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( schema: StructType, - options: JSONOptions) extends Logging { + val options: JSONOptions) extends Logging { import JacksonUtils._ - import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -55,108 +52,6 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) - private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) - - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - @transient - private[this] var isWarningPrinted: Boolean = false - - @transient - private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { - def sampleRecord: String = { - if (options.wholeFile) { - "" - } else { - s"Sample record: ${record()}\n" - } - } - - def footer: String = { - s"""Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin - } - - if (options.permissive) { - logWarning( - s"""Found at least one malformed record. The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } else if (options.dropMalformed) { - logWarning( - s"""Found at least one malformed record. The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } - } - - @transient - private def printWarningIfWholeFile(): Unit = { - if (options.wholeFile && corruptFieldIndex.isDefined) { - logWarning( - s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result - |in very large allocations or OutOfMemoryExceptions being raised. - | - """.stripMargin) - } - } - - /** - * This function deals with the cases it fails to parse. This function will be called - * when exceptions are caught during converting. This functions also deals with `mode` option. - */ - private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { - corruptFieldIndex match { - case _ if options.failFast => - if (options.wholeFile) { - throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") - } else { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") - } - - case _ if options.dropMalformed => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - Nil - - case None => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - emptyRow - - case Some(corruptIndex) => - if (!isWarningPrinted) { - printWarningIfWholeFile() - isWarningPrinted = true - } - val row = new GenericInternalRow(schema.length) - row.update(corruptIndex, record()) - Seq(row) - } - } - /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -239,7 +134,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toFloat } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") + throw new RuntimeException(s"Cannot parse $value as FloatType.") } } @@ -259,7 +154,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toDouble } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") + throw new RuntimeException(s"Cannot parse $value as DoubleType.") } } @@ -391,9 +286,8 @@ class JacksonParser( case token => // We cannot parse this token based on the given data type. So, we throw a - // SparkSQLJsonProcessingException and this exception will be caught by - // `parse` method. - throw new SparkSQLJsonProcessingException( + // RuntimeException and this exception will be caught by `parse` method. + throw new RuntimeException( s"Failed to parse a value for data type $dataType (current token: $token).") } @@ -466,14 +360,14 @@ class JacksonParser( parser.nextToken() match { case null => Nil case _ => rootConverter.apply(parser) match { - case null => throw new SparkSQLJsonProcessingException("Root converter returned null") + case null => throw new RuntimeException("Root converter returned null") case rows => rows } } } } catch { - case _: JsonProcessingException | _: SparkSQLJsonProcessingException => - failedRecord(() => recordLiteral(record)) + case e @ (_: RuntimeException | _: JsonProcessingException) => + throw BadRecordException(() => recordLiteral(record), () => None, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala new file mode 100644 index 0000000000000..e8da10d65ecb9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class FailureSafeParser[IN]( + rawParser: IN => Seq[InternalRow], + mode: String, + schema: StructType, + columnNameOfCorruptRecord: String) { + + private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) + private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) + private val resultRow = new GenericInternalRow(schema.length) + private val nullResult = new GenericInternalRow(schema.length) + + // This function takes 2 parameters: an optional partial result, and the bad record. If the given + // schema doesn't contain a field for corrupted record, we just return the partial result or a + // row with all fields null. If the given schema contains a field for corrupted record, we will + // set the bad record to this field, and set other fields according to the partial result or null. + private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { + if (corruptFieldIndex.isDefined) { + (row, badRecord) => { + var i = 0 + while (i < actualSchema.length) { + val from = actualSchema(i) + resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull + i += 1 + } + resultRow(corruptFieldIndex.get) = badRecord() + resultRow + } + } else { + (row, _) => row.getOrElse(nullResult) + } + } + + def parse(input: IN): Iterator[InternalRow] = { + try { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } catch { + case e: BadRecordException if ParseModes.isPermissiveMode(mode) => + Iterator(toResultRow(e.partialResult(), e.record)) + case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => + Iterator.empty + case e: BadRecordException => throw e.cause + } + } +} + +/** + * Exception thrown when the underlying parser meet a bad record and can't parse it. + * @param record a function to return the record that cause the parser to fail + * @param partialResult a function that returns an optional row, which is the partial result of + * parsing this bad record. + * @param cause the actual exception about why the record is bad and can't be parsed. + */ +case class BadRecordException( + record: () => UTF8String, + partialResult: () => Option[InternalRow], + cause: Throwable) extends Exception(cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 88fbfb4c92a00..767a636d70731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.csv._ @@ -382,11 +383,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => - val parser = new JacksonParser(schema, parsedOptions) - iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) + val rawParser = new JacksonParser(actualSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => rawParser.parse(input, createParser, UTF8String.fromString), + parsedOptions.parseMode, + schema, + parsedOptions.columnNameOfCorruptRecord) + iter.flatMap(parser.parse) } Dataset.ofRows( @@ -435,14 +443,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) val parsed = linesWithoutHeader.mapPartitions { iter => - val parser = new UnivocityParser(schema, parsedOptions) - iter.flatMap(line => parser.parse(line)) + val rawParser = new UnivocityParser(actualSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + parsedOptions.parseMode, + schema, + parsedOptions.columnNameOfCorruptRecord) + iter.flatMap(parser.parse) } Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 35ff924f27ce5..63af18ec5b8eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -49,7 +49,7 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] + schema: StructType): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -115,17 +115,17 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + new String(line.getBytes, 0, line.getLength, parser.options.charset) } } - val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + val shouldDropHeader = parser.options.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) } override def infer( @@ -192,11 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - parsedOptions.headerFlag, - parser) + parser.options.headerFlag, + parser, + schema) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 29c41455279e6..eef43c7629c12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -113,8 +113,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value - val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) + val parser = new UnivocityParser( + StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + parsedOptions) + CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 2632e87971d68..f6c6b6f56cd9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -82,7 +82,7 @@ class CSVOptions( val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index e42ea3fa391f5..263f77e11c4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,14 +30,14 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( schema: StructType, requiredSchema: StructType, - private val options: CSVOptions) extends Logging { + val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -46,39 +46,26 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val tokenizer = new CsvParser(options.asParserSettings) - private var numMalformedRecords = 0 - private val row = new GenericInternalRow(requiredSchema.length) - // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field - // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. - private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + // Retrieve the raw record string. + private def getCurrentInput: UTF8String = { + UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) + } - // This parser loads an `tokenIndexArr`-th position value in input tokens, - // then put the value in `row(rowIndexArr)`. + // This parser first picks some tokens from the input tokens, according to the required schema, + // then parse these tokens and put the values in a row, with the order specified by the required + // schema. // // For example, let's say there is CSV data as below: // // a,b,c // 1,2,A // - // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` - // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need - // to map those values below: - // - // required schema - ["c", "b", "_unparsed", "a"] - // CSV data schema - ["a", "b", "c"] - // required CSV data schema - ["c", "b", "a"] + // So the CSV data schema is: ["a", "b", "c"] + // And let's say the required schema is: ["c", "b"] // // with the input tokens, // @@ -86,45 +73,12 @@ class UnivocityParser( // // Each input token is placed in each output row's position by mapping these. In this case, // - // output row - ["A", 2, null, 1] - // - // In more details, - // - `valueConverters`, input tokens - CSV data schema - // `valueConverters` keeps the positions of input token indices (by its index) to each - // value's converter (by its value) in an order of CSV data schema. In this case, - // [string->int, string->int, string->string]. - // - // - `tokenIndexArr`, input tokens - required CSV data schema - // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered - // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. - // - // - `rowIndexArr`, input tokens - required schema - // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered - // field indices given the required schema (by its value). In this case, [0, 1, 3]. + // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means - // the fields that we should try to convert. - private val reorderedFields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val tokenIndexArr: Array[Int] = { - reorderedFields - .filter(_.name != options.columnNameOfCorruptRecord) - .map(f => dataSchema.indexOf(f)).toArray - } - - private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { - val corrFieldIndex = corruptFieldIndex.get - reorderedFields.indices.filter(_ != corrFieldIndex).toArray - } else { - reorderedFields.indices.toArray + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -205,7 +159,7 @@ class UnivocityParser( } case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) @@ -233,81 +187,41 @@ class UnivocityParser( * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) - - private def convert(tokens: Array[String]): Option[InternalRow] = { - convertWithParseMode(tokens) { tokens => - var i: Int = 0 - while (i < tokenIndexArr.length) { - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. - val from = tokenIndexArr(i) - val to = rowIndexArr(i) - val value = valueConverters(from).apply(tokens(from)) - if (i < requiredSchema.length) { - row(to) = value - } - i += 1 - } - row - } - } - - private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && dataSchema.length != tokens.length) { - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): InternalRow = { + if (tokens.length != schema.length) { + // If the number of tokens doesn't match the schema, we should treat it as a malformed record. + // However, we still have chance to parse some of the tokens, by adding extra null tokens in + // the tail if the number is smaller, or by dropping extra tokens if the number is larger. + val checkedTokens = if (schema.length > tokens.length) { + tokens ++ new Array[String](schema.length - tokens.length) + } else { + tokens.take(schema.length) } - numMalformedRecords += 1 - None - } else if (options.failFast && dataSchema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(options.delimiter.toString)}") - } else { - // If a length of parsed tokens is not equal to expected one, it makes the length the same - // with the expected. If the length is shorter, it adds extra tokens in the tail. - // If longer, it drops extra tokens. - // - // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, - // we probably need to treat it as a malformed record. - // See an URL below for related discussions: - // https://github.com/apache/spark/pull/16928#discussion_r102657214 - val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { - if (dataSchema.length > tokens.length) { - tokens ++ new Array[String](dataSchema.length - tokens.length) - } else { - tokens.take(dataSchema.length) + def getPartialResult(): Option[InternalRow] = { + try { + Some(convert(checkedTokens)) + } catch { + case _: BadRecordException => None } - } else { - tokens } - + throw BadRecordException( + () => getCurrentInput, + getPartialResult, + new RuntimeException("Malformed CSV record")) + } else { try { - Some(convert(checkedTokens)) + var i = 0 + while (i < requiredSchema.length) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) + i += 1 + } + row } catch { - case NonFatal(e) if options.permissive => - val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) - Some(row) - case NonFatal(e) if options.dropMalformed => - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") - } - numMalformedRecords += 1 - None + case NonFatal(e) => + throw BadRecordException(() => getCurrentInput, () => None, e) } } } @@ -331,10 +245,16 @@ private[csv] object UnivocityParser { def parseStream( inputStream: InputStream, shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + schema: StructType): Iterator[InternalRow] = { val tokenizer = parser.tokenizer + val safeParser = new FailureSafeParser[Array[String]]( + input => Seq(parser.convert(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => - parser.convert(tokens) + safeParser.parse(tokens) }.flatten } @@ -368,7 +288,8 @@ private[csv] object UnivocityParser { def parseIterator( lines: Iterator[String], shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + schema: StructType): Iterator[InternalRow] = { val options = parser.options val linesWithoutHeader = if (shouldDropHeader) { @@ -381,6 +302,12 @@ private[csv] object UnivocityParser { val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) - filteredLines.flatMap(line => parser.parse(line)) + + val safeParser = new FailureSafeParser[String]( + input => Seq(parser.parse(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 84f026620d907..51e952c12202e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.io.InputStream + import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -31,6 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -49,7 +52,8 @@ abstract class JsonDataSource extends Serializable { def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] final def inferSchema( sparkSession: SparkSession, @@ -127,10 +131,16 @@ object TextInputJsonDataSource extends JsonDataSource { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String)) + val safeParser = new FailureSafeParser[Text]( + input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + linesReader.flatMap(safeParser.parse) } private def textToUTF8String(value: Text): UTF8String = { @@ -180,7 +190,8 @@ object WholeFileJsonDataSource extends JsonDataSource { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) @@ -189,9 +200,13 @@ object WholeFileJsonDataSource extends JsonDataSource { } } - parser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - CreateJacksonParser.inputStream, - partitionedFileString).toIterator + val safeParser = new FailureSafeParser[InputStream]( + input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + + safeParser.parse( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index a9dd91eba6f72..53d62d88b04c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -102,6 +102,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val actualSchema = + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt records here to throw an exception in a driver side dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = dataSchema(corruptFieldIndex) @@ -112,11 +114,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val parser = new JacksonParser(requiredSchema, parsedOptions) + val parser = new JacksonParser(actualSchema, parsedOptions) JsonDataSource(parsedOptions).readFile( broadcastedHadoopConf.value.value, file, - parser) + parser, + requiredSchema) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 95dfdf5b298e6..598babfe0e7ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -293,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)).collect() } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed CSV record")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9b0efcbdaf5c3..56fcf773f7dd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionOne.getMessage.contains("JsonParseException")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1052,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionTwo.getMessage.contains("JsonParseException")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1929,7 +1929,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionOne.getMessage.contains("Failed to parse a value")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1939,7 +1939,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionTwo.getMessage.contains("Failed to parse a value")) } } From d2dcd6792f4cea39e12945ad8c4cda5d8d034de4 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 20 Mar 2017 22:52:45 -0700 Subject: [PATCH 0071/1765] [SPARK-20024][SQL][TEST-MAVEN] SessionCatalog reset need to set the current database of ExternalCatalog ### What changes were proposed in this pull request? SessionCatalog API setCurrentDatabase does not set the current database of the underlying ExternalCatalog. Thus, weird errors could come in the test suites after we call reset. We need to fix it. So far, have not found the direct impact in the other code paths because we expect all the SessionCatalog APIs should always use the current database value we managed, unless some of code paths skip it. Thus, we fix it in the test-only function reset(). ### How was this patch tested? Multiple test case failures are observed in mvn and add a test case in SessionCatalogSuite. Author: Xiao Li Closes #17354 from gatorsmile/useDB. --- .../org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 1 + .../apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 25aa8d3ba921f..b134fd44a311f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1175,6 +1175,7 @@ class SessionCatalog( */ def reset(): Unit = synchronized { setCurrentDatabase(DEFAULT_DATABASE) + externalCatalog.setCurrentDatabase(DEFAULT_DATABASE) listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index bb87763e0bbb0..fd9e5d6bb13ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -53,7 +53,6 @@ abstract class SessionCatalogSuite extends PlanTest { private def withBasicCatalog(f: SessionCatalog => Unit): Unit = { val catalog = new SessionCatalog(newBasicCatalog()) - catalog.createDatabase(newDb("default"), ignoreIfExists = true) try { f(catalog) } finally { @@ -76,7 +75,6 @@ abstract class SessionCatalogSuite extends PlanTest { test("basic create and list databases") { withEmptyCatalog { catalog => - catalog.createDatabase(newDb("default"), ignoreIfExists = true) assert(catalog.databaseExists("default")) assert(!catalog.databaseExists("testing")) assert(!catalog.databaseExists("testing2")) From 7620aed828d8baefc425b54684a83c81f1507b02 Mon Sep 17 00:00:00 2001 From: christopher snow Date: Tue, 21 Mar 2017 13:23:59 +0000 Subject: [PATCH 0072/1765] [SPARK-20011][ML][DOCS] Clarify documentation for ALS 'rank' parameter ## What changes were proposed in this pull request? API documentation and collaborative filtering documentation page changes to clarify inconsistent description of ALS rank parameter. - [DOCS] was previously: "rank is the number of latent factors in the model." - [API] was previously: "rank - number of features to use" This change describes rank in both places consistently as: - "Number of features to use (also referred to as the number of latent factors)" Author: Chris Snow Author: christopher snow Closes #17345 from snowch/SPARK-20011. --- docs/mllib-collaborative-filtering.md | 2 +- .../apache/spark/mllib/recommendation/ALS.scala | 16 ++++++++-------- python/pyspark/mllib/recommendation.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0f891a09a6e61..d1bb6d69f1256 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -20,7 +20,7 @@ algorithm to learn these latent factors. The implementation in `spark.mllib` has following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in the model. +* *rank* is the number of features to use (also referred to as the number of latent factors). * *iterations* is the number of iterations of ALS to run. ALS typically converges to a reasonable solution in 20 iterations or less. * *lambda* specifies the regularization parameter in ALS. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 76b1bc13b4b05..14288221b6945 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -301,7 +301,7 @@ object ALS { * level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -326,7 +326,7 @@ object ALS { * level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -349,7 +349,7 @@ object ALS { * parallelism automatically based on the number of partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter */ @@ -366,7 +366,7 @@ object ALS { * parallelism automatically based on the number of partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS */ @Since("0.8.0") @@ -383,7 +383,7 @@ object ALS { * a level of parallelism given by `blocks`. * * @param ratings RDD of (userID, productID, rating) pairs - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -410,7 +410,7 @@ object ALS { * iteratively with a configurable level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -436,7 +436,7 @@ object ALS { * partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param alpha confidence parameter @@ -455,7 +455,7 @@ object ALS { * partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS */ @Since("0.8.1") diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 732300ee9c2c9..81182881352bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -249,7 +249,7 @@ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative :param ratings: RDD of `Rating` or (userID, productID, rating) tuple. :param rank: - Rank of the feature matrices computed (number of features). + Number of features to use (also referred to as the number of latent factors). :param iterations: Number of iterations of ALS. (default: 5) @@ -287,7 +287,7 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp :param ratings: RDD of `Rating` or (userID, productID, rating) tuple. :param rank: - Rank of the feature matrices computed (number of features). + Number of features to use (also referred to as the number of latent factors). :param iterations: Number of iterations of ALS. (default: 5) From 650d03cfc9a609a2c603f9ced452d03ec8429b0d Mon Sep 17 00:00:00 2001 From: "jianran.tfh" Date: Tue, 21 Mar 2017 15:15:19 +0000 Subject: [PATCH 0073/1765] [SPARK-19998][BLOCK MANAGER] Change the exception log to add RDD id of the related the block ## What changes were proposed in this pull request? "java.lang.Exception: Could not compute split, block $blockId not found" doesn't have the rdd id info, the "BlockManager: Removing RDD $id" has only the RDD id, so it couldn't find that the Exception's reason is the Removing; so it's better block not found Exception add RDD id info ## How was this patch tested? Existing tests Author: jianran.tfh Author: jianran Closes #17334 from jianran/SPARK-19998. --- core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index d47b75544fdba..4e036c2ed49b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -47,7 +47,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo blockManager.get[T](blockId) match { case Some(block) => block.data.asInstanceOf[Iterator[T]] case None => - throw new Exception("Could not compute split, block " + blockId + " not found") + throw new Exception(s"Could not compute split, block $blockId of RDD $id not found") } } From 14865d7ff78db5cf9a3e8626204c8e7ed059c353 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 21 Mar 2017 08:44:09 -0700 Subject: [PATCH 0074/1765] [SPARK-17080][SQL][FOLLOWUP] Improve documentation, change buildJoin method structure and add a debug log ## What changes were proposed in this pull request? 1. Improve documentation for class `Cost` and `JoinReorderDP` and method `buildJoin()`. 2. Change code structure of `buildJoin()` to make the logic clearer. 3. Add a debug-level log to record information for join reordering, including time cost, the number of items and the number of plans in memo. ## How was this patch tested? Not related. Author: wangzhenhua Closes #17353 from wzhfy/reorderFollow. --- .../optimizer/CostBasedJoinReorder.scala | 109 +++++++++++------- .../apache/spark/sql/internal/SQLConf.scala | 1 + 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 521c468fe18af..fc37720809ba2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} @@ -51,7 +52,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } } - def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { + private def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) // TODO: Compute the set of star-joins and use them in the join enumeration // algorithm to prune un-optimal plan choices. @@ -69,7 +70,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } /** - * Extract consecutive inner joinable items and join conditions. + * Extracts items of consecutive inner joins and join conditions. * This method works for bushy trees and left/right deep trees. */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { @@ -119,18 +120,21 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among * plans (A J B) J C, (A J C) J B and (B J C) J A. - * - * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows: + * We also prune cartesian product candidates when building a new plan if there exists no join + * condition involving references from both left and right. This pruning strategy significantly + * reduces the search space. + * E.g., given A J B J C J D with join conditions A.k1 = B.k1 and B.k2 = C.k2 and C.k3 = D.k3, + * plans maintained for each level are as follows: * level 0: p({A}), p({B}), p({C}), p({D}) - * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D}) - * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D}) + * level 1: p({A, B}), p({B, C}), p({C, D}) + * level 2: p({A, B, C}), p({B, C, D}) * level 3: p({A, B, C, D}) * where p({A, B, C, D}) is the final output plan. * * For cost evaluation, since physical costs for operators are not available currently, we use * cardinalities and sizes to compute costs. */ -object JoinReorderDP extends PredicateHelper { +object JoinReorderDP extends PredicateHelper with Logging { def search( conf: SQLConf, @@ -138,6 +142,7 @@ object JoinReorderDP extends PredicateHelper { conditions: Set[Expression], topOutput: AttributeSet): LogicalPlan = { + val startTime = System.nanoTime() // Level i maintains all found plans for i + 1 items. // Create the initial plans: each plan is a single item with zero cost. val itemIndex = items.zipWithIndex @@ -152,6 +157,10 @@ object JoinReorderDP extends PredicateHelper { foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) } + val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) + logDebug(s"Join reordering finished. Duration: $durationInMs ms, number of items: " + + s"${items.length}, number of plans in memo: ${foundPlans.map(_.size).sum}") + // The last level must have one and only one plan, because all items are joinable. assert(foundPlans.size == items.length && foundPlans.last.size == 1) foundPlans.last.head._2.plan @@ -183,18 +192,15 @@ object JoinReorderDP extends PredicateHelper { } otherSideCandidates.foreach { otherSidePlan => - // Should not join two overlapping item sets. - if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) { - val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) - if (joinPlan.isDefined) { - val newJoinPlan = joinPlan.get + buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) match { + case Some(newJoinPlan) => // Check if it's the first plan for the item set, or it's a better plan than // the existing one due to lower cost. val existingPlan = nextLevel.get(newJoinPlan.itemIds) if (existingPlan.isEmpty || newJoinPlan.betterThan(existingPlan.get, conf)) { nextLevel.update(newJoinPlan.itemIds, newJoinPlan) } - } + case None => } } } @@ -203,7 +209,17 @@ object JoinReorderDP extends PredicateHelper { nextLevel.toMap } - /** Build a new join node. */ + /** + * Builds a new JoinPlan when both conditions hold: + * - the sets of items contained in left and right sides do not overlap. + * - there exists at least one join condition involving references from both sides. + * @param oneJoinPlan One side JoinPlan for building a new JoinPlan. + * @param otherJoinPlan The other side JoinPlan for building a new join node. + * @param conf SQLConf for statistics computation. + * @param conditions The overall set of join conditions. + * @param topOutput The output attributes of the final plan. + * @return Builds and returns a new JoinPlan if both conditions hold. Otherwise, returns None. + */ private def buildJoin( oneJoinPlan: JoinPlan, otherJoinPlan: JoinPlan, @@ -211,6 +227,11 @@ object JoinReorderDP extends PredicateHelper { conditions: Set[Expression], topOutput: AttributeSet): Option[JoinPlan] = { + if (oneJoinPlan.itemIds.intersect(otherJoinPlan.itemIds).nonEmpty) { + // Should not join two overlapping item sets. + return None + } + val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan val joinConds = conditions @@ -220,33 +241,33 @@ object JoinReorderDP extends PredicateHelper { if (joinConds.isEmpty) { // Cartesian product is very expensive, so we exclude them from candidate plans. // This also significantly reduces the search space. - None + return None + } + + // Put the deeper side on the left, tend to build a left-deep tree. + val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) } else { - // Put the deeper side on the left, tend to build a left-deep tree. - val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { - (onePlan, otherPlan) + (otherPlan, onePlan) + } + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin.toSeq, newJoin) } else { - (otherPlan, onePlan) + newJoin } - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) - val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds - val remainingConds = conditions -- collectedJoinConds - val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput - val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) - val newPlan = - if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { - Project(neededFromNewJoin.toSeq, newJoin) - } else { - newJoin - } - val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) - // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf - // item), so the cost of the new join should also include its own cost. - val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + - otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) - Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) - } + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf + // item), so the cost of the new join should also include its own cost. + val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + + otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) + Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) } /** Map[set of item ids, join plan for these items] */ @@ -278,10 +299,10 @@ object JoinReorderDP extends PredicateHelper { } def betterThan(other: JoinPlan, conf: SQLConf): Boolean = { - if (other.planCost.rows == 0 || other.planCost.size == 0) { + if (other.planCost.card == 0 || other.planCost.size == 0) { false } else { - val relativeRows = BigDecimal(this.planCost.rows) / BigDecimal(other.planCost.rows) + val relativeRows = BigDecimal(this.planCost.card) / BigDecimal(other.planCost.card) val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size) relativeRows * conf.joinReorderCardWeight + relativeSize * (1 - conf.joinReorderCardWeight) < 1 @@ -290,7 +311,11 @@ object JoinReorderDP extends PredicateHelper { } } -/** This class defines the cost model. */ -case class Cost(rows: BigInt, size: BigInt) { - def +(other: Cost): Cost = Cost(this.rows + other.rows, this.size + other.size) +/** + * This class defines the cost model for a plan. + * @param card Cardinality (number of rows). + * @param size Size in bytes. + */ +case class Cost(card: BigInt, size: BigInt) { + def +(other: Cost): Cost = Cost(this.card + other.card, this.size + other.size) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b6e0b8ccbeed6..d5006c16469bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -708,6 +708,7 @@ object SQLConf { buildConf("spark.sql.cbo.joinReorder.dp.threshold") .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.") .intConf + .checkValue(number => number > 0, "The maximum number must be a positive integer.") .createWithDefault(12) val JOIN_REORDER_CARD_WEIGHT = From 63f077fbe50b4094340e9915db41d7dbdba52975 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 21 Mar 2017 08:45:59 -0700 Subject: [PATCH 0075/1765] [SPARK-20041][DOC] Update docs for NaN handling in approxQuantile ## What changes were proposed in this pull request? Update docs for NaN handling in approxQuantile. ## How was this patch tested? existing tests. Author: Zheng RuiFeng Closes #17369 from zhengruifeng/doc_quantiles_nan. --- R/pkg/R/stats.R | 3 ++- .../org/apache/spark/ml/feature/QuantileDiscretizer.scala | 4 ++-- python/pyspark/sql/dataframe.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 8d1d165052f7f..d78a10893f92e 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -149,7 +149,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed #' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. -#' Note that rows containing any NA values will be removed before calculation. +#' Note that NA values will be ignored in numerical columns before calculation. For +#' columns only containing NA values, an empty list is returned. #' #' @param x A SparkDataFrame. #' @param cols A single column name, or a list of names for multiple columns. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 80c7f55e26b84..feceeba866dfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -93,8 +93,8 @@ private[feature] trait QuantileDiscretizerBase extends Params * are too few distinct values of the input to create enough distinct quantiles. * * NaN handling: - * NaN values will be removed from the column during `QuantileDiscretizer` fitting. This will - * produce a `Bucketizer` model for making predictions. During the transformation, + * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This + * will produce a `Bucketizer` model for making predictions. During the transformation, * `Bucketizer` will raise an error when it finds NaN values in the dataset, but the user can * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. * If the user chooses to keep NaN values, they will be handled specially and placed into their own diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bb6df22682095..a24512f53c525 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1384,7 +1384,8 @@ def approxQuantile(self, col, probabilities, relativeError): Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. - Note that rows containing any null values will be removed before calculation. + Note that null values will be ignored in numerical columns before calculation. + For columns only containing null values, an empty list is returned. :param col: str, list. Can be a single column name, or a list of names for multiple columns. From 4c0ff5f58565f811b65f1a11b6121da007bcbd5f Mon Sep 17 00:00:00 2001 From: Xin Wu Date: Tue, 21 Mar 2017 08:49:54 -0700 Subject: [PATCH 0076/1765] [SPARK-19261][SQL] Alter add columns for Hive serde and some datasource tables ## What changes were proposed in this pull request? Support` ALTER TABLE ADD COLUMNS (...) `syntax for Hive serde and some datasource tables. In this PR, we consider a few aspects: 1. View is not supported for `ALTER ADD COLUMNS` 2. Since tables created in SparkSQL with Hive DDL syntax will populate table properties with schema information, we need make sure the consistency of the schema before and after ALTER operation in order for future use. 3. For embedded-schema type of format, such as `parquet`, we need to make sure that the predicate on the newly-added columns can be evaluated properly, or pushed down properly. In case of the data file does not have the columns for the newly-added columns, such predicates should return as if the column values are NULLs. 4. For datasource table, this feature does not support the following: 4.1 TEXT format, since there is only one default column `value` is inferred for text format data. 4.2 ORC format, since SparkSQL native ORC reader does not support the difference between user-specified-schema and inferred schema from ORC files. 4.3 Third party datasource types that implements RelationProvider, including the built-in JDBC format, since different implementations by the vendors may have different ways to dealing with schema. 4.4 Other datasource types, such as `parquet`, `json`, `csv`, `hive` are supported. 5. Column names being added can not be duplicate of any existing data column or partition column names. Case sensitivity is taken into consideration according to the sql configuration. 6. This feature also supports In-Memory catalog, while Hive support is turned off. ## How was this patch tested? Add new test cases Author: Xin Wu Closes #16626 from xwu0226/alter_add_columns. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../sql/catalyst/catalog/SessionCatalog.scala | 56 ++++++++ .../catalog/SessionCatalogSuite.scala | 29 +++++ .../spark/sql/execution/SparkSqlParser.scala | 16 +++ .../spark/sql/execution/command/tables.scala | 76 ++++++++++- .../execution/command/DDLCommandSuite.scala | 8 +- .../sql/execution/command/DDLSuite.scala | 122 ++++++++++++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 100 +++++++++++++- 8 files changed, 400 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index cc3b8fd3b4689..c4a590ec6916b 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -85,6 +85,8 @@ statement LIKE source=tableIdentifier locationSpec? #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq)? #analyze + | ALTER TABLE tableIdentifier + ADD COLUMNS '(' columns=colTypeList ')' #addTableColumns | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable | ALTER (TABLE | VIEW) tableIdentifier @@ -198,7 +200,6 @@ unsupportedHiveNativeCommands | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT - | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=ADD kw4=COLUMNS | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS | kw1=START kw2=TRANSACTION | kw1=COMMIT diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b134fd44a311f..a469d12451643 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types.{StructField, StructType} object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -161,6 +162,20 @@ class SessionCatalog( throw new TableAlreadyExistsException(db = db, table = name.table) } } + + private def checkDuplication(fields: Seq[StructField]): Unit = { + val columnNames = if (conf.caseSensitiveAnalysis) { + fields.map(_.name) + } else { + fields.map(_.name.toLowerCase) + } + if (columnNames.distinct.length != columnNames.length) { + val duplicateColumns = columnNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => x + } + throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}") + } + } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -295,6 +310,47 @@ class SessionCatalog( externalCatalog.alterTable(newTableDefinition) } + /** + * Alter the schema of a table identified by the provided table identifier. The new schema + * should still contain the existing bucket columns and partition columns used by the table. This + * method will also update any Spark SQL-related parameters stored as Hive table properties (such + * as the schema itself). + * + * @param identifier TableIdentifier + * @param newSchema Updated schema to be used for the table (must contain existing partition and + * bucket columns, and partition columns need to be at the end) + */ + def alterTableSchema( + identifier: TableIdentifier, + newSchema: StructType): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + checkDuplication(newSchema) + + val catalogTable = externalCatalog.getTable(db, table) + val oldSchema = catalogTable.schema + + // not supporting dropping columns yet + val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _)) + if (nonExistentColumnNames.nonEmpty) { + throw new AnalysisException( + s""" + |Some existing schema fields (${nonExistentColumnNames.mkString("[", ",", "]")}) are + |not present in the new schema. We don't support dropping columns yet. + """.stripMargin) + } + + // assuming the newSchema has all partition columns at the end as required + externalCatalog.alterTableSchema(db, table, newSchema) + } + + private def columnNameResolved(schema: StructType, colName: String): Boolean = { + schema.fields.map(_.name).exists(conf.resolver(_, colName)) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index fd9e5d6bb13ed..ca4ce1c11707a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +import org.apache.spark.sql.types._ class InMemorySessionCatalogSuite extends SessionCatalogSuite { protected val utils = new CatalogTestUtils { @@ -448,6 +449,34 @@ abstract class SessionCatalogSuite extends PlanTest { } } + test("alter table add columns") { + withBasicCatalog { sessionCatalog => + sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) + val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") + sessionCatalog.alterTableSchema( + TableIdentifier("t1", Some("default")), + StructType(oldTab.dataSchema.add("c3", IntegerType) ++ oldTab.partitionSchema)) + + val newTab = sessionCatalog.externalCatalog.getTable("default", "t1") + // construct the expected table schema + val expectedTableSchema = StructType(oldTab.dataSchema.fields ++ + Seq(StructField("c3", IntegerType)) ++ oldTab.partitionSchema) + assert(newTab.schema == expectedTableSchema) + } + } + + test("alter table drop columns") { + withBasicCatalog { sessionCatalog => + sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) + val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") + val e = intercept[AnalysisException] { + sessionCatalog.alterTableSchema( + TableIdentifier("t1", Some("default")), StructType(oldTab.schema.drop(1))) + }.getMessage + assert(e.contains("We don't support dropping columns yet.")) + } + } + test("get table") { withBasicCatalog { catalog => assert(catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index abea7a3bcf146..d4f23f9dd5185 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -741,6 +741,22 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.VIEW != null) } + /** + * Create a [[AlterTableAddColumnsCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} + */ + override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddColumnsCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitColTypeList(ctx.columns) + ) + } + /** * Create an [[AlterTableSetPropertiesCommand]] command. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index beb3dcafd64f9..93307fc883565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -174,6 +177,77 @@ case class AlterTableRenameCommand( } +/** + * A command that add columns to a table + * The syntax of using this command in SQL is: + * {{{ + * ALTER TABLE table_identifier + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} +*/ +case class AlterTableAddColumnsCommand( + table: TableIdentifier, + columns: Seq[StructField]) extends RunnableCommand { + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val catalogTable = verifyAlterTableAddColumn(catalog, table) + + try { + sparkSession.catalog.uncacheTable(table.quotedString) + } catch { + case NonFatal(e) => + log.warn(s"Exception when attempting to uncache table ${table.quotedString}", e) + } + catalog.refreshTable(table) + + // make sure any partition columns are at the end of the fields + val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema + catalog.alterTableSchema( + table, catalogTable.schema.copy(fields = reorderedSchema.toArray)) + + Seq.empty[Row] + } + + /** + * ALTER TABLE ADD COLUMNS command does not support temporary view/table, + * view, or datasource table with text, orc formats or external provider. + * For datasource table, it currently only supports parquet, json, csv. + */ + private def verifyAlterTableAddColumn( + catalog: SessionCatalog, + table: TableIdentifier): CatalogTable = { + val catalogTable = catalog.getTempViewOrPermanentTableMetadata(table) + + if (catalogTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s""" + |ALTER ADD COLUMNS does not support views. + |You must drop and re-create the views for adding the new columns. Views: $table + """.stripMargin) + } + + if (DDLUtils.isDatasourceTable(catalogTable)) { + DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { + // For datasource table, this command can only support the following File format. + // TextFileFormat only default to one column "value" + // OrcFileFormat can not handle difference between user-specified schema and + // inferred schema yet. TODO, once this issue is resolved , we can add Orc back. + // Hive type is already considered as hive serde table, so the logic will not + // come in here. + case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s => + throw new AnalysisException( + s""" + |ALTER ADD COLUMNS does not support datasource table with type $s. + |You must drop and re-create the table for adding the new columns. Tables: $table + """.stripMargin) + } + } + catalogTable + } +} + + /** * A command that loads data into a Hive table. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 4b73b078da38e..13202a57851e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -780,13 +780,7 @@ class DDLCommandSuite extends PlanTest { assertUnsupported("ALTER TABLE table_name SKEWED BY (key) ON (1,5,6) STORED AS DIRECTORIES") } - test("alter table: add/replace columns (not allowed)") { - assertUnsupported( - """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |ADD COLUMNS (new_col1 INT COMMENT 'test_comment', new_col2 LONG - |COMMENT 'test_comment2') CASCADE - """.stripMargin) + test("alter table: replace columns (not allowed)") { assertUnsupported( """ |ALTER TABLE table_name REPLACE COLUMNS (new_col1 INT diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 235c6bf6ad592..648b1798c66e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2185,4 +2185,126 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + } + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - partitioned - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + } + + test("alter datasource table add columns - text format not supported") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING text") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + }.getMessage + assert(e.contains("ALTER ADD COLUMNS does not support datasource table with type")) + } + } + + test("alter table add columns -- not support temp view") { + withTempView("tmp_v") { + sql("CREATE TEMPORARY VIEW tmp_v AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns -- not support view") { + withView("v1") { + sql("CREATE VIEW v1 AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns with existing column name") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } + } + + Seq(true, false).foreach { caseSensitive => + test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + if (!caseSensitive) { + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } else { + if (isUsingHiveMetastore) { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("HiveException")) + } else { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema + .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) + } + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index d752c415c1ed8..04bc79d430324 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{MetadataBuilder, StructType} +import org.apache.spark.sql.types._ // TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { @@ -112,6 +112,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { import testImplicits._ + val hiveFormats = Seq("PARQUET", "ORC", "TEXTFILE", "SEQUENCEFILE", "RCFILE", "AVRO") override def afterEach(): Unit = { try { @@ -1860,4 +1861,101 @@ class HiveDDLSuite } } } + + hiveFormats.foreach { tableType => + test(s"alter hive serde table add columns -- partitioned - $tableType") { + withTable("tab") { + sql( + s""" + |CREATE TABLE tab (c1 int, c2 int) + |PARTITIONED BY (c3 int) STORED AS $tableType + """.stripMargin) + + sql("INSERT INTO tab PARTITION (c3=1) VALUES (1, 2)") + sql("ALTER TABLE tab ADD COLUMNS (c4 int)") + + checkAnswer( + sql("SELECT * FROM tab WHERE c3 = 1"), + Seq(Row(1, 2, null, 1)) + ) + assert(spark.table("tab").schema + .contains(StructField("c4", IntegerType))) + sql("INSERT INTO tab PARTITION (c3=2) VALUES (2, 3, 4)") + checkAnswer( + spark.table("tab"), + Seq(Row(1, 2, null, 1), Row(2, 3, 4, 2)) + ) + checkAnswer( + sql("SELECT * FROM tab WHERE c3 = 2 AND c4 IS NOT NULL"), + Seq(Row(2, 3, 4, 2)) + ) + + sql("ALTER TABLE tab ADD COLUMNS (c5 char(10))") + assert(spark.table("tab").schema.find(_.name == "c5") + .get.metadata.getString("HIVE_TYPE_STRING") == "char(10)") + } + } + } + + hiveFormats.foreach { tableType => + test(s"alter hive serde table add columns -- with predicate - $tableType ") { + withTable("tab") { + sql(s"CREATE TABLE tab (c1 int, c2 int) STORED AS $tableType") + sql("INSERT INTO tab VALUES (1, 2)") + sql("ALTER TABLE tab ADD COLUMNS (c4 int)") + checkAnswer( + sql("SELECT * FROM tab WHERE c4 IS NULL"), + Seq(Row(1, 2, null)) + ) + assert(spark.table("tab").schema + .contains(StructField("c4", IntegerType))) + sql("INSERT INTO tab VALUES (2, 3, 4)") + checkAnswer( + sql("SELECT * FROM tab WHERE c4 = 4 "), + Seq(Row(2, 3, 4)) + ) + checkAnswer( + spark.table("tab"), + Seq(Row(1, 2, null), Row(2, 3, 4)) + ) + } + } + } + + Seq(true, false).foreach { caseSensitive => + test(s"alter add columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("tab") { + sql("CREATE TABLE tab (c1 int) PARTITIONED BY (c2 int) STORED AS PARQUET") + if (!caseSensitive) { + // duplicating partitioning column name + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C2 string)") + }.getMessage + assert(e1.contains("Found duplicate column(s)")) + + // duplicating data column name + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C1 string)") + }.getMessage + assert(e2.contains("Found duplicate column(s)")) + } else { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C2 string)") + }.getMessage + assert(e1.contains("HiveException")) + + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C1 string)") + }.getMessage + assert(e2.contains("HiveException")) + } + } + } + } + } } From ae4b91d1f5734b9d66f3b851b71b3c179f3cdd76 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Mar 2017 11:01:25 -0700 Subject: [PATCH 0077/1765] [SPARK-20039][ML] rename ChiSquare to ChiSquareTest ## What changes were proposed in this pull request? I realized that since ChiSquare is in the package stat, it's pretty unclear if it's the hypothesis test, distribution, or what. This PR renames it to ChiSquareTest to clarify this. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #17368 from jkbradley/SPARK-20039. --- .../ml/stat/{ChiSquare.scala => ChiSquareTest.scala} | 2 +- .../{ChiSquareSuite.scala => ChiSquareTestSuite.scala} | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/stat/{ChiSquare.scala => ChiSquareTest.scala} (99%) rename mllib/src/test/scala/org/apache/spark/ml/stat/{ChiSquareSuite.scala => ChiSquareTestSuite.scala} (94%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala rename to mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala index c3865ce6a9e2a..21eba9a49809f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.functions.col */ @Experimental @Since("2.2.0") -object ChiSquare { +object ChiSquareTest { /** Used to construct output schema of tests */ private case class ChiSquareResult( diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala similarity index 94% rename from mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala index b4bed82e4d00f..2d6aad0808bc6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSquareSuite +class ChiSquareTestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import testImplicits._ @@ -45,7 +45,7 @@ class ChiSquareSuite LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) for (numParts <- List(2, 4, 6, 8)) { val df = spark.createDataFrame(sc.parallelize(data, numParts)) - val chi = ChiSquare.test(df, "features", "label") + val chi = ChiSquareTest.test(df, "features", "label") val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = chi.select("pValues", "degreesOfFreedom", "statistics") .as[(Vector, Array[Int], Vector)].head() @@ -62,7 +62,7 @@ class ChiSquareSuite LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) val df = spark.createDataFrame(sparseData) - val chi = ChiSquare.test(df, "features", "label") + val chi = ChiSquareTest.test(df, "features", "label") val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = chi.select("pValues", "degreesOfFreedom", "statistics") .as[(Vector, Array[Int], Vector)].head() @@ -83,7 +83,7 @@ class ChiSquareSuite withClue("ChiSquare should throw an exception when given a continuous-valued label") { intercept[SparkException] { val df = spark.createDataFrame(continuousLabel) - ChiSquare.test(df, "features", "label") + ChiSquareTest.test(df, "features", "label") } } val continuousFeature = Seq.fill(tooManyCategories)( @@ -91,7 +91,7 @@ class ChiSquareSuite withClue("ChiSquare should throw an exception when given continuous-valued features") { intercept[SparkException] { val df = spark.createDataFrame(continuousFeature) - ChiSquare.test(df, "features", "label") + ChiSquareTest.test(df, "features", "label") } } } From 7dbc162f12cc1a447c85a1a2c20d32ebb5cbeacf Mon Sep 17 00:00:00 2001 From: zhaorongsheng <334362872@qq.com> Date: Tue, 21 Mar 2017 11:30:55 -0700 Subject: [PATCH 0078/1765] [SPARK-20017][SQL] change the nullability of function 'StringToMap' from 'false' to 'true' ## What changes were proposed in this pull request? Change the nullability of function `StringToMap` from `false` to `true`. Author: zhaorongsheng <334362872@qq.com> Closes #17350 from zhaorongsheng/bug-fix_strToMap_NPE. --- .../sql/catalyst/expressions/complexTypeCreator.scala | 4 +++- .../spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 22277ad8d56ee..b6675a84ece48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -390,6 +390,8 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateName Examples: > SELECT _FUNC_('a:1,b:2,c:3', ',', ':'); map("a":"1","b":"2","c":"3") + > SELECT _FUNC_('a'); + map("a":null) """) // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) @@ -407,7 +409,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = false) + override def dataType: DataType = MapType(StringType, StringType) override def checkInputDataTypes(): TypeCheckResult = { if (Seq(pairDelim, keyValueDelim).exists(! _.foldable)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index abe1d2b2c99e1..5f8a8f44d48e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -251,6 +251,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("StringToMap") { + val expectedDataType = MapType(StringType, StringType, valueContainsNull = true) + assert(new StringToMap("").dataType === expectedDataType) + val s0 = Literal("a:1,b:2,c:3") val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3") checkEvaluation(new StringToMap(s0), m0) @@ -271,6 +274,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3") checkEvaluation(new StringToMap(s4, Literal("_")), m4) + val s5 = Literal("a") + val m5 = Map("a" -> null) + checkEvaluation(new StringToMap(s5), m5) + // arguments checking assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess) assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure) From a8877bdbba6df105740f909bc87a13cdd4440757 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 21 Mar 2017 14:24:41 -0700 Subject: [PATCH 0079/1765] [SPARK-19237][SPARKR][CORE] On Windows spark-submit should handle when java is not installed ## What changes were proposed in this pull request? When SparkR is installed as a R package there might not be any java runtime. If it is not there SparkR's `sparkR.session()` will block waiting for the connection timeout, hanging the R IDE/shell, without any notification or message. ## How was this patch tested? manually - [x] need to test on Windows Author: Felix Cheung Closes #16596 from felixcheung/rcheckjava. --- R/pkg/inst/tests/testthat/test_Windows.R | 1 + bin/spark-class2.cmd | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index e8d983426a676..1d777ddb286df 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -20,6 +20,7 @@ test_that("sparkJars tag in SparkContext", { if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } + testOutput <- launchScript("ECHO", "a/b/c", wait = TRUE) abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 869c0b202f7f3..9faa7d65f83e4 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -50,7 +50,16 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java -if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java +if not "x%JAVA_HOME%"=="x" ( + set RUNNER="%JAVA_HOME%\bin\java" +) else ( + where /q "%RUNNER%" + if ERRORLEVEL 1 ( + echo Java not found and JAVA_HOME environment variable is not set. + echo Install Java and set JAVA_HOME to point to the Java installation directory. + exit /b 1 + ) +) rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. From a04dcde8cb191e591a5f5d7a67a5371e31e7343c Mon Sep 17 00:00:00 2001 From: Will Manning Date: Wed, 22 Mar 2017 00:40:48 +0100 Subject: [PATCH 0080/1765] clarify array_contains function description ## What changes were proposed in this pull request? The description in the comment for array_contains is vague/incomplete (i.e., doesn't mention that it returns `null` if the array is `null`); this PR fixes that. ## How was this patch tested? No testing, since it merely changes a comment. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Will Manning Closes #17380 from lwwmanning/patch-1. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a9f089c850d42..66bb8816a6701 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2896,7 +2896,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contains `value` + * Returns null if the array is null, true if the array contains `value`, and false otherwise. * @group collection_funcs * @since 1.5.0 */ From 9281a3d504d526440c1d445075e38a6d9142ac93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Mar 2017 08:41:46 +0800 Subject: [PATCH 0081/1765] [SPARK-19919][SQL] Defer throwing the exception for empty paths in CSV datasource into `DataSource` ## What changes were proposed in this pull request? This PR proposes to defer throwing the exception within `DataSource`. Currently, if other datasources fail to infer the schema, it returns `None` and then this is being validated in `DataSource` as below: ``` scala> spark.read.json("emptydir") org.apache.spark.sql.AnalysisException: Unable to infer schema for JSON. It must be specified manually.; ``` ``` scala> spark.read.orc("emptydir") org.apache.spark.sql.AnalysisException: Unable to infer schema for ORC. It must be specified manually.; ``` ``` scala> spark.read.parquet("emptydir") org.apache.spark.sql.AnalysisException: Unable to infer schema for Parquet. It must be specified manually.; ``` However, CSV it checks it within the datasource implementation and throws another exception message as below: ``` scala> spark.read.csv("emptydir") java.lang.IllegalArgumentException: requirement failed: Cannot infer schema from an empty set of files ``` We could remove this duplicated check and validate this in one place in the same way with the same message. ## How was this patch tested? Unit test in `CSVSuite` and manual test. Author: hyukjinkwon Closes #17256 from HyukjinKwon/SPARK-19919. --- .../datasources/csv/CSVDataSource.scala | 25 +++++++++++++------ .../datasources/csv/CSVFileFormat.scala | 4 +-- .../sql/test/DataFrameReaderWriterSuite.scala | 6 +++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 63af18ec5b8eb..83bdf6fe224be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -54,10 +54,21 @@ abstract class CSVDataSource extends Serializable { /** * Infers the schema from `inputPaths` files. */ - def infer( + final def inferSchema( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] + parsedOptions: CSVOptions): Option[StructType] = { + if (inputPaths.nonEmpty) { + Some(infer(sparkSession, inputPaths, parsedOptions)) + } else { + None + } + } + + protected def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): StructType /** * Generates a header from the given row which is null-safe and duplicate-safe. @@ -131,10 +142,10 @@ object TextInputCSVDataSource extends CSVDataSource { override def infer( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] = { + parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption - Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)) + inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } /** @@ -203,7 +214,7 @@ object WholeFileCSVDataSource extends CSVDataSource { override def infer( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] = { + parsedOptions: CSVOptions): StructType = { val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) csv.flatMap { lines => UnivocityParser.tokenizeStream( @@ -222,10 +233,10 @@ object WholeFileCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + CSVInferSchema.infer(tokenRDD, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. - Some(StructType(Nil)) + StructType(Nil) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index eef43c7629c12..a99bdfee5d6e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -51,12 +51,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - require(files.nonEmpty, "Cannot infer schema from an empty set of files") - val parsedOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) + CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } override def prepareWrite( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8a8ba05534529..8287776f8f558 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -370,9 +370,11 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be val schema = df.schema // Reader, without user specified schema - intercept[IllegalArgumentException] { + val message = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) - } + }.getMessage + assert(message.contains("Unable to infer schema for CSV. It must be specified manually.")) + testRead(spark.read.csv(dir), data, schema) testRead(spark.read.csv(dir, dir), data ++ data, schema) testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema) From 2d73fcced0492c606feab8fe84f62e8318ebcaa1 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Tue, 21 Mar 2017 18:56:14 -0700 Subject: [PATCH 0082/1765] [SPARK-20051][SS] Fix StreamSuite flaky test - recover from v2.1 checkpoint ## What changes were proposed in this pull request? There is a race condition between calling stop on a streaming query and deleting directories in `withTempDir` that causes test to fail, fixing to do lazy deletion using delete on shutdown JVM hook. ## How was this patch tested? - Unit test - repeated 300 runs with no failure Author: Kunal Khamar Closes #17382 from kunalkhamar/partition-bugfix. --- .../spark/sql/streaming/StreamSuite.scala | 77 +++++++++---------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index e867fc40f7f1a..f01211e20cbfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.Utils class StreamSuite extends StreamTest { @@ -438,52 +439,48 @@ class StreamSuite extends StreamTest { // 1 - Test if recovery from the checkpoint is successful. prepareMemoryStream() - withTempDir { dir => - // Copy the checkpoint to a temp dir to prevent changes to the original. - // Not doing this will lead to the test passing on the first run, but fail subsequent runs. - FileUtils.copyDirectory(checkpointDir, dir) - - // Checkpoint data was generated by a query with 10 shuffle partitions. - // In order to test reading from the checkpoint, the checkpoint must have two or more batches, - // since the last batch may be rerun. - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - var streamingQuery: StreamingQuery = null - try { - streamingQuery = - query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start() - streamingQuery.processAllAvailable() - inputData.addData(9) - streamingQuery.processAllAvailable() - - QueryTest.checkAnswer(spark.table("counts").toDF(), - Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: - Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) - } finally { - if (streamingQuery ne null) { - streamingQuery.stop() - } + val dir1 = Utils.createTempDir().getCanonicalFile // not using withTempDir {}, makes test flaky + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(checkpointDir, dir1) + // Checkpoint data was generated by a query with 10 shuffle partitions. + // In order to test reading from the checkpoint, the checkpoint must have two or more batches, + // since the last batch may be rerun. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + var streamingQuery: StreamingQuery = null + try { + streamingQuery = + query.queryName("counts").option("checkpointLocation", dir1.getCanonicalPath).start() + streamingQuery.processAllAvailable() + inputData.addData(9) + streamingQuery.processAllAvailable() + + QueryTest.checkAnswer(spark.table("counts").toDF(), + Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: + Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() } } } // 2 - Check recovery with wrong num shuffle partitions prepareMemoryStream() - withTempDir { dir => - FileUtils.copyDirectory(checkpointDir, dir) - - // Since the number of partitions is greater than 10, should throw exception. - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { - var streamingQuery: StreamingQuery = null - try { - intercept[StreamingQueryException] { - streamingQuery = - query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start() - streamingQuery.processAllAvailable() - } - } finally { - if (streamingQuery ne null) { - streamingQuery.stop() - } + val dir2 = Utils.createTempDir().getCanonicalFile + FileUtils.copyDirectory(checkpointDir, dir2) + // Since the number of partitions is greater than 10, should throw exception. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { + var streamingQuery: StreamingQuery = null + try { + intercept[StreamingQueryException] { + streamingQuery = + query.queryName("badQuery").option("checkpointLocation", dir2.getCanonicalPath).start() + streamingQuery.processAllAvailable() + } + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() } } } From c1e87e384d1878308b42da80bb3d65be512aab55 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 21 Mar 2017 21:27:08 -0700 Subject: [PATCH 0083/1765] [SPARK-20030][SS] Event-time-based timeout for MapGroupsWithState ## What changes were proposed in this pull request? Adding event time based timeout. The user sets the timeout timestamp directly using `KeyedState.setTimeoutTimestamp`. The keys times out when the watermark crosses the timeout timestamp. ## How was this patch tested? Unit tests Author: Tathagata Das Closes #17361 from tdas/SPARK-20030. --- .../sql/streaming/KeyedStateTimeout.java | 22 +- .../UnsupportedOperationChecker.scala | 96 +++-- .../sql/catalyst/plans/logical/object.scala | 3 +- .../analysis/UnsupportedOperationsSuite.scala | 16 + .../spark/sql/execution/SparkStrategies.scala | 3 +- .../FlatMapGroupsWithStateExec.scala | 87 ++-- .../streaming/IncrementalExecution.scala | 5 +- .../execution/streaming/KeyedStateImpl.scala | 139 ++++-- .../streaming/statefulOperators.scala | 14 +- .../spark/sql/streaming/KeyedState.scala | 97 ++++- .../FlatMapGroupsWithStateSuite.scala | 402 ++++++++++++------ 11 files changed, 616 insertions(+), 268 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java index cf112f2e02a95..e2e7ab1d2609f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java @@ -19,9 +19,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout; -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.*; /** * Represents the type of timeouts possible for the Dataset operations @@ -34,9 +32,23 @@ @InterfaceStability.Evolving public class KeyedStateTimeout { - /** Timeout based on processing time. */ + /** + * Timeout based on processing time. The duration of timeout can be set for each group in + * `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation + * on `KeyedState` for more details. + */ public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } - /** No timeout */ + /** + * Timeout based on event-time. The event-time timestamp for timeout can be set for each + * group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`. + * In addition, you have to define the watermark in the query using `Dataset.withWatermark`. + * When the watermark advances beyond the set timestamp of a group and the group has not + * received any data, then the group times out. See documentation on + * `KeyedState` for more details. + */ + public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } + + /** No timeout. */ public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index a9ff61e0e8802..7da7f55aa5d7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -147,49 +147,69 @@ object UnsupportedOperationChecker { throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") - // mapGroupsWithState: Allowed only when no aggregation + Update output mode - case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState => - if (collectStreamingAggregates(plan).isEmpty) { - if (outputMode != InternalOutputModes.Update) { - throwError("mapGroupsWithState is not supported with " + - s"$outputMode output mode on a streaming DataFrame/Dataset") - } else { - // Allowed when no aggregation + Update output mode - } - } else { - throwError("mapGroupsWithState is not supported with aggregation " + - "on a streaming DataFrame/Dataset") - } - - // flatMapGroupsWithState without aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && collectStreamingAggregates(plan).isEmpty => - m.outputMode match { - case InternalOutputModes.Update => - if (outputMode != InternalOutputModes.Update) { - throwError("flatMapGroupsWithState in update mode is not supported with " + + // mapGroupsWithState and flatMapGroupsWithState + case m: FlatMapGroupsWithState if m.isStreaming => + + // Check compatibility with output modes and aggregations in query + val aggsAfterFlatMapGroups = collectStreamingAggregates(plan) + + if (m.isMapGroupsWithState) { // check mapGroupsWithState + // allowed only in update query output mode and without aggregation + if (aggsAfterFlatMapGroups.nonEmpty) { + throwError( + "mapGroupsWithState is not supported with aggregation " + + "on a streaming DataFrame/Dataset") + } else if (outputMode != InternalOutputModes.Update) { + throwError( + "mapGroupsWithState is not supported with " + s"$outputMode output mode on a streaming DataFrame/Dataset") + } + } else { // check latMapGroupsWithState + if (aggsAfterFlatMapGroups.isEmpty) { + // flatMapGroupsWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "flatMapGroupsWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => } - case InternalOutputModes.Append => - if (outputMode != InternalOutputModes.Append) { - throwError("flatMapGroupsWithState in append mode is not supported with " + - s"$outputMode output mode on a streaming DataFrame/Dataset") + } else { + // flatMapGroupsWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "flatMapGroupsWithState in append mode is not supported after " + + s"aggregation on a streaming DataFrame/Dataset") } + } } - // flatMapGroupsWithState(Update) with aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && m.outputMode == InternalOutputModes.Update - && collectStreamingAggregates(plan).nonEmpty => - throwError("flatMapGroupsWithState in update mode is not supported with " + - "aggregation on a streaming DataFrame/Dataset") - - // flatMapGroupsWithState(Append) with aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && m.outputMode == InternalOutputModes.Append - && collectStreamingAggregates(m).nonEmpty => - throwError("flatMapGroupsWithState in append mode is not supported after " + - s"aggregation on a streaming DataFrame/Dataset") + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. + val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "[map|flatMap]GroupsWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index d1f95faf2db0c..e0ecf8c5f2643 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -353,9 +353,10 @@ case class MapGroups( /** Internal class representing State */ trait LogicalKeyedState[S] -/** Possible types of timeouts used in FlatMapGroupsWithState */ +/** Types of timeouts used in FlatMapGroupsWithState */ case object NoTimeout extends KeyedStateTimeout case object ProcessingTimeTimeout extends KeyedStateTimeout +case object EventTimeTimeout extends KeyedStateTimeout /** Factory for constructing new `MapGroupsWithState` nodes. */ object FlatMapGroupsWithState { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 08216e2660400..8f0a0c0d99d15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -345,6 +345,22 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) + // mapGroupsWithState with event time timeout + watermark + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout without watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, streamRelation), + outputMode = Update, + expectedMsgs = Seq("watermark")) + + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout with watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)), + outputMode = Update) + // Deduplicate assertSupportedInStreamingPlan( "Deduplicate - Deduplicate on streaming relation before aggregation", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e58e8ce3d5f8..ca2f6dd7a84b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -336,8 +336,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { timeout, child) => val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP, - planLater(child)) + timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 991d8ef707567..52ad70c7dc886 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, ProcessingTimeTimeout} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} -import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -39,7 +40,7 @@ import org.apache.spark.util.CompletionIterator * @param outputObjAttr used to define the output object * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` - * @param timeout used to timeout groups that have not received data in a while + * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. */ case class FlatMapGroupsWithStateExec( @@ -52,11 +53,15 @@ case class FlatMapGroupsWithStateExec( stateId: Option[OperatorStateId], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - timeout: KeyedStateTimeout, - batchTimestampMs: Long, - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { + timeoutConf: KeyedStateTimeout, + batchTimestampMs: Option[Long], + override val eventTimeWatermark: Option[Long], + child: SparkPlan + ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - private val isTimeoutEnabled = timeout == ProcessingTimeTimeout + import KeyedStateImpl._ + + private val isTimeoutEnabled = timeoutConf != NoTimeout private val timestampTimeoutAttribute = AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() private val stateAttributes: Seq[Attribute] = { @@ -64,8 +69,6 @@ case class FlatMapGroupsWithStateExec( if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs } - import KeyedStateImpl._ - /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -74,9 +77,21 @@ case class FlatMapGroupsWithStateExec( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) + override def keyExpressions: Seq[Attribute] = groupingAttributes + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + // Throw errors early if parameters are not as expected + timeoutConf match { + case ProcessingTimeTimeout => + require(batchTimestampMs.nonEmpty) + case EventTimeTimeout => + require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(watermarkExpression.nonEmpty) // input schema has watermark attribute + case _ => + } + child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, getStateId.operatorId, @@ -84,15 +99,23 @@ case class FlatMapGroupsWithStateExec( groupingAttributes.toStructType, stateAttributes.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) => + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val updater = new StateStoreUpdater(store) + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if timeoutConf == EventTimeTimeout => + iter.filter(row => !predicate.eval(row)) + case None => + iter + } + // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(iterator) ++ updater.updateStateForTimedOutKeys() + updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -124,7 +147,7 @@ case class FlatMapGroupsWithStateExec( private val stateSerializer = { val encoderSerializer = stateEncoder.namedExpressions if (isTimeoutEnabled) { - encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET) + encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP) } else { encoderSerializer } @@ -157,16 +180,19 @@ case class FlatMapGroupsWithStateExec( /** Find the groups that have timeout set and are timing out right now, and call the function */ def updateStateForTimedOutKeys(): Iterator[InternalRow] = { if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } val timingOutKeys = store.filter { case (_, stateRow) => val timeoutTimestamp = getTimeoutTimestamp(stateRow) - timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < batchTimestampMs + timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } timingOutKeys.flatMap { case (keyRow, stateRow) => - callFunctionAndUpdateState( - keyRow, - Iterator.empty, - Some(stateRow), - hasTimedOut = true) + callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) } } else Iterator.empty } @@ -186,7 +212,11 @@ case class FlatMapGroupsWithStateExec( val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) val keyedState = new KeyedStateImpl( - stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut) + stateObjOption, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut) // Call function, get the returned objects and convert them to rows val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => @@ -196,8 +226,6 @@ case class FlatMapGroupsWithStateExec( // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - // Has the timeout information changed - if (keyedState.hasRemoved) { store.remove(keyRow) numUpdatedStateRows += 1 @@ -205,26 +233,25 @@ case class FlatMapGroupsWithStateExec( } else { val previousTimeoutTimestamp = prevStateRowOption match { case Some(row) => getTimeoutTimestamp(row) - case None => TIMEOUT_TIMESTAMP_NOT_SET + case None => NO_TIMESTAMP } - + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { prevStateRowOption.orNull } - val hasTimeoutChanged = keyedState.getTimeoutTimestamp != previousTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged if (shouldWriteState) { if (stateRowToWrite == null) { // This should never happen because checks in KeyedStateImpl should avoid cases // where empty state would need to be written - throw new IllegalStateException( - "Attempting to write empty state") + throw new IllegalStateException("Attempting to write empty state") } - setTimeoutTimestamp(stateRowToWrite, keyedState.getTimeoutTimestamp) + setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) store.put(keyRow.copy(), stateRowToWrite.copy()) numUpdatedStateRows += 1 } @@ -247,7 +274,7 @@ case class FlatMapGroupsWithStateExec( /** Returns the timeout timestamp of a state row is set */ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else TIMEOUT_TIMESTAMP_NOT_SET + if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP } /** Set the timestamp in a state row */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a934c75a02457..0f0e4a91f8cc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -108,7 +108,10 @@ class IncrementalExecution( case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - m.copy(stateId = Some(stateId), batchTimestampMs = offsetSeqMetadata.batchTimestampMs) + m.copy( + stateId = Some(stateId), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index ac421d395beb4..edfd35bd5dd75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -17,37 +17,45 @@ package org.apache.spark.sql.execution.streaming +import java.sql.Date + import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.streaming.KeyedState +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.execution.streaming.KeyedStateImpl._ +import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout} import org.apache.spark.unsafe.types.CalendarInterval + /** * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. * @param optionalValue Optional value of the state * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp * for processing time timeouts - * @param isTimeoutEnabled Whether timeout is enabled. This will be used to check whether the user - * is allowed to configure timeouts. + * @param timeoutConf Type of timeout configured. Based on this, different operations will + * be supported. * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ private[sql] class KeyedStateImpl[S]( optionalValue: Option[S], batchProcessingTimeMs: Long, - isTimeoutEnabled: Boolean, + eventTimeWatermarkMs: Long, + timeoutConf: KeyedStateTimeout, override val hasTimedOut: Boolean) extends KeyedState[S] { - import KeyedStateImpl._ - // Constructor to create dummy state when using mapGroupsWithState in a batch query def this(optionalValue: Option[S]) = this( - optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false) + optionalValue, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf = KeyedStateTimeout.NoTimeout, + hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed - private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET + private var timeoutTimestamp: Long = NO_TIMESTAMP // ========= Public API ========= override def exists: Boolean = defined @@ -82,13 +90,14 @@ private[sql] class KeyedStateImpl[S]( defined = false updated = false removed = true - timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET + timeoutTimestamp = NO_TIMESTAMP } override def setTimeoutDuration(durationMs: Long): Unit = { - if (!isTimeoutEnabled) { + if (timeoutConf != ProcessingTimeTimeout) { throw new UnsupportedOperationException( - "Cannot set timeout information without enabling timeout in map/flatMapGroupsWithState") + "Cannot set timeout duration without enabling processing time timeout in " + + "map/flatMapGroupsWithState") } if (!defined) { throw new IllegalStateException( @@ -99,7 +108,7 @@ private[sql] class KeyedStateImpl[S]( if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) { + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { timeoutTimestamp = durationMs + batchProcessingTimeMs } else { // This is being called in a batch query, hence no processing timestamp. @@ -108,29 +117,55 @@ private[sql] class KeyedStateImpl[S]( } override def setTimeoutDuration(duration: String): Unit = { - if (StringUtils.isBlank(duration)) { - throw new IllegalArgumentException( - "The window duration, slide duration and start time cannot be null or blank.") - } - val intervalString = if (duration.startsWith("interval")) { - duration - } else { - "interval " + duration + setTimeoutDuration(parseDuration(duration)) + } + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long): Unit = { + checkTimeoutTimestampAllowed() + if (timestampMs <= 0) { + throw new IllegalArgumentException("Timeout timestamp must be positive") } - val cal = CalendarInterval.fromString(intervalString) - if (cal == null) { + if (eventTimeWatermarkMs != NO_TIMESTAMP && timestampMs < eventTimeWatermarkMs) { throw new IllegalArgumentException( - s"The provided duration ($duration) is not valid.") + s"Timeout timestamp ($timestampMs) cannot be earlier than the " + + s"current watermark ($eventTimeWatermarkMs)") } - if (cal.milliseconds < 0 || cal.months < 0) { - throw new IllegalArgumentException("Timeout duration must be positive") + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + timeoutTimestamp = timestampMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. } + } - val delayMs = { - val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 - cal.milliseconds + cal.months * millisPerMonth - } - setTimeoutDuration(delayMs) + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) + } + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime) + } + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } override def toString: String = { @@ -147,14 +182,46 @@ private[sql] class KeyedStateImpl[S]( /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ def getTimeoutTimestamp: Long = timeoutTimestamp + + private def parseDuration(duration: String): Long = { + if (StringUtils.isBlank(duration)) { + throw new IllegalArgumentException( + "Provided duration is null or blank.") + } + val intervalString = if (duration.startsWith("interval")) { + duration + } else { + "interval " + duration + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"Provided duration ($duration) is not valid.") + } + if (cal.milliseconds < 0 || cal.months < 0) { + throw new IllegalArgumentException(s"Provided duration ($duration) is not positive") + } + + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + cal.milliseconds + cal.months * millisPerMonth + } + + private def checkTimeoutTimestampAllowed(): Unit = { + if (timeoutConf != EventTimeTimeout) { + throw new UnsupportedOperationException( + "Cannot set timeout timestamp without enabling event time timeout in " + + "map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout timestamp without any state value, " + + "state has either not been initialized, or has already been removed") + } + } } private[sql] object KeyedStateImpl { - // Value used in the state row to represent the lack of any timeout timestamp - val TIMEOUT_TIMESTAMP_NOT_SET = -1L - - // Value to represent that no batch processing timestamp is passed to KeyedStateImpl. This is - // used in batch queries where there are no streaming batches and timeouts. - val NO_BATCH_PROCESSING_TIMESTAMP = -1L + // Value used represent the lack of valid timestamp as a long + val NO_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6d2de441eb44c..f72144a25d5cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -80,7 +80,7 @@ trait WatermarkSupport extends UnaryExecNode { /** Generate an expression that matches data older than the watermark */ lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)) optionalWatermarkAttribute.map { watermarkAttribute => // If we are evicting based on a window, use the end of the window. Otherwise just @@ -101,14 +101,12 @@ trait WatermarkSupport extends UnaryExecNode { } } - /** Generate a predicate based on keys that matches data older than the watermark */ + /** Predicate based on keys that matches data older than the watermark */ lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.map(newPredicate(_, keyExpressions)) - /** - * Generate a predicate based on the child output that matches data older than the watermark. - */ - lazy val watermarkPredicate: Option[Predicate] = + /** Predicate based on the child output that matches data older than the watermark. */ + lazy val watermarkPredicateForData: Option[Predicate] = watermarkExpression.map(newPredicate(_, child.output)) } @@ -218,7 +216,7 @@ case class StateStoreSaveExec( new Iterator[InternalRow] { // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicate match { + private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } @@ -285,7 +283,7 @@ case class StreamingDeduplicateExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val baseIterator = watermarkPredicate match { + val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala index 6b4b1ced98a34..461de04f6bbe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * batch, nor with streaming Datasets. * - All the data will be shuffled before applying the function. * - If timeout is set, then the function will also be called with no values. - * See more details on KeyedStateTimeout` below. + * See more details on `KeyedStateTimeout` below. * * Important points to note about using `KeyedState`. * - The value of the state cannot be null. So updating state with null will throw @@ -68,20 +68,38 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * Important points to note about using `KeyedStateTimeout`. * - The timeout type is a global param across all the keys (set as `timeout` param in - * `[map|flatMap]GroupsWithState`, but the exact timeout duration is configurable per key - * (by calling `setTimeout...()` in `KeyedState`). - * - When the timeout occurs for a key, the function is called with no values, and + * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per + * key by calling `setTimeout...()` in `KeyedState`. + * - Timeouts can be either based on processing time (i.e. + * [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. + * [[KeyedStateTimeout.EventTimeTimeout]]). + * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling + * `KeyedState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set + * duration. Guarantees provided by this timeout with a duration of D ms are as follows: + * - Timeout will never be occur before the clock time has advanced by D ms + * - Timeout will occur eventually when there is a trigger in the query + * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. + * For example, the trigger interval of the query will affect when the timeout actually occurs. + * If there is no data in the stream (for any key) for a while, then their will not be + * any trigger and timeout function call will not occur until there is data. + * - Since the processing time timeout is based on the clock time, it is affected by the + * variations in the system clock (i.e. time zone changes, clock skew, etc.). + * - With `EventTimeTimeout`, the user also has to specify the the the event time watermark in + * the query using `Dataset.withWatermark()`. With this setting, data that is older than the + * watermark are filtered out. The timeout can be enabled for a key by setting a timestamp using + * `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the watermark advances + * beyond the set timestamp. You can control the timeout delay by two parameters - (i) watermark + * delay and an additional duration beyond the timestamp in the event (which is guaranteed to + * > watermark due to the filtering). Guarantees provided by this timeout are as follows: + * - Timeout will never be occur before watermark has exceeded the set timeout. + * - Similar to processing time timeouts, there is a no strict upper bound on the delay when + * the timeout actually occurs. The watermark can advance only when there is data in the + * stream, and the event time of the data has actually advanced. + * - When the timeout occurs for a key, the function is called for that key with no values, and * `KeyedState.hasTimedOut()` set to true. * - The timeout is reset for key every time the function is called on the key, that is, * when the key has new data, or the key has timed out. So the user has to set the timeout * duration every time the function is called, otherwise there will not be any timeout set. - * - Guarantees provided on processing-time-based timeout of key, when timeout duration is D ms: - * - Timeout will never be called before real clock time has advanced by D ms - * - Timeout will be called eventually when there is a trigger in the query - * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. - * For example, the trigger interval of the query will affect when the timeout is actually hit. - * If there is no data in the stream (for any key) for a while, then their will not be - * any trigger and timeout will not be hit until there is data. * * Scala example of using KeyedState in `mapGroupsWithState`: * {{{ @@ -194,7 +212,8 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** * Set the timeout duration in ms for this key. - * @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + * + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") @throws[IllegalStateException]("when state is either not initialized, or already removed") @@ -204,11 +223,63 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. - * @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + * + * @note, ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") def setTimeoutDuration(duration: String): Unit + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time. + * This timestamp cannot be older than the current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date. + * This timestamp cannot be older than the current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestamp: java.sql.Date): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 7daa5e6a0f61f..fe72283bb608f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util +import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll @@ -44,6 +44,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf import testImplicits._ import KeyedStateImpl._ + import KeyedStateTimeout._ override def afterAll(): Unit = { super.afterAll() @@ -96,77 +97,93 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - test("KeyedState - setTimeoutDuration, hasTimedOut") { - import KeyedStateImpl._ - var state: KeyedStateImpl[Int] = null - - // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed + test("KeyedState - setTimeout**** with NoTimeout") { for (initState <- Seq(None, Some(5))) { // for different initial state - state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - intercept[UnsupportedOperationException] { - state.setTimeoutDuration(1000) - } - intercept[UnsupportedOperationException] { - state.setTimeoutDuration("1 day") - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + implicit val state = new KeyedStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } + } - def testTimeoutNotAllowed(): Unit = { - intercept[IllegalStateException] { - state.setTimeoutDuration(1000) - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - intercept[IllegalStateException] { - state.setTimeoutDuration("2 second") - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - } + test("KeyedState - setTimeout**** with ProcessingTimeTimeout") { + implicit var state: KeyedStateImpl[Int] = null - // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the - // state is be defined - state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - testTimeoutNotAllowed() + state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - // After state has been set, setTimeoutDuration() is allowed, and - // getTimeoutTimestamp returned correct timestamp state.update(5) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") assert(state.getTimeoutTimestamp === 3000) - assert(state.hasTimedOut === false) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + + test("KeyedState - setTimeout**** with EventTimeTimeout") { + implicit val state = new KeyedStateImpl[Int]( + None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + + state.update(5) + state.setTimeoutTimestamp(10000) + assert(state.getTimeoutTimestamp === 10000) + state.setTimeoutTimestamp(new Date(20000)) + assert(state.getTimeoutTimestamp === 20000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + } + + test("KeyedState - illegal params to setTimeout****") { + var state: KeyedStateImpl[Int] = null - // setTimeoutDuration() with negative values or 0 is not allowed + // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { intercept[IllegalArgumentException] { body } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + + state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } testIllegalTimeout { state.setTimeoutDuration(0) } testIllegalTimeout { state.setTimeoutDuration("-2 second") } testIllegalTimeout { state.setTimeoutDuration("-1 month") } testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } - // Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) - state.remove() - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - testTimeoutNotAllowed() - - // Test hasTimedOut = true - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true) - assert(state.hasTimedOut === true) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutTimestamp(-10000) } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + } + + test("KeyedState - hasTimedOut") { + for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + for (initState <- Seq(None, Some(5))) { + val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + assert(state1.hasTimedOut === false) + val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + assert(state2.hasTimedOut === true) + } + } } test("KeyedState - primitive type") { @@ -187,133 +204,186 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Values used for testing StateStoreUpdater - val currentTimestamp = 1000 - val beforeCurrentTimestamp = 999 - val afterCurrentTimestamp = 1001 + val currentBatchTimestamp = 1000 + val currentBatchWatermark = 1000 + val beforeTimeoutThreshold = 999 + val afterTimeoutThreshold = 1001 + - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" - val testName = s"timeout disabled - $priorStateStr - " + val testName = s"NoTimeout - $priorStateStr - " testStateUpdateWithData( testName + "no update", stateUpdates = state => { /* do nothing */ }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change testStateUpdateWithData( testName + "state updated", stateUpdates = state => { state.update(5) }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = Some(5)) // should change testStateUpdateWithData( testName + "state removed", stateUpdates = state => { state.remove() }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { - for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) { - var testName = s"timeout enabled - " + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + var testName = s"" if (priorState.nonEmpty) { testName += "prior state set, " if (priorTimeoutTimestamp == 1000) { - testName += "prior timeout set - " + testName += "prior timeout set" } else { - testName += "no prior timeout - " + testName += "no prior timeout" } } else { - testName += "no prior state - " + testName += "no prior state" + } + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + + testStateUpdateWithData( + s"$timeoutConf - $testName - no update", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = priorState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state updated", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state removed", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None) // state should be removed } testStateUpdateWithData( - testName + "no update", - stateUpdates = state => { /* do nothing */ }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, - priorState = priorState, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = priorState, // state should not change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset - - testStateUpdateWithData( - testName + "state updated", - stateUpdates = state => { state.update(5) }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"ProcessingTimeTimeout - $testName - state and timeout duration updated", + stateUpdates = + (state: KeyedState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change testStateUpdateWithData( - testName + "state removed", - stateUpdates = state => { state.remove() }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"EventTimeTimeout - $testName - state and timeout timestamp updated", + stateUpdates = + (state: KeyedState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None) // state should be removed + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = 5000) // timestamp should change testStateUpdateWithData( - testName + "timeout and state updated", - stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"EventTimeTimeout - $testName - timeout timestamp updated to before watermark", + stateUpdates = + (state: KeyedState[Int]) => { + state.update(5) + intercept[IllegalArgumentException] { + state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark + } + }, + timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update } } // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + testStateUpdateWithTimeout( + s"$timeoutConf - should not timeout", + stateUpdates = state => { assert(false, "function called without timeout") }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = afterTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = afterTimeoutThreshold) // timestamp should not change + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - no update/remove", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset - testStateUpdateWithTimeout( - "should not timeout", - stateUpdates = state => { assert(false, "function called without timeout") }, - priorTimeoutTimestamp = afterCurrentTimestamp, - expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - update state", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - remove state", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = None, // state should be removed + expectedTimeoutTimestamp = NO_TIMESTAMP) + } testStateUpdateWithTimeout( - "should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "ProcessingTimeTimeout - should timeout - timeout duration updated", + stateUpdates = state => { state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - update state", - stateUpdates = state => { state.update(5) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "ProcessingTimeTimeout - should timeout - timeout duration and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = Some(5), // state should change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - remove state", - stateUpdates = state => { state.remove() }, - priorTimeoutTimestamp = beforeCurrentTimestamp, - expectedState = None, // state should be removed - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) - - testStateUpdateWithTimeout( - "should timeout - timeout updated", - stateUpdates = state => { state.setTimeoutDuration(2000) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "EventTimeTimeout - should timeout - timeout timestamp updated", + stateUpdates = state => { state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + expectedTimeoutTimestamp = 5000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - timeout and state updated", - stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "EventTimeTimeout - should timeout - timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = Some(5), // state should change - expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + expectedTimeoutTimestamp = 5000) // timestamp should change test("StateStoreUpdater - rows are cloned before writing to StateStore") { // function for running count @@ -481,11 +551,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val clock = new StreamManualClock val inputData = MemoryStream[String] - val timeout = KeyedStateTimeout.ProcessingTimeTimeout val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, timeout)(stateFunc) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( StartStream(ProcessingTime("1 second"), triggerClock = clock), @@ -519,6 +588,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("flatMapGroupsWithState - streaming with event time timeout") { + // Function to maintain the max event time + // Returns the max event time in the state, or -1 if the state was removed by timeout + val stateFunc = ( + key: String, + values: Iterator[(String, Long)], + state: KeyedState[Long]) => { + val timeoutDelay = 5 + if (key != "a") { + Iterator.empty + } else { + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampMs = maxEventTime + timeoutDelay + state.update(maxEventTime) + state.setTimeoutTimestamp(timeoutTimestampMs * 1000) + Iterator((key, maxEventTime.toInt)) + } + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(ProcessingTime("1 second")), + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... + CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckLastBatch(), // No output as data should get filtered by watermark + AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s + CheckLastBatch(), // No output as no data for "a" + AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored + CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + ) + } + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -612,7 +727,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key val inputData = MemoryStream[String] val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) - result testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch("a"), @@ -649,13 +763,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithData( testName: String, stateUpdates: KeyedState[Int] => Unit, - timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + timeoutConf: KeyedStateTimeout, priorState: Option[Int], - priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET, + priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, - expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) { + if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } test(s"StateStoreUpdater - updates with data - $testName") { @@ -666,7 +780,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf Iterator.empty } testStateUpdate( - testTimeoutUpdates = false, mapGroupsFunc, timeoutType, + testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) } } @@ -674,9 +788,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithTimeout( testName: String, stateUpdates: KeyedState[Int] => Unit, + timeoutConf: KeyedStateTimeout, priorTimeoutTimestamp: Long, expectedState: Option[Int], - expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { test(s"StateStoreUpdater - updates for timeout - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { @@ -686,16 +801,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf Iterator.empty } testStateUpdate( - testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout, - preTimeoutState, priorTimeoutTimestamp, - expectedState, expectedTimeoutTimestamp) + testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) } } def testStateUpdate( testTimeoutUpdates: Boolean, mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutType: KeyedStateTimeout, + timeoutConf: KeyedStateTimeout, priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], @@ -703,7 +817,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( - mapGroupsFunc, timeoutType, currentTimestamp) + mapGroupsFunc, timeoutConf, currentBatchTimestamp) val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) val key = intToRow(0) // Prepare store with prior state configs @@ -736,7 +850,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def newFlatMapGroupsWithStateExec( func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, - batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = { + batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { MemoryStream[Int] .toDS .groupByKey(x => x) @@ -744,11 +858,31 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, currentTimestamp, - RDDScanExec(g, null, "rdd")) + f, k, v, g, d, o, None, s, m, t, + Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) }.get } + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutDuration(1000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutDuration("2 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutTimestamp(2000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + def newStateStore(): StateStore = new MemoryStateStore() val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) From 478fbc866fbfdb4439788583281863ecea14e8af Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 21 Mar 2017 21:50:54 -0700 Subject: [PATCH 0084/1765] [SPARK-19925][SPARKR] Fix SparkR spark.getSparkFiles fails when it was called on executors. ## What changes were proposed in this pull request? SparkR ```spark.getSparkFiles``` fails when it was called on executors, see details at [SPARK-19925](https://issues.apache.org/jira/browse/SPARK-19925). ## How was this patch tested? Add unit tests, and verify this fix at standalone and yarn cluster. Author: Yanbo Liang Closes #17274 from yanboliang/spark-19925. --- R/pkg/R/context.R | 16 ++++++++++++++-- R/pkg/inst/tests/testthat/test_context.R | 7 +++++++ .../scala/org/apache/spark/api/r/RRunner.scala | 2 ++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 1ca573e5bd614..50856e3d9856c 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -330,7 +330,13 @@ spark.addFile <- function(path, recursive = FALSE) { #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 spark.getSparkFilesRootDirectory <- function() { - callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + } else { + # Running on worker. + Sys.getenv("SPARKR_SPARKFILES_ROOT_DIR") + } } #' Get the absolute path of a file added through spark.addFile. @@ -345,7 +351,13 @@ spark.getSparkFilesRootDirectory <- function() { #'} #' @note spark.getSparkFiles since 2.1.0 spark.getSparkFiles <- function(fileName) { - callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + } else { + # Running on worker. + file.path(spark.getSparkFilesRootDirectory(), as.character(fileName)) + } } #' Run a function over a list of elements, distributing the computations with Spark diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index caca06933952b..c847113491113 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -177,6 +177,13 @@ test_that("add and get file to be downloaded with Spark job on every node", { spark.addFile(path) download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) + + # Test spark.getSparkFiles works well on executors. + seq <- seq(from = 1, to = 10, length.out = 5) + f <- function(seq) { spark.getSparkFiles(filename) } + results <- spark.lapply(seq, f) + for (i in 1:5) { expect_equal(basename(results[[i]]), filename) } + unlink(path) # Test add directory recursively. diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 29e21b3b1aa8a..88118392003e8 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -347,6 +347,8 @@ private[r] object RRunner { pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) + pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) + pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) From 7343a09401e7d6636634968b1cd8bc403a1f77b6 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 22 Mar 2017 19:08:28 +0800 Subject: [PATCH 0085/1765] [SPARK-20023][SQL] Output table comment for DESC FORMATTED ### What changes were proposed in this pull request? Currently, `DESC FORMATTED` did not output the table comment, unlike what `DESC EXTENDED` does. This PR is to fix it. Also correct the following displayed names in `DESC FORMATTED`, for being consistent with `DESC EXTENDED` - `"Create Time:"` -> `"Created:"` - `"Last Access Time:"` -> `"Last Access:"` ### How was this patch tested? Added test cases in `describe.sql` Author: Xiao Li Closes #17381 from gatorsmile/descFormattedTableComment. --- .../spark/sql/execution/command/tables.scala | 5 +- .../resources/sql-tests/inputs/describe.sql | 14 +- .../sql-tests/results/describe.sql.out | 125 ++++++++++++++++-- .../apache/spark/sql/SQLQueryTestSuite.scala | 6 +- 4 files changed, 124 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 93307fc883565..c7aeef06a0bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -568,11 +568,12 @@ case class DescribeTableCommand( append(buffer, "# Detailed Table Information", "", "") append(buffer, "Database:", table.database, "") append(buffer, "Owner:", table.owner, "") - append(buffer, "Create Time:", new Date(table.createTime).toString, "") - append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") + append(buffer, "Created:", new Date(table.createTime).toString, "") + append(buffer, "Last Access:", new Date(table.lastAccessTime).toString, "") append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_)) .getOrElse(""), "") append(buffer, "Table Type:", table.tableType.name, "") + append(buffer, "Comment:", table.comment.getOrElse(""), "") table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) append(buffer, "Table Parameters:", "", "") diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index ff327f5e82b13..56f3281440d29 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,4 +1,4 @@ -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d); +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment'; ALTER TABLE t ADD PARTITION (c='Us', d=1); @@ -8,15 +8,15 @@ DESC t; DESC TABLE t; --- Ignore these because there exist timestamp results, e.g., `Create Table`. --- DESC EXTENDED t; --- DESC FORMATTED t; +DESC FORMATTED t; + +DESC EXTENDED t; DESC t PARTITION (c='Us', d=1); --- Ignore these because there exist timestamp results, e.g., transient_lastDdlTime. --- DESC EXTENDED t PARTITION (c='Us', d=1); --- DESC FORMATTED t PARTITION (c='Us', d=1); +DESC EXTENDED t PARTITION (c='Us', d=1); + +DESC FORMATTED t PARTITION (c='Us', d=1); -- NoSuchPartitionException: Partition not found in table DESC t PARTITION (c='Us', d=2); diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 0a11c1cde2b45..422d548ea8de8 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,9 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 14 -- !query 0 -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment' -- !query 0 schema struct<> -- !query 0 output @@ -64,12 +64,25 @@ d string -- !query 5 -DESC t PARTITION (c='Us', d=1) +DESC FORMATTED t -- !query 5 schema struct -- !query 5 output +# Detailed Table Information # Partition Information +# Storage Information # col_name data_type comment +Comment: table_comment +Compressed: No +Created: +Database: default +Last Access: +Location: sql/core/spark-warehouse/t +Owner: +Partition Provider: Catalog +Storage Desc Parameters: +Table Parameters: +Table Type: MANAGED a string b int c string @@ -79,30 +92,114 @@ d string -- !query 6 -DESC t PARTITION (c='Us', d=2) +DESC EXTENDED t -- !query 6 schema -struct<> +struct -- !query 6 output +# Detailed Table Information CatalogTable( + Table: `default`.`t` + Created: + Last Access: + Type: MANAGED + Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] + Provider: parquet + Partition Columns: [`c`, `d`] + Comment: table_comment + Storage(Location: sql/core/spark-warehouse/t) + Partition Provider: Catalog) +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 7 +DESC t PARTITION (c='Us', d=1) +-- !query 7 schema +struct +-- !query 7 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 8 +DESC EXTENDED t PARTITION (c='Us', d=1) +-- !query 8 schema +struct +-- !query 8 output +# Partition Information +# col_name data_type comment +Detailed Partition Information CatalogPartition( + Partition Values: [c=Us, d=1] + Storage(Location: sql/core/spark-warehouse/t/c=Us/d=1) + Partition Parameters:{}) +a string +b int +c string +c string +d string +d string + + +-- !query 9 +DESC FORMATTED t PARTITION (c='Us', d=1) +-- !query 9 schema +struct +-- !query 9 output +# Detailed Partition Information +# Partition Information +# Storage Information +# col_name data_type comment +Compressed: No +Database: default +Location: sql/core/spark-warehouse/t/c=Us/d=1 +Partition Parameters: +Partition Value: [Us, 1] +Storage Desc Parameters: +Table: t +a string +b int +c string +c string +d string +d string + + +-- !query 10 +DESC t PARTITION (c='Us', d=2) +-- !query 10 schema +struct<> +-- !query 10 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 7 +-- !query 11 DESC t PARTITION (c='Us') --- !query 7 schema +-- !query 11 schema struct<> --- !query 7 output +-- !query 11 output org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 8 +-- !query 12 DESC t PARTITION (c='Us', d) --- !query 8 schema +-- !query 12 schema struct<> --- !query 8 output +-- !query 12 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -112,9 +209,9 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 9 +-- !query 13 DROP TABLE t --- !query 9 schema +-- !query 13 schema struct<> --- !query 9 output +-- !query 13 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index c285995514c85..4092862c430b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -223,9 +223,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val schema = df.schema // Get answer, but also get rid of the #1234 expression ids that show up in explain plans val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") - .replaceAll("Location: .*/sql/core/", "Location: sql/core/") - .replaceAll("Created: .*\n", "Created: \n") - .replaceAll("Last Access: .*\n", "Last Access: \n")) + .replaceAll("Location:.*/sql/core/", "Location: sql/core/") + .replaceAll("Created: .*", "Created: ") + .replaceAll("Last Access: .*", "Last Access: ")) // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) From facfd608865c385c0dabfe09cffe5874532a9cdf Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 22 Mar 2017 11:10:08 +0000 Subject: [PATCH 0086/1765] [SPARK-20021][PYSPARK] Miss backslash in python code ## What changes were proposed in this pull request? Add backslash for line continuation in python code. ## How was this patch tested? Jenkins. Author: uncleGen Author: dylon Closes #17352 from uncleGen/python-example-doc. --- docs/structured-streaming-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 798847237866b..ff07ad11943bd 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -764,11 +764,11 @@ Dataset windowedCounts = words words = ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } # Group the data by window and word and compute the count of each group -windowedCounts = words - .withWatermark("timestamp", "10 minutes") +windowedCounts = words \ + .withWatermark("timestamp", "10 minutes") \ .groupBy( window(words.timestamp, "10 minutes", "5 minutes"), - words.word) + words.word) \ .count() {% endhighlight %} From 0caade634076034182e22318eb09a6df1c560576 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 22 Mar 2017 13:52:03 +0000 Subject: [PATCH 0087/1765] [SPARK-20027][DOCS] Compilation fix in java docs. ## What changes were proposed in this pull request? During build/sbt publish-local, build breaks due to javadocs errors. This patch fixes those errors. ## How was this patch tested? Tested by running the sbt build. Author: Prashant Sharma Closes #17358 from ScrapCodes/docs-fix. --- .../java/org/apache/spark/network/crypto/ClientChallenge.java | 2 +- .../java/org/apache/spark/network/crypto/ServerResponse.java | 2 +- .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- .../api/java/function/FlatMapGroupsWithStateFunction.java | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java index 3312a5bd81a66..819b8a7efbdba 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -28,7 +28,7 @@ /** * The client challenge message, used to initiate authentication. * - * @see README.md + * Please see crypto/README.md for more details of implementation. */ public class ClientChallenge implements Encodable { /** Serialization tag used to catch incorrect payloads. */ diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java index affdbf450b1d0..caf3a0f3b38cc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java @@ -28,7 +28,7 @@ /** * Server's response to client's challenge. * - * @see README.md + * Please see crypto/README.md for more details. */ public class ServerResponse implements Encodable { /** Serialization tag used to catch incorrect payloads. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4c28075bd9386..5437e998c085f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -863,7 +863,7 @@ public static class LongWrapper { * This is done solely for better performance and is not expected to be used by end users. * * {@link LongWrapper} could have been used here but using `int` directly save the extra cost of - * conversion from `long` -> `int` + * conversion from `long` to `int` */ public static class IntWrapper { public int value = 0; diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 29af78c4f6a85..bdda8aaf734dd 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -28,7 +28,8 @@ * ::Experimental:: * Base interface for a map function used in * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( - * FlatMapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}. + * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, + * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 */ @Experimental From 465818389aab1217c9de5c685cfaee3ffaec91bb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Mar 2017 09:52:37 -0700 Subject: [PATCH 0088/1765] [SPARK-19949][SQL][FOLLOW-UP] Clean up parse modes and update related comments ## What changes were proposed in this pull request? This PR proposes to make `mode` options in both CSV and JSON to use `cass object` and fix some related comments related previous fix. Also, this PR modifies some tests related parse modes. ## How was this patch tested? Modified unit tests in both `CSVSuite.scala` and `JsonSuite.scala`. Author: hyukjinkwon Closes #17377 from HyukjinKwon/SPARK-19949. --- python/pyspark/sql/readwriter.py | 6 +- python/pyspark/sql/streaming.py | 2 + .../expressions/jsonExpressions.scala | 4 +- .../spark/sql/catalyst/json/JSONOptions.scala | 12 +- .../sql/catalyst/util/FailureSafeParser.scala | 15 ++- .../spark/sql/catalyst/util/ParseMode.scala | 56 +++++++++ .../spark/sql/catalyst/util/ParseModes.scala | 41 ------- .../expressions/JsonExpressionsSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../datasources/csv/CSVOptions.scala | 14 +-- .../datasources/json/JsonInferSchema.scala | 3 +- .../sql/streaming/DataStreamReader.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 7 +- .../datasources/json/JsonSuite.scala | 113 +++++++----------- 14 files changed, 130 insertions(+), 153 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 122e17f2020f4..759c27507c397 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -369,10 +369,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param maxCharsPerColumn: defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, ``-1`` meaning unlimited length. - :param maxMalformedLogPerPartition: sets the maximum number of malformed rows Spark will - log for each partition. Malformed records beyond this - number will be ignored. If None is set, it - uses the default value, ``10``. + :param maxMalformedLogPerPartition: this parameter is no longer used since Spark 2.2.0. + If specified, it is ignored. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 288cc1e4f64dc..e227f9ceb5769 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -625,6 +625,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param maxCharsPerColumn: defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, ``-1`` meaning unlimited length. + :param maxMalformedLogPerPartition: this parameter is no longer used since Spark 2.2.0. + If specified, it is ignored. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 08af5522d822d..df4d406b84d60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -548,7 +548,7 @@ case class JsonToStructs( lazy val parser = new JacksonParser( rowSchema, - new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) + new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) override def dataType: DataType = schema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 355c26afa6f0d..c22b1ade4e64b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -65,7 +65,8 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -82,15 +83,6 @@ private[sql] class JSONOptions( val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) - /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index e8da10d65ecb9..725e3015b3416 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], - mode: String, + mode: ParseMode, schema: StructType, columnNameOfCorruptRecord: String) { @@ -58,11 +58,14 @@ class FailureSafeParser[IN]( try { rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) } catch { - case e: BadRecordException if ParseModes.isPermissiveMode(mode) => - Iterator(toResultRow(e.partialResult(), e.record)) - case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => - Iterator.empty - case e: BadRecordException => throw e.cause + case e: BadRecordException => mode match { + case PermissiveMode => + Iterator(toResultRow(e.partialResult(), e.record)) + case DropMalformedMode => + Iterator.empty + case FailFastMode => + throw e.cause + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala new file mode 100644 index 0000000000000..4565dbde88c88 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.internal.Logging + +sealed trait ParseMode { + /** + * String name of the parse mode. + */ + def name: String +} + +/** + * This mode permissively parses the records. + */ +case object PermissiveMode extends ParseMode { val name = "PERMISSIVE" } + +/** + * This mode ignores the whole corrupted records. + */ +case object DropMalformedMode extends ParseMode { val name = "DROPMALFORMED" } + +/** + * This mode throws an exception when it meets corrupted records. + */ +case object FailFastMode extends ParseMode { val name = "FAILFAST" } + +object ParseMode extends Logging { + /** + * Returns the parse mode from the given string. + */ + def fromString(mode: String): ParseMode = mode.toUpperCase match { + case PermissiveMode.name => PermissiveMode + case DropMalformedMode.name => DropMalformedMode + case FailFastMode.name => FailFastMode + case _ => + logWarning(s"$mode is not a valid parse mode. Using ${PermissiveMode.name}.") + PermissiveMode + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala deleted file mode 100644 index 0e466962b4678..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" - val DROP_MALFORMED_MODE = "DROPMALFORMED" - val FAIL_FAST_MODE = "FAILFAST" - - val DEFAULT = PERMISSIVE_MODE - - def isValidMode(mode: String): Boolean = { - mode.toUpperCase match { - case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true - case _ => false - } - } - - def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE - def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE - def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { - mode.toUpperCase == PERMISSIVE_MODE - } else { - true // We default to permissive is the mode string is not valid - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index e4698d44636b6..c5b72235e5db0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -21,7 +21,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -367,7 +367,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 767a636d70731..e39b4d91f1f6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -510,10 +510,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed * for any given value being read. By default, it is -1 meaning unlimited length
  • - *
  • `maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows - * Spark will log for each partition. Malformed records beyond this number will be ignored.
  • *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing. + * during parsing. It supports the following case-insensitive modes. *
      *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index f6c6b6f56cd9d..5d2c23ed9618c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -82,7 +82,8 @@ class CSVOptions( val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) @@ -95,15 +96,6 @@ class CSVOptions( val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) - val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -139,8 +131,6 @@ class CSVOptions( val escapeQuotes = getBool("escapeQuotes", true) - val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10) - val quoteAll = getBool("quoteAll", false) val inputBufferSize = 128 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 7475f8ec79331..e15c30b4374bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.util.PermissiveMode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -40,7 +41,7 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val shouldHandleCorruptRecord = configOptions.permissive + val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 388ef182ce3a6..f6e2fef74b8db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -260,7 +260,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed * for any given value being read. By default, it is -1 meaning unlimited length
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing. + * during parsing. It supports the following case-insensitive modes. *
        *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 598babfe0e7ad..2600894ca303c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -992,9 +992,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { Seq(false, true).foreach { wholeFile => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "abcd") .option("wholeFile", wholeFile) .schema(schema) .csv(testFile(valueMalformedFile)) @@ -1008,7 +1009,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) val df2 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .option("wholeFile", wholeFile) .schema(schemaWithCorrField1) @@ -1025,7 +1026,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .add("b", TimestampType) val df3 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .option("wholeFile", wholeFile) .schema(schemaWithCorrField2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 56fcf773f7dd9..b09cef76d2be7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1083,83 +1083,59 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") { - withTempView("jsonTable") { - val schema = StructType( - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) + val schema = StructType( + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) - val jsonDF = spark.read.schema(schema).json(corruptRecords) - jsonDF.createOrReplaceTempView("jsonTable") + val jsonDF = spark.read.schema(schema).json(corruptRecords) - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - """.stripMargin), - Seq( - // Corrupted records are replaced with null - Row(null, null, null), - Row(null, null, null), - Row(null, null, null), - Row("str_a_4", "str_b_4", "str_c_4"), - Row(null, null, null)) - ) - } + checkAnswer( + jsonDF.select($"a", $"b", $"c"), + Seq( + // Corrupted records are replaced with null + Row(null, null, null), + Row(null, null, null), + Row(null, null, null), + Row("str_a_4", "str_b_4", "str_c_4"), + Row(null, null, null)) + ) } test("Corrupt records: PERMISSIVE mode, with designated column for malformed records") { // Test if we can query corrupt records. withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { - withTempView("jsonTable") { - val jsonDF = spark.read.json(corruptRecords) - jsonDF.createOrReplaceTempView("jsonTable") - val schema = StructType( - StructField("_unparsed", StringType, true) :: + val jsonDF = spark.read.json(corruptRecords) + val schema = StructType( + StructField("_unparsed", StringType, true) :: StructField("a", StringType, true) :: StructField("b", StringType, true) :: StructField("c", StringType, true) :: Nil) - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - } + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + jsonDF.select($"a", $"b", $"c", $"_unparsed"), + Row(null, null, null, "{") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + jsonDF.filter($"_unparsed".isNull).select($"a", $"b", $"c"), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + jsonDF.filter($"_unparsed".isNotNull).select($"_unparsed"), + Row("{") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) } } @@ -1952,19 +1928,20 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("c", StringType, true) :: Nil) val errMsg = intercept[AnalysisException] { spark.read - .option("mode", "PERMISSIVE") + .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schema) .json(corruptRecords) }.getMessage assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + // We use `PERMISSIVE` mode by default if invalid string is given. withTempPath { dir => val path = dir.getCanonicalPath corruptRecords.toDF("value").write.text(path) val errMsg = intercept[AnalysisException] { spark.read - .option("mode", "PERMISSIVE") + .option("mode", "permm") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schema) .json(path) From 80fd070389a9c8ffa342d7b11f1ab2ea92e0f562 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Mar 2017 09:58:46 -0700 Subject: [PATCH 0089/1765] [SPARK-20018][SQL] Pivot with timestamp and count should not print internal representation ## What changes were proposed in this pull request? Currently, when we perform count with timestamp types, it prints the internal representation as the column name as below: ```scala Seq(new java.sql.Timestamp(1)).toDF("a").groupBy("a").pivot("a").count().show() ``` ``` +--------------------+----+ | a|1000| +--------------------+----+ |1969-12-31 16:00:...| 1| +--------------------+----+ ``` This PR proposes to use external Scala value instead of the internal representation in the column names as below: ``` +--------------------+-----------------------+ | a|1969-12-31 16:00:00.001| +--------------------+-----------------------+ |1969-12-31 16:00:...| 1| +--------------------+-----------------------+ ``` ## How was this patch tested? Unit test in `DataFramePivotSuite` and manual tests. Author: hyukjinkwon Closes #17348 from HyukjinKwon/SPARK-20018. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++-- .../apache/spark/sql/DataFramePivotSuite.scala | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 574f91b09912b..036ed060d9efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -486,14 +486,16 @@ class Analyzer( case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { + val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") if (singleAgg) { - value.toString + stringValue } else { val suffix = aggregate match { case n: NamedExpression => n.name case _ => toPrettySQL(aggregate) } - value + "_" + suffix + stringValue + "_" + suffix } } if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index ca3cb5676742e..6ca9ee57e8f49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class DataFramePivotSuite extends QueryTest with SharedSQLContext{ +class DataFramePivotSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("pivot courses") { @@ -230,4 +230,20 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ .groupBy($"a").pivot("a").agg(min($"b")), Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil) } + + test("pivot with timestamp and count should not print internal representation") { + val ts = "2012-12-31 16:00:10.011" + val tsWithZone = "2013-01-01 00:00:10.011" + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + val df = Seq(java.sql.Timestamp.valueOf(ts)).toDF("a").groupBy("a").pivot("a").count() + val expected = StructType( + StructField("a", TimestampType) :: + StructField(tsWithZone, LongType) :: Nil) + assert(df.schema == expected) + // String representation of timestamp with timezone should take the time difference + // into account. + checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) + } + } } From 82b598b963a21ae9d6a2a9638e86b4165c2a78c9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 22 Mar 2017 12:30:36 -0700 Subject: [PATCH 0090/1765] [SPARK-20057][SS] Renamed KeyedState to GroupState in mapGroupsWithState ## What changes were proposed in this pull request? Since the state is tied a "group" in the "mapGroupsWithState" operations, its better to call the state "GroupState" instead of a key. This would make it more general if you extends this operation to RelationGroupedDataset and python APIs. ## How was this patch tested? Existing unit tests. Author: Tathagata Das Closes #17385 from tdas/SPARK-20057. --- ...ateTimeout.java => GroupStateTimeout.java} | 18 +-- .../sql/catalyst/plans/logical/object.scala | 18 +-- ...e.java => JavaGroupStateTimeoutSuite.java} | 8 +- .../FlatMapGroupsWithStateFunction.java | 4 +- .../function/MapGroupsWithStateFunction.java | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 46 +++---- .../apache/spark/sql/execution/objects.scala | 8 +- .../FlatMapGroupsWithStateExec.scala | 16 +-- ...edStateImpl.scala => GroupStateImpl.scala} | 19 +-- .../streaming/statefulOperators.scala | 4 +- .../{KeyedState.scala => GroupState.scala} | 68 +++++----- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../FlatMapGroupsWithStateSuite.scala | 122 +++++++++--------- 13 files changed, 172 insertions(+), 167 deletions(-) rename sql/catalyst/src/main/java/org/apache/spark/sql/streaming/{KeyedStateTimeout.java => GroupStateTimeout.java} (79%) rename sql/catalyst/src/test/java/org/apache/spark/sql/streaming/{JavaKeyedStateTimeoutSuite.java => JavaGroupStateTimeoutSuite.java} (70%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{KeyedStateImpl.scala => GroupStateImpl.scala} (94%) rename sql/core/src/main/scala/org/apache/spark/sql/streaming/{KeyedState.scala => GroupState.scala} (84%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java similarity index 79% rename from sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index e2e7ab1d2609f..bd5e2d7ecca9b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -24,31 +24,31 @@ /** * Represents the type of timeouts possible for the Dataset operations * `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on - * `KeyedState` for more details. + * `GroupState` for more details. * * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving -public class KeyedStateTimeout { +public class GroupStateTimeout { /** * Timeout based on processing time. The duration of timeout can be set for each group in - * `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation - * on `KeyedState` for more details. + * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation + * on `GroupState` for more details. */ - public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } /** * Timeout based on event-time. The event-time timestamp for timeout can be set for each - * group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`. + * group in `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutTimestamp()`. * In addition, you have to define the watermark in the query using `Dataset.withWatermark`. * When the watermark advances beyond the set timestamp of a group and the group has not * received any data, then the group times out. See documentation on - * `KeyedState` for more details. + * `GroupState` for more details. */ - public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } + public static GroupStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } /** No timeout. */ - public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index e0ecf8c5f2643..6225b3fa42990 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode } +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode } import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -351,22 +351,22 @@ case class MapGroups( child: LogicalPlan) extends UnaryNode with ObjectProducer /** Internal class representing State */ -trait LogicalKeyedState[S] +trait LogicalGroupState[S] /** Types of timeouts used in FlatMapGroupsWithState */ -case object NoTimeout extends KeyedStateTimeout -case object ProcessingTimeTimeout extends KeyedStateTimeout -case object EventTimeTimeout extends KeyedStateTimeout +case object NoTimeout extends GroupStateTimeout +case object ProcessingTimeTimeout extends GroupStateTimeout +case object EventTimeTimeout extends GroupStateTimeout /** Factory for constructing new `MapGroupsWithState` nodes. */ object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputMode: OutputMode, isMapGroupsWithState: Boolean, - timeout: KeyedStateTimeout, + timeout: GroupStateTimeout, child: LogicalPlan): LogicalPlan = { val encoder = encoderFor[S] @@ -404,7 +404,7 @@ object FlatMapGroupsWithState { * @param timeout used to timeout groups that have not received data in a while */ case class FlatMapGroupsWithState( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], @@ -413,7 +413,7 @@ case class FlatMapGroupsWithState( stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, isMapGroupsWithState: Boolean = false, - timeout: KeyedStateTimeout, + timeout: GroupStateTimeout, child: LogicalPlan) extends UnaryNode with ObjectProducer { if (isMapGroupsWithState) { diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java similarity index 70% rename from sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java rename to sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java index 02c94b0b32449..2e8f2e3fd9f47 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java @@ -17,13 +17,17 @@ package org.apache.spark.sql.streaming; +import org.apache.spark.sql.catalyst.plans.logical.EventTimeTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; import org.junit.Test; -public class JavaKeyedStateTimeoutSuite { +public class JavaGroupStateTimeoutSuite { @Test public void testTimeouts() { - assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + assert (GroupStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + assert (GroupStateTimeout.EventTimeTimeout() == EventTimeTimeout$.MODULE$); + assert (GroupStateTimeout.NoTimeout() == NoTimeout$.MODULE$); } } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index bdda8aaf734dd..026b37cabbf1c 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.streaming.KeyedState; +import org.apache.spark.sql.streaming.GroupState; /** * ::Experimental:: @@ -35,5 +35,5 @@ @Experimental @InterfaceStability.Evolving public interface FlatMapGroupsWithStateFunction extends Serializable { - Iterator call(K key, Iterator values, KeyedState state) throws Exception; + Iterator call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 70f3f01a8e9da..353e9886a8a57 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.streaming.KeyedState; +import org.apache.spark.sql.streaming.GroupState; /** * ::Experimental:: @@ -34,5 +34,5 @@ @Experimental @InterfaceStability.Evolving public interface MapGroupsWithStateFunction extends Serializable { - R call(K key, Iterator values, KeyedState state) throws Exception; + R call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 96437f868a6e0..87c5621768872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** * :: Experimental :: @@ -228,7 +228,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -240,17 +240,17 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( - flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, OutputMode.Update, isMapGroupsWithState = true, - KeyedStateTimeout.NoTimeout, + GroupStateTimeout.NoTimeout, child = logicalPlan)) } @@ -262,7 +262,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -275,13 +275,13 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( - timeoutConf: KeyedStateTimeout)( - func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( - flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, OutputMode.Update, @@ -298,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -316,7 +316,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } @@ -328,7 +328,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -346,9 +346,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: KeyedStateTimeout): Dataset[U] = { + timeoutConf: GroupStateTimeout): Dataset[U] = { mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } @@ -360,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -375,15 +375,15 @@ class KeyValueGroupedDataset[K, V] private[sql]( @InterfaceStability.Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, - timeoutConf: KeyedStateTimeout)( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { throw new IllegalArgumentException("The output mode of function should be append or update") } Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( - func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, outputMode, @@ -400,7 +400,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -420,8 +420,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: KeyedStateTimeout): Dataset[U] = { - val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + timeoutConf: GroupStateTimeout): Dataset[U] = { + val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index fdd1bcc94be25..48c7b80bffe03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState -import org.apache.spark.sql.execution.streaming.KeyedStateImpl +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState +import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -355,14 +355,14 @@ case class MapGroupsExec( object MapGroupsExec { def apply( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None)) new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 52ad70c7dc886..c7262ea97200f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator @@ -44,7 +44,7 @@ import org.apache.spark.util.CompletionIterator * @param batchTimestampMs processing timestamp of the current batch. */ case class FlatMapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], @@ -53,13 +53,13 @@ case class FlatMapGroupsWithStateExec( stateId: Option[OperatorStateId], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - timeoutConf: KeyedStateTimeout, + timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], override val eventTimeWatermark: Option[Long], child: SparkPlan ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - import KeyedStateImpl._ + import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout private val timestampTimeoutAttribute = @@ -147,7 +147,7 @@ case class FlatMapGroupsWithStateExec( private val stateSerializer = { val encoderSerializer = stateEncoder.namedExpressions if (isTimeoutEnabled) { - encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP) + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) } else { encoderSerializer } @@ -211,7 +211,7 @@ case class FlatMapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) - val keyedState = new KeyedStateImpl( + val keyedState = new GroupStateImpl( stateObjOption, batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), @@ -247,7 +247,7 @@ case class FlatMapGroupsWithStateExec( if (shouldWriteState) { if (stateRowToWrite == null) { - // This should never happen because checks in KeyedStateImpl should avoid cases + // This should never happen because checks in GroupStateImpl should avoid cases // where empty state would need to be written throw new IllegalStateException("Attempting to write empty state") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index edfd35bd5dd75..148d92247d6f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -22,13 +22,14 @@ import java.sql.Date import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} -import org.apache.spark.sql.execution.streaming.KeyedStateImpl._ -import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.unsafe.types.CalendarInterval /** - * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. + * Internal implementation of the [[GroupState]] interface. Methods are not thread-safe. + * * @param optionalValue Optional value of the state * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp * for processing time timeouts @@ -37,19 +38,19 @@ import org.apache.spark.unsafe.types.CalendarInterval * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class KeyedStateImpl[S]( +private[sql] class GroupStateImpl[S]( optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, - timeoutConf: KeyedStateTimeout, - override val hasTimedOut: Boolean) extends KeyedState[S] { + timeoutConf: GroupStateTimeout, + override val hasTimedOut: Boolean) extends GroupState[S] { // Constructor to create dummy state when using mapGroupsWithState in a batch query def this(optionalValue: Option[S]) = this( optionalValue, batchProcessingTimeMs = NO_TIMESTAMP, eventTimeWatermarkMs = NO_TIMESTAMP, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -169,7 +170,7 @@ private[sql] class KeyedStateImpl[S]( } override def toString: String = { - s"KeyedState(${getOption.map(_.toString).getOrElse("")})" + s"GroupState(${getOption.map(_.toString).getOrElse("")})" } // ========= Internal API ========= @@ -221,7 +222,7 @@ private[sql] class KeyedStateImpl[S]( } -private[sql] object KeyedStateImpl { +private[sql] object GroupStateImpl { // Value used represent the lack of valid timestamp as a long val NO_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index f72144a25d5cc..8dbda298c87bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 461de04f6bbe2..60a4d0d8f98a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.streaming import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset} -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** * :: Experimental :: * - * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[KeyValueGroupedDataset]]. + * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on [[KeyValueGroupedDataset]]. * * Detail description on `[map/flatMap]GroupsWithState` operation * -------------------------------------------------------------- @@ -37,11 +36,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * Dataset, the function will be invoked for each group repeatedly in every trigger. * That is, in every batch of the `streaming.StreamingQuery`, * the function will be invoked once for each group that has data in the trigger. Furthermore, - * if timeout is set, then the function will invoked on timed out keys (more detail below). + * if timeout is set, then the function will invoked on timed out groups (more detail below). * * The function is invoked with following parameters. * - The key of the group. - * - An iterator containing all the values for this key. + * - An iterator containing all the values for this group. * - A user-defined state object set by previous invocations of the given function. * In case of a batch Dataset, there is only one invocation and state object will be empty as * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` @@ -55,57 +54,58 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * batch, nor with streaming Datasets. * - All the data will be shuffled before applying the function. * - If timeout is set, then the function will also be called with no values. - * See more details on `KeyedStateTimeout` below. + * See more details on `GroupStateTimeout` below. * - * Important points to note about using `KeyedState`. + * Important points to note about using `GroupState`. * - The value of the state cannot be null. So updating state with null will throw * `IllegalArgumentException`. - * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. * - If `remove()` is called, then `exists()` will return `false`, * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` * - After that, if `update(newState)` is called, then `exists()` will again return `true`, * `get()` and `getOption()`will return the updated value. * - * Important points to note about using `KeyedStateTimeout`. - * - The timeout type is a global param across all the keys (set as `timeout` param in + * Important points to note about using `GroupStateTimeout`. + * - The timeout type is a global param across all the groups (set as `timeout` param in * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per - * key by calling `setTimeout...()` in `KeyedState`. + * group by calling `setTimeout...()` in `GroupState`. * - Timeouts can be either based on processing time (i.e. - * [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. - * [[KeyedStateTimeout.EventTimeTimeout]]). + * [[GroupStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. + * [[GroupStateTimeout.EventTimeTimeout]]). * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling - * `KeyedState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set + * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set * duration. Guarantees provided by this timeout with a duration of D ms are as follows: * - Timeout will never be occur before the clock time has advanced by D ms * - Timeout will occur eventually when there is a trigger in the query * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. * For example, the trigger interval of the query will affect when the timeout actually occurs. - * If there is no data in the stream (for any key) for a while, then their will not be + * If there is no data in the stream (for any group) for a while, then their will not be * any trigger and timeout function call will not occur until there is data. * - Since the processing time timeout is based on the clock time, it is affected by the * variations in the system clock (i.e. time zone changes, clock skew, etc.). * - With `EventTimeTimeout`, the user also has to specify the the the event time watermark in * the query using `Dataset.withWatermark()`. With this setting, data that is older than the - * watermark are filtered out. The timeout can be enabled for a key by setting a timestamp using - * `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the watermark advances - * beyond the set timestamp. You can control the timeout delay by two parameters - (i) watermark - * delay and an additional duration beyond the timestamp in the event (which is guaranteed to - * > watermark due to the filtering). Guarantees provided by this timeout are as follows: + * watermark are filtered out. The timeout can be set for a group by setting a timeout timestamp + * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark + * advances beyond the set timestamp. You can control the timeout delay by two parameters - + * (i) watermark delay and an additional duration beyond the timestamp in the event (which + * is guaranteed to be newer than watermark due to the filtering). Guarantees provided by this + * timeout are as follows: * - Timeout will never be occur before watermark has exceeded the set timeout. * - Similar to processing time timeouts, there is a no strict upper bound on the delay when * the timeout actually occurs. The watermark can advance only when there is data in the * stream, and the event time of the data has actually advanced. - * - When the timeout occurs for a key, the function is called for that key with no values, and - * `KeyedState.hasTimedOut()` set to true. - * - The timeout is reset for key every time the function is called on the key, that is, - * when the key has new data, or the key has timed out. So the user has to set the timeout + * - When the timeout occurs for a group, the function is called for that group with no values, and + * `GroupState.hasTimedOut()` set to true. + * - The timeout is reset every time the function is called on a group, that is, + * when the group has new data, or the group has timed out. So the user has to set the timeout * duration every time the function is called, otherwise there will not be any timeout set. * - * Scala example of using KeyedState in `mapGroupsWithState`: + * Scala example of using GroupState in `mapGroupsWithState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. - * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { + * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = { * * if (state.hasTimedOut) { // If called when timing out, remove the state * state.remove() @@ -133,10 +133,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * dataset * .groupByKey(...) - * .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout)(mappingFunction) + * .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction) * }}} * - * Java example of using `KeyedState`: + * Java example of using `GroupState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. @@ -144,7 +144,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * new MapGroupsWithStateFunction() { * * @Override - * public String call(String key, Iterator value, KeyedState state) { + * public String call(String key, Iterator value, GroupState state) { * if (state.hasTimedOut()) { // If called when timing out, remove the state * state.remove(); * @@ -173,16 +173,16 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * dataset * .groupByKey(...) * .mapGroupsWithState( - * mappingFunction, Encoders.INT, Encoders.STRING, KeyedStateTimeout.ProcessingTimeTimeout); + * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout); * }}} * - * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * @tparam S User-defined type of the state to be stored for each group. Must be encodable into * Spark SQL types (see [[Encoder]] for more details). * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving -trait KeyedState[S] extends LogicalKeyedState[S] { +trait GroupState[S] extends LogicalGroupState[S] { /** Whether state exists or not. */ def exists: Boolean @@ -201,7 +201,7 @@ trait KeyedState[S] extends LogicalKeyedState[S] { @throws[IllegalArgumentException]("when updating with null") def update(newState: S): Unit - /** Remove this keyed state. Note that this resets any timeout configuration as well. */ + /** Remove this state. Note that this resets any timeout configuration as well. */ def remove(): Unit /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ffb4c6273ff85..78cf033dd81d7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,7 +23,7 @@ import java.sql.Timestamp; import java.util.*; -import org.apache.spark.sql.streaming.KeyedStateTimeout; +import org.apache.spark.sql.streaming.GroupStateTimeout; import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; @@ -210,7 +210,7 @@ public void testGroupBy() { OutputMode.Append(), Encoders.LONG(), Encoders.STRING(), - KeyedStateTimeout.NoTimeout()); + GroupStateTimeout.NoTimeout()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index fe72283bb608f..3dabef6a9a35f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream} +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.types.{DataType, IntegerType} @@ -43,16 +43,16 @@ case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ - import KeyedStateImpl._ - import KeyedStateTimeout._ + import GroupStateImpl._ + import GroupStateTimeout._ override def afterAll(): Unit = { super.afterAll() StateStore.stop() } - test("KeyedState - get, exists, update, remove") { - var state: KeyedStateImpl[String] = null + test("GroupState - get, exists, update, remove") { + var state: GroupStateImpl[String] = null def testState( expectedData: Option[String], @@ -73,13 +73,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Updating empty state - state = new KeyedStateImpl[String](None) + state = new GroupStateImpl[String](None) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = new KeyedStateImpl[String](Some("2")) + state = new GroupStateImpl[String](Some("2")) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -97,19 +97,19 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - test("KeyedState - setTimeout**** with NoTimeout") { + test("GroupState - setTimeout**** with NoTimeout") { for (initState <- Seq(None, Some(5))) { // for different initial state - implicit val state = new KeyedStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) + implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } } - test("KeyedState - setTimeout**** with ProcessingTimeTimeout") { - implicit var state: KeyedStateImpl[Int] = null + test("GroupState - setTimeout**** with ProcessingTimeTimeout") { + implicit var state: GroupStateImpl[Int] = null - state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[IllegalStateException](state) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -128,8 +128,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } - test("KeyedState - setTimeout**** with EventTimeTimeout") { - implicit val state = new KeyedStateImpl[Int]( + test("GroupState - setTimeout**** with EventTimeTimeout") { + implicit val state = new GroupStateImpl[Int]( None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) @@ -148,8 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[IllegalStateException](state) } - test("KeyedState - illegal params to setTimeout****") { - var state: KeyedStateImpl[Int] = null + test("GroupState - illegal params to setTimeout****") { + var state: GroupStateImpl[Int] = null // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { @@ -157,14 +157,14 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } testIllegalTimeout { state.setTimeoutDuration(0) } testIllegalTimeout { state.setTimeoutDuration("-2 second") } testIllegalTimeout { state.setTimeoutDuration("-1 month") } testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } - state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } @@ -175,25 +175,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } } - test("KeyedState - hasTimedOut") { + test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { for (initState <- Seq(None, Some(5))) { - val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) assert(state2.hasTimedOut === true) } } } - test("KeyedState - primitive type") { - var intState = new KeyedStateImpl[Int](None) + test("GroupState - primitive type") { + var intState = new GroupStateImpl[Int](None) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = new KeyedStateImpl[Int](Some(10)) + intState = new GroupStateImpl[Int](Some(10)) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -218,21 +218,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( testName + "no update", stateUpdates = state => { /* do nothing */ }, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change testStateUpdateWithData( testName + "state updated", stateUpdates = state => { state.update(5) }, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = Some(5)) // should change testStateUpdateWithData( testName + "state removed", stateUpdates = state => { state.remove() }, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = None) // should be removed } @@ -283,7 +283,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = - (state: KeyedState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, + (state: GroupState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, timeoutConf = ProcessingTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -293,7 +293,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: KeyedState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -303,7 +303,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"EventTimeTimeout - $testName - timeout timestamp updated to before watermark", stateUpdates = - (state: KeyedState[Int]) => { + (state: GroupState[Int]) => { state.update(5) intercept[IllegalArgumentException] { state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark @@ -387,7 +387,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("StateStoreUpdater - rows are cloned before writing to StateStore") { // function for running count - val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { state.update(state.getOption.getOrElse(0) + values.size) Iterator.empty } @@ -404,7 +404,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -420,7 +420,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a"), @@ -446,7 +446,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything // Additionally, it updates state lazily as the returned iterator get consumed - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { values.flatMap { _ => val count = state.getOption.map(_.count).getOrElse(0L) + 1 if (count == 3) { @@ -463,7 +463,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), @@ -481,7 +481,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -497,7 +497,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc) .groupByKey(_._1) .count() @@ -524,20 +524,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } val df = Seq("a", "a", "b").toDS .groupByKey(x => x) - .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc).toDF checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } test("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (state.hasTimedOut) { state.remove() Iterator((key, "-1")) @@ -594,7 +594,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val stateFunc = ( key: String, values: Iterator[(String, Long)], - state: KeyedState[Long]) => { + state: GroupState[Long]) => { val timeoutDelay = 5 if (key != "a") { Iterator.empty @@ -637,7 +637,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -676,7 +676,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") (key, values.size) } @@ -690,7 +690,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } testQuietly("StateStore.abort on task failure handling") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) @@ -724,7 +724,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("output partitioning is unknown") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => key val inputData = MemoryStream[String] val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( @@ -735,13 +735,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("disallow complete mode") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { Iterator[String]() } var e = intercept[IllegalArgumentException] { MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( - OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc) + OutputMode.Complete, GroupStateTimeout.NoTimeout)(stateFunc) } assert(e.getMessage === "The output mode of function should be append or update") @@ -750,20 +750,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf override def call( key: String, values: JIterator[String], - state: KeyedState[Int]): JIterator[String] = { null } + state: GroupState[Int]): JIterator[String] = { null } } e = intercept[IllegalArgumentException] { MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( javaStateFunc, OutputMode.Complete, - implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout) + implicitly[Encoder[Int]], implicitly[Encoder[String]], GroupStateTimeout.NoTimeout) } assert(e.getMessage === "The output mode of function should be append or update") } def testStateUpdateWithData( testName: String, - stateUpdates: KeyedState[Int] => Unit, - timeoutConf: KeyedStateTimeout, + stateUpdates: GroupState[Int] => Unit, + timeoutConf: GroupStateTimeout, priorState: Option[Int], priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, @@ -773,7 +773,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf return // there can be no prior timestamp, when there is no prior state } test(s"StateStoreUpdater - updates with data - $testName") { - val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") stateUpdates(state) @@ -787,14 +787,14 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithTimeout( testName: String, - stateUpdates: KeyedState[Int] => Unit, - timeoutConf: KeyedStateTimeout, + stateUpdates: GroupState[Int] => Unit, + timeoutConf: GroupStateTimeout, priorTimeoutTimestamp: Long, expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { test(s"StateStoreUpdater - updates for timeout - $testName") { - val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") stateUpdates(state) @@ -808,8 +808,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdate( testTimeoutUpdates: Boolean, - mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutConf: KeyedStateTimeout, + mapGroupsFunc: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], + timeoutConf: GroupStateTimeout, priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], @@ -848,8 +848,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } def newFlatMapGroupsWithStateExec( - func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], + timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { MemoryStream[Int] .toDS @@ -863,7 +863,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf }.get } - def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { val prevTimestamp = state.getTimeoutTimestamp intercept[T] { state.setTimeoutDuration(1000) } assert(state.getTimeoutTimestamp === prevTimestamp) @@ -871,7 +871,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.getTimeoutTimestamp === prevTimestamp) } - def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { val prevTimestamp = state.getTimeoutTimestamp intercept[T] { state.setTimeoutTimestamp(2000) } assert(state.getTimeoutTimestamp === prevTimestamp) From 12cd00706cbfff4c8ac681fcae65b4c4c8751877 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 22 Mar 2017 15:58:42 -0700 Subject: [PATCH 0091/1765] [BUILD][MINOR] Fix 2.10 build ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/17385 breaks the 2.10 sbt/maven builds by hitting an empty-string interpolation bug (https://issues.scala-lang.org/browse/SI-7919). https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-sbt-scala-2.10/4072/ https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-maven-scala-2.10/3987/ ## How was this patch tested? Compiles Author: Sameer Agarwal Closes #17391 from sameeragarwal/build-fix. --- .../spark/sql/streaming/FlatMapGroupsWithStateSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 3dabef6a9a35f..89a25973afdd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -240,7 +240,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - var testName = s"" + var testName = "" if (priorState.nonEmpty) { testName += "prior state set, " if (priorTimeoutTimestamp == 1000) { From 07c12c09a75645f6b56b30654455b3838b7b6637 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 23 Mar 2017 00:25:01 -0700 Subject: [PATCH 0092/1765] [SPARK-18579][SQL] Use ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options in CSV writing ## What changes were proposed in this pull request? This PR proposes to support _not_ trimming the white spaces when writing out. These are `false` by default in CSV reading path but these are `true` by default in CSV writing in univocity parser. Both `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` options are not being used for writing and therefore, we are always trimming the white spaces. It seems we should provide a way to keep this white spaces easily. WIth the data below: ```scala val df = spark.read.csv(Seq("a , b , c").toDS) df.show() ``` ``` +---+----+---+ |_c0| _c1|_c2| +---+----+---+ | a | b | c| +---+----+---+ ``` **Before** ```scala df.write.csv("/tmp/text.csv") spark.read.text("/tmp/text.csv").show() ``` ``` +-----+ |value| +-----+ |a,b,c| +-----+ ``` It seems this can't be worked around via `quoteAll` too. ```scala df.write.option("quoteAll", true).csv("/tmp/text.csv") spark.read.text("/tmp/text.csv").show() ``` ``` +-----------+ | value| +-----------+ |"a","b","c"| +-----------+ ``` **After** ```scala df.write.option("ignoreLeadingWhiteSpace", false).option("ignoreTrailingWhiteSpace", false).csv("/tmp/text.csv") spark.read.text("/tmp/text.csv").show() ``` ``` +----------+ | value| +----------+ |a , b , c| +----------+ ``` Note that this case is possible in R ```r > system("cat text.csv") f1,f2,f3 a , b , c > df <- read.csv(file="text.csv") > df f1 f2 f3 1 a b c > write.csv(df, file="text1.csv", quote=F, row.names=F) > system("cat text1.csv") f1,f2,f3 a , b , c ``` ## How was this patch tested? Unit tests in `CSVSuite` and manual tests for Python. Author: hyukjinkwon Closes #17310 from HyukjinKwon/SPARK-18579. --- python/pyspark/sql/readwriter.py | 28 +++++---- python/pyspark/sql/streaming.py | 12 ++-- python/pyspark/sql/tests.py | 13 +++++ .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../datasources/csv/CSVOptions.scala | 15 +++-- .../sql/streaming/DataStreamReader.scala | 6 +- .../execution/datasources/csv/CSVSuite.scala | 57 +++++++++++++++++++ 8 files changed, 116 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 759c27507c397..5e732b4bec8fd 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -341,12 +341,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. - :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. - :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. + :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. + :param ignoreTrailingWhiteSpace: A flag indicating whether or not trailing whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param applies to all supported types including the string type. @@ -706,7 +706,7 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None): + timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -728,10 +728,10 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No empty string. :param escape: sets the single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, ``\`` - :param escapeQuotes: A flag indicating whether values containing quotes should always + :param escapeQuotes: a flag indicating whether values containing quotes should always be enclosed in quotes. If None is set, it uses the default value ``true``, escaping all values containing a quote character. - :param quoteAll: A flag indicating whether all values should always be enclosed in + :param quoteAll: a flag indicating whether all values should always be enclosed in quotes. If None is set, it uses the default value ``false``, only escaping values containing a quote character. :param header: writes the names of columns as the first line. If None is set, it uses @@ -746,13 +746,21 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from + values being written should be skipped. If None is set, it + uses the default value, ``true``. + :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from + values being written should be skipped. If None is set, it + uses the default value, ``true``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, - dateFormat=dateFormat, timestampFormat=timestampFormat) + dateFormat=dateFormat, timestampFormat=timestampFormat, + ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e227f9ceb5769..80f4340cdf134 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -597,12 +597,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. - :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. - :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. + :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. + :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param applies to all supported types including the string type. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f0a9a0400e392..29d613bc5fe3c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -450,6 +450,19 @@ def test_wholefile_csv(self): Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] self.assertEqual(ages_newlines.collect(), expected) + def test_ignorewhitespace_csv(self): + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( + tmpPath, + ignoreLeadingWhiteSpace=False, + ignoreTrailingWhiteSpace=False) + + expected = [Row(value=u' a,b , c ')] + readback = self.spark.read.text(tmpPath) + self.assertEqual(readback.collect(), expected) + shutil.rmtree(tmpPath) + def test_read_multiple_orc_file(self): df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", "python/test_support/sql/orc_partitioned/b=1/c=1"]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e39b4d91f1f6a..e6d2b1bc28d95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -489,9 +489,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
      • `header` (default `false`): uses the first line as names of columns.
      • *
      • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
      • - *
      • `ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces - * from values being read should be skipped.
      • - *
      • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing + *
      • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading + * whitespaces from values being read should be skipped.
      • + *
      • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing * whitespaces from values being read should be skipped.
      • *
      • `nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.
      • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3e975ef6a3c24..e973d0bc6d09b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -573,7 +573,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
      • `escapeQuotes` (default `true`): a flag indicating whether values containing * quotes should always be enclosed in quotes. Default is to escape all values containing * a quote character.
      • - *
      • `quoteAll` (default `false`): A flag indicating whether all values should always be + *
      • `quoteAll` (default `false`): a flag indicating whether all values should always be * enclosed in quotes. Default is to only escape values containing a quote character.
      • *
      • `header` (default `false`): writes the names of columns as the first line.
      • *
      • `nullValue` (default empty string): sets the string representation of a null value.
      • @@ -586,6 +586,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
      • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
      • + *
      • `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading + * whitespaces from values being written should be skipped.
      • + *
      • `ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not + * trailing whitespaces from values being written should be skipped.
      • *
      * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 5d2c23ed9618c..e7b79e0cbfd17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -93,8 +93,13 @@ class CSVOptions( val headerFlag = getBool("header") val inferSchemaFlag = getBool("inferSchema") - val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") - val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + val ignoreLeadingWhiteSpaceInRead = getBool("ignoreLeadingWhiteSpace", default = false) + val ignoreTrailingWhiteSpaceInRead = getBool("ignoreTrailingWhiteSpace", default = false) + + // For write, both options were `true` by default. We leave it as `true` for + // backwards compatibility. + val ignoreLeadingWhiteSpaceFlagInWrite = getBool("ignoreLeadingWhiteSpace", default = true) + val ignoreTrailingWhiteSpaceFlagInWrite = getBool("ignoreTrailingWhiteSpace", default = true) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -144,6 +149,8 @@ class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) + writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) writerSettings.setEmptyValue(nullValue) writerSettings.setSkipEmptyLines(true) @@ -159,8 +166,8 @@ class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) - settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlag) - settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlag) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) + settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f6e2fef74b8db..997ca286597da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -238,9 +238,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `header` (default `false`): uses the first line as names of columns.
    • *
    • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
    • - *
    • `ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces - * from values being read should be skipped.
    • - *
    • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing + *
    • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading + * whitespaces from values being read should be skipped.
    • + *
    • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing * whitespaces from values being read should be skipped.
    • *
    • `nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.
    • diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2600894ca303c..d70c47f4e2379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1117,4 +1117,61 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(df2.schema === schema) } + test("ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options - read") { + val input = " a,b , c " + + // For reading, default of both `ignoreLeadingWhiteSpace` and`ignoreTrailingWhiteSpace` + // are `false`. So, these are excluded. + val combinations = Seq( + (true, true), + (false, true), + (true, false)) + + // Check if read rows ignore whitespaces as configured. + val expectedRows = Seq( + Row("a", "b", "c"), + Row(" a", "b", " c"), + Row("a", "b ", "c ")) + + combinations.zip(expectedRows) + .foreach { case ((ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace), expected) => + val df = spark.read + .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + .csv(Seq(input).toDS()) + + checkAnswer(df, expected) + } + } + + test("SPARK-18579: ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options - write") { + val df = Seq((" a", "b ", " c ")).toDF() + + // For writing, default of both `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` + // are `true`. So, these are excluded. + val combinations = Seq( + (false, false), + (false, true), + (true, false)) + + // Check if written lines ignore each whitespaces as configured. + val expectedLines = Seq( + " a,b , c ", + " a,b, c", + "a,b ,c ") + + combinations.zip(expectedLines) + .foreach { case ((ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace), expected) => + withTempPath { path => + df.write + .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + .csv(path.getAbsolutePath) + + // Read back the written lines. + val readBack = spark.read.text(path.getAbsolutePath) + checkAnswer(readBack, Row(expected)) + } + } + } } From aefe79890541bc0829f184e03eb3961739ca8ef2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 23 Mar 2017 08:41:30 +0000 Subject: [PATCH 0093/1765] [MINOR][BUILD] Fix javadoc8 break ## What changes were proposed in this pull request? Several javadoc8 breaks have been introduced. This PR proposes fix those instances so that we can build Scala/Java API docs. ``` [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:6: error: reference not found [error] * flatMapGroupsWithState operations on {link KeyValueGroupedDataset}. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:10: error: reference not found [error] * Both, mapGroupsWithState and flatMapGroupsWithState in {link KeyValueGroupedDataset} [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:51: error: reference not found [error] * {link GroupStateTimeout.ProcessingTimeTimeout}) or event time (i.e. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:52: error: reference not found [error] * {link GroupStateTimeout.EventTimeTimeout}). [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:158: error: reference not found [error] * Spark SQL types (see {link Encoder} for more details). [error] ^ [error] .../spark/mllib/target/java/org/apache/spark/ml/fpm/FPGrowthParams.java:26: error: bad use of '>' [error] * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and [error] ^ [error] .../spark/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java:30: error: reference not found [error] * {link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:211: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:232: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:254: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:277: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/core/target/java/org/apache/spark/TaskContextImpl.java:10: error: reference not found [error] * {link TaskMetrics} & {link MetricsSystem} objects are not thread safe. [error] ^ [error] .../spark/core/target/java/org/apache/spark/TaskContextImpl.java:10: error: reference not found [error] * {link TaskMetrics} & {link MetricsSystem} objects are not thread safe. [error] ^ [info] 13 errors ``` ``` jekyll 3.3.1 | Error: Unidoc generation failed ``` ## How was this patch tested? Manually via `jekyll build` Author: hyukjinkwon Closes #17389 from HyukjinKwon/minor-javadoc8-fix. --- .../org/apache/spark/TaskContextImpl.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../FlatMapGroupsWithStateFunction.java | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 8 +++---- .../spark/sql/streaming/GroupState.scala | 22 +++++++++---------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index ea8dcdfd5d7d9..f346cf8d65806 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -38,7 +38,7 @@ import org.apache.spark.util._ * callbacks are protected by locking on the context instance. For instance, this ensures * that you cannot add a completion listener in one thread while we are completing (and calling * the completion listeners) in another thread. Other state is immutable, however the exposed - * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe. + * `TaskMetrics` & `MetricsSystem` objects are not thread safe. */ private[spark] class TaskContextImpl( val stageId: Int, diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index e2bc270b38da7..65cc80619569e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -69,8 +69,8 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { def getMinSupport: Double = $(minSupport) /** - * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and - * partition number of the input dataset is used. + * Number of partitions (at least 1) used by parallel FP-growth. By default the param is not + * set, and partition number of the input dataset is used. * @group expertParam */ @Since("2.2.0") diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 026b37cabbf1c..802949c0ddb60 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -27,7 +27,7 @@ /** * ::Experimental:: * Base interface for a map function used in - * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( + * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState( * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 87c5621768872..022c2f5629e86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -298,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -328,7 +328,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -360,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -400,7 +400,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 60a4d0d8f98a1..15df906ca7b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.streaming import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset} +import org.apache.spark.sql.KeyValueGroupedDataset import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** * :: Experimental :: * * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on [[KeyValueGroupedDataset]]. + * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`. * * Detail description on `[map/flatMap]GroupsWithState` operation * -------------------------------------------------------------- - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in `KeyValueGroupedDataset` * will invoke the user-given function on each group (defined by the grouping function in * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -70,8 +70,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per * group by calling `setTimeout...()` in `GroupState`. * - Timeouts can be either based on processing time (i.e. - * [[GroupStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. - * [[GroupStateTimeout.EventTimeTimeout]]). + * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. + * `GroupStateTimeout.EventTimeTimeout`). * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set * duration. Guarantees provided by this timeout with a duration of D ms are as follows: @@ -177,7 +177,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * }}} * * @tparam S User-defined type of the state to be stored for each group. Must be encodable into - * Spark SQL types (see [[Encoder]] for more details). + * Spark SQL types (see `Encoder` for more details). * @since 2.2.0 */ @Experimental @@ -224,7 +224,7 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. * - * @note, ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") @throws[IllegalStateException]("when state is either not initialized, or already removed") @@ -240,7 +240,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * Set the timeout timestamp for this key as milliseconds in epoch time. * This timestamp cannot be older than the current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestampMs: Long): Unit @@ -254,7 +254,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit @@ -265,7 +265,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * Set the timeout timestamp for this key as a java.sql.Date. * This timestamp cannot be older than the current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestamp: java.sql.Date): Unit @@ -279,7 +279,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit } From b70c03a42002e924e979acbc98a8b464830be532 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 23 Mar 2017 08:42:42 +0000 Subject: [PATCH 0094/1765] [INFRA] Close stale PRs Closes #16819 Closes #13467 Closes #16083 Closes #17135 Closes #8785 Closes #16278 Closes #16997 Closes #17073 Closes #17220 Added: Closes #12059 Closes #12524 Closes #12888 Closes #16061 Author: Sean Owen Closes #17386 from srowen/StalePRs. From b0ae6a38a3ef65e4e853781c5127ba38997a8546 Mon Sep 17 00:00:00 2001 From: Ye Yin Date: Thu, 23 Mar 2017 13:30:50 +0100 Subject: [PATCH 0095/1765] Typo fixup in comment ## What changes were proposed in this pull request? Fixup typo in comment. ## How was this patch tested? Don't need. Author: Ye Yin Closes #17396 from hustcat/fix. --- .../spark/scheduler/cluster/mesos/MesosClusterManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index ed29b346ba263..911a0857917ef 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} /** - * Cluster Manager for creation of Yarn scheduler and backend + * Cluster Manager for creation of Mesos scheduler and backend */ private[spark] class MesosClusterManager extends ExternalClusterManager { private val MESOS_REGEX = """mesos://(.*)""".r From 746a558de2136f91f8fe77c6e51256017aa50913 Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Thu, 23 Mar 2017 14:32:05 -0700 Subject: [PATCH 0096/1765] [SPARK-19876][SS][WIP] OneTime Trigger Executor ## What changes were proposed in this pull request? An additional trigger and trigger executor that will execute a single trigger only. One can use this OneTime trigger to have more control over the scheduling of triggers. In addition, this patch requires an optimization to StreamExecution that logs a commit record at the end of successfully processing a batch. This new commit log will be used to determine the next batch (offsets) to process after a restart, instead of using the offset log itself to determine what batch to process next after restart; using the offset log to determine this would process the previously logged batch, always, thus not permitting a OneTime trigger feature. ## How was this patch tested? A number of existing tests have been revised. These tests all assumed that when restarting a stream, the last batch in the offset log is to be re-processed. Given that we now have a commit log that will tell us if that last batch was processed successfully, the results/assumptions of those tests needed to be revised accordingly. In addition, a OneTime trigger test was added to StreamingQuerySuite, which tests: - The semantics of OneTime trigger (i.e., on start, execute a single batch, then stop). - The case when the commit log was not able to successfully log the completion of a batch before restart, which would mean that we should fall back to what's in the offset log. - A OneTime trigger execution that results in an exception being thrown. marmbrus tdas zsxwing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tyson Condie Author: Tathagata Das Closes #17219 from tcondie/stream-commit. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 2 - project/MimaExcludes.scala | 6 +- python/pyspark/sql/streaming.py | 63 +++-------- python/pyspark/sql/tests.py | 17 ++- .../execution/streaming/BatchCommitLog.scala | 77 +++++++++++++ .../execution/streaming/StreamExecution.scala | 81 +++++++++++--- .../execution/streaming/TriggerExecutor.scala | 11 ++ .../sql/execution/streaming/Triggers.scala | 29 +++++ .../sql/streaming/DataStreamWriter.scala | 2 +- .../{Trigger.scala => ProcessingTime.scala} | 36 +++--- .../apache/spark/sql/streaming/Trigger.java | 105 ++++++++++++++++++ .../streaming/EventTimeWatermarkSuite.scala | 4 +- .../FlatMapGroupsWithStateSuite.scala | 3 +- .../spark/sql/streaming/StreamSuite.scala | 20 +++- .../spark/sql/streaming/StreamTest.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 4 + .../StreamingQueryListenerSuite.scala | 18 ++- .../sql/streaming/StreamingQuerySuite.scala | 48 +++++++- .../test/DataStreamReaderWriterSuite.scala | 5 +- 19 files changed, 439 insertions(+), 94 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala rename sql/core/src/main/scala/org/apache/spark/sql/streaming/{Trigger.scala => ProcessingTime.scala} (74%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 7b6396e0291c9..6391d6269c5ab 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -301,8 +301,6 @@ class KafkaSourceSuite extends KafkaSourceTest { StopStream, StartStream(ProcessingTime(100), clock), waitUntilBatchProcessed, - AdvanceManualClock(100), - waitUntilBatchProcessed, // smallest now empty, 1 more from middle, 9 more from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bd4528bd21264..9925a8ba72662 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -64,7 +64,11 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$11"), // [SPARK-17161] Removing Python-friendly constructors not needed - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + + // [SPARK-19876] Add one time trigger, and improve Trigger APIs + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") ) // Exclude rules for 2.1.x diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 80f4340cdf134..27d6725615a4c 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -277,44 +277,6 @@ def resetTerminated(self): self._jsqm.resetTerminated() -class Trigger(object): - """Used to indicate how often results should be produced by a :class:`StreamingQuery`. - - .. note:: Experimental - - .. versionadded:: 2.0 - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def _to_java_trigger(self, sqlContext): - """Internal method to construct the trigger on the jvm. - """ - pass - - -class ProcessingTime(Trigger): - """A trigger that runs a query periodically based on the processing time. If `interval` is 0, - the query will run as fast as possible. - - The interval should be given as a string, e.g. '2 seconds', '5 minutes', ... - - .. note:: Experimental - - .. versionadded:: 2.0 - """ - - def __init__(self, interval): - if type(interval) != str or len(interval.strip()) == 0: - raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.") - self.interval = interval - - def _to_java_trigger(self, sqlContext): - return sqlContext._sc._jvm.org.apache.spark.sql.streaming.ProcessingTime.create( - self.interval) - - class DataStreamReader(OptionUtils): """ Interface used to load a streaming :class:`DataFrame` from external storage systems @@ -790,7 +752,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None): + def trigger(self, processingTime=None, once=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -800,17 +762,26 @@ def trigger(self, processingTime=None): >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') + >>> # trigger the query for just once batch of data + >>> writer = sdf.writeStream.trigger(once=True) """ - from pyspark.sql.streaming import ProcessingTime - trigger = None + jTrigger = None if processingTime is not None: + if once is not None: + raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: - raise ValueError('The processing time must be a non empty string. Got: %s' % + raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) - trigger = ProcessingTime(processingTime) - if trigger is None: - raise ValueError('A trigger was not provided. Supported triggers: processingTime.') - self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) + interval = processingTime.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( + interval) + elif once is not None: + if once is not True: + raise ValueError('Value for once must be True. Got: %s' % once) + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: + raise ValueError('No trigger provided') + self._jwrite = self._jwrite.trigger(jTrigger) return self @ignore_unicode_prefix diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 29d613bc5fe3c..b93b7ed192104 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1255,13 +1255,26 @@ def test_save_and_load_builder(self): shutil.rmtree(tmpPath) - def test_stream_trigger_takes_keyword_args(self): + def test_stream_trigger(self): df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + + # Should take at least one arg + try: + df.writeStream.trigger() + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(once=True, processingTime='5 seconds') + except ValueError: + pass + + # Should take only keyword args try: df.writeStream.trigger('5 seconds') self.fail("Should have thrown an exception") except TypeError: - # should throw error pass def test_stream_read_options(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala new file mode 100644 index 0000000000000..fb1a4fb9b12f5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStream, OutputStream} +import java.nio.charset.StandardCharsets._ + +import scala.io.{Source => IOSource} + +import org.apache.spark.sql.SparkSession + +/** + * Used to write log files that represent batch commit points in structured streaming. + * A commit log file will be written immediately after the successful completion of a + * batch, and before processing the next batch. Here is an execution summary: + * - trigger batch 1 + * - obtain batch 1 offsets and write to offset log + * - process batch 1 + * - write batch 1 to completion log + * - trigger batch 2 + * - obtain bactch 2 offsets and write to offset log + * - process batch 2 + * - write batch 2 to completion log + * .... + * + * The current format of the batch completion log is: + * line 1: version + * line 2: metadata (optional json string) + */ +class BatchCommitLog(sparkSession: SparkSession, path: String) + extends HDFSMetadataLog[String](sparkSession, path) { + + override protected def deserialize(in: InputStream): String = { + // called inside a try-finally where the underlying stream is closed in the caller + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { + throw new IllegalStateException("Incomplete log file in the offset commit log") + } + parseVersion(lines.next().trim, BatchCommitLog.VERSION) + // read metadata + lines.next().trim match { + case BatchCommitLog.SERIALIZED_VOID => null + case metadata => metadata + } + } + + override protected def serialize(metadata: String, out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + out.write(s"v${BatchCommitLog.VERSION}".getBytes(UTF_8)) + out.write('\n') + + // write metadata or void + out.write((if (metadata == null) BatchCommitLog.SERIALIZED_VOID else metadata) + .getBytes(UTF_8)) + } +} + +object BatchCommitLog { + private val VERSION = 1 + private val SERIALIZED_VOID = "{}" +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 60d5283e6b211..34e9262af7cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -165,6 +165,8 @@ class StreamExecution( private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) + case OneTimeTrigger => OneTimeExecutor() + case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } /** Defines the internal state of execution */ @@ -209,6 +211,13 @@ class StreamExecution( */ val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets")) + /** + * A log that records the batch ids that have completed. This is used to check if a batch was + * fully processed, and its output was committed to the sink, hence no need to process it again. + * This is used (for instance) during restart, to help identify which batch to run next. + */ + val batchCommitLog = new BatchCommitLog(sparkSession, checkpointFile("commits")) + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING @@ -291,10 +300,13 @@ class StreamExecution( runBatch(sparkSessionToRunBatches) } } - // Report trigger as finished and construct progress object. finishTrigger(dataAvailable) if (dataAvailable) { + // Update committed offsets. + committedOffsets ++= availableOffsets + batchCommitLog.add(currentBatchId, null) + logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 } else { @@ -306,9 +318,6 @@ class StreamExecution( } else { false } - - // Update committed offsets. - committedOffsets ++= availableOffsets updateStatusMessage("Waiting for next trigger") continueToRun }) @@ -392,13 +401,33 @@ class StreamExecution( * - currentBatchId * - committedOffsets * - availableOffsets + * The basic structure of this method is as follows: + * + * Identify (from the offset log) the offsets used to run the last batch + * IF last batch exists THEN + * Set the next batch to be executed as the last recovered batch + * Check the commit log to see which batch was committed last + * IF the last batch was committed THEN + * Call getBatch using the last batch start and end offsets + * // ^^^^ above line is needed since some sources assume last batch always re-executes + * Setup for a new batch i.e., start = last batch end, and identify new end + * DONE + * ELSE + * Identify a brand new batch + * DONE */ private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { offsetLog.getLatest() match { - case Some((batchId, nextOffsets)) => - logInfo(s"Resuming streaming query, starting with batch $batchId") - currentBatchId = batchId + case Some((latestBatchId, nextOffsets)) => + /* First assume that we are re-executing the latest known batch + * in the offset log */ + currentBatchId = latestBatchId availableOffsets = nextOffsets.toStreamProgress(sources) + /* Initialize committed offsets to a committed batch, which at this + * is the second latest batch id in the offset log. */ + offsetLog.get(latestBatchId - 1).foreach { secondLatestBatchId => + committedOffsets = secondLatestBatchId.toStreamProgress(sources) + } // update offset metadata nextOffsets.metadata.foreach { metadata => @@ -419,14 +448,37 @@ class StreamExecution( SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) } - logDebug(s"Found possibly unprocessed offsets $availableOffsets " + - s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}") - - offsetLog.get(batchId - 1).foreach { - case lastOffsets => - committedOffsets = lastOffsets.toStreamProgress(sources) - logDebug(s"Resuming with committed offsets: $committedOffsets") + /* identify the current batch id: if commit log indicates we successfully processed the + * latest batch id in the offset log, then we can safely move to the next batch + * i.e., committedBatchId + 1 */ + batchCommitLog.getLatest() match { + case Some((latestCommittedBatchId, _)) => + if (latestBatchId == latestCommittedBatchId) { + /* The last batch was successfully committed, so we can safely process a + * new next batch but first: + * Make a call to getBatch using the offsets from previous batch. + * because certain sources (e.g., KafkaSource) assume on restart the last + * batch will be executed before getOffset is called again. */ + availableOffsets.foreach { ao: (Source, Offset) => + val (source, end) = ao + if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { + val start = committedOffsets.get(source) + source.getBatch(start, end) + } + } + currentBatchId = latestCommittedBatchId + 1 + committedOffsets ++= availableOffsets + // Construct a new batch be recomputing availableOffsets + constructNextBatch() + } else if (latestCommittedBatchId < latestBatchId - 1) { + logWarning(s"Batch completion log latest batch id is " + + s"${latestCommittedBatchId}, which is not trailing " + + s"batchid $latestBatchId by one") + } + case None => logInfo("no commit log present") } + logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 @@ -523,6 +575,7 @@ class StreamExecution( // Note that purge is exclusive, i.e. it purges everything before the target ID. if (minBatchesToRetain < currentBatchId) { offsetLog.purge(currentBatchId - minBatchesToRetain) + batchCommitLog.purge(currentBatchId - minBatchesToRetain) } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index ac510df209f0a..02996ac854f69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -29,6 +29,17 @@ trait TriggerExecutor { def execute(batchRunner: () => Boolean): Unit } +/** + * A trigger executor that runs a single batch only, then terminates. + */ +case class OneTimeExecutor() extends TriggerExecutor { + + /** + * Execute a single batch using `batchRunner`. + */ + override def execute(batchRunner: () => Boolean): Unit = batchRunner() +} + /** * A trigger executor that runs a batch every `intervalMs` milliseconds. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala new file mode 100644 index 0000000000000..271bc4da99c08 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.streaming.Trigger + +/** + * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * the query. + */ +@Experimental +@InterfaceStability.Evolving +case object OneTimeTrigger extends Trigger diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index fe52013badb65..f2f700590ca8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -377,7 +377,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var outputMode: OutputMode = OutputMode.Append - private var trigger: Trigger = ProcessingTime(0L) + private var trigger: Trigger = Trigger.ProcessingTime(0L) private var extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index 68f2eab9d45fc..bdad8e4717be4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -26,16 +26,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.unsafe.types.CalendarInterval -/** - * :: Experimental :: - * Used to indicate how often results should be produced by a [[StreamingQuery]]. - * - * @since 2.0.0 - */ -@Experimental -@InterfaceStability.Evolving -sealed trait Trigger - /** * :: Experimental :: * A trigger that runs a query periodically based on the processing time. If `interval` is 0, @@ -43,24 +33,25 @@ sealed trait Trigger * * Scala Example: * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) + * df.writeStream.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) + * df.writeStream.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 */ @Experimental @InterfaceStability.Evolving +@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -73,6 +64,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { */ @Experimental @InterfaceStability.Evolving +@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") object ProcessingTime { /** @@ -80,11 +72,13 @@ object ProcessingTime { * * Example: * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) + * df.writeStream.trigger(ProcessingTime("10 seconds")) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { throw new IllegalArgumentException( @@ -110,11 +104,13 @@ object ProcessingTime { * Example: * {{{ * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(ProcessingTime(10.seconds)) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) } @@ -124,11 +120,13 @@ object ProcessingTime { * * Example: * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) + * df.writeStream.trigger(ProcessingTime.create("10 seconds")) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") def create(interval: String): ProcessingTime = { apply(interval) } @@ -139,11 +137,13 @@ object ProcessingTime { * Example: * {{{ * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval, unit)", "2.2.0") def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java new file mode 100644 index 0000000000000..a03a851f245fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import java.util.concurrent.TimeUnit; + +import scala.concurrent.duration.Duration; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; + +/** + * :: Experimental :: + * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. + * + * @since 2.0.0 + */ +@Experimental +@InterfaceStability.Evolving +public class Trigger { + + /** + * :: Experimental :: + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is 0, the query will run as fast as possible. + * + * @since 2.2.0 + */ + public static Trigger ProcessingTime(long intervalMs) { + return ProcessingTime.apply(intervalMs); + } + + /** + * :: Experimental :: + * (Java-friendly) + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is 0, the query will run as fast as possible. + * + * {{{ + * import java.util.concurrent.TimeUnit + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.2.0 + */ + public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { + return ProcessingTime.create(interval, timeUnit); + } + + /** + * :: Experimental :: + * (Scala-friendly) + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `duration` is 0, the query will run as fast as possible. + * + * {{{ + * import scala.concurrent.duration._ + * df.writeStream.trigger(ProcessingTime(10.seconds)) + * }}} + * @since 2.2.0 + */ + public static Trigger ProcessingTime(Duration interval) { + return ProcessingTime.apply(interval); + } + + /** + * :: Experimental :: + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is effectively 0, the query will run as fast as possible. + * + * {{{ + * df.writeStream.trigger(Trigger.ProcessingTime("10 seconds")) + * }}} + * @since 2.2.0 + */ + public static Trigger ProcessingTime(String interval) { + return ProcessingTime.apply(interval); + } + + /** + * A trigger that process only one batch of data in a streaming query then terminates + * the query. + * + * @since 2.2.0 + */ + public static Trigger Once() { + return OneTimeTrigger$.MODULE$; + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7614ea5eb3c01..fd850a7365e20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -218,7 +218,9 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin AddData(inputData, 25), // Evict items less than previous watermark. CheckLastBatch((10, 5)), StopStream, - AssertOnQuery { q => // clear the sink + AssertOnQuery { q => // purge commit and clear the sink + val commit = q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + q.batchCommitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 89a25973afdd1..a00a1a582a971 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -575,9 +575,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf StopStream, StartStream(ProcessingTime("1 second"), triggerClock = clock), + AdvanceManualClock(10 * 1000), AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), + AdvanceManualClock(1 * 1000), CheckLastBatch(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f01211e20cbfc..32920f6dfa223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -156,6 +156,15 @@ class StreamSuite extends StreamTest { AssertOnQuery(_.offsetLog.getLatest().get._1 == expectedId, s"offsetLog's latest should be $expectedId") + // Check the latest batchid in the commit log + def CheckCommitLogLatestBatchId(expectedId: Int): AssertOnQuery = + AssertOnQuery(_.batchCommitLog.getLatest().get._1 == expectedId, + s"commitLog's latest should be $expectedId") + + // Ensure that there has not been an incremental execution after restart + def CheckNoIncrementalExecutionCurrentBatchId(): AssertOnQuery = + AssertOnQuery(_.lastExecution == null, s"lastExecution not expected to run") + // For each batch, we would log the state change during the execution // This checks whether the key of the state change log is the expected batch id def CheckIncrementalExecutionCurrentBatchId(expectedId: Int): AssertOnQuery = @@ -181,6 +190,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 0 CheckAnswer(1, 2, 3), CheckIncrementalExecutionCurrentBatchId(0), + CheckCommitLogLatestBatchId(0), CheckOffsetLogLatestBatchId(0), CheckSinkLatestBatchId(0), // Add some data in batch 1 @@ -191,6 +201,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 1 CheckAnswer(1, 2, 3, 4, 5, 6), CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), @@ -203,6 +214,7 @@ class StreamSuite extends StreamTest { // the currentId does not get logged (e.g. as 2) even if the clock has advanced many times CheckAnswer(1, 2, 3, 4, 5, 6), CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), @@ -210,14 +222,15 @@ class StreamSuite extends StreamTest { StopStream, StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), - /* -- batch 1 rerun ----------------- */ - // this batch 1 would re-run because the latest batch id logged in offset log is 1 + /* -- batch 1 no rerun ----------------- */ + // batch 1 would not re-run because the latest batch id logged in commit log is 1 AdvanceManualClock(10 * 1000), + CheckNoIncrementalExecutionCurrentBatchId(), /* -- batch 2 ----------------------- */ // Check the results of batch 1 CheckAnswer(1, 2, 3, 4, 5, 6), - CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), // Add some data in batch 2 @@ -228,6 +241,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 2 CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8, 9), CheckIncrementalExecutionCurrentBatchId(2), + CheckCommitLogLatestBatchId(2), CheckOffsetLogLatestBatchId(2), CheckSinkLatestBatchId(2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 60e2375a9817d..8cf1791336814 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -159,7 +159,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = ProcessingTime(0), + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty) extends StreamAction diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 0c8015672bab4..600c039cd0b9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -272,11 +272,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() + q.batchCommitLog.purge(3) // advance by a minute i.e., 90 seconds total clock.advance(60 * 1000L) true }, StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + // The commit log blown, causing the last batch to re-run CheckLastBatch((20L, 1), (85L, 1)), AssertOnQuery { q => clock.getTimeMillis() == 90000L @@ -322,11 +324,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() + q.batchCommitLog.purge(3) // advance by 60 days i.e., 90 days total clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true }, StartStream(ProcessingTime("10 day"), triggerClock = clock), + // Commit log blown, causing a re-run of the last batch CheckLastBatch((20L, 1), (85L, 1)), // advance clock to 100 days, should retain keys >= 90 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index eb09b9ffcfc5d..03dad8a6ddbc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -57,6 +57,20 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val inputData = new MemoryStream[Int](0, sqlContext) val df = inputData.toDS().as[Long].map { 10 / _ } val listener = new EventCollector + + case class AssertStreamExecThreadToWaitForClock() + extends AssertOnQuery(q => { + eventually(Timeout(streamingTimeout)) { + if (q.exception.isEmpty) { + assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + }, "") + try { // No events until started spark.streams.addListener(listener) @@ -81,6 +95,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // Progress event generated when data processed AddData(inputData, 1, 2), AdvanceManualClock(100), + AssertStreamExecThreadToWaitForClock(), CheckAnswer(10, 5), AssertOnQuery { query => assert(listener.progressEvents.nonEmpty) @@ -109,8 +124,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // Termination event generated with exception message when stopped with error StartStream(ProcessingTime(100), triggerClock = clock), + AssertStreamExecThreadToWaitForClock(), AddData(inputData, 0), - AdvanceManualClock(100), + AdvanceManualClock(100), // process bad data ExpectFailure[SparkException](), AssertOnQuery { query => eventually(Timeout(streamingTimeout)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index a0a2b2b4c9b3b..3f41ecdb7ff68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -158,6 +158,49 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi ) } + testQuietly("OneTime trigger, commit log, and exception") { + import Trigger.Once + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + testStream(mapped)( + AssertOnQuery(_.isActive === true), + StopStream, + AddData(inputData, 1, 2), + StartStream(trigger = Once), + CheckAnswer(6, 3), + StopStream, // clears out StreamTest state + AssertOnQuery { q => + // both commit log and offset log contain the same (latest) batch id + q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) == + q.offsetLog.getLatest().map(_._1).getOrElse(-2L) + }, + AssertOnQuery { q => + // blow away commit log and sink result + q.batchCommitLog.purge(1) + q.sink.asInstanceOf[MemorySink].clear() + true + }, + StartStream(trigger = Once), + CheckAnswer(6, 3), // ensure we fall back to offset log and reprocess batch + StopStream, + AddData(inputData, 3), + StartStream(trigger = Once), + CheckLastBatch(2), // commit log should be back in place + StopStream, + AddData(inputData, 0), + StartStream(trigger = Once), + ExpectFailure[SparkException](), + AssertOnQuery(_.isActive === false), + AssertOnQuery(q => { + q.exception.get.startOffset === + q.committedOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString && + q.exception.get.endOffset === + q.availableOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString + }, "incorrect start offset or end offset on exception") + ) + } + testQuietly("status, lastProgress, and recentProgress") { import StreamingQuerySuite._ clock = new StreamManualClock @@ -237,6 +280,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AdvanceManualClock(500), // time = 1100 to unblock job AssertOnQuery { _ => clock.getTimeMillis() === 1100 }, CheckAnswer(2), + AssertStreamExecThreadToWaitForClock(), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -275,6 +319,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AddData(inputData, 1, 2), AdvanceManualClock(100), // allow another trigger + AssertStreamExecThreadToWaitForClock(), CheckAnswer(4), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), @@ -306,8 +351,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // Test status and progress after query terminated with error StartStream(ProcessingTime(100), triggerClock = clock), + AdvanceManualClock(100), // ensure initial trigger completes before AddData AddData(inputData, 0), - AdvanceManualClock(100), + AdvanceManualClock(100), // allow another trigger ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 341ab0eb923da..05cd3d9f7c2fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _} +import org.apache.spark.sql.streaming.Trigger._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -346,7 +347,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) + .trigger(ProcessingTime(100, TimeUnit.SECONDS)) .start() q.stop() From b7be05a203b3e2a307147ea0c6cb0dec03da82a2 Mon Sep 17 00:00:00 2001 From: erenavsarogullari Date: Thu, 23 Mar 2017 17:20:52 -0700 Subject: [PATCH 0097/1765] [SPARK-19567][CORE][SCHEDULER] Support some Schedulable variables immutability and access ## What changes were proposed in this pull request? Some `Schedulable` Entities(`Pool` and `TaskSetManager`) variables need refactoring for _immutability_ and _access modifiers_ levels as follows: - From `var` to `val` (if there is no requirement): This is important to support immutability as much as possible. - Sample => `Pool`: `weight`, `minShare`, `priority`, `name` and `taskSetSchedulingAlgorithm`. - Access modifiers: Specially, `var`s access needs to be restricted from other parts of codebase to prevent potential side effects. - `TaskSetManager`: `tasksSuccessful`, `totalResultSize`, `calculatedTasks` etc... This PR is related with #15604 and has been created seperatedly to keep patch content as isolated and to help the reviewers. ## How was this patch tested? Added new UTs and existing UT coverage. Author: erenavsarogullari Closes #16905 from erenavsarogullari/SPARK-19567. --- .../org/apache/spark/scheduler/Pool.scala | 12 +++---- .../spark/scheduler/TaskSchedulerImpl.scala | 19 ++++++---- .../spark/scheduler/TaskSetManager.scala | 36 ++++++++++--------- .../spark/scheduler/DAGSchedulerSuite.scala | 8 ++--- .../ExternalClusterManagerSuite.scala | 4 +-- .../apache/spark/scheduler/PoolSuite.scala | 6 ++++ .../scheduler/TaskSchedulerImplSuite.scala | 12 +++++-- 7 files changed, 58 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 2a69a6c5e8790..1181371ab425a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -37,24 +37,24 @@ private[spark] class Pool( val schedulableQueue = new ConcurrentLinkedQueue[Schedulable] val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable] - var weight = initWeight - var minShare = initMinShare + val weight = initWeight + val minShare = initMinShare var runningTasks = 0 - var priority = 0 + val priority = 0 // A pool's stage id is used to break the tie in scheduling. var stageId = -1 - var name = poolName + val name = poolName var parent: Pool = null - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + private val taskSetSchedulingAlgorithm: SchedulingAlgorithm = { schedulingMode match { case SchedulingMode.FAIR => new FairSchedulingAlgorithm() case SchedulingMode.FIFO => new FIFOSchedulingAlgorithm() case _ => - val msg = "Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." + val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." throw new IllegalArgumentException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bfbcfa1aa386f..8257c70d672a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -59,6 +59,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( extends TaskScheduler with Logging { + import TaskSchedulerImpl._ + def this(sc: SparkContext) = { this( sc, @@ -130,17 +132,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val mapOutputTracker = SparkEnv.get.mapOutputTracker - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null + private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO - private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO") + private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString) val schedulingMode: SchedulingMode = try { SchedulingMode.withName(schedulingModeConf.toUpperCase) } catch { case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf") + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") } + val rootPool: Pool = new Pool("", schedulingMode, 0, 0) + // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) @@ -150,8 +153,6 @@ private[spark] class TaskSchedulerImpl private[scheduler]( def initialize(backend: SchedulerBackend) { this.backend = backend - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { schedulingMode match { case SchedulingMode.FIFO => @@ -159,7 +160,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( case SchedulingMode.FAIR => new FairSchedulableBuilder(rootPool, conf) case _ => - throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode") + throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " + + s"$schedulingMode") } } schedulableBuilder.buildPools() @@ -683,6 +685,9 @@ private[spark] class TaskSchedulerImpl private[scheduler]( private[spark] object TaskSchedulerImpl { + + val SCHEDULER_MODE_PROPERTY = "spark.scheduler.mode" + /** * Used to balance containers across hosts. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 11633bef3cfc7..fd93a1f5c5d2a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -78,16 +78,16 @@ private[spark] class TaskSetManager( private val numFailures = new Array[Int](numTasks) val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 + private[scheduler] var tasksSuccessful = 0 - var weight = 1 - var minShare = 0 + val weight = 1 + val minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId val name = "TaskSet_" + taskSet.id var parent: Pool = null - var totalResultSize = 0L - var calculatedTasks = 0 + private var totalResultSize = 0L + private var calculatedTasks = 0 private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => @@ -95,7 +95,7 @@ private[spark] class TaskSetManager( } } - val runningTasksSet = new HashSet[Long] + private[scheduler] val runningTasksSet = new HashSet[Long] override def runningTasks: Int = runningTasksSet.size @@ -105,7 +105,7 @@ private[spark] class TaskSetManager( // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie // state in order to continue to track and account for the running tasks. // TODO: We should kill any running task attempts when the task set manager becomes a zombie. - var isZombie = false + private[scheduler] var isZombie = false // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -129,17 +129,17 @@ private[spark] class TaskSetManager( private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. - var pendingTasksWithNoPrefs = new ArrayBuffer[Int] + private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] + private val allPendingTasks = new ArrayBuffer[Int] // Tasks that can be speculated. Since these will be a small fraction of total // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] + private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] + private val taskInfos = new HashMap[Long, TaskInfo] // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = @@ -148,7 +148,7 @@ private[spark] class TaskSetManager( // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception // was printed. This should ideally be an LRU map that can drop old exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() + private val recentExceptions = HashMap[String, (Int, Long)]() // Figure out the current map output tracker epoch and set it on all tasks val epoch = sched.mapOutputTracker.getEpoch @@ -169,20 +169,22 @@ private[spark] class TaskSetManager( * This allows a performance optimization, of skipping levels that aren't relevant (eg., skip * PROCESS_LOCAL if no tasks could be run PROCESS_LOCAL for the current set of executors). */ - var myLocalityLevels = computeValidLocalityLevels() - var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + private[scheduler] var myLocalityLevels = computeValidLocalityLevels() + + // Time to wait at each level + private[scheduler] var localityWaits = myLocalityLevels.map(getLocalityWait) // Delay scheduling variables: we keep track of our current locality level and the time we // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level + private var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + private var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - var emittedTaskSizeWarning = false + private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. */ private def addPendingTask(index: Int) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index dfad5db68a914..a9389003d5db8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -110,8 +110,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val cancelledStages = new HashSet[Int]() val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start() = {} override def stop() = {} override def executorHeartbeatReceived( @@ -542,8 +542,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // make sure that the DAGScheduler doesn't crash when the TaskScheduler // doesn't implement killTask() val noKillTaskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index e87cebf0cf358..37c124a726be2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -73,8 +73,8 @@ private class DummySchedulerBackend extends SchedulerBackend { private class DummyTaskScheduler extends TaskScheduler { var initialized = false - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index cddff3dd35861..4901062a78553 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -286,6 +286,12 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { assert(testPool.getSchedulableByName(taskSetManager.name) === taskSetManager) } + test("Pool should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + new Pool("TestPool", SchedulingMode.NONE, 0, 1) + } + } + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { val selectedPool = rootPool.getSchedulableByName(poolName) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ae0bcd9b8860..8b9d45f734cda 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -75,9 +75,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") - confs.foreach { case (k, v) => - conf.set(k, v) - } + confs.foreach { case (k, v) => conf.set(k, v) } sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc) setupHelper() @@ -904,4 +902,12 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(taskDescs.size === 1) assert(taskDescs.head.executorId === "exec2") } + + test("TaskScheduler should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + val taskScheduler = setupScheduler( + TaskSchedulerImpl.SCHEDULER_MODE_PROPERTY -> SchedulingMode.NONE.toString) + taskScheduler.initialize(new FakeSchedulerBackend) + } + } } From c7911807050227fcd13161ce090330d9d8daa533 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Thu, 23 Mar 2017 17:39:33 -0700 Subject: [PATCH 0098/1765] [SPARK-10849][SQL] Adds option to the JDBC data source write for user to specify database column type for the create table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently JDBC data source creates tables in the target database using the default type mapping, and the JDBC dialect mechanism.  If users want to specify different database data type for only some of columns, there is no option available. In scenarios where default mapping does not work, users are forced to create tables on the target database before writing. This workaround is probably not acceptable from a usability point of view. This PR is to provide a user-defined type mapping for specific columns. The solution is to allow users to specify database column data type for the create table as JDBC datasource option(createTableColumnTypes) on write. Data type information can be specified in the same format as table schema DDL format (e.g: `name CHAR(64), comments VARCHAR(1024)`). All supported target database types can not be specified , the data types has to be valid spark sql data types also. For example user can not specify target database CLOB data type. This will be supported in the follow-up PR. Example: ```Scala df.write .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") .jdbc(url, "TEST.DBCOLTYPETEST", properties) ``` ## How was this patch tested? Added new test cases to the JDBCWriteSuite Author: sureshthalamati Closes #16209 from sureshthalamati/jdbc_custom_dbtype_option_json-spark-10849. --- docs/sql-programming-guide.md | 7 + .../sql/JavaSQLDataSourceExample.java | 5 + examples/src/main/python/sql/datasource.py | 6 + .../examples/sql/SQLDataSourceExample.scala | 5 + .../datasources/jdbc/JDBCOptions.scala | 2 + .../jdbc/JdbcRelationProvider.scala | 4 +- .../datasources/jdbc/JdbcUtils.scala | 66 +++++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 150 +++++++++++++++++- 9 files changed, 235 insertions(+), 12 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b077575155eb0..7ae9847983d4d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1223,6 +1223,13 @@ the following case-insensitive options: This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table (e.g., CREATE TABLE t (name string) ENGINE=InnoDB.). This option applies only to writing. + + + createTableColumnTypes + + The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: "name CHAR(64), comments VARCHAR(1024)"). The specified types should be valid spark sql data types. This option applies only to writing. + +
      diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 82bb284ea3e58..1a7054614b348 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -258,6 +258,11 @@ private static void runJdbcDatasetExample(SparkSession spark) { jdbcDF2.write() .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); + + // Specifying create table column data types on write + jdbcDF.write() + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); // $example off:jdbc_dataset$ } } diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index e9aa9d9ac2583..e4abb0933345d 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -169,6 +169,12 @@ def jdbc_dataset_example(spark): jdbcDF2.write \ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", properties={"user": "username", "password": "password"}) + + # Specifying create table column data types on write + jdbcDF.write \ + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") \ + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", "password": "password"}) # $example off:jdbc_dataset$ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 381e69cda841c..82fd56de39847 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -181,6 +181,11 @@ object SQLDataSourceExample { jdbcDF2.write .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) + + // Specifying create table column data types on write + jdbcDF.write + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) // $example off:jdbc_dataset$ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d4d34646545ba..89fe86c038b16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -119,6 +119,7 @@ class JDBCOptions( // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES) val batchSize = { val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt require(size >= 1, @@ -154,6 +155,7 @@ object JDBCOptions { val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") + val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 88f6cb0021305..74dcfb06f5c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -69,7 +69,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, options.table) - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } @@ -87,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider // Therefore, it is okay to do nothing here and then just return the relation below. } } else { - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d89f600874177..774d1ba194321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -680,18 +681,70 @@ object JdbcUtils extends Logging { /** * Compute the schema string for this RDD. */ - def schemaString(schema: StructType, url: String): String = { + def schemaString( + df: DataFrame, + url: String, + createTableColumnTypes: Option[String] = None): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - schema.fields foreach { field => + val userSpecifiedColTypesMap = createTableColumnTypes + .map(parseUserSpecifiedCreateTableColumnTypes(df, _)) + .getOrElse(Map.empty[String, String]) + df.schema.fields.foreach { field => val name = dialect.quoteIdentifier(field.name) - val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition + val typ = userSpecifiedColTypesMap + .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") } if (sb.length < 2) "" else sb.substring(2) } + /** + * Parses the user specified createTableColumnTypes option value string specified in the same + * format as create table ddl column types, and returns Map of field name and the data type to + * use in-place of the default data type. + */ + private def parseUserSpecifiedCreateTableColumnTypes( + df: DataFrame, + createTableColumnTypes: String): Map[String, String] = { + def typeName(f: StructField): String = { + // char/varchar gets translated to string type. Real data type specified by the user + // is available in the field metadata as HIVE_TYPE_STRING + if (f.metadata.contains(HIVE_TYPE_STRING)) { + f.metadata.getString(HIVE_TYPE_STRING) + } else { + f.dataType.catalogString + } + } + + val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val nameEquality = df.sparkSession.sessionState.conf.resolver + + // checks duplicate columns in the user specified column types. + userSchema.fieldNames.foreach { col => + val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col)) + if (duplicatesCols.size >= 2) { + throw new AnalysisException( + "Found duplicate column(s) in createTableColumnTypes option value: " + + duplicatesCols.mkString(", ")) + } + } + + // checks if user specified column names exist in the DataFrame schema + userSchema.fieldNames.foreach { col => + df.schema.find(f => nameEquality(f.name, col)).getOrElse { + throw new AnalysisException( + s"createTableColumnTypes option column $col not found in schema " + + df.schema.catalogString) + } + } + + val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis + if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) + } + /** * Saves the RDD to the database in a single transaction. */ @@ -726,9 +779,10 @@ object JdbcUtils extends Logging { */ def createTable( conn: Connection, - schema: StructType, + df: DataFrame, options: JDBCOptions): Unit = { - val strSchema = schemaString(schema, options.url) + val strSchema = schemaString( + df, options.url, options.createTableColumnTypes) val table = options.table val createTableOptions = options.createTableOptions // Create the table if the table does not exist. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5463728ca0c1d..4a02277631f14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -869,7 +869,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") - val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp") + val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ec7b19e666ec0..bf1fd160704fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.jdbc -import java.sql.DriverManager +import java.sql.{Date, DriverManager, Timestamp} import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, Row, SaveMode} -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -362,4 +363,147 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(sql("select * from people_view").count() == 2) } } + + test("SPARK-10849: test schemaString - from createTableColumnTypes option values") { + def testCreateTableColDataTypes(types: Seq[String]): Unit = { + val colTypes = types.zipWithIndex.map { case (t, i) => (s"col$i", t) } + val schema = colTypes + .foldLeft(new StructType())((schema, colType) => schema.add(colType._1, colType._2)) + val createTableColTypes = + colTypes.map { case (col, dataType) => s"$col $dataType" }.mkString(", ") + val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row.empty)), schema) + + val expectedSchemaStr = + colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ") + + assert(JdbcUtils.schemaString(df, url1, Option(createTableColTypes)) == expectedSchemaStr) + } + + testCreateTableColDataTypes(Seq("boolean")) + testCreateTableColDataTypes(Seq("tinyint", "smallint", "int", "bigint")) + testCreateTableColDataTypes(Seq("float", "double")) + testCreateTableColDataTypes(Seq("string", "char(10)", "varchar(20)")) + testCreateTableColDataTypes(Seq("decimal(10,0)", "decimal(10,5)")) + testCreateTableColDataTypes(Seq("date", "timestamp")) + testCreateTableColDataTypes(Seq("binary")) + } + + test("SPARK-10849: create table using user specified column type and verify on target table") { + def testUserSpecifiedColTypes( + df: DataFrame, + createTableColTypes: String, + expectedTypes: Map[String, String]): Unit = { + df.write + .mode(SaveMode.Overwrite) + .option("createTableColumnTypes", createTableColTypes) + .jdbc(url1, "TEST.DBCOLTYPETEST", properties) + + // verify the data types of the created table by reading the database catalog of H2 + val query = + """ + |(SELECT column_name, type_name, character_maximum_length + | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST') + """.stripMargin + val rows = spark.read.jdbc(url1, query, properties).collect() + + rows.foreach { row => + val typeName = row.getString(1) + // For CHAR and VARCHAR, we also compare the max length + if (typeName.contains("CHAR")) { + val charMaxLength = row.getInt(2) + assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)") + } else { + assert(expectedTypes(row.getString(0)) == typeName) + } + } + } + + val data = Seq[Row](Row(1, "dave", "Boston")) + val schema = StructType( + StructField("id", IntegerType) :: + StructField("first#name", StringType) :: + StructField("city", StringType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + + // out-of-order + val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1) + // partial schema + val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + // should still respect the original column names + val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB") + testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val schema = StructType( + StructField("id", IntegerType) :: + StructField("First#Name", StringType) :: + StructField("city", StringType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + val expected = Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB") + testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected) + } + } + + test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid data type") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[ParseException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "name CLOB(2000)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("DataType clob(2000) is not supported.")) + } + + test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid syntax") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[ParseException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("no viable alternative at input")) + } + + test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "name CHAR(20), id int, NaMe VARCHAR(100)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains( + "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) + } + } + + test("SPARK-10849: jdbc CreateTableColumnTypes invalid columns") { + // schema2 has the column "id" and "name" + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "firstName CHAR(20), id int") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("createTableColumnTypes option column firstName not found in " + + "schema struct")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "id int, Name VARCHAR(100)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("createTableColumnTypes option column Name not found in " + + "schema struct")) + } + } } From 93581fbc18c01595918c565f6737aaa666116114 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 23 Mar 2017 17:57:31 -0700 Subject: [PATCH 0099/1765] Fix compilation of the Scala 2.10 master branch ## What changes were proposed in this pull request? Fixes break caused by: https://github.com/apache/spark/commit/746a558de2136f91f8fe77c6e51256017aa50913 ## How was this patch tested? Compiled with `build/sbt -Dscala2.10 sql/compile` locally Author: Burak Yavuz Closes #17403 from brkyvz/onceTrigger2.10. --- .../spark/sql/streaming/ProcessingTime.scala | 20 +++++++++---------- .../apache/spark/sql/streaming/Trigger.java | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index bdad8e4717be4..9ba1fc01cbd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.CalendarInterval */ @Experimental @InterfaceStability.Evolving -@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") +@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -64,7 +64,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { */ @Experimental @InterfaceStability.Evolving -@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") +@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { /** @@ -76,9 +76,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { throw new IllegalArgumentException( @@ -108,9 +108,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) } @@ -124,9 +124,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def create(interval: String): ProcessingTime = { apply(interval) } @@ -141,9 +141,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval, unit) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval, unit)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval, unit)", "2.2.0") def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java index a03a851f245fc..3e3997fa9bfec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java @@ -43,7 +43,7 @@ public class Trigger { * @since 2.2.0 */ public static Trigger ProcessingTime(long intervalMs) { - return ProcessingTime.apply(intervalMs); + return ProcessingTime.create(intervalMs, TimeUnit.MILLISECONDS); } /** From d27daa54bd341b29737a6352d9a1055151248ae7 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Thu, 23 Mar 2017 18:42:13 -0700 Subject: [PATCH 0100/1765] [SPARK-19636][ML] Feature parity for correlation statistics in MLlib ## What changes were proposed in this pull request? This patch adds the Dataframes-based support for the correlation statistics found in the `org.apache.spark.mllib.stat.correlation.Statistics`, following the design doc discussed in the JIRA ticket. The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage. ## How was this patch tested? ``` build/sbt "testOnly org.apache.spark.ml.stat.StatisticsSuite" ``` Author: Timothy Hunter Closes #17108 from thunterdb/19636. --- .../apache/spark/ml/util/TestingUtils.scala | 8 ++ .../apache/spark/ml/stat/Correlation.scala | 86 +++++++++++++++++++ .../spark/ml/stat/CorrelationSuite.scala | 77 +++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 2327917e2cad7..30edd00fb53e1 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -32,6 +32,10 @@ object TestingUtils { * the relative tolerance is meaningless, so the exception will be raised to warn users. */ private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } val absX = math.abs(x) val absY = math.abs(y) val diff = math.abs(x - y) @@ -49,6 +53,10 @@ object TestingUtils { * Private helper function for comparing two values using absolute tolerance. */ private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } math.abs(x - y) < eps } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala new file mode 100644 index 0000000000000..a7243ccbf28cc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{SQLDataTypes, Vector} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * API for correlation functions in MLlib, compatible with Dataframes and Datasets. + * + * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]] + * to spark.ml's Vector types. + */ +@Since("2.2.0") +@Experimental +object Correlation { + + /** + * :: Experimental :: + * Compute the correlation matrix for the input RDD of Vectors using the specified method. + * Methods currently supported: `pearson` (default), `spearman`. + * + * @param dataset A dataset or a dataframe + * @param column The name of the column of vectors for which the correlation coefficient needs + * to be computed. This must be a column of the dataset, and it must contain + * Vector objects. + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + * @return A dataframe that contains the correlation matrix of the column of vectors. This + * dataframe contains a single row and a single column of name + * '$METHODNAME($COLUMN)'. + * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if + * the content of this column is not of type Vector. + * + * Here is how to access the correlation coefficient: + * {{{ + * val data: Dataset[Vector] = ... + * val Row(coeff: Matrix) = Statistics.corr(data, "value").head + * // coeff now contains the Pearson correlation matrix. + * }}} + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { + val rdd = dataset.select(column).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } + val oldM = OldStatistics.corr(rdd, method) + val name = s"$method($column)" + val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false))) + dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema) + } + + /** + * Compute the Pearson correlation matrix for the input Dataset of Vectors. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String): DataFrame = { + corr(dataset, column, "pearson") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala new file mode 100644 index 0000000000000..7d935e651f220 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + + +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + val xData = Array(1.0, 0.0, -2.0) + val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) + val data = Seq( + Vectors.dense(1.0, 0.0, 0.0, -2.0), + Vectors.dense(4.0, 5.0, 0.0, 3.0), + Vectors.dense(6.0, 7.0, 0.0, 8.0), + Vectors.dense(9.0, 0.0, 0.0, 1.0) + ) + + private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + private def extract(df: DataFrame): BDM[Double] = { + val Array(Row(mat: Matrix)) = df.collect() + mat.asBreeze.toDenseMatrix + } + + + test("corr(X) default, pearson") { + val defaultMat = Correlation.corr(X, "features") + val pearsonMat = Correlation.corr(X, "features", "pearson") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( + (1.00000000, 0.05564149, Double.NaN, 0.4004714), + (0.05564149, 1.00000000, Double.NaN, 0.9135959), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.40047142, 0.91359586, Double.NaN, 1.0000000))) + // scalastyle:on + + assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4) + assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4) + } + + test("corr(X) spearman") { + val spearmanMat = Correlation.corr(X, "features", "spearman") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( + (1.0000000, 0.1054093, Double.NaN, 0.4000000), + (0.1054093, 1.0000000, Double.NaN, 0.9486833), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.4000000, 0.9486833, Double.NaN, 1.0000000))) + // scalastyle:on + assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4) + } + +} From bb823ca4b479a00030c4919c2d857d254b2a44d8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 24 Mar 2017 12:57:56 +0800 Subject: [PATCH 0101/1765] [SPARK-19959][SQL] Fix to throw NullPointerException in df[java.lang.Long].collect ## What changes were proposed in this pull request? This PR fixes `NullPointerException` in the generated code by Catalyst. When we run the following code, we get the following `NullPointerException`. This is because there is no null checks for `inputadapter_value` while `java.lang.Long inputadapter_value` at Line 30 may have `null`. This happen when a type of DataFrame is nullable primitive type such as `java.lang.Long` and the wholestage codegen is used. While the physical plan keeps `nullable=true` in `input[0, java.lang.Long, true].longValue`, `BoundReference.doGenCode` ignores `nullable=true`. Thus, nullcheck code will not be generated and `NullPointerException` will occur. This PR checks the nullability and correctly generates nullcheck if needed. ```java sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF.collect ``` ```java Caused by: java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(generated.java:37) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:393) ... ``` Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow serializefromobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ serializefromobject_result = new UnsafeRow(1); /* 022 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 023 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ java.lang.Long inputadapter_value = (java.lang.Long)inputadapter_row.get(0, null); /* 031 */ /* 032 */ boolean serializefromobject_isNull = true; /* 033 */ long serializefromobject_value = -1L; /* 034 */ if (!false) { /* 035 */ serializefromobject_isNull = false; /* 036 */ if (!serializefromobject_isNull) { /* 037 */ serializefromobject_value = inputadapter_value.longValue(); /* 038 */ } /* 039 */ /* 040 */ } /* 041 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 042 */ /* 043 */ if (serializefromobject_isNull) { /* 044 */ serializefromobject_rowWriter.setNullAt(0); /* 045 */ } else { /* 046 */ serializefromobject_rowWriter.write(0, serializefromobject_value); /* 047 */ } /* 048 */ append(serializefromobject_result); /* 049 */ if (shouldStop()) return; /* 050 */ } /* 051 */ } /* 052 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow serializefromobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ serializefromobject_result = new UnsafeRow(1); /* 022 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 023 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 031 */ java.lang.Long inputadapter_value = inputadapter_isNull ? null : ((java.lang.Long)inputadapter_row.get(0, null)); /* 032 */ /* 033 */ boolean serializefromobject_isNull = true; /* 034 */ long serializefromobject_value = -1L; /* 035 */ if (!inputadapter_isNull) { /* 036 */ serializefromobject_isNull = false; /* 037 */ if (!serializefromobject_isNull) { /* 038 */ serializefromobject_value = inputadapter_value.longValue(); /* 039 */ } /* 040 */ /* 041 */ } /* 042 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 043 */ /* 044 */ if (serializefromobject_isNull) { /* 045 */ serializefromobject_rowWriter.setNullAt(0); /* 046 */ } else { /* 047 */ serializefromobject_rowWriter.write(0, serializefromobject_value); /* 048 */ } /* 049 */ append(serializefromobject_result); /* 050 */ if (shouldStop()) return; /* 051 */ } /* 052 */ } /* 053 */ } ``` ## How was this patch tested? Added new test suites in `DataFrameSuites` Author: Kazuaki Ishizaki Closes #17302 from kiszk/SPARK-19959. --- .../spark/sql/catalyst/plans/logical/object.scala | 5 ++++- .../apache/spark/sql/DataFrameImplicitsSuite.scala | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 6225b3fa42990..bfb70c2ef4c89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -41,7 +41,10 @@ object CatalystSerde { } def generateObjAttr[T : Encoder]: Attribute = { - AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() + val enc = encoderFor[T] + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 094efbaeadcd5..63094d1b6122b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -51,4 +51,15 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } + + test("SPARK-19959: df[java.lang.Long].collect includes null throws NullPointerException") { + checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 1).toDF, + Seq(Row(0), Row(null), Row(2))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF, + Seq(Row(0L), Row(null), Row(2L))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 2.0F), 1).toDF, + Seq(Row(0.0F), Row(null), Row(2.0F))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 2.0D), 1).toDF, + Seq(Row(0.0D), Row(null), Row(2.0D))) + } } From 19596c28b6ef6e7abe0cfccfd2269c2fddf1fdee Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 23 Mar 2017 23:25:56 -0700 Subject: [PATCH 0102/1765] [SPARK-16929] Improve performance when check speculatable tasks. ## What changes were proposed in this pull request? 1. Use a MedianHeap to record durations of successful tasks. When check speculatable tasks, we can get the median duration with O(1) time complexity. 2. `checkSpeculatableTasks` will synchronize `TaskSchedulerImpl`. If `checkSpeculatableTasks` doesn't finish with 100ms, then the possibility exists for that thread to release and then immediately re-acquire the lock. Change `scheduleAtFixedRate` to be `scheduleWithFixedDelay` when call method of `checkSpeculatableTasks`. ## How was this patch tested? Added MedianHeapSuite. Author: jinxing Closes #16867 from jinxing64/SPARK-16929. --- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 19 +++- .../spark/util/collection/MedianHeap.scala | 93 +++++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 2 + .../util/collection/MedianHeapSuite.scala | 66 +++++++++++++ 5 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8257c70d672a0..d6225a08739dd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -174,7 +174,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - speculationScheduler.scheduleAtFixedRate(new Runnable { + speculationScheduler.scheduleWithFixedDelay(new Runnable { override def run(): Unit = Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index fd93a1f5c5d2a..f4a21bca79aa4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -19,11 +19,10 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.math.{max, min} +import scala.math.max import scala.util.control.NonFatal import org.apache.spark._ @@ -31,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.collection.MedianHeap /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -63,6 +63,8 @@ private[spark] class TaskSetManager( // Limit of bytes for total size of results (default is 1GB) val maxResultSize = Utils.getMaxResultSize(conf) + val speculationEnabled = conf.getBoolean("spark.speculation", false) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -141,6 +143,11 @@ private[spark] class TaskSetManager( // Task index, start and finish time for each task attempt (indexed by task ID) private val taskInfos = new HashMap[Long, TaskInfo] + // Use a MedianHeap to record durations of successful tasks so we know when to launch + // speculative tasks. This is only used when speculation is enabled, to avoid the overhead + // of inserting into the heap when the heap won't be used. + val successfulTaskDurations = new MedianHeap() + // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = conf.getLong("spark.logging.exceptionPrintInterval", 10000) @@ -698,6 +705,9 @@ private[spark] class TaskSetManager( val info = taskInfos(tid) val index = info.index info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) + if (speculationEnabled) { + successfulTaskDurations.insert(info.duration) + } removeRunningTask(tid) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not @@ -919,11 +929,10 @@ private[spark] class TaskSetManager( var foundTasks = false val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1)) + var medianDuration = successfulTaskDurations.median val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. diff --git a/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala new file mode 100644 index 0000000000000..6e57c3c5bee8c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.PriorityQueue + +/** + * MedianHeap is designed to be used to quickly track the median of a group of numbers + * that may contain duplicates. Inserting a new number has O(log n) time complexity and + * determining the median has O(1) time complexity. + * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf + * stores the smaller half of all numbers while the largerHalf stores the larger half. + * The sizes of two heaps need to be balanced each time when a new number is inserted so + * that their sizes will not be different by more than 1. Therefore each time when + * findMedian() is called we check if two heaps have the same size. If they do, we should + * return the average of the two top values of heaps. Otherwise we return the top of the + * heap which has one more element. + */ +private[spark] class MedianHeap(implicit val ord: Ordering[Double]) { + + /** + * Stores all the numbers less than the current median in a smallerHalf, + * i.e median is the maximum, at the root. + */ + private[this] var smallerHalf = PriorityQueue.empty[Double](ord) + + /** + * Stores all the numbers greater than the current median in a largerHalf, + * i.e median is the minimum, at the root. + */ + private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse) + + def isEmpty(): Boolean = { + smallerHalf.isEmpty && largerHalf.isEmpty + } + + def size(): Int = { + smallerHalf.size + largerHalf.size + } + + def insert(x: Double): Unit = { + // If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf. + if (isEmpty) { + largerHalf.enqueue(x) + } else { + // If the number is larger than current median, it should be inserted into largerHalf, + // otherwise smallerHalf. + if (x > median) { + largerHalf.enqueue(x) + } else { + smallerHalf.enqueue(x) + } + } + rebalance() + } + + private[this] def rebalance(): Unit = { + if (largerHalf.size - smallerHalf.size > 1) { + smallerHalf.enqueue(largerHalf.dequeue()) + } + if (smallerHalf.size - largerHalf.size > 1) { + largerHalf.enqueue(smallerHalf.dequeue) + } + } + + def median: Double = { + if (isEmpty) { + throw new NoSuchElementException("MedianHeap is empty.") + } + if (largerHalf.size == smallerHalf.size) { + (largerHalf.head + smallerHalf.head) / 2.0 + } else if (largerHalf.size > smallerHalf.size) { + largerHalf.head + } else { + smallerHalf.head + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f36bcd8504b05..064af381a76d2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -893,6 +893,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(4) // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => @@ -948,6 +949,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") sc.conf.set("spark.speculation.quantile", "0.6") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => diff --git a/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala new file mode 100644 index 0000000000000..c2a3ee95f1c55 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.NoSuchElementException + +import org.apache.spark.SparkFunSuite + +class MedianHeapSuite extends SparkFunSuite { + + test("If no numbers in MedianHeap, NoSuchElementException is thrown.") { + val medianHeap = new MedianHeap() + intercept[NoSuchElementException] { + medianHeap.median + } + } + + test("Median should be correct when size of MedianHeap is even") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 10) + assert(medianHeap.median === 4.5) + } + + test("Median should be correct when size of MedianHeap is odd") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 9) + assert(medianHeap.median === 4) + } + + test("Median should be correct though there are duplicated numbers inside.") { + val array = Array(0, 0, 1, 1, 2, 3, 4) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size === 7) + assert(medianHeap.median === 1) + } + + test("Median should be correct when input data is skewed.") { + val medianHeap = new MedianHeap() + (0 until 10).foreach(_ => medianHeap.insert(5)) + assert(medianHeap.median === 5) + (0 until 100).foreach(_ => medianHeap.insert(10)) + assert(medianHeap.median === 10) + (0 until 1000).foreach(_ => medianHeap.insert(0)) + assert(medianHeap.median === 0) + } +} From 8e558041aa0c41ba9fb2ce242daaf6d6ed4d85b7 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 23 Mar 2017 23:30:40 -0700 Subject: [PATCH 0103/1765] [SPARK-19820][CORE] Add interface to kill tasks w/ a reason This commit adds a killTaskAttempt method to SparkContext, to allow users to kill tasks so that they can be re-scheduled elsewhere. This also refactors the task kill path to allow specifying a reason for the task kill. The reason is propagated opaquely through events, and will show up in the UI automatically as `(N killed: $reason)` and `TaskKilled: $reason`. Without this change, there is no way to provide the user feedback through the UI. Currently used reasons are "stage cancelled", "another attempt succeeded", and "killed via SparkContext.killTask". The user can also specify a custom reason through `SparkContext.killTask`. cc rxin In the stage overview UI the reasons are summarized: ![1](https://cloud.githubusercontent.com/assets/14922/23929209/a83b2862-08e1-11e7-8b3e-ae1967bbe2e5.png) Within the stage UI you can see individual task kill reasons: ![2](https://cloud.githubusercontent.com/assets/14922/23929200/9a798692-08e1-11e7-8697-72b27ad8a287.png) Existing tests, tried killing some stages in the UI and verified the messages are as expected. Author: Eric Liang Author: Eric Liang Closes #17166 from ericl/kill-reason. --- .../unsafe/sort/UnsafeInMemorySorter.java | 5 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 5 +- .../apache/spark/InterruptibleIterator.scala | 7 +-- .../scala/org/apache/spark/SparkContext.scala | 18 +++++++ .../scala/org/apache/spark/TaskContext.scala | 10 ++++ .../org/apache/spark/TaskContextImpl.scala | 21 ++++++-- .../org/apache/spark/TaskEndReason.scala | 4 +- .../apache/spark/TaskKilledException.scala | 4 +- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 52 ++++++++++--------- .../apache/spark/scheduler/DAGScheduler.scala | 11 +++- .../spark/scheduler/SchedulerBackend.scala | 15 +++++- .../org/apache/spark/scheduler/Task.scala | 21 ++++---- .../spark/scheduler/TaskScheduler.scala | 7 +++ .../spark/scheduler/TaskSchedulerImpl.scala | 16 +++++- .../spark/scheduler/TaskSetManager.scala | 10 +++- .../cluster/CoarseGrainedClusterMessage.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 10 ++-- .../local/LocalSchedulerBackend.scala | 11 ++-- .../scala/org/apache/spark/ui/UIUtils.scala | 7 ++- .../apache/spark/ui/jobs/AllJobsPage.scala | 4 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 4 +- .../spark/ui/jobs/JobProgressListener.scala | 17 +++--- .../org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../org/apache/spark/ui/jobs/UIData.scala | 6 +-- .../org/apache/spark/util/JsonProtocol.scala | 7 ++- .../org/apache/spark/SparkContextSuite.scala | 47 ++++++++++++++++- .../apache/spark/executor/ExecutorSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 6 +++ .../ExternalClusterManagerSuite.scala | 2 + .../OutputCommitCoordinatorSuite.scala | 4 +- .../scheduler/SchedulerIntegrationSuite.scala | 3 +- .../spark/scheduler/TaskSetManagerSuite.scala | 17 +++--- .../org/apache/spark/ui/UIUtilsSuite.scala | 2 +- .../ui/jobs/JobProgressListenerSuite.scala | 5 +- .../apache/spark/util/JsonProtocolSuite.scala | 5 +- project/MimaExcludes.scala | 13 +++++ .../spark/executor/MesosExecutorBackend.scala | 3 +- .../MesosFineGrainedSchedulerBackend.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 4 +- .../spark/streaming/ui/AllBatchesTable.scala | 2 +- .../apache/spark/streaming/ui/BatchPage.scala | 2 +- 43 files changed, 289 insertions(+), 115 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f219c5605b643..c14c12664f5ab 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -23,7 +23,6 @@ import org.apache.avro.reflect.Nullable; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -291,8 +290,8 @@ public void loadNext() { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index b6323c624b7b9..9521ab86a12d5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,7 +24,6 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -102,8 +101,8 @@ public void loadNext() throws IOException { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } recordLength = din.readInt(); keyPrefix = din.readLong(); diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 5c262bcbddf76..7f2c0068174b5 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.isInterrupted) { - throw new TaskKilledException - } else { - delegate.hasNext - } + context.killTaskIfInterrupted() + delegate.hasNext } def next(): T = delegate.next() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e36a30c933d0..0225fd6056074 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2249,6 +2249,24 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelStage(stageId, None) } + /** + * Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI + * or through SparkListener.onTaskStart. + * + * @param taskId the task ID to kill. This id uniquely identifies the task attempt. + * @param interruptThread whether to interrupt the thread running the task. + * @param reason the reason for killing the task, which should be a short string. If a task + * is killed multiple times with different reasons, only one reason will be reported. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt( + taskId: Long, + interruptThread: Boolean = true, + reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = { + dagScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5acfce17593b3..0b87cd503d4fa 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -184,6 +184,16 @@ abstract class TaskContext extends Serializable { @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] + /** + * If the task is interrupted, throws TaskKilledException with the reason for the interrupt. + */ + private[spark] def killTaskIfInterrupted(): Unit + + /** + * If the task is interrupted, the reason this task was killed, otherwise None. + */ + private[spark] def getKillReason(): Option[String] + /** * Returns the manager for this task's managed memory. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index f346cf8d65806..8cd1d1c96aa0a 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -59,8 +59,8 @@ private[spark] class TaskContextImpl( /** List of callback functions to execute when the task fails. */ @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false + // If defined, the corresponding task has been killed and this option contains the reason. + @volatile private var reasonIfKilled: Option[String] = None // Whether the task has completed. private var completed: Boolean = false @@ -140,8 +140,19 @@ private[spark] class TaskContextImpl( } /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(): Unit = { - interrupted = true + private[spark] def markInterrupted(reason: String): Unit = { + reasonIfKilled = Some(reason) + } + + private[spark] override def killTaskIfInterrupted(): Unit = { + val reason = reasonIfKilled + if (reason.isDefined) { + throw new TaskKilledException(reason.get) + } + } + + private[spark] override def getKillReason(): Option[String] = { + reasonIfKilled } @GuardedBy("this") @@ -149,7 +160,7 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = false - override def isInterrupted(): Boolean = interrupted + override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8c1b5f7bf0d9b..a76283e33fa65 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case object TaskKilled extends TaskFailedReason { - override def toErrorString: String = "TaskKilled (killed intentionally)" +case class TaskKilled(reason: String) extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala index ad487c4efb87a..9dbf0d493be11 100644 --- a/core/src/main/scala/org/apache/spark/TaskKilledException.scala +++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala @@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * Exception thrown when a task is explicitly killed (i.e., task failure is expected). */ @DeveloperApi -class TaskKilledException extends RuntimeException +class TaskKilledException(val reason: String) extends RuntimeException { + def this() = this("unknown reason") +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 04ae97ed3ccbe..b0dd2fc187baf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -215,7 +215,7 @@ private[spark] class PythonRunner( case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) case e: Exception if env.isStopped => logDebug("Exception thrown after context is stopped", e) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b376ecd301eab..ba0096d874567 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend( executor.launchTask(this, taskDesc) } - case KillTask(taskId, _, interruptThread) => + case KillTask(taskId, _, interruptThread, reason) => if (executor == null) { exitExecutor(1, "Received KillTask command but executor was null") } else { - executor.killTask(taskId, interruptThread) + executor.killTask(taskId, interruptThread, reason) } case StopExecutor => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 790c1ae942474..99b1608010ddb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -158,7 +158,7 @@ private[spark] class Executor( threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean): Unit = { + def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { val taskRunner = runningTasks.get(taskId) if (taskRunner != null) { if (taskReaperEnabled) { @@ -168,7 +168,8 @@ private[spark] class Executor( case Some(existingReaper) => interruptThread && !existingReaper.interruptThread } if (shouldCreateReaper) { - val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + val taskReaper = new TaskReaper( + taskRunner, interruptThread = interruptThread, reason = reason) taskReaperForTask(taskId) = taskReaper Some(taskReaper) } else { @@ -178,7 +179,7 @@ private[spark] class Executor( // Execute the TaskReaper from outside of the synchronized block. maybeNewTaskReaper.foreach(taskReaperPool.execute) } else { - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) } } } @@ -189,8 +190,9 @@ private[spark] class Executor( * tasks instead of taking the JVM down. * @param interruptThread whether to interrupt the task thread */ - def killAllTasks(interruptThread: Boolean) : Unit = { - runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) + def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { + runningTasks.keys().asScala.foreach(t => + killTask(t, interruptThread = interruptThread, reason = reason)) } def stop(): Unit = { @@ -217,8 +219,8 @@ private[spark] class Executor( val threadName = s"Executor task launch worker for task $taskId" private val taskName = taskDescription.name - /** Whether this task has been killed. */ - @volatile private var killed = false + /** If specified, this task has been killed and this option contains the reason. */ + @volatile private var reasonIfKilled: Option[String] = None @volatile private var threadId: Long = -1 @@ -239,13 +241,13 @@ private[spark] class Executor( */ @volatile var task: Task[Any] = _ - def kill(interruptThread: Boolean): Unit = { - logInfo(s"Executor is trying to kill $taskName (TID $taskId)") - killed = true + def kill(interruptThread: Boolean, reason: String): Unit = { + logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason") + reasonIfKilled = Some(reason) if (task != null) { synchronized { if (!finished) { - task.kill(interruptThread) + task.kill(interruptThread, reason) } } } @@ -296,12 +298,13 @@ private[spark] class Executor( // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. - if (killed) { + val killReason = reasonIfKilled + if (killReason.isDefined) { // Throw an exception rather than returning, because returning within a try{} block // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl // exception will be caught by the catch block, leading to an incorrect ExceptionFailure // for the task. - throw new TaskKilledException + throw new TaskKilledException(killReason.get) } logDebug("Task " + taskId + "'s epoch is " + task.epoch) @@ -358,9 +361,7 @@ private[spark] class Executor( } else 0L // If the task has been killed, let's fail it. - if (task.killed) { - throw new TaskKilledException - } + task.context.killTaskIfInterrupted() val resultSer = env.serializer.newInstance() val beforeSerialization = System.currentTimeMillis() @@ -426,15 +427,17 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case _: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId)") + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case _: InterruptedException if task.killed => - logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") + case _: InterruptedException if task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskFailedReason @@ -512,7 +515,8 @@ private[spark] class Executor( */ private class TaskReaper( taskRunner: TaskRunner, - val interruptThread: Boolean) + val interruptThread: Boolean, + val reason: String) extends Runnable { private[this] val taskId: Long = taskRunner.taskId @@ -533,7 +537,7 @@ private[spark] class Executor( // Only attempt to kill the task once. If interruptThread = false then a second kill // attempt would be a no-op and if interruptThread = true then it may not be safe or // effective to interrupt multiple times: - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) // Monitor the killed task until it exits. The synchronization logic here is complicated // because we don't want to synchronize on the taskRunner while possibly taking a thread // dump, but we also need to be careful to avoid races between checking whether the task diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d944f268755de..09717316833a7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -738,6 +738,15 @@ class DAGScheduler( eventProcessLoop.post(StageCancelled(stageId, reason)) } + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since * the last fetch failure. @@ -1353,7 +1362,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | TaskKilled | UnknownReason => + case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 8801a761afae3..22db3350abfa7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + /** + * Requests that an executor kills a running task. + * + * @param taskId Id of the task. + * @param executorId Id of the executor the task is running on. + * @param interruptThread Whether the executor should interrupt the task thread. + * @param reason The reason for the task kill. + */ + def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = throw new UnsupportedOperationException + def isReady(): Boolean = true /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 70213722aae4f..46ef23f316a61 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -89,8 +89,8 @@ private[spark] abstract class Task[T]( TaskContext.setTaskContext(context) taskThread = Thread.currentThread() - if (_killed) { - kill(interruptThread = false) + if (_reasonIfKilled != null) { + kill(interruptThread = false, _reasonIfKilled) } new CallerContext( @@ -158,17 +158,17 @@ private[spark] abstract class Task[T]( // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ - // A flag to indicate whether the task is killed. This is used in case context is not yet - // initialized when kill() is invoked. - @volatile @transient private var _killed = false + // If non-null, this task has been killed and the reason is as specified. This is used in case + // context is not yet initialized when kill() is invoked. + @volatile @transient private var _reasonIfKilled: String = null protected var _executorDeserializeTime: Long = 0 protected var _executorDeserializeCpuTime: Long = 0 /** - * Whether the task has been killed. + * If defined, this task has been killed and this option contains the reason. */ - def killed: Boolean = _killed + def reasonIfKilled: Option[String] = Option(_reasonIfKilled) /** * Returns the amount of time spent deserializing the RDD and function to be run. @@ -201,10 +201,11 @@ private[spark] abstract class Task[T]( * be called multiple times. * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. */ - def kill(interruptThread: Boolean) { - _killed = true + def kill(interruptThread: Boolean, reason: String) { + require(reason != null) + _reasonIfKilled = reason if (context != null) { - context.markInterrupted() + context.markInterrupted(reason) } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cd13eebe74a99..3de7d1f7de22b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -54,6 +54,13 @@ private[spark] trait TaskScheduler { // Cancel a stage. def cancelTasks(stageId: Int, interruptThread: Boolean): Unit + /** + * Kills a task attempt. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index d6225a08739dd..07aea773fa632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -241,7 +241,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // simply abort the stage. tsm.runningTasksSet.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) @@ -249,6 +249,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + logInfo(s"Killing task $taskId: $reason") + val execId = taskIdToExecutorId.get(taskId) + if (execId.isDefined) { + backend.killTask(taskId, execId.get, interruptThread, reason) + true + } else { + logWarning(s"Could not kill task $taskId because no task with that ID was found.") + false + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -469,7 +481,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( taskState: TaskState, reason: TaskFailedReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. backend.reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f4a21bca79aa4..a177aab5f95de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -101,6 +101,10 @@ private[spark] class TaskSetManager( override def runningTasks: Int = runningTasksSet.size + def someAttemptSucceeded(tid: Long): Boolean = { + successful(taskInfos(tid).index) + } + // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie @@ -722,7 +726,11 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true) + sched.backend.killTask( + attemptInfo.taskId, + attemptInfo.executorId, + interruptThread = true, + reason = "another attempt succeeded") } if (!successful(index)) { tasksSuccessful += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 2898cd7d17ca0..6b49bd699a13a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -40,7 +40,7 @@ private[spark] object CoarseGrainedClusterMessages { // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String) extends CoarseGrainedClusterMessage case class KillExecutorsOnHost(host: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7e2cfaccfc7ba..4eedaaea61195 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -132,10 +132,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case ReviveOffers => makeOffers() - case KillTask(taskId, executorId, interruptThread) => + case KillTask(taskId, executorId, interruptThread, reason) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorEndpoint.send( + KillTask(taskId, executorId, interruptThread, reason)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -428,8 +429,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(ReviveOffers) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason)) } override def defaultParallelism(): Int = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 625f998cd4608..35509bc2f85b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -34,7 +34,7 @@ private case class ReviveOffers() private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) -private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String) private case class StopExecutor() @@ -70,8 +70,8 @@ private[spark] class LocalEndpoint( reviveOffers() } - case KillTask(taskId, interruptThread) => - executor.killTask(taskId, interruptThread) + case KillTask(taskId, interruptThread, reason) => + executor.killTask(taskId, interruptThread, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -143,8 +143,9 @@ private[spark] class LocalSchedulerBackend( override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localEndpoint.send(KillTask(taskId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + localEndpoint.send(KillTask(taskId, interruptThread, reason)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index d161843dd2230..e53d6907bc404 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -342,7 +342,7 @@ private[spark] object UIUtils extends Logging { completed: Int, failed: Int, skipped: Int, - killed: Int, + reasonToNumKilled: Map[String, Int], total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) // started + completed can be > total when there are speculative tasks @@ -354,7 +354,10 @@ private[spark] object UIUtils extends Logging { {completed}/{total} { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { if (killed > 0) s"($killed killed)" } + { reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s"($count killed: $reason)" + } + }
      diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d217f558045f2..18be0870746e9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -630,8 +630,8 @@ private[ui] class JobPagedTable( {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, - failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks, - total = job.numTasks - job.numSkippedTasks)} + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index cd1b02addc789..52f41298a1729 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum} {v.failedTasks} - {v.killedTasks} + {v.reasonToNumKilled.map(_._2).sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index e87caff426436..1cf03e1541d14 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -371,8 +371,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => execSummary.succeededTasks += 1 - case TaskKilled => - execSummary.killedTasks += 1 + case kill: TaskKilled => + execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( + kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => execSummary.failedTasks += 1 } @@ -385,9 +386,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 None - case TaskKilled => - stageData.numKilledTasks += 1 - Some(TaskKilled.toErrorString) + case kill: TaskKilled => + stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( + kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + Some(kill.toErrorString) case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 Some(e.toErrorString) @@ -422,8 +424,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => jobData.numCompletedTasks += 1 - case TaskKilled => - jobData.numKilledTasks += 1 + case kill: TaskKilled => + jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( + kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => jobData.numFailedTasks += 1 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e1fa9043b6a15..f4caad0f58715 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -300,7 +300,7 @@ private[ui] class StagePagedTable( {UIUtils.makeProgressBar(started = stageData.numActiveTasks, completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)} + skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 073f7edfc2fe9..ac1a74ad8029d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -32,7 +32,7 @@ private[spark] object UIData { var taskTime : Long = 0 var failedTasks : Int = 0 var succeededTasks : Int = 0 - var killedTasks : Int = 0 + var reasonToNumKilled : Map[String, Int] = Map.empty var inputBytes : Long = 0 var inputRecords : Long = 0 var outputBytes : Long = 0 @@ -64,7 +64,7 @@ private[spark] object UIData { var numCompletedTasks: Int = 0, var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, - var numKilledTasks: Int = 0, + var reasonToNumKilled: Map[String, Int] = Map.empty, /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: @@ -78,7 +78,7 @@ private[spark] object UIData { var numCompleteTasks: Int = _ var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ - var numKilledTasks: Int = _ + var reasonToNumKilled: Map[String, Int] = Map.empty var executorRunTime: Long = _ var executorCpuTime: Long = _ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 4b4d2d10cbf8d..2cb88919c8c83 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -390,6 +390,8 @@ private[spark] object JsonProtocol { ("Executor ID" -> executorId) ~ ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) + case taskKilled: TaskKilled => + ("Kill Reason" -> taskKilled.reason) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -877,7 +879,10 @@ private[spark] object JsonProtocol { })) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost - case `taskKilled` => TaskKilled + case `taskKilled` => + val killReason = Utils.jsonOption(json \ "Kill Reason") + .map(_.extract[String]).getOrElse("unknown reason") + TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index d08a162feda03..2c947556dfd30 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFor import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.Utils @@ -540,6 +540,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + // Launches one task that will run forever. Once the SparkListener detects the task has + // started, kill and re-schedule it. The second run of the task will complete immediately. + // If this test times out, then the first version of the task wasn't killed successfully. + test("Killing tasks") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + SparkContextSuite.isTaskStarted = false + SparkContextSuite.taskKilled = false + SparkContextSuite.taskSucceeded = false + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + if (!SparkContextSuite.taskKilled) { + SparkContextSuite.taskKilled = true + sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang") + } + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) { + SparkContextSuite.taskSucceeded = true + } + } + } + sc.addSparkListener(listener) + eventually(timeout(20.seconds)) { + sc.parallelize(1 to 1).foreach { x => + // first attempt will hang + if (!SparkContextSuite.isTaskStarted) { + SparkContextSuite.isTaskStarted = true + Thread.sleep(9999999) + } + // second attempt succeeds immediately + } + } + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.taskSucceeded) + } + } + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + "open streams to help debugging") { val fs = new DebugFilesystem() @@ -555,11 +597,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } - } object SparkContextSuite { @volatile var cancelJob = false @volatile var cancelStage = false @volatile var isTaskStarted = false + @volatile var taskKilled = false + @volatile var taskSucceeded = false } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 8150fff2d018d..f47e574b4fc4b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -110,14 +110,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. - executor.killAllTasks(true) + executor.killAllTasks(true, "test") executorSuiteHelper.latch2.countDown() if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { fail("executor did not send second status update in time") } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` - assert(executorSuiteHelper.testFailedReason === TaskKilled) + assert(executorSuiteHelper.testFailedReason === TaskKilled("test")) assert(executorSuiteHelper.taskState === TaskState.KILLED) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a9389003d5db8..a10941b579fe2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { cancelledStages += stageId } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 37c124a726be2..ba56af8215cd7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -79,6 +79,8 @@ private class DummyTaskScheduler extends TaskScheduler { override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 38b9d40329d48..e51e6a0d3ff6b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -176,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed assert( !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer assert( outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 398ac3d6202db..8103983c4392a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -410,7 +410,8 @@ private[spark] abstract class MockBackend( } } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { // We have to implement this b/c of SPARK-15385. // Its OK for this to be a no-op, because even if a backend does implement killTask, // it really can only be "best-effort" in any case, and the scheduler should be robust to that. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 064af381a76d2..132caef0978fb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -677,7 +677,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} }) // Keep track of the number of tasks that are resubmitted, @@ -935,7 +939,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt - verify(sched.backend).killTask(3, "exec2", true) + verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. @@ -1023,14 +1027,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask = originalTasks(speculativeTask.index) - verify(sched.backend).killTask(origTask.taskId, "exec2", true) + verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. assert(sched.endedTasks(3) === Success) // also because the scheduler is a mock, our manager isn't notified about the task killed event, // so we do that manually - manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test")) // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage assert(manager.tasksSuccessful === 4) assert(!manager.isZombie) @@ -1047,7 +1051,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask2 = originalTasks(speculativeTask2.index) - verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded") assert(manager.tasksSuccessful === 5) assert(manager.isZombie) } @@ -1102,8 +1106,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, TaskCommitDenied(0, 2, 0)) - tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, - TaskKilled) + tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 6335d905c0fbf..c770fd5da76f7 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite { } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { - val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div") + val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div") val expected = Seq(
      ,
      diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e3127da9a6b24..93964a2d56743 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -274,8 +274,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with // Make sure killed tasks are accounted for correctly. listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1) + SparkListenerTaskEnd( + task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1)) // Make sure we count success as success. listener.onTaskEnd( diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 9f76c74bce89e..a64dbeae47294 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -164,7 +164,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(fetchMetadataFailed) testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) - testTaskEndReason(TaskKilled) + testTaskEndReason(TaskKilled("test")) testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) testTaskEndReason(UnknownReason) @@ -676,7 +676,8 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(r1.fullStackTrace === r2.fullStackTrace) assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => - case (TaskKilled, TaskKilled) => + case (r1: TaskKilled, r2: TaskKilled) => + assert(r1.reason == r2.reason) case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => assert(jobId1 === jobId2) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9925a8ba72662..8ce9367c9b446 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -66,6 +66,19 @@ object MimaExcludes { // [SPARK-17161] Removing Python-friendly constructors not needed ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + // [SPARK-19820] Allow reason to be specified to task kill + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"), + // [SPARK-19876] Add one time trigger, and improve Trigger APIs ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b252539782580..a086ec7ea2da6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -104,7 +104,8 @@ private[spark] class MesosExecutorBackend logError("Received KillTask but executor was null") } else { // TODO: Determine the 'interruptOnCancel' property set for the given job. - executor.killTask(t.getValue.toLong, interruptThread = false) + executor.killTask( + t.getValue.toLong, interruptThread = false, reason = "killed by mesos") } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index f198f8893b3db..735c879c63c55 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -428,7 +428,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index a89d172a911ab..9df20731c71d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -101,9 +101,7 @@ class FileScanRDD( // Kill the task in case it has been marked as killed. This logic is from // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order // to avoid performance overhead. - if (context.isInterrupted()) { - throw new TaskKilledException - } + context.killTaskIfInterrupted() (currentIterator != null && currentIterator.hasNext) || nextIterator() } def next(): Object = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 1352ca1c4c95f..70b4bb466c46b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -97,7 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) completed = batch.numCompletedOutputOp, failed = batch.numFailedOutputOp, skipped = 0, - killed = 0, + reasonToNumKilled = Map.empty, total = batch.outputOperations.size) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 1a87fc790f91b..f55af6a5cc358 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { completed = sparkJob.numCompletedTasks, failed = sparkJob.numFailedTasks, skipped = sparkJob.numSkippedTasks, - killed = sparkJob.numKilledTasks, + reasonToNumKilled = sparkJob.reasonToNumKilled, total = sparkJob.numTasks - sparkJob.numSkippedTasks) } From 344f38b04b271b5f3ec2748b34db4e52d54da1bc Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 24 Mar 2017 14:42:33 +0800 Subject: [PATCH 0104/1765] [SPARK-19970][SQL][FOLLOW-UP] Table owner should be USER instead of PRINCIPAL in kerberized clusters #17311 ### What changes were proposed in this pull request? This is a follow-up for the PR: https://github.com/apache/spark/pull/17311 - For safety, use `sessionState` to get the user name, instead of calling `SessionState.get()` in the function `toHiveTable`. - Passing `user names` instead of `conf` when calling `toHiveTable`. ### How was this patch tested? N/A Author: Xiao Li Closes #17405 from gatorsmile/user. --- .../sql/hive/client/HiveClientImpl.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 13edcd051768c..56ccac32a8d88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -207,6 +207,8 @@ private[hive] class HiveClientImpl( /** Returns the configuration for the current session. */ def conf: HiveConf = state.getConf + private val userName = state.getAuthenticator.getUserName + override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) } @@ -413,7 +415,7 @@ private[hive] class HiveClientImpl( createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( - locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI(_)), + locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI), // To avoid ClassNotFound exception, we try our best to not get the format class, but get // the class name directly. However, for non-native tables, there is no interface to get // the format class name, so we may still throw ClassNotFound in this case. @@ -441,7 +443,7 @@ private[hive] class HiveClientImpl( } override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState { - client.createTable(toHiveTable(table, Some(conf)), ignoreIfExists) + client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists) } override def dropTable( @@ -453,7 +455,7 @@ private[hive] class HiveClientImpl( } override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) // Do not use `table.qualifiedName` here because this may be a rename val qualifiedTableName = s"${table.database}.$tableName" shim.alterTable(client, qualifiedTableName, hiveTable) @@ -522,7 +524,7 @@ private[hive] class HiveClientImpl( newSpecs: Seq[TablePartitionSpec]): Unit = withHiveState { require(specs.size == newSpecs.size, "number of old and new partition specs differ") val catalogTable = getTable(db, table) - val hiveTable = toHiveTable(catalogTable, Some(conf)) + val hiveTable = toHiveTable(catalogTable, Some(userName)) specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => val hivePart = getPartitionOption(catalogTable, oldSpec) .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) } @@ -535,7 +537,7 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(conf)) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) } @@ -563,7 +565,7 @@ private[hive] class HiveClientImpl( override def getPartitionOption( table: CatalogTable, spec: TablePartitionSpec): Option[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val hivePartition = client.getPartition(hiveTable, spec.asJava, false) Option(hivePartition).map(fromHivePartition) } @@ -575,7 +577,7 @@ private[hive] class HiveClientImpl( override def getPartitions( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val parts = spec match { case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) case Some(s) => @@ -589,7 +591,7 @@ private[hive] class HiveClientImpl( override def getPartitionsByFilter( table: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts @@ -817,9 +819,7 @@ private[hive] object HiveClientImpl { /** * Converts the native table metadata representation format CatalogTable to Hive's Table. */ - def toHiveTable( - table: CatalogTable, - conf: Option[HiveConf] = None): HiveTable = { + def toHiveTable(table: CatalogTable, userName: Option[String] = None): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. // Otherwise, Hive metastore will change the table to a MANAGED_TABLE. @@ -851,10 +851,10 @@ private[hive] object HiveClientImpl { hiveTable.setFields(schema.asJava) } hiveTable.setPartCols(partCols.asJava) - conf.foreach { _ => hiveTable.setOwner(SessionState.get().getAuthenticator().getUserName()) } + userName.foreach(hiveTable.setOwner) hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) - table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc => + table.storage.locationUri.map(CatalogUtils.URIToString).foreach { loc => hiveTable.getTTable.getSd.setLocation(loc)} table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) From d9f4ce6943c16a7e29f98e57c33acbfc0379b54d Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 24 Mar 2017 08:01:15 -0700 Subject: [PATCH 0105/1765] [SPARK-15040][ML][PYSPARK] Add Imputer to PySpark Add Python wrapper for `Imputer` feature transformer. ## How was this patch tested? New doc tests and tweak to PySpark ML `tests.py` Author: Nick Pentreath Closes #17316 from MLnick/SPARK-15040-pyspark-imputer. --- .../org/apache/spark/ml/feature/Imputer.scala | 10 +- python/pyspark/ml/feature.py | 160 ++++++++++++++++++ python/pyspark/ml/tests.py | 10 ++ 3 files changed, 175 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index b1a802ee13fc4..ec4c6ad75ee23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -93,12 +93,12 @@ private[feature] trait ImputerParams extends Params with HasInputCols { /** * :: Experimental :: * Imputation estimator for completing missing values, either using the mean or the median - * of the column in which the missing values are located. The input column should be of - * DoubleType or FloatType. Currently Imputer does not support categorical features yet + * of the columns in which the missing values are located. The input columns should be of + * DoubleType or FloatType. Currently Imputer does not support categorical features * (SPARK-15041) and possibly creates incorrect values for a categorical feature. * * Note that the mean/median value is computed after filtering out missing values. - * All Null values in the input column are treated as missing, and so are also imputed. For + * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. */ @Experimental @@ -176,8 +176,8 @@ object Imputer extends DefaultParamsReadable[Imputer] { * :: Experimental :: * Model fitted by [[Imputer]]. * - * @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are - * used to replace the missing values in the input DataFrame. + * @param surrogateDF a DataFrame containing inputCols and their corresponding surrogates, + * which are used to replace the missing values in the input DataFrame. */ @Experimental class ImputerModel private[ml]( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92f8549e9cb9e..8d25f5b3a771a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -36,6 +36,7 @@ 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', + 'Imputer', 'ImputerModel', 'IndexToString', 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinHashLSH', 'MinHashLSHModel', @@ -870,6 +871,165 @@ def idf(self): return self._call_java("idf") +@inherit_doc +class Imputer(JavaEstimator, HasInputCols, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Imputation estimator for completing missing values, either using the mean or the median + of the columns in which the missing values are located. The input columns should be of + DoubleType or FloatType. Currently Imputer does not support categorical features and + possibly creates incorrect values for a categorical feature. + + Note that the mean/median value is computed after filtering out missing values. + All Null values in the input columns are treated as missing, and so are also imputed. For + computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a + relative error of `0.001`. + + >>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0), + ... (4.0, 4.0), (5.0, 5.0)], ["a", "b"]) + >>> imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + >>> model = imputer.fit(df) + >>> model.surrogateDF.show() + +---+---+ + | a| b| + +---+---+ + |3.0|4.0| + +---+---+ + ... + >>> model.transform(df).show() + +---+---+-----+-----+ + | a| b|out_a|out_b| + +---+---+-----+-----+ + |1.0|NaN| 1.0| 4.0| + |2.0|NaN| 2.0| 4.0| + |NaN|3.0| 3.0| 3.0| + ... + >>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show() + +---+---+-----+-----+ + | a| b|out_a|out_b| + +---+---+-----+-----+ + |1.0|NaN| 4.0| NaN| + ... + >>> imputerPath = temp_path + "/imputer" + >>> imputer.save(imputerPath) + >>> loadedImputer = Imputer.load(imputerPath) + >>> loadedImputer.getStrategy() == imputer.getStrategy() + True + >>> loadedImputer.getMissingValue() + 1.0 + >>> modelPath = temp_path + "/imputer-model" + >>> model.save(modelPath) + >>> loadedModel = ImputerModel.load(modelPath) + >>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a + True + + .. versionadded:: 2.2.0 + """ + + outputCols = Param(Params._dummy(), "outputCols", + "output column names.", typeConverter=TypeConverters.toListString) + + strategy = Param(Params._dummy(), "strategy", + "strategy for imputation. If mean, then replace missing values using the mean " + "value of the feature. If median, then replace missing values using the " + "median value of the feature.", + typeConverter=TypeConverters.toString) + + missingValue = Param(Params._dummy(), "missingValue", + "The placeholder for the missing values. All occurrences of missingValue " + "will be imputed.", typeConverter=TypeConverters.toFloat) + + @keyword_only + def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, + outputCols=None): + """ + __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, \ + outputCols=None): + """ + super(Imputer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid) + self._setDefault(strategy="mean", missingValue=float("nan")) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.2.0") + def setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, + outputCols=None): + """ + setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, \ + outputCols=None) + Sets params for this Imputer. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.2.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + @since("2.2.0") + def getOutputCols(self): + """ + Gets the value of :py:attr:`outputCols` or its default value. + """ + return self.getOrDefault(self.outputCols) + + @since("2.2.0") + def setStrategy(self, value): + """ + Sets the value of :py:attr:`strategy`. + """ + return self._set(strategy=value) + + @since("2.2.0") + def getStrategy(self): + """ + Gets the value of :py:attr:`strategy` or its default value. + """ + return self.getOrDefault(self.strategy) + + @since("2.2.0") + def setMissingValue(self, value): + """ + Sets the value of :py:attr:`missingValue`. + """ + return self._set(missingValue=value) + + @since("2.2.0") + def getMissingValue(self): + """ + Gets the value of :py:attr:`missingValue` or its default value. + """ + return self.getOrDefault(self.missingValue) + + def _create_model(self, java_model): + return ImputerModel(java_model) + + +class ImputerModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Model fitted by :py:class:`Imputer`. + + .. versionadded:: 2.2.0 + """ + + @property + @since("2.2.0") + def surrogateDF(self): + """ + Returns a DataFrame containing inputCols and their corresponding surrogates, + which are used to replace the missing values in the input DataFrame. + """ + return self._call_java("surrogateDF") + + @inherit_doc class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index f052f5bb770c6..cc559db58720f 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1273,6 +1273,7 @@ class DefaultValuesTests(PySparkTestCase): """ def check_params(self, py_stage): + import pyspark.ml.feature if not hasattr(py_stage, "_to_java"): return java_stage = py_stage._to_java() @@ -1292,6 +1293,15 @@ def check_params(self, py_stage): _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) py_stage._clear(p) py_default = py_stage.getOrDefault(p) + if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue": + # SPARK-15040 - default value for Imputer param 'missingValue' is NaN, + # and NaN != NaN, so handle it specially here + import math + self.assertTrue(math.isnan(java_default) and math.isnan(py_default), + "Java default %s and python default %s are not both NaN for " + "param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + return self.assertEqual(java_default, py_default, "Java default %s != python default %s of param %s for Params %s" % (str(java_default), str(py_default), p.name, str(py_stage))) From 9299d071f95798e33b18c08d3c75bb26f88b266b Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 24 Mar 2017 09:56:05 -0700 Subject: [PATCH 0106/1765] [SQL][MINOR] Fix for typo in Analyzer ## What changes were proposed in this pull request? Fix for typo in Analyzer ## How was this patch tested? local build Author: Jacek Laskowski Closes #17409 from jaceklaskowski/analyzer-typo. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 036ed060d9efe..1b3a53c6359e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2502,7 +2502,7 @@ object TimeWindowing extends Rule[LogicalPlan] { substitutedPlan.withNewChildren(expandedPlan :: Nil) } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are not currently not supported.") + "of rows, therefore they are currently not supported.") } else { p // Return unchanged. Analyzer will throw exception later } From 707e501832fa7adde0a884c528a7352983d83520 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Fri, 24 Mar 2017 12:40:29 -0700 Subject: [PATCH 0107/1765] [SPARK-19911][STREAMING] Add builder interface for Kinesis DStreams ## What changes were proposed in this pull request? - Add new KinesisDStream.scala containing KinesisDStream.Builder class - Add KinesisDStreamBuilderSuite test suite - Make KinesisInputDStream ctor args package private for testing - Add JavaKinesisDStreamBuilderSuite test suite - Add args to KinesisInputDStream and KinesisReceiver for optional service-specific auth (Kinesis, DynamoDB and CloudWatch) ## How was this patch tested? Added ```KinesisDStreamBuilderSuite``` to verify builder class works as expected Author: Adam Budde Closes #17250 from budde/KinesisStreamBuilder. --- .../kinesis/KinesisBackedBlockRDD.scala | 6 +- .../kinesis/KinesisInputDStream.scala | 259 +++++++++++++++++- .../streaming/kinesis/KinesisReceiver.scala | 20 +- .../streaming/kinesis/KinesisUtils.scala | 43 +-- .../SerializableCredentialsProvider.scala | 85 ------ .../kinesis/SparkAWSCredentials.scala | 182 ++++++++++++ .../JavaKinesisInputDStreamBuilderSuite.java | 63 +++++ .../KinesisInputDStreamBuilderSuite.scala | 115 ++++++++ .../kinesis/KinesisReceiverSuite.scala | 23 -- .../kinesis/KinesisStreamSuite.scala | 2 +- .../SparkAWSCredentialsBuilderSuite.scala | 100 +++++++ 11 files changed, 749 insertions(+), 149 deletions(-) delete mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala create mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala create mode 100644 external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java create mode 100644 external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala create mode 100644 external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 0f1790bddcc3d..f31ebf1ec8da0 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -82,8 +82,8 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, - val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider + val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _, + val kinesisCreds: SparkAWSCredentials = DefaultCredentials ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -109,7 +109,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = kinesisCredsProvider.provider.getCredentials + val credentials = kinesisCreds.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index fbc6b99443ed7..8970ad2bafda0 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,24 +22,28 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.api.java.JavaStreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointAppName: String, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider + val streamName: String, + val endpointUrl: String, + val regionName: String, + val initialPositionInStream: InitialPositionInStream, + val checkpointAppName: String, + val checkpointInterval: Duration, + val _storageLevel: StorageLevel, + val messageHandler: Record => T, + val kinesisCreds: SparkAWSCredentials, + val dynamoDBCreds: Option[SparkAWSCredentials], + val cloudWatchCreds: Option[SparkAWSCredentials] ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -61,7 +65,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - kinesisCredsProvider = kinesisCredsProvider) + kinesisCreds = kinesisCreds) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,7 +75,238 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, - kinesisCredsProvider) + checkpointAppName, checkpointInterval, _storageLevel, messageHandler, + kinesisCreds, dynamoDBCreds, cloudWatchCreds) } } + +@InterfaceStability.Evolving +object KinesisInputDStream { + /** + * Builder for [[KinesisInputDStream]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + // Required params + private var streamingContext: Option[StreamingContext] = None + private var streamName: Option[String] = None + private var checkpointAppName: Option[String] = None + + // Params with defaults + private var endpointUrl: Option[String] = None + private var regionName: Option[String] = None + private var initialPositionInStream: Option[InitialPositionInStream] = None + private var checkpointInterval: Option[Duration] = None + private var storageLevel: Option[StorageLevel] = None + private var kinesisCredsProvider: Option[SparkAWSCredentials] = None + private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None + private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param ssc [[StreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(ssc: StreamingContext): Builder = { + streamingContext = Option(ssc) + this + } + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param jssc [[JavaStreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(jssc: JavaStreamingContext): Builder = { + streamingContext = Option(jssc.ssc) + this + } + + /** + * Sets the name of the Kinesis stream that the DStream will read from. This is a required + * parameter. + * + * @param streamName Name of Kinesis stream that the DStream will read from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamName(streamName: String): Builder = { + this.streamName = Option(streamName) + this + } + + /** + * Sets the KCL application name to use when checkpointing state to DynamoDB. This is a + * required parameter. + * + * @param appName Value to use for the KCL app name (used when creating the DynamoDB checkpoint + * table and when writing metrics to CloudWatch) + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointAppName(appName: String): Builder = { + checkpointAppName = Option(appName) + this + } + + /** + * Sets the AWS Kinesis endpoint URL. Defaults to "https://kinesis.us-east-1.amazonaws.com" if + * no custom value is specified + * + * @param url Kinesis endpoint URL to use + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def endpointUrl(url: String): Builder = { + endpointUrl = Option(url) + this + } + + /** + * Sets the AWS region to construct clients for. Defaults to "us-east-1" if no custom value + * is specified. + * + * @param regionName Name of AWS region to use (e.g. "us-west-2") + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def regionName(regionName: String): Builder = { + this.regionName = Option(regionName) + this + } + + /** + * Sets the initial position data is read from in the Kinesis stream. Defaults to + * [[InitialPositionInStream.LATEST]] if no custom value is specified. + * + * @param initialPosition InitialPositionInStream value specifying where Spark Streaming + * will start reading records in the Kinesis stream from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = { + initialPositionInStream = Option(initialPosition) + this + } + + /** + * Sets how often the KCL application state is checkpointed to DynamoDB. Defaults to the Spark + * Streaming batch interval if no custom value is specified. + * + * @param interval [[Duration]] specifying how often the KCL state should be checkpointed to + * DynamoDB. + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointInterval(interval: Duration): Builder = { + checkpointInterval = Option(interval) + this + } + + /** + * Sets the storage level of the blocks for the DStream created. Defaults to + * [[StorageLevel.MEMORY_AND_DISK_2]] if no custom value is specified. + * + * @param storageLevel [[StorageLevel]] to use for the DStream data blocks + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def storageLevel(storageLevel: StorageLevel): Builder = { + this.storageLevel = Option(storageLevel) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS Kinesis + * endpoint. Defaults to [[DefaultCredentialsProvider]] if no custom value is specified. + * + * @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication + */ + def kinesisCredentials(credentials: SparkAWSCredentials): Builder = { + kinesisCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS DynamoDB + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication + */ + def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = { + dynamoDBCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS CloudWatch + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication + */ + def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = { + cloudWatchCredsProvider = Option(credentials) + this + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided + * message handler. + * + * @param handler Function converting [[Record]] instances read by the KCL to DStream type [[T]] + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def buildWithMessageHandler[T: ClassTag]( + handler: Record => T): KinesisInputDStream[T] = { + val ssc = getRequiredParam(streamingContext, "streamingContext") + new KinesisInputDStream( + ssc, + getRequiredParam(streamName, "streamName"), + endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL), + regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME), + initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM), + getRequiredParam(checkpointAppName, "checkpointAppName"), + checkpointInterval.getOrElse(ssc.graph.batchDuration), + storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), + handler, + kinesisCredsProvider.getOrElse(DefaultCredentials), + dynamoDBCredsProvider, + cloudWatchCredsProvider) + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and using the + * default message handler, which returns [[Array[Byte]]]. + * + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def build(): KinesisInputDStream[Array[Byte]] = buildWithMessageHandler(defaultMessageHandler) + + private def getRequiredParam[T](param: Option[T], paramName: String): T = param.getOrElse { + throw new IllegalArgumentException(s"No value provided for required parameter $paramName") + } + } + + /** + * Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances. + * + * @since 2.2.0 + * + * @return [[KinesisInputDStream.Builder]] instance + */ + def builder: Builder = new Builder + + private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + if (record == null) return null + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } + + private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String = + "https://kinesis.us-east-1.amazonaws.com" + private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" + private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream = + InitialPositionInStream.LATEST + private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 320728f4bb221..1026d0fcb59bd 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -70,9 +70,14 @@ import org.apache.spark.util.Utils * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to - * generate the AWSCredentialsProvider instance used for KCL - * authorization. + * @param kinesisCreds SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize Kinesis API calls. + * @param cloudWatchCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize CloudWatch API + * calls. Will use kinesisCreds if value is None. + * @param dynamoDBCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize DynamoDB API calls. + * Will use kinesisCreds if value is None. */ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -83,7 +88,9 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider) + kinesisCreds: SparkAWSCredentials, + dynamoDBCreds: Option[SparkAWSCredentials], + cloudWatchCreds: Option[SparkAWSCredentials]) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -140,10 +147,13 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) + val kinesisProvider = kinesisCreds.provider val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( checkpointAppName, streamName, - kinesisCredsProvider.provider, + kinesisProvider, + dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), + cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2d777982e760c..1298463bfba1e 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -58,6 +58,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -73,7 +74,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, DefaultCredentialsProvider) + cleanedHandler, DefaultCredentials, None, None) } } @@ -108,6 +109,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -123,12 +125,12 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -169,6 +171,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -187,16 +190,16 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = STSCredentialsProvider( + val kinesisCredsProvider = STSCredentials( stsRoleArn = stsAssumeRoleArn, stsSessionName = stsSessionName, stsExternalId = Option(stsExternalId), - longLivedCredsProvider = BasicCredentialsProvider( + longLivedCreds = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey)) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -227,6 +230,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -240,7 +244,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, DefaultCredentialsProvider) + KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) } } @@ -272,6 +276,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -284,12 +289,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, kinesisCredsProvider) + KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) } } @@ -323,6 +328,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -372,6 +378,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -431,6 +438,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -482,6 +490,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -493,7 +502,8 @@ object KinesisUtils { storageLevel: StorageLevel ): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_)) + initialPositionInStream, checkpointInterval, storageLevel, + KinesisInputDStream.defaultMessageHandler(_)) } /** @@ -524,6 +534,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -537,7 +548,7 @@ object KinesisUtils { awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel, - defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } private def validateRegion(regionName: String): String = { @@ -545,14 +556,6 @@ object KinesisUtils { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") } } - - private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { - if (record == null) return null - val byteBuffer = record.getData() - val byteArray = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(byteArray) - byteArray - } } /** @@ -597,7 +600,7 @@ private class KinesisUtilsPythonHelper { validateAwsCreds(awsAccessKeyId, awsSecretKey) KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) } else { validateAwsCreds(awsAccessKeyId, awsSecretKey) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala deleted file mode 100644 index aa6fe12edf74e..0000000000000 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import scala.collection.JavaConverters._ - -import com.amazonaws.auth._ - -import org.apache.spark.internal.Logging - -/** - * Serializable interface providing a method executors can call to obtain an - * AWSCredentialsProvider instance for authenticating to AWS services. - */ -private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { - /** - * Return an AWSCredentialProvider instance that can be used by the Kinesis Client - * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). - */ - def provider: AWSCredentialsProvider -} - -/** Returns DefaultAWSCredentialsProviderChain for authentication. */ -private[kinesis] final case object DefaultCredentialsProvider - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain -} - -/** - * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using - * DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain - * instance with the provided arguments (e.g. if they are null). - */ -private[kinesis] final case class BasicCredentialsProvider( - awsAccessKeyId: String, - awsSecretKey: String) extends SerializableCredentialsProvider with Logging { - - def provider: AWSCredentialsProvider = try { - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) - } catch { - case e: IllegalArgumentException => - logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + - "falling back to DefaultAWSCredentialsProviderChain.", e) - new DefaultAWSCredentialsProviderChain - } -} - -/** - * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM - * role in order to authenticate against resources in an external account. - */ -private[kinesis] final case class STSCredentialsProvider( - stsRoleArn: String, - stsSessionName: String, - stsExternalId: Option[String] = None, - longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = { - val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) - .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) - stsExternalId match { - case Some(stsExternalId) => - builder.withExternalId(stsExternalId) - .build() - case None => - builder.build() - } - } -} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala new file mode 100644 index 0000000000000..9facfe8ff2b0f --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SparkAWSCredentials extends Serializable { + /** + * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentials extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * instance with the provided arguments (e.g. if they are null). + */ +private[kinesis] final case class BasicCredentials( + awsAccessKeyId: String, + awsSecretKey: String) extends SparkAWSCredentials with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentials( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCreds: SparkAWSCredentials = DefaultCredentials) + extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCreds.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} + +@InterfaceStability.Evolving +object SparkAWSCredentials { + /** + * Builder for [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + private var basicCreds: Option[BasicCredentials] = None + private var stsCreds: Option[STSCredentials] = None + + // scalastyle:off + /** + * Use a basic AWS keypair for long-lived authorization. + * + * @note The given AWS keypair will be saved in DStream checkpoints if checkpointing is + * enabled. Make sure that your checkpoint directory is secure. Prefer using the + * [[http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default default provider chain]] + * instead if possible. + * + * @param accessKeyId AWS access key ID + * @param secretKey AWS secret key + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + // scalastyle:on + def basicCredentials(accessKeyId: String, secretKey: String): Builder = { + basicCreds = Option(BasicCredentials( + awsAccessKeyId = accessKeyId, + awsSecretKey = secretKey)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). + * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String): Builder = { + stsCreds = Option(STSCredentials(stsRoleArn = roleArn, stsSessionName = sessionName)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). STS will validate the provided external ID with the one defined + * in the trust policy of the IAM role to be assumed (if one is present). + * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @param externalId External ID to validate against assumed IAM role's trust policy + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String, externalId: String): Builder = { + stsCreds = Option(STSCredentials( + stsRoleArn = roleArn, + stsSessionName = sessionName, + stsExternalId = Option(externalId))) + this + } + + /** + * Returns the appropriate instance of [[SparkAWSCredentials]] given the configured + * parameters. + * + * - The long-lived credentials will either be [[DefaultCredentials]] or [[BasicCredentials]] + * if they were provided. + * + * - If STS credentials were provided, the configured long-lived credentials will be added to + * them and the result will be returned. + * + * - The long-lived credentials will be returned otherwise. + * + * @return [[SparkAWSCredentials]] to use for configured parameters + */ + def build(): SparkAWSCredentials = + stsCreds.map(_.copy(longLivedCreds = longLivedCreds)).getOrElse(longLivedCreds) + + private def longLivedCreds: SparkAWSCredentials = basicCreds.getOrElse(DefaultCredentials) + } + + /** + * Creates a [[SparkAWSCredentials.Builder]] for constructing + * [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + * + * @return [[SparkAWSCredentials.Builder]] instance + */ + def builder: Builder = new Builder +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java new file mode 100644 index 0000000000000..7205f6e27266c --- /dev/null +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Seconds; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; + +public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { + /** + * Basic test to ensure that the KinesisDStream.Builder interface is accessible from Java. + */ + @Test + public void testJavaKinesisDStreamBuilder() { + String streamName = "a-very-nice-stream-name"; + String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; + String region = "us-west-2"; + InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON; + String appName = "a-very-nice-kinesis-app"; + Duration checkpointInterval = Seconds.apply(30); + StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); + + KinesisInputDStream kinesisDStream = KinesisInputDStream.builder() + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPositionInStream(initialPosition) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build(); + assert(kinesisDStream.streamName() == streamName); + assert(kinesisDStream.endpointUrl() == endpointUrl); + assert(kinesisDStream.regionName() == region); + assert(kinesisDStream.initialPositionInStream() == initialPosition); + assert(kinesisDStream.checkpointAppName() == appName); + assert(kinesisDStream.checkpointInterval() == checkpointInterval); + assert(kinesisDStream._storageLevel() == storageLevel); + ssc.stop(); + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala new file mode 100644 index 0000000000000..1c130654f3f95 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.lang.IllegalArgumentException + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase} + +class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach + with MockitoSugar { + import KinesisInputDStream._ + + private val ssc = new StreamingContext(conf, batchDuration) + private val streamName = "a-very-nice-kinesis-stream-name" + private val checkpointAppName = "a-very-nice-kcl-app-name" + private def baseBuilder = KinesisInputDStream.builder + private def builder = baseBuilder.streamingContext(ssc) + .streamName(streamName) + .checkpointAppName(checkpointAppName) + + override def afterAll(): Unit = { + ssc.stop() + } + + test("should raise an exception if the StreamingContext is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamName(streamName).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the stream name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the checkpoint app name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).streamName(streamName).build() + } + } + + test("should propagate required values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.context == ssc) + assert(dstream.streamName == streamName) + assert(dstream.checkpointAppName == checkpointAppName) + } + + test("should propagate default values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL) + assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME) + assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM) + assert(dstream.checkpointInterval == batchDuration) + assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL) + assert(dstream.kinesisCreds == DefaultCredentials) + assert(dstream.dynamoDBCreds == None) + assert(dstream.cloudWatchCreds == None) + } + + test("should propagate custom non-auth values to KinesisInputDStream") { + val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + val customRegion = "us-west-2" + val customInitialPosition = InitialPositionInStream.TRIM_HORIZON + val customAppName = "a-very-nice-kinesis-app" + val customCheckpointInterval = Seconds(30) + val customStorageLevel = StorageLevel.MEMORY_ONLY + val customKinesisCreds = mock[SparkAWSCredentials] + val customDynamoDBCreds = mock[SparkAWSCredentials] + val customCloudWatchCreds = mock[SparkAWSCredentials] + + val dstream = builder + .endpointUrl(customEndpointUrl) + .regionName(customRegion) + .initialPositionInStream(customInitialPosition) + .checkpointAppName(customAppName) + .checkpointInterval(customCheckpointInterval) + .storageLevel(customStorageLevel) + .kinesisCredentials(customKinesisCreds) + .dynamoDBCredentials(customDynamoDBCreds) + .cloudWatchCredentials(customCloudWatchCreds) + .build() + assert(dstream.endpointUrl == customEndpointUrl) + assert(dstream.regionName == customRegion) + assert(dstream.initialPositionInStream == customInitialPosition) + assert(dstream.checkpointAppName == customAppName) + assert(dstream.checkpointInterval == customCheckpointInterval) + assert(dstream._storageLevel == customStorageLevel) + assert(dstream.kinesisCreds == customKinesisCreds) + assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) + assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index deb411d73e588..3b14c8471e205 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -31,7 +31,6 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} -import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -62,28 +61,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of credential provider classes") { - Utils.deserialize[BasicCredentialsProvider]( - Utils.serialize(BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y"))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId")))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId"), - longLivedCredsProvider = BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y")))) - } - test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index afb55c84f81fe..ed7e35805026e 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -138,7 +138,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + assert(kinesisRDD.kinesisCreds === BasicCredentials( awsAccessKeyId = dummyAWSAccessKey, awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala new file mode 100644 index 0000000000000..f579c2c3a6799 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.util.Utils + +class SparkAWSCredentialsBuilderSuite extends TestSuiteBase { + private def builder = SparkAWSCredentials.builder + + private val basicCreds = BasicCredentials( + awsAccessKeyId = "a-very-nice-access-key", + awsSecretKey = "a-very-nice-secret-key") + + private val stsCreds = STSCredentials( + stsRoleArn = "a-very-nice-role-arn", + stsSessionName = "a-very-nice-secret-key", + stsExternalId = Option("a-very-nice-external-id"), + longLivedCreds = basicCreds) + + test("should build DefaultCredentials when given no params") { + assert(builder.build() == DefaultCredentials) + } + + test("should build BasicCredentials") { + assertResult(basicCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + } + + test("should build STSCredentials") { + // No external ID, default long-lived creds + assertResult(stsCreds.copy(stsExternalId = None, longLivedCreds = DefaultCredentials)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .build() + } + // Default long-lived creds + assertResult(stsCreds.copy(longLivedCreds = DefaultCredentials)) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + // No external ID, basic keypair for long-lived creds + assertResult(stsCreds.copy(stsExternalId = None)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Basic keypair for long-lived creds + assertResult(stsCreds) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Order shouldn't matter + assertResult(stsCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + } + + test("SparkAWSCredentials classes should be serializable") { + assertResult(basicCreds) { + Utils.deserialize[BasicCredentials](Utils.serialize(basicCreds)) + } + assertResult(stsCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsCreds)) + } + // Will also test if DefaultCredentials can be serialized + val stsDefaultCreds = stsCreds.copy(longLivedCreds = DefaultCredentials) + assertResult(stsDefaultCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsDefaultCreds)) + } + } +} From e8810b73c495b6d437dd3b9bb334762126b3c063 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 24 Mar 2017 20:32:42 +0000 Subject: [PATCH 0108/1765] [SPARK-17471][ML] Add compressed method to ML matrices ## What changes were proposed in this pull request? This patch adds a `compressed` method to ML `Matrix` class, which returns the minimal storage representation of the matrix - either sparse or dense. Because the space occupied by a sparse matrix is dependent upon its layout (i.e. column major or row major), this method must consider both cases. It may also be useful to force the layout to be column or row major beforehand, so an overload is added which takes in a `columnMajor: Boolean` parameter. The compressed implementation relies upon two new abstract methods `toDense(columnMajor: Boolean)` and `toSparse(columnMajor: Boolean)`, similar to the compressed method implemented in the `Vector` class. These methods also allow the layout of the resulting matrix to be specified via the `columnMajor` parameter. More detail on the new methods is given below. ## How was this patch tested? Added many new unit tests ## New methods (summary, not exhaustive list) **Matrix trait** - `private[ml] def toDenseMatrix(columnMajor: Boolean): DenseMatrix` (abstract) - converts the matrix (either sparse or dense) to dense format - `private[ml] def toSparseMatrix(columnMajor: Boolean): SparseMatrix` (abstract) - converts the matrix (either sparse or dense) to sparse format - `def toDense: DenseMatrix = toDense(true)` - converts the matrix (either sparse or dense) to dense format in column major layout - `def toSparse: SparseMatrix = toSparse(true)` - converts the matrix (either sparse or dense) to sparse format in column major layout - `def compressed: Matrix` - finds the minimum space representation of this matrix, considering both column and row major layouts, and converts it - `def compressed(columnMajor: Boolean): Matrix` - finds the minimum space representation of this matrix considering only column OR row major, and converts it **DenseMatrix class** - `private[ml] def toDenseMatrix(columnMajor: Boolean): DenseMatrix` - converts the dense matrix to a dense matrix, optionally changing the layout (data is NOT duplicated if the layouts are the same) - `private[ml] def toSparseMatrix(columnMajor: Boolean): SparseMatrix` - converts the dense matrix to sparse matrix, using the specified layout **SparseMatrix class** - `private[ml] def toDenseMatrix(columnMajor: Boolean): DenseMatrix` - converts the sparse matrix to a dense matrix, using the specified layout - `private[ml] def toSparseMatrix(columnMajors: Boolean): SparseMatrix` - converts the sparse matrix to sparse matrix. If the sparse matrix contains any explicit zeros, they are removed. If the layout requested does not match the current layout, data is copied to a new representation. If the layouts match and no explicit zeros exist, the current matrix is returned. Author: sethah Closes #15628 from sethah/matrix_compress. --- .../org/apache/spark/ml/linalg/Matrices.scala | 274 ++++++++++-- .../spark/ml/linalg/MatricesSuite.scala | 420 +++++++++++++++++- .../apache/spark/ml/linalg/VectorsSuite.scala | 5 + project/MimaExcludes.scala | 20 +- 4 files changed, 673 insertions(+), 46 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index d9ffdeb797fb8..07f3bc27280bd 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -44,6 +44,12 @@ sealed trait Matrix extends Serializable { @Since("2.0.0") val isTransposed: Boolean = false + /** Indicates whether the values backing this matrix are arranged in column major order. */ + private[ml] def isColMajor: Boolean = !isTransposed + + /** Indicates whether the values backing this matrix are arranged in row major order. */ + private[ml] def isRowMajor: Boolean = isTransposed + /** Converts to a dense array in column major. */ @Since("2.0.0") def toArray: Array[Double] = { @@ -148,7 +154,8 @@ sealed trait Matrix extends Serializable { * and column indices respectively with the type `Int`, and the final parameter is the * corresponding value in the matrix with type `Double`. */ - private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + @Since("2.2.0") + def foreachActive(f: (Int, Int, Double) => Unit): Unit /** * Find the number of non-zero active values. @@ -161,6 +168,116 @@ sealed trait Matrix extends Serializable { */ @Since("2.0.0") def numActives: Int + + /** + * Converts this matrix to a sparse matrix. + * + * @param colMajor Whether the values of the resulting sparse matrix should be in column major + * or row major order. If `false`, resulting matrix will be row major. + */ + private[ml] def toSparseMatrix(colMajor: Boolean): SparseMatrix + + /** + * Converts this matrix to a sparse matrix in column major order. + */ + @Since("2.2.0") + def toSparseColMajor: SparseMatrix = toSparseMatrix(colMajor = true) + + /** + * Converts this matrix to a sparse matrix in row major order. + */ + @Since("2.2.0") + def toSparseRowMajor: SparseMatrix = toSparseMatrix(colMajor = false) + + /** + * Converts this matrix to a sparse matrix while maintaining the layout of the current matrix. + */ + @Since("2.2.0") + def toSparse: SparseMatrix = toSparseMatrix(colMajor = isColMajor) + + /** + * Converts this matrix to a dense matrix. + * + * @param colMajor Whether the values of the resulting dense matrix should be in column major + * or row major order. If `false`, resulting matrix will be row major. + */ + private[ml] def toDenseMatrix(colMajor: Boolean): DenseMatrix + + /** + * Converts this matrix to a dense matrix while maintaining the layout of the current matrix. + */ + @Since("2.2.0") + def toDense: DenseMatrix = toDenseMatrix(colMajor = isColMajor) + + /** + * Converts this matrix to a dense matrix in row major order. + */ + @Since("2.2.0") + def toDenseRowMajor: DenseMatrix = toDenseMatrix(colMajor = false) + + /** + * Converts this matrix to a dense matrix in column major order. + */ + @Since("2.2.0") + def toDenseColMajor: DenseMatrix = toDenseMatrix(colMajor = true) + + /** + * Returns a matrix in dense or sparse column major format, whichever uses less storage. + */ + @Since("2.2.0") + def compressedColMajor: Matrix = { + if (getDenseSizeInBytes <= getSparseSizeInBytes(colMajor = true)) { + this.toDenseColMajor + } else { + this.toSparseColMajor + } + } + + /** + * Returns a matrix in dense or sparse row major format, whichever uses less storage. + */ + @Since("2.2.0") + def compressedRowMajor: Matrix = { + if (getDenseSizeInBytes <= getSparseSizeInBytes(colMajor = false)) { + this.toDenseRowMajor + } else { + this.toSparseRowMajor + } + } + + /** + * Returns a matrix in dense column major, dense row major, sparse row major, or sparse column + * major format, whichever uses less storage. When dense representation is optimal, it maintains + * the current layout order. + */ + @Since("2.2.0") + def compressed: Matrix = { + val cscSize = getSparseSizeInBytes(colMajor = true) + val csrSize = getSparseSizeInBytes(colMajor = false) + if (getDenseSizeInBytes <= math.min(cscSize, csrSize)) { + // dense matrix size is the same for column major and row major, so maintain current layout + this.toDense + } else if (cscSize <= csrSize) { + this.toSparseColMajor + } else { + this.toSparseRowMajor + } + } + + /** Gets the size of the dense representation of this `Matrix`. */ + private[ml] def getDenseSizeInBytes: Long = { + Matrices.getDenseSize(numCols, numRows) + } + + /** Gets the size of the minimal sparse representation of this `Matrix`. */ + private[ml] def getSparseSizeInBytes(colMajor: Boolean): Long = { + val nnz = numNonzeros + val numPtrs = if (colMajor) numCols + 1L else numRows + 1L + Matrices.getSparseSize(nnz, numPtrs) + } + + /** Gets the current size in bytes of this `Matrix`. Useful for testing */ + private[ml] def getSizeInBytes: Long } /** @@ -258,7 +375,7 @@ class DenseMatrix @Since("2.0.0") ( override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) - private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { // outer loop over columns var j = 0 @@ -291,31 +408,49 @@ class DenseMatrix @Since("2.0.0") ( override def numActives: Int = values.length /** - * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed - * set to false. + * Generate a `SparseMatrix` from the given `DenseMatrix`. + * + * @param colMajor Whether the resulting `SparseMatrix` values will be in column major order. */ - @Since("2.0.0") - def toSparse: SparseMatrix = { - val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble - val colPtrs: Array[Int] = new Array[Int](numCols + 1) - val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt - var nnz = 0 - var j = 0 - while (j < numCols) { - var i = 0 - while (i < numRows) { - val v = values(index(i, j)) - if (v != 0.0) { - rowIndices += i - spVals += v - nnz += 1 + private[ml] override def toSparseMatrix(colMajor: Boolean): SparseMatrix = { + if (!colMajor) this.transpose.toSparseColMajor.transpose + else { + val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble + val colPtrs: Array[Int] = new Array[Int](numCols + 1) + val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt + var nnz = 0 + var j = 0 + while (j < numCols) { + var i = 0 + while (i < numRows) { + val v = values(index(i, j)) + if (v != 0.0) { + rowIndices += i + spVals += v + nnz += 1 + } + i += 1 } - i += 1 + j += 1 + colPtrs(j) = nnz } - j += 1 - colPtrs(j) = nnz + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) + } + } + + /** + * Generate a `DenseMatrix` from this `DenseMatrix`. + * + * @param colMajor Whether the resulting `DenseMatrix` values will be in column major order. + */ + private[ml] override def toDenseMatrix(colMajor: Boolean): DenseMatrix = { + if (isRowMajor && colMajor) { + new DenseMatrix(numRows, numCols, this.toArray, isTransposed = false) + } else if (isColMajor && !colMajor) { + new DenseMatrix(numRows, numCols, this.transpose.toArray, isTransposed = true) + } else { + this } - new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) } override def colIter: Iterator[Vector] = { @@ -331,6 +466,8 @@ class DenseMatrix @Since("2.0.0") ( } } } + + private[ml] def getSizeInBytes: Long = Matrices.getDenseSize(numCols, numRows) } /** @@ -560,7 +697,7 @@ class SparseMatrix @Since("2.0.0") ( override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) - private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { var j = 0 while (j < numCols) { @@ -587,18 +724,67 @@ class SparseMatrix @Since("2.0.0") ( } } + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + /** - * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed - * set to false. + * Generate a `SparseMatrix` from this `SparseMatrix`, removing explicit zero values if they + * exist. + * + * @param colMajor Whether or not the resulting `SparseMatrix` values are in column major + * order. */ - @Since("2.0.0") - def toDense: DenseMatrix = { - new DenseMatrix(numRows, numCols, toArray) + private[ml] override def toSparseMatrix(colMajor: Boolean): SparseMatrix = { + if (isColMajor && !colMajor) { + // it is col major and we want row major, use breeze to remove explicit zeros + val breezeTransposed = asBreeze.asInstanceOf[BSM[Double]].t + Matrices.fromBreeze(breezeTransposed).transpose.asInstanceOf[SparseMatrix] + } else if (isRowMajor && colMajor) { + // it is row major and we want col major, use breeze to remove explicit zeros + val breezeTransposed = asBreeze.asInstanceOf[BSM[Double]] + Matrices.fromBreeze(breezeTransposed).asInstanceOf[SparseMatrix] + } else { + val nnz = numNonzeros + if (nnz != numActives) { + // remove explicit zeros + val rr = new Array[Int](nnz) + val vv = new Array[Double](nnz) + val numPtrs = if (isRowMajor) numRows else numCols + val cc = new Array[Int](numPtrs + 1) + var nzIdx = 0 + var j = 0 + while (j < numPtrs) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + cc(j) = nzIdx + while (idx < idxEnd) { + if (values(idx) != 0.0) { + vv(nzIdx) = values(idx) + rr(nzIdx) = rowIndices(idx) + nzIdx += 1 + } + idx += 1 + } + j += 1 + } + cc(j) = nnz + new SparseMatrix(numRows, numCols, cc, rr, vv, isTransposed = isTransposed) + } else { + this + } + } } - override def numNonzeros: Int = values.count(_ != 0) - - override def numActives: Int = values.length + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. + * + * @param colMajor Whether the resulting `DenseMatrix` values are in column major order. + */ + private[ml] override def toDenseMatrix(colMajor: Boolean): DenseMatrix = { + if (colMajor) new DenseMatrix(numRows, numCols, this.toArray) + else new DenseMatrix(numRows, numCols, this.transpose.toArray, isTransposed = true) + } override def colIter: Iterator[Vector] = { if (isTransposed) { @@ -631,6 +817,8 @@ class SparseMatrix @Since("2.0.0") ( } } } + + private[ml] def getSizeInBytes: Long = Matrices.getSparseSize(numActives, colPtrs.length) } /** @@ -1079,4 +1267,26 @@ object Matrices { SparseMatrix.fromCOO(numRows, numCols, entries) } } + + private[ml] def getSparseSize(numActives: Long, numPtrs: Long): Long = { + /* + Sparse matrices store two int arrays, one double array, two ints, and one boolean: + 8 * values.length + 4 * rowIndices.length + 4 * colPtrs.length + arrayHeader * 3 + 2 * 4 + 1 + */ + val doubleBytes = java.lang.Double.BYTES + val intBytes = java.lang.Integer.BYTES + val arrayHeader = 12L + doubleBytes * numActives + intBytes * numActives + intBytes * numPtrs + arrayHeader * 3L + 9L + } + + private[ml] def getDenseSize(numCols: Long, numRows: Long): Long = { + /* + Dense matrices store one double array, two ints, and one boolean: + 8 * values.length + arrayHeader + 2 * 4 + 1 + */ + val doubleBytes = java.lang.Double.BYTES + val arrayHeader = 12L + doubleBytes * numCols * numRows + arrayHeader + 9L + } + } diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index 9c0aa73938478..9f8202086817d 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -160,22 +160,416 @@ class MatricesSuite extends SparkMLFunSuite { assert(sparseMat.values(2) === 10.0) } - test("toSparse, toDense") { - val m = 3 - val n = 2 - val values = Array(1.0, 2.0, 4.0, 5.0) - val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) - val colPtrs = Array(0, 2, 4) - val rowIndices = Array(0, 1, 1, 2) + test("dense to dense") { + /* + dm1 = 4.0 2.0 -8.0 + -1.0 7.0 4.0 + + dm2 = 5.0 -9.0 4.0 + 1.0 -3.0 -8.0 + */ + val dm1 = new DenseMatrix(2, 3, Array(4.0, -1.0, 2.0, 7.0, -8.0, 4.0)) + val dm2 = new DenseMatrix(2, 3, Array(5.0, -9.0, 4.0, 1.0, -3.0, -8.0), isTransposed = true) + + val dm8 = dm1.toDenseColMajor + assert(dm8 === dm1) + assert(dm8.isColMajor) + assert(dm8.values.equals(dm1.values)) + + val dm5 = dm2.toDenseColMajor + assert(dm5 === dm2) + assert(dm5.isColMajor) + assert(dm5.values === Array(5.0, 1.0, -9.0, -3.0, 4.0, -8.0)) + + val dm4 = dm1.toDenseRowMajor + assert(dm4 === dm1) + assert(dm4.isRowMajor) + assert(dm4.values === Array(4.0, 2.0, -8.0, -1.0, 7.0, 4.0)) + + val dm6 = dm2.toDenseRowMajor + assert(dm6 === dm2) + assert(dm6.isRowMajor) + assert(dm6.values.equals(dm2.values)) + + val dm3 = dm1.toDense + assert(dm3 === dm1) + assert(dm3.isColMajor) + assert(dm3.values.equals(dm1.values)) + + val dm9 = dm2.toDense + assert(dm9 === dm2) + assert(dm9.isRowMajor) + assert(dm9.values.equals(dm2.values)) + } - val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) - val deMat1 = new DenseMatrix(m, n, allValues) + test("dense to sparse") { + /* + dm1 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + + dm2 = 0.0 4.0 5.0 + 0.0 2.0 0.0 - val spMat2 = deMat1.toSparse - val deMat2 = spMat1.toDense + dm3 = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val dm1 = new DenseMatrix(2, 3, Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + val dm2 = new DenseMatrix(2, 3, Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0), isTransposed = true) + val dm3 = new DenseMatrix(2, 3, Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + + val sm1 = dm1.toSparseColMajor + assert(sm1 === dm1) + assert(sm1.isColMajor) + assert(sm1.values === Array(4.0, 2.0, 5.0)) + + val sm3 = dm2.toSparseColMajor + assert(sm3 === dm2) + assert(sm3.isColMajor) + assert(sm3.values === Array(4.0, 2.0, 5.0)) + + val sm5 = dm3.toSparseColMajor + assert(sm5 === dm3) + assert(sm5.values === Array.empty[Double]) + assert(sm5.isColMajor) + + val sm2 = dm1.toSparseRowMajor + assert(sm2 === dm1) + assert(sm2.isRowMajor) + assert(sm2.values === Array(4.0, 5.0, 2.0)) + + val sm4 = dm2.toSparseRowMajor + assert(sm4 === dm2) + assert(sm4.isRowMajor) + assert(sm4.values === Array(4.0, 5.0, 2.0)) + + val sm6 = dm3.toSparseRowMajor + assert(sm6 === dm3) + assert(sm6.values === Array.empty[Double]) + assert(sm6.isRowMajor) + + val sm7 = dm1.toSparse + assert(sm7 === dm1) + assert(sm7.values === Array(4.0, 2.0, 5.0)) + assert(sm7.isColMajor) + + val sm10 = dm2.toSparse + assert(sm10 === dm2) + assert(sm10.values === Array(4.0, 5.0, 2.0)) + assert(sm10.isRowMajor) + } + + test("sparse to sparse") { + /* + sm1 = sm2 = sm3 = sm4 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + smZeros = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val sm1 = new SparseMatrix(2, 3, Array(0, 0, 2, 3), Array(0, 1, 0), Array(4.0, 2.0, 5.0)) + val sm2 = new SparseMatrix(2, 3, Array(0, 2, 3), Array(1, 2, 1), Array(4.0, 5.0, 2.0), + isTransposed = true) + val sm3 = new SparseMatrix(2, 3, Array(0, 0, 2, 4), Array(0, 1, 0, 1), + Array(4.0, 2.0, 5.0, 0.0)) + val sm4 = new SparseMatrix(2, 3, Array(0, 2, 4), Array(1, 2, 1, 2), + Array(4.0, 5.0, 2.0, 0.0), isTransposed = true) + val smZeros = new SparseMatrix(2, 3, Array(0, 2, 4, 6), Array(0, 1, 0, 1, 0, 1), + Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + + val sm6 = sm1.toSparseColMajor + assert(sm6 === sm1) + assert(sm6.isColMajor) + assert(sm6.values.equals(sm1.values)) + + val sm7 = sm2.toSparseColMajor + assert(sm7 === sm2) + assert(sm7.isColMajor) + assert(sm7.values === Array(4.0, 2.0, 5.0)) + + val sm16 = sm3.toSparseColMajor + assert(sm16 === sm3) + assert(sm16.isColMajor) + assert(sm16.values === Array(4.0, 2.0, 5.0)) + + val sm14 = sm4.toSparseColMajor + assert(sm14 === sm4) + assert(sm14.values === Array(4.0, 2.0, 5.0)) + assert(sm14.isColMajor) + + val sm15 = smZeros.toSparseColMajor + assert(sm15 === smZeros) + assert(sm15.values === Array.empty[Double]) + assert(sm15.isColMajor) + + val sm5 = sm1.toSparseRowMajor + assert(sm5 === sm1) + assert(sm5.isRowMajor) + assert(sm5.values === Array(4.0, 5.0, 2.0)) + + val sm8 = sm2.toSparseRowMajor + assert(sm8 === sm2) + assert(sm8.isRowMajor) + assert(sm8.values.equals(sm2.values)) + + val sm10 = sm3.toSparseRowMajor + assert(sm10 === sm3) + assert(sm10.values === Array(4.0, 5.0, 2.0)) + assert(sm10.isRowMajor) + + val sm11 = sm4.toSparseRowMajor + assert(sm11 === sm4) + assert(sm11.values === Array(4.0, 5.0, 2.0)) + assert(sm11.isRowMajor) + + val sm17 = smZeros.toSparseRowMajor + assert(sm17 === smZeros) + assert(sm17.values === Array.empty[Double]) + assert(sm17.isRowMajor) + + val sm9 = sm3.toSparse + assert(sm9 === sm3) + assert(sm9.values === Array(4.0, 2.0, 5.0)) + assert(sm9.isColMajor) + + val sm12 = sm4.toSparse + assert(sm12 === sm4) + assert(sm12.values === Array(4.0, 5.0, 2.0)) + assert(sm12.isRowMajor) + + val sm13 = smZeros.toSparse + assert(sm13 === smZeros) + assert(sm13.values === Array.empty[Double]) + assert(sm13.isColMajor) + } + + test("sparse to dense") { + /* + sm1 = sm2 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + + sm3 = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val sm1 = new SparseMatrix(2, 3, Array(0, 0, 2, 3), Array(0, 1, 0), Array(4.0, 2.0, 5.0)) + val sm2 = new SparseMatrix(2, 3, Array(0, 2, 3), Array(1, 2, 1), Array(4.0, 5.0, 2.0), + isTransposed = true) + val sm3 = new SparseMatrix(2, 3, Array(0, 0, 0, 0), Array.empty[Int], Array.empty[Double]) + + val dm6 = sm1.toDenseColMajor + assert(dm6 === sm1) + assert(dm6.isColMajor) + assert(dm6.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm7 = sm2.toDenseColMajor + assert(dm7 === sm2) + assert(dm7.isColMajor) + assert(dm7.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm2 = sm1.toDenseRowMajor + assert(dm2 === sm1) + assert(dm2.isRowMajor) + assert(dm2.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm4 = sm2.toDenseRowMajor + assert(dm4 === sm2) + assert(dm4.isRowMajor) + assert(dm4.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm1 = sm1.toDense + assert(dm1 === sm1) + assert(dm1.isColMajor) + assert(dm1.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm3 = sm2.toDense + assert(dm3 === sm2) + assert(dm3.isRowMajor) + assert(dm3.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm5 = sm3.toDense + assert(dm5 === sm3) + assert(dm5.isColMajor) + assert(dm5.values === Array.fill(6)(0.0)) + } + + test("compressed dense") { + /* + dm1 = 1.0 0.0 0.0 0.0 + 1.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + + dm2 = 1.0 1.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + */ + // this should compress to a sparse matrix + val dm1 = new DenseMatrix(3, 4, Array.fill(2)(1.0) ++ Array.fill(10)(0.0)) + + // optimal compression layout is row major since numRows < numCols + val cm1 = dm1.compressed.asInstanceOf[SparseMatrix] + assert(cm1 === dm1) + assert(cm1.isRowMajor) + assert(cm1.getSizeInBytes < dm1.getSizeInBytes) + + // force compressed column major + val cm2 = dm1.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm2 === dm1) + assert(cm2.isColMajor) + assert(cm2.getSizeInBytes < dm1.getSizeInBytes) + + // optimal compression layout for transpose is column major + val dm2 = dm1.transpose + val cm3 = dm2.compressed.asInstanceOf[SparseMatrix] + assert(cm3 === dm2) + assert(cm3.isColMajor) + assert(cm3.getSizeInBytes < dm2.getSizeInBytes) + + /* + dm3 = 1.0 1.0 1.0 0.0 + 1.0 1.0 0.0 0.0 + 1.0 1.0 0.0 0.0 + + dm4 = 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 0.0 + 0.0 0.0 0.0 0.0 + */ + // this should compress to a dense matrix + val dm3 = new DenseMatrix(3, 4, Array.fill(7)(1.0) ++ Array.fill(5)(0.0)) + val dm4 = new DenseMatrix(3, 4, Array.fill(7)(1.0) ++ Array.fill(5)(0.0), isTransposed = true) + + val cm4 = dm3.compressed.asInstanceOf[DenseMatrix] + assert(cm4 === dm3) + assert(cm4.isColMajor) + assert(cm4.values.equals(dm3.values)) + assert(cm4.getSizeInBytes === dm3.getSizeInBytes) + + // force compressed row major + val cm5 = dm3.compressedRowMajor.asInstanceOf[DenseMatrix] + assert(cm5 === dm3) + assert(cm5.isRowMajor) + assert(cm5.getSizeInBytes === dm3.getSizeInBytes) + + val cm6 = dm4.compressed.asInstanceOf[DenseMatrix] + assert(cm6 === dm4) + assert(cm6.isRowMajor) + assert(cm6.values.equals(dm4.values)) + assert(cm6.getSizeInBytes === dm4.getSizeInBytes) + + val cm7 = dm4.compressedColMajor.asInstanceOf[DenseMatrix] + assert(cm7 === dm4) + assert(cm7.isColMajor) + assert(cm7.getSizeInBytes === dm4.getSizeInBytes) + + // this has the same size sparse or dense + val dm5 = new DenseMatrix(4, 4, Array.fill(7)(1.0) ++ Array.fill(9)(0.0)) + // should choose dense to break ties + val cm8 = dm5.compressed.asInstanceOf[DenseMatrix] + assert(cm8.getSizeInBytes === dm5.toSparseColMajor.getSizeInBytes) + } - assert(spMat1.asBreeze === spMat2.asBreeze) - assert(deMat1.asBreeze === deMat2.asBreeze) + test("compressed sparse") { + /* + sm1 = 0.0 -1.0 + 0.0 0.0 + 0.0 0.0 + 0.0 0.0 + + sm2 = 0.0 0.0 0.0 0.0 + -1.0 0.0 0.0 0.0 + */ + // these should compress to sparse matrices + val sm1 = new SparseMatrix(4, 2, Array(0, 0, 1), Array(0), Array(-1.0)) + val sm2 = sm1.transpose + + val cm1 = sm1.compressed.asInstanceOf[SparseMatrix] + // optimal is column major + assert(cm1 === sm1) + assert(cm1.isColMajor) + assert(cm1.values.equals(sm1.values)) + assert(cm1.getSizeInBytes === sm1.getSizeInBytes) + + val cm2 = sm1.compressedRowMajor.asInstanceOf[SparseMatrix] + assert(cm2 === sm1) + assert(cm2.isRowMajor) + // forced to be row major, so we have increased the size + assert(cm2.getSizeInBytes > sm1.getSizeInBytes) + assert(cm2.getSizeInBytes < sm1.toDense.getSizeInBytes) + + val cm9 = sm1.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm9 === sm1) + assert(cm9.values.equals(sm1.values)) + assert(cm9.getSizeInBytes === sm1.getSizeInBytes) + + val cm3 = sm2.compressed.asInstanceOf[SparseMatrix] + assert(cm3 === sm2) + assert(cm3.isRowMajor) + assert(cm3.values.equals(sm2.values)) + assert(cm3.getSizeInBytes === sm2.getSizeInBytes) + + val cm8 = sm2.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm8 === sm2) + assert(cm8.isColMajor) + // forced to be col major, so we have increased the size + assert(cm8.getSizeInBytes > sm2.getSizeInBytes) + assert(cm8.getSizeInBytes < sm2.toDense.getSizeInBytes) + + val cm10 = sm2.compressedRowMajor.asInstanceOf[SparseMatrix] + assert(cm10 === sm2) + assert(cm10.isRowMajor) + assert(cm10.values.equals(sm2.values)) + assert(cm10.getSizeInBytes === sm2.getSizeInBytes) + + + /* + sm3 = 0.0 -1.0 + 2.0 3.0 + -4.0 9.0 + */ + // this should compress to a dense matrix + val sm3 = new SparseMatrix(3, 2, Array(0, 2, 5), Array(1, 2, 0, 1, 2), + Array(2.0, -4.0, -1.0, 3.0, 9.0)) + + // dense is optimal, and maintains column major + val cm4 = sm3.compressed.asInstanceOf[DenseMatrix] + assert(cm4 === sm3) + assert(cm4.isColMajor) + assert(cm4.getSizeInBytes < sm3.getSizeInBytes) + + val cm5 = sm3.compressedRowMajor.asInstanceOf[DenseMatrix] + assert(cm5 === sm3) + assert(cm5.isRowMajor) + assert(cm5.getSizeInBytes < sm3.getSizeInBytes) + + val cm11 = sm3.compressedColMajor.asInstanceOf[DenseMatrix] + assert(cm11 === sm3) + assert(cm11.isColMajor) + assert(cm11.getSizeInBytes < sm3.getSizeInBytes) + + /* + sm4 = 1.0 0.0 0.0 ... + + sm5 = 1.0 + 0.0 + 0.0 + ... + */ + val sm4 = new SparseMatrix(Int.MaxValue, 1, Array(0, 1), Array(0), Array(1.0)) + val cm6 = sm4.compressed.asInstanceOf[SparseMatrix] + assert(cm6 === sm4) + assert(cm6.isColMajor) + assert(cm6.getSizeInBytes <= sm4.getSizeInBytes) + + val sm5 = new SparseMatrix(1, Int.MaxValue, Array(0, 1), Array(0), Array(1.0), + isTransposed = true) + val cm7 = sm5.compressed.asInstanceOf[SparseMatrix] + assert(cm7 === sm5) + assert(cm7.isRowMajor) + assert(cm7.getSizeInBytes <= sm5.getSizeInBytes) + + // this has the same size sparse or dense + val sm6 = new SparseMatrix(4, 4, Array(0, 4, 7, 7, 7), Array(0, 1, 2, 3, 0, 1, 2), + Array.fill(7)(1.0)) + // should choose dense to break ties + val cm12 = sm6.compressed.asInstanceOf[DenseMatrix] + assert(cm12.getSizeInBytes === sm6.getSizeInBytes) } test("map, update") { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index ea22c2787fb3c..dfbdaf19d374b 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -336,6 +336,11 @@ class VectorsSuite extends SparkMLFunSuite { val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) val sv1c = sv1.compressed.asInstanceOf[DenseVector] assert(sv1 === sv1c) + + val sv2 = Vectors.sparse(Int.MaxValue, Array(0), Array(3.4)) + val sv2c = sv2.compressed.asInstanceOf[SparseVector] + assert(sv2c === sv2) + assert(sv2c.numActives === 1) } test("SparseVector.slice") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8ce9367c9b446..2e3f9f2d0f3ac 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -81,7 +81,25 @@ object MimaExcludes { // [SPARK-19876] Add one time trigger, and improve Trigger APIs ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime"), + + // [SPARK-17471][ML] Add compressed method to ML matrices + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSparseSizeInBytes"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDense"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparse"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getDenseSizeInBytes"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") ) // Exclude rules for 2.1.x From 91fa80fe8a2480d64c430bd10f97b3d44c007bcc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 24 Mar 2017 15:52:48 -0700 Subject: [PATCH 0109/1765] [SPARK-20070][SQL] Redact DataSourceScanExec treeString ## What changes were proposed in this pull request? The explain output of `DataSourceScanExec` can contain sensitive information (like Amazon keys). Such information should not end up in logs, or be exposed to non privileged users. This PR addresses this by adding a redaction facility for the `DataSourceScanExec.treeString`. A user can enable this by setting a regex in the `spark.redaction.string.regex` configuration. ## How was this patch tested? Added a unit test to check the output of DataSourceScanExec. Author: Herman van Hovell Closes #17397 from hvanhovell/SPARK-20070. --- .../spark/internal/config/ConfigBuilder.scala | 13 ++++ .../spark/internal/config/package.scala | 12 +++- .../scala/org/apache/spark/util/Utils.scala | 17 +++++- .../internal/config/ConfigEntrySuite.scala | 19 ++++-- .../sql/execution/DataSourceScanExec.scala | 41 ++++++++----- .../DataSourceScanExecRedactionSuite.scala | 60 +++++++++++++++++++ 6 files changed, 138 insertions(+), 24 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index a177e66645c7d..d87619afd3b2f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -18,6 +18,9 @@ package org.apache.spark.internal.config import java.util.concurrent.TimeUnit +import java.util.regex.PatternSyntaxException + +import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} @@ -65,6 +68,13 @@ private object ConfigHelpers { def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b" + def regexFromString(str: String, key: String): Regex = { + try str.r catch { + case e: PatternSyntaxException => + throw new IllegalArgumentException(s"$key should be a regex, but was $str", e) + } + } + } /** @@ -214,4 +224,7 @@ private[spark] case class ConfigBuilder(key: String) { new FallbackConfigEntry(key, _doc, _public, fallback) } + def regexConf: TypedConfigBuilder[Regex] = { + new TypedConfigBuilder(this, regexFromString(_, this.key), _.regex) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 223c921810378..89aeea4939086 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -246,8 +246,16 @@ package object config { "driver and executor environments contain sensitive information. When this regex matches " + "a property, its value is redacted from the environment UI and various logs like YARN " + "and event logs.") - .stringConf - .createWithDefault("(?i)secret|password") + .regexConf + .createWithDefault("(?i)secret|password".r) + + private[spark] val STRING_REDACTION_PATTERN = + ConfigBuilder("spark.redaction.string.regex") + .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + + "information. When this regex matches a string part, that string part is replaced by a " + + "dummy value. This is currently used to redact the output of SQL explain commands.") + .regexConf + .createOptional private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1af34e3da231f..943dde0723271 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2585,13 +2585,26 @@ private[spark] object Utils extends Logging { } } - private[util] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + /** + * Redact the sensitive values in the given map. If a map key matches the redaction pattern then + * its value is replaced with a dummy text. + */ def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { - val redactionPattern = conf.get(SECRET_REDACTION_PATTERN).r + val redactionPattern = conf.get(SECRET_REDACTION_PATTERN) redact(redactionPattern, kvs) } + /** + * Redact the sensitive information in the given string. + */ + def redact(conf: SparkConf, text: String): String = { + if (text == null || text.isEmpty || !conf.contains(STRING_REDACTION_PATTERN)) return text + val regex = conf.get(STRING_REDACTION_PATTERN).get + regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { kvs.map { kv => redactionPattern.findFirstIn(kv._1) diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 71eed464880b5..f3756b21080b2 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.internal.config import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap - import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.SparkConfWithEnv @@ -98,6 +95,21 @@ class ConfigEntrySuite extends SparkFunSuite { assert(conf.get(bytes) === 1L) } + test("conf entry: regex") { + val conf = new SparkConf() + val rConf = ConfigBuilder(testKey("regex")).regexConf.createWithDefault(".*".r) + + conf.set(rConf, "[0-9a-f]{8}".r) + assert(conf.get(rConf).regex === "[0-9a-f]{8}") + + conf.set(rConf.key, "[0-9a-f]{4}") + assert(conf.get(rConf).regex === "[0-9a-f]{4}") + + conf.set(rConf.key, "[.") + val e = intercept[IllegalArgumentException](conf.get(rConf)) + assert(e.getMessage.contains("regex should be a regex, but was")) + } + test("conf entry: string seq") { val conf = new SparkConf() val seq = ConfigBuilder(testKey("seq")).stringConf.toSequence.createWithDefault(Seq()) @@ -239,5 +251,4 @@ class ConfigEntrySuite extends SparkFunSuite { .createWithDefault(null) testEntryRef(nullConf, ref(nullConf)) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index bfe9c8e351abc..28156b277f597 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -41,9 +41,33 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation val metastoreTableIdentifier: Option[TableIdentifier] + protected val nodeNamePrefix: String = "" + override val nodeName: String = { s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}" } + + override def simpleString: String = { + val metadataEntries = metadata.toSeq.sorted.map { + case (key, value) => + key + ": " + StringUtils.abbreviate(redact(value), 100) + } + val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") + s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" + } + + override def verboseString: String = redact(super.verboseString) + + override def treeString(verbose: Boolean, addSuffix: Boolean): String = { + redact(super.treeString(verbose, addSuffix)) + } + + /** + * Shorthand for calling redactString() without specifying redacting rules + */ + private def redact(text: String): String = { + Utils.redact(SparkSession.getActiveSession.get.sparkContext.conf, text) + } } /** Physical plan node for scanning data from a relation. */ @@ -85,15 +109,6 @@ case class RowDataSourceScanExec( } } - override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { - key + ": " + StringUtils.abbreviate(value, 100) - } - - s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + - s"${Utils.truncatedString(metadataEntries, " ", ", ", "")}" - } - override def inputRDDs(): Seq[RDD[InternalRow]] = { rdd :: Nil } @@ -307,13 +322,7 @@ case class FileSourceScanExec( } } - override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { - key + ": " + StringUtils.abbreviate(value, 100) - } - val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") - s"File$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" - } + override val nodeNamePrefix: String = "File" override protected def doProduce(ctx: CodegenContext): String = { if (supportsBatch) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala new file mode 100644 index 0000000000000..986fa878ee29b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +/** + * Suite that tests the redaction of DataSourceScanExec + */ +class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { + + import Utils._ + + override def beforeAll(): Unit = { + sparkConf.set("spark.redaction.string.regex", + "spark-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}") + super.beforeAll() + } + + test("treeString is redacted") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + val df = spark.read.parquet(basePath) + + val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head + assert(rootPath.toString.contains(basePath.toString)) + + assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName)) + assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName)) + assert(!df.queryExecution.toString.contains(rootPath.getName)) + assert(!df.queryExecution.simpleString.contains(rootPath.getName)) + + val replacement = "*********" + assert(df.queryExecution.sparkPlan.treeString(verbose = true).contains(replacement)) + assert(df.queryExecution.executedPlan.treeString(verbose = true).contains(replacement)) + assert(df.queryExecution.toString.contains(replacement)) + assert(df.queryExecution.simpleString.contains(replacement)) + } + } +} From b5c5bd98ea5e8dbfebcf86c5459bdf765f5ceb53 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Mar 2017 23:57:29 +0100 Subject: [PATCH 0110/1765] Disable generate codegen since it fails my workload. --- .../spark/sql/execution/GenerateExec.scala | 2 +- .../execution/WholeStageCodegenSuite.scala | 28 ------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 69be7094d2c39..f87d05884b276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -119,7 +119,7 @@ case class GenerateExec( } } - override def supportCodegen: Boolean = generator.supportCodegen + override def supportCodegen: Boolean = false override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4d9203556d49e..a4b30a2f8cec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -116,34 +116,6 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } - test("generate should be included in WholeStageCodegen") { - import org.apache.spark.sql.functions._ - val ds = spark.range(2).select( - col("id"), - explode(array(col("id") + 1, col("id") + 2)).as("value")) - val plan = ds.queryExecution.executedPlan - assert(plan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined) - assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3))) - } - - test("large stack generator should not use WholeStageCodegen") { - def createStackGenerator(rows: Int): SparkPlan = { - val id = UnresolvedAttribute("id") - val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i)))) - spark.range(500).select(Column(stack)).queryExecution.executedPlan - } - val isCodeGenerated: SparkPlan => Boolean = { - case WholeStageCodegenExec(_: GenerateExec) => true - case _ => false - } - - // Only 'stack' generators that produce 50 rows or less are code generated. - assert(createStackGenerator(50).find(isCodeGenerated).isDefined) - assert(createStackGenerator(100).find(isCodeGenerated).isEmpty) - } - test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) From e011004bedca47be998a0c14fe22a6f9bb5090cd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 25 Mar 2017 00:04:51 +0100 Subject: [PATCH 0111/1765] [SPARK-19846][SQL] Add a flag to disable constraint propagation ## What changes were proposed in this pull request? Constraint propagation can be computation expensive and block the driver execution for long time. For example, the below benchmark needs 30mins. Compared with previous PRs #16998, #16785, this is a much simpler option: add a flag to disable constraint propagation. ### Benchmark Run the following codes locally. import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler} import org.apache.spark.sql.internal.SQLConf spark.conf.set(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key, false) val df = (1 to 40).foldLeft(Seq((1, "foo"), (2, "bar"), (3, "baz")).toDF("id", "x0"))((df, i) => df.withColumn(s"x$i", $"x0")) val indexers = df.columns.tail.map(c => new StringIndexer() .setInputCol(c) .setOutputCol(s"${c}_indexed") .setHandleInvalid("skip")) val encoders = indexers.map(indexer => new OneHotEncoder() .setInputCol(indexer.getOutputCol) .setOutputCol(s"${indexer.getOutputCol}_encoded") .setDropLast(true)) val stages: Array[PipelineStage] = indexers ++ encoders val pipeline = new Pipeline().setStages(stages) val startTime = System.nanoTime pipeline.fit(df).transform(df).show val runningTime = System.nanoTime - startTime Before this patch: 1786001 ms ~= 30 mins After this patch: 26392 ms = less than half of a minute Related PRs: #16998, #16785. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17186 from viirya/add-flag-disable-constraint-propagation. --- .../sql/catalyst/SimpleCatalystConf.scala | 3 +- .../sql/catalyst/optimizer/Optimizer.scala | 22 ++++++---- .../spark/sql/catalyst/optimizer/joins.scala | 6 ++- .../spark/sql/catalyst/plans/QueryPlan.scala | 11 +++++ .../apache/spark/sql/internal/SQLConf.scala | 11 +++++ .../BinaryComparisonSimplificationSuite.scala | 5 ++- .../BooleanSimplificationSuite.scala | 5 ++- .../InferFiltersFromConstraintsSuite.scala | 19 ++++++++- .../optimizer/OuterJoinEliminationSuite.scala | 30 +++++++++++++- .../PropagateEmptyRelationSuite.scala | 5 ++- .../optimizer/PruneFiltersSuite.scala | 40 ++++++++++++++++++- .../optimizer/SetOperationSuite.scala | 3 +- .../plans/ConstraintPropagationSuite.scala | 18 +++++++++ 13 files changed, 158 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala index ac97987c55e08..8498cf1c9be79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -43,7 +43,8 @@ case class SimpleCatalystConf( override val starSchemaDetection: Boolean = false, override val warehousePath: String = "/user/hive/warehouse", override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, - override val maxNestedViewDepth: Int = 100) + override val maxNestedViewDepth: Int = 100, + override val constraintPropagationEnabled: Boolean = true) extends SQLConf { override def clone(): SimpleCatalystConf = this.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d7524a57adbc7..ee7de86921496 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -83,12 +83,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // Operator push down PushProjectionThroughUnion, ReorderJoin(conf), - EliminateOuterJoin, + EliminateOuterJoin(conf), PushPredicateThroughJoin, PushDownPredicate, LimitPushDown(conf), ColumnPruning, - InferFiltersFromConstraints, + InferFiltersFromConstraints(conf), // Operator combine CollapseRepartition, CollapseProject, @@ -107,7 +107,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - PruneFilters, + PruneFilters(conf), EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -615,8 +615,16 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { +case class InferFiltersFromConstraints(conf: CatalystConf) + extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { + inferFilters(plan) + } else { + plan + } + + + private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, child) => val newFilters = filter.constraints -- (child.constraints ++ splitConjunctivePredicates(condition)) @@ -705,7 +713,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. */ -object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { +case class PruneFilters(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -718,7 +726,7 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(fc, p: LogicalPlan) => val (prunedPredicates, remainingPredicates) = splitConjunctivePredicates(fc).partition { cond => - cond.deterministic && p.constraints.contains(cond) + cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond) } if (prunedPredicates.isEmpty) { f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 58e4a230f4ef0..5f7316566b3ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ @@ -439,7 +440,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe * * This rule should be executed before pushing down the Filter */ -object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { +case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. @@ -455,7 +456,8 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints + val conditions = splitConjunctivePredicates(filter.condition) ++ + filter.getConstraints(conf.constraintPropagationEnabled) val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a5761703fd655..9fd95a4b368ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -186,6 +186,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + /** + * Returns [[constraints]] depending on the config of enabling constraint propagation. If the + * flag is disabled, simply returning an empty constraints. + */ + private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = + if (constraintPropagationEnabled) { + constraints + } else { + ExpressionSet(Set.empty) + } + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d5006c16469bc..5566b06aa3553 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -187,6 +187,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") + .internal() + .doc("When true, the query optimizer will infer and propagate data constraints in the query " + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "for certain kinds of query plans (such as those with a large number of predicates and " + + "aliases) which might negatively impact overall runtime.") + .booleanConf + .createWithDefault(true) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -887,6 +896,8 @@ class SQLConf extends Serializable with Logging { def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index a0d489681fd9f..2bfddb7bc2f35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -30,15 +30,16 @@ import org.apache.spark.sql.catalyst.rules._ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val nullableRelation = LocalRelation('a.int.withNullability(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1b9db06014921..4d404f55aa570 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -30,14 +30,15 @@ import org.apache.spark.sql.catalyst.rules._ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9f57f66a2ea20..98d8b897a9165 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -31,7 +32,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints, + InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true)), + CombineFilters) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + PushPredicateThroughJoin, + PushDownPredicate, + InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true, + constraintPropagationEnabled = false)), CombineFilters) :: Nil } @@ -201,4 +212,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("No inferred filter when constraint propagation is disabled") { + val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery) + comparePlans(optimized, originalQuery) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index c168a55e40c54..cbabc1fa6d929 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -31,7 +32,17 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin, + EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PushPredicateThroughJoin) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Outer Join Elimination", Once, + EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true, + constraintPropagationEnabled = false)), PushPredicateThroughJoin) :: Nil } @@ -231,4 +242,21 @@ class OuterJoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("no outer join elimination if constraint propagation is disabled") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + // The predicate "x.b + y.d >= 3" will be inferred constraints like: + // "x.b != null" and "y.d != null", if constraint propagation is enabled. + // When we disable it, the predicate can't be evaluated on left or right plan and used to + // filter out nulls. So the Outer Join will not be eliminated. + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr + "y.d".attr >= 3) + + val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze) + + comparePlans(optimized, originalQuery.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 908dde7a66988..f771e3e9eba65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans._ @@ -33,7 +34,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters, + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), PropagateEmptyRelation) :: Nil } @@ -45,7 +46,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index d8cfec5391497..20f7f69e86c05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -33,7 +34,19 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters, + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PushDownPredicate, + PushPredicateThroughJoin) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown and Pruning", Once, + CombineFilters, + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true, + constraintPropagationEnabled = false)), PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -133,4 +146,29 @@ class PruneFiltersSuite extends PlanTest { val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze comparePlans(optimized, correctAnswer) } + + test("No pruning when constraint propagation is disabled") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + + val query = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + + val queryWithUselessFilter = + query.where( + ("tr1.a".attr > 10 || "tr1.c".attr < 10) && + 'd.attr < 100) + + val optimized = + OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze) + // When constraint propagation is disabled, the useless filter won't be pruned. + // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant + // and duplicate filters. + val correctAnswer = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100).where('d.attr < 100), + Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 21b7f49e14bd5..ca4976f0d6db0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -34,7 +35,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 908b370408280..4061394b862a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -397,4 +397,22 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "c"))))) } + + test("enable/disable constraint propagation") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) + + verifyConstraints( + filterRelation.analyze.getConstraints(constraintPropagationEnabled = true), + filterRelation.analyze.constraints) + + assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + + verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true), + aliasedRelation.analyze.constraints) + assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + } } From f88f56b835b3a61ff2d59236e7fa05eda5aefcaa Mon Sep 17 00:00:00 2001 From: Roxanne Moslehi Date: Sat, 25 Mar 2017 00:10:30 +0100 Subject: [PATCH 0112/1765] [DOCS] Clarify round mode for format_number & round functions ## What changes were proposed in this pull request? Updated the description for the `format_number` description to indicate that it uses `HALF_EVEN` rounding. Updated the description for the `round` description to indicate that it uses `HALF_UP` rounding. ## How was this patch tested? Just changing the two function comments so no testing involved. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Roxanne Moslehi Author: roxannemoslehi Closes #17399 from roxannemoslehi/patch-1. --- .../main/scala/org/apache/spark/sql/functions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 66bb8816a6701..acdb8e2d3edc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1861,7 +1861,7 @@ object functions { def rint(columnName: String): Column = rint(Column(columnName)) /** - * Returns the value of the column `e` rounded to 0 decimal places. + * Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. * * @group math_funcs * @since 1.5.0 @@ -1869,8 +1869,8 @@ object functions { def round(e: Column): Column = round(e, 0) /** - * Round the value of `e` to `scale` decimal places if `scale` is greater than or equal to 0 - * or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_UP round mode + * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 1.5.0 @@ -2191,8 +2191,8 @@ object functions { } /** - * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, - * and returns the result as a string column. + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places + * with HALF_EVEN round mode, and returns the result as a string column. * * If d is 0, the result has no decimal point or fractional part. * If d is less than 0, the result will be null. From 0a6c50711b871dce1a04f5dc7652a0b936369fa0 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 25 Mar 2017 01:07:50 +0100 Subject: [PATCH 0113/1765] [SPARK-20070][SQL] Fix 2.10 build ## What changes were proposed in this pull request? Commit https://github.com/apache/spark/commit/91fa80fe8a2480d64c430bd10f97b3d44c007bcc broke the build for scala 2.10. The commit uses `Regex.regex` field which is not available in Scala 2.10. This PR fixes this. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17420 from hvanhovell/SPARK-20070-2.0. --- .../org/apache/spark/internal/config/ConfigBuilder.scala | 2 +- .../org/apache/spark/internal/config/ConfigEntrySuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index d87619afd3b2f..b9921138cc6c7 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -225,6 +225,6 @@ private[spark] case class ConfigBuilder(key: String) { } def regexConf: TypedConfigBuilder[Regex] = { - new TypedConfigBuilder(this, regexFromString(_, this.key), _.regex) + new TypedConfigBuilder(this, regexFromString(_, this.key), _.toString) } } diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index f3756b21080b2..3ff7e84d73bd4 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -100,10 +100,10 @@ class ConfigEntrySuite extends SparkFunSuite { val rConf = ConfigBuilder(testKey("regex")).regexConf.createWithDefault(".*".r) conf.set(rConf, "[0-9a-f]{8}".r) - assert(conf.get(rConf).regex === "[0-9a-f]{8}") + assert(conf.get(rConf).toString === "[0-9a-f]{8}") conf.set(rConf.key, "[0-9a-f]{4}") - assert(conf.get(rConf).regex === "[0-9a-f]{4}") + assert(conf.get(rConf).toString === "[0-9a-f]{4}") conf.set(rConf.key, "[.") val e = intercept[IllegalArgumentException](conf.get(rConf)) From a2ce0a2e309e70d74ae5d2ed203f7919a0f79397 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 24 Mar 2017 23:27:42 -0700 Subject: [PATCH 0114/1765] [HOTFIX][SQL] Fix the failed test cases in GeneratorFunctionSuite ### What changes were proposed in this pull request? Multiple tests failed. Revert the changes on `supportCodegen` of `GenerateExec`. For example, - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/75194/testReport/ ### How was this patch tested? N/A Author: Xiao Li Closes #17425 from gatorsmile/turnOnCodeGenGenerateExec. --- .../apache/spark/sql/GeneratorFunctionSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index b9871afd59e4f..cef5bbf0e85a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(explode_outer('intList)), - Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Nil) } test("single posexplode") { @@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(posexplode_outer('intList)), - Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil) + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) } test("explode and other columns") { @@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Nil) checkAnswer( df.select(explode('intList).as('int)).select(sum('int)), @@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map)), - Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil) + Row("a", "b") :: Row("c", "d") :: Nil) } test("explode on map with aliases") { @@ -198,7 +198,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b") :: Row(null, null) :: Nil) + Row("a", "b") :: Nil) } test("self join explode") { @@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( df2.selectExpr("inline_outer(col1)"), - Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil + Row(3, "4") :: Row(5, "6") :: Nil ) } From e8ddb91c7ea5a0b4576cf47aaf969bcc82860b7c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Sat, 25 Mar 2017 10:42:15 +0000 Subject: [PATCH 0115/1765] [SPARK-20078][MESOS] Mesos executor configurability for task name and labels ## What changes were proposed in this pull request? Adding configurable mesos executor names and labels using `spark.mesos.task.name` and `spark.mesos.task.labels`. Labels were defined as `k1:v1,k2:v2`. mgummelt ## How was this patch tested? Added unit tests to verify labels were added correctly, with incorrect labels being ignored and added a test to test the name of the executor. Tested with: `./build/sbt -Pmesos mesos/test` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Kalvin Chau Closes #17404 from kalvinnchau/mesos-config. --- .../mesos/MesosCoarseGrainedSchedulerBackend.scala | 3 ++- .../MesosCoarseGrainedSchedulerBackendSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index c049a32eabf90..5bdc2a2b840e3 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -403,7 +403,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) - .setName("Task " + taskId) + .setName(s"${sc.appName} $taskId") + taskBuilder.addAllResources(resourcesToUse.asJava) taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 98033bec6dd68..eb83926ae4102 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -464,6 +464,17 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(!uris.asScala.head.getCache) } + test("mesos sets task name to spark.app.name") { + setBackend() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + // Add " 0" to the taskName to match the executor number that is appended + assert(launchedTasks.head.getName == "test-mesos-dynamic-alloc 0") + } + test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" From be85245a98d58f636ff54956cdfde15ea5cd6122 Mon Sep 17 00:00:00 2001 From: sethah Date: Sat, 25 Mar 2017 17:41:59 +0000 Subject: [PATCH 0116/1765] [SPARK-17137][ML][WIP] Compress logistic regression coefficients ## What changes were proposed in this pull request? Use the new `compressed` method on matrices to store the logistic regression coefficients as sparse or dense - whichever is requires less memory. Marked as WIP so we can add some performance test results. Basically, we should see if prediction is slower because of using a sparse matrix over a dense one. This can happen since sparse matrices do not use native BLAS operations when computing the margins. ## How was this patch tested? Unit tests added. Author: sethah Closes #17426 from sethah/SPARK-17137. --- .../classification/LogisticRegression.scala | 28 ++------- .../LogisticRegressionSuite.scala | 58 ++++++++++++++----- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 1a78187d4f8e3..7b56bce41c326 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -399,14 +399,9 @@ class LogisticRegression @Since("1.2.0") ( logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + s"will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax - // TODO: use `compressed` after SPARK-17471 - val coefMatrix = if (numFeatures < numCoefficientSets) { - new SparseMatrix(numCoefficientSets, numFeatures, - Array.fill(numFeatures + 1)(0), Array.empty[Int], Array.empty[Double]) - } else { - new SparseMatrix(numCoefficientSets, numFeatures, Array.fill(numCoefficientSets + 1)(0), - Array.empty[Int], Array.empty[Double], isTransposed = true) - } + val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, + new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], + isTransposed = true).compressed val interceptVec = if (isMultinomial) { Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity))) } else { @@ -617,26 +612,13 @@ class LogisticRegression @Since("1.2.0") ( denseCoefficientMatrix.update(_ - coefficientMean) } - // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 - val compressedCoefficientMatrix = if (isMultinomial) { - denseCoefficientMatrix - } else { - val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed - compressedVector match { - case dv: DenseVector => denseCoefficientMatrix - case sv: SparseVector => - new SparseMatrix(1, numFeatures, Array(0, sv.indices.length), sv.indices, sv.values, - isTransposed = true) - } - } - // center the intercepts when using multinomial algorithm if ($(fitIntercept) && isMultinomial) { val interceptArray = interceptVec.toArray val interceptMean = interceptArray.sum / interceptArray.length (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) + (denseCoefficientMatrix.compressed, interceptVec.compressed, arrayBuilder.result()) } } @@ -713,7 +695,7 @@ class LogisticRegressionModel private[spark] ( // convert to appropriate vector representation without replicating data private lazy val _coefficients: Vector = { require(coefficientMatrix.isTransposed, - "LogisticRegressionModel coefficients should be row major.") + "LogisticRegressionModel coefficients should be row major for binomial model.") coefficientMatrix match { case dm: DenseMatrix => Vectors.dense(dm.values) case sm: SparseMatrix => Vectors.sparse(coefficientMatrix.numCols, sm.rowIndices, sm.values) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index affaa573749e8..1b64480373492 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -713,8 +713,6 @@ class LogisticRegressionSuite assert(model2.intercept ~== interceptR relTol 1E-2) assert(model2.coefficients ~== coefficientsR absTol 1E-3) - // TODO: move this to a standalone test of compression after SPARK-17471 - assert(model2.coefficients.isInstanceOf[SparseVector]) } test("binary logistic regression without intercept with L1 regularization") { @@ -2031,29 +2029,61 @@ class LogisticRegressionSuite // TODO: check num iters is zero when it become available in the model } - test("compressed storage") { + test("compressed storage for constant label") { + /* + When the label is constant and fit intercept is true, all the coefficients will be + zeros, and so the model coefficients should be stored as sparse data structures, except + when the matrix dimensions are very small. + */ val moreClassesThanFeatures = Seq( - LabeledPoint(4.0, Vectors.dense(0.0, 0.0, 0.0)), - LabeledPoint(4.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() - val mlr = new LogisticRegression().setFamily("multinomial") + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(0.0))), + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(1.0))), + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(2.0)))).toDF() + val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true) val model = mlr.fit(moreClassesThanFeatures) assert(model.coefficientMatrix.isInstanceOf[SparseMatrix]) - assert(model.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 4) + assert(model.coefficientMatrix.isColMajor) + + // in this case, it should be stored as row major val moreFeaturesThanClasses = Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(0.0))), + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(1.0))), + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(2.0)))).toDF() val model2 = mlr.fit(moreFeaturesThanClasses) assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) - assert(model2.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 3) + assert(model2.coefficientMatrix.isRowMajor) - val blr = new LogisticRegression().setFamily("binomial") + val blr = new LogisticRegression().setFamily("binomial").setFitIntercept(true) val blrModel = blr.fit(moreFeaturesThanClasses) assert(blrModel.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(blrModel.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 2) } + test("compressed coefficients") { + + val trainer1 = new LogisticRegression() + .setRegParam(0.1) + .setElasticNetParam(1.0) + + // compressed row major is optimal + val model1 = trainer1.fit(multinomialDataset.limit(100)) + assert(model1.coefficientMatrix.isInstanceOf[SparseMatrix]) + assert(model1.coefficientMatrix.isRowMajor) + + // compressed column major is optimal since there are more classes than features + val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata() + val model2 = trainer1.fit(multinomialDataset + .withColumn("label", col("label").as("label", labelMeta)).limit(100)) + assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) + assert(model2.coefficientMatrix.isColMajor) + + // coefficients are dense without L1 regularization + val trainer2 = new LogisticRegression() + .setElasticNetParam(0.0) + val model3 = trainer2.fit(multinomialDataset.limit(100)) + assert(model3.coefficientMatrix.isInstanceOf[DenseMatrix]) + } + test("numClasses specified in metadata/inferred") { val lr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") From 0b903caef3183c5113feb09995874f6a07aa6698 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 25 Mar 2017 11:46:54 -0700 Subject: [PATCH 0117/1765] [SPARK-19949][SQL][FOLLOW-UP] move FailureSafeParser from catalyst to sql core ## What changes were proposed in this pull request? The `FailureSafeParser` is only used in sql core, it doesn't make sense to put it in catalyst module. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17408 from cloud-fan/minor. --- .../catalyst/util/BadRecordException.scala | 33 +++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 3 +- .../datasources}/FailureSafeParser.scala | 15 ++------- .../datasources/csv/UnivocityParser.scala | 3 +- .../datasources/json/JsonDataSource.scala | 3 +- 5 files changed, 39 insertions(+), 18 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala rename sql/{catalyst/src/main/scala/org/apache/spark/sql/catalyst/util => core/src/main/scala/org/apache/spark/sql/execution/datasources}/FailureSafeParser.scala (82%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala new file mode 100644 index 0000000000000..985f0dc1cd60e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String + +/** + * Exception thrown when the underlying parser meet a bad record and can't parse it. + * @param record a function to return the record that cause the parser to fail + * @param partialResult a function that returns an optional row, which is the partial result of + * parsing this bad record. + * @param cause the actual exception about why the record is bad and can't be parsed. + */ +case class BadRecordException( + record: () => UTF8String, + partialResult: () => Option[InternalRow], + cause: Throwable) extends Exception(cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e6d2b1bc28d95..6c238618f2af7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,11 +27,10 @@ import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} import org.apache.spark.sql.execution.datasources.csv._ -import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.types.{StringType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala similarity index 82% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 725e3015b3416..159aef220be15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -15,10 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.util +package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String @@ -69,15 +70,3 @@ class FailureSafeParser[IN]( } } } - -/** - * Exception thrown when the underlying parser meet a bad record and can't parse it. - * @param record a function to return the record that cause the parser to fail - * @param partialResult a function that returns an optional row, which is the partial result of - * parsing this bad record. - * @param cause the actual exception about why the record is bad and can't be parsed. - */ -case class BadRecordException( - record: () => UTF8String, - partialResult: () => Option[InternalRow], - cause: Throwable) extends Exception(cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 263f77e11c4da..c3657acb7d867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,7 +30,8 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} +import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 51e952c12202e..4f2963da9ace9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -33,8 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.FailureSafeParser -import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String From 2422c86f2ce2dd649b1d63062ec5c5fc1716c519 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 25 Mar 2017 23:29:02 -0700 Subject: [PATCH 0118/1765] [SPARK-20092][R][PROJECT INFRA] Add the detection for Scala codes dedicated for R in AppVeyor tests ## What changes were proposed in this pull request? We are currently detecting the changes in `R/` directory only and then trigger AppVeyor tests. It seems we need to tests when there are Scala codes dedicated for R in `core/src/main/scala/org/apache/spark/api/r/`, `sql/core/src/main/scala/org/apache/spark/sql/api/r/` and `mllib/src/main/scala/org/apache/spark/ml/r/` too. This will enables the tests, for example, for SPARK-20088. ## How was this patch tested? Tests with manually created PRs. - Changes in `sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala` https://github.com/spark-test/spark/pull/13 - Changes in `core/src/main/scala/org/apache/spark/api/r/SerDe.scala` https://github.com/spark-test/spark/pull/12 - Changes in `README.md` https://github.com/spark-test/spark/pull/14 Author: hyukjinkwon Closes #17427 from HyukjinKwon/SPARK-20092. --- appveyor.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/appveyor.yml b/appveyor.yml index 5adf1b4bedb44..bbb27589cad09 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -27,6 +27,9 @@ branches: only_commits: files: - R/ + - sql/core/src/main/scala/org/apache/spark/sql/api/r/ + - core/src/main/scala/org/apache/spark/api/r/ + - mllib/src/main/scala/org/apache/spark/ml/r/ cache: - C:\Users\appveyor\.m2 From 93bb0b911b6c790fa369b39da51a83d8f62da909 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 26 Mar 2017 09:20:22 +0200 Subject: [PATCH 0119/1765] [SPARK-20046][SQL] Facilitate loop optimizations in a JIT compiler regarding sqlContext.read.parquet() ## What changes were proposed in this pull request? This PR improves performance of operations with `sqlContext.read.parquet()` by changing Java code generated by Catalyst. This PR is inspired by [the blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html) and [this stackoverflow entry](http://stackoverflow.com/questions/40629435/fast-parquet-row-count-in-spark). This PR changes generated code in the following two points. 1. Replace a while-loop with long instance variables a for-loop with int local variables 2. Suppress generation of `shouldStop()` method if this method is unnecessary (e.g. `append()` is not generated). These points facilitates compiler optimizations in a JIT compiler by feeding the simplified Java code into the JIT compiler. The performance of `sqlContext.read.parquet().count` is improved by 1.09x. Benchmark program: ```java val dir = "/dev/shm/parquet" val N = 1000 * 1000 * 40 val iters = 20 val benchmark = new Benchmark("Parquet", N * iters, minNumIters = 5, warmupTime = 30.seconds) sparkSession.range(n).write.mode("overwrite").parquet(dir) benchmark.addCase("count") { i: Int => var n = 0 var len = 0L while (n < iters) { len += sparkSession.read.parquet(dir).count n += 1 } } benchmark.run ``` Performance result without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Parquet: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o this PR 1152 / 1211 694.7 1.4 1.0X ``` Performance result with this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Parquet: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ with this PR 1053 / 1121 760.0 1.3 1.0X ``` Here is a comparison between generated code w/o and with this PR. Only the method ```agg_doAggregateWithoutKey``` is changed. Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private scala.collection.Iterator scan_input; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows; /* 013 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime; /* 014 */ private long scan_scanTime1; /* 015 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch; /* 016 */ private int scan_batchIdx; /* 017 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 019 */ private UnsafeRow agg_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ agg_initAgg = false; /* 031 */ /* 032 */ scan_input = inputs[0]; /* 033 */ this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 034 */ this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 035 */ scan_scanTime1 = 0; /* 036 */ scan_batch = null; /* 037 */ scan_batchIdx = 0; /* 038 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 039 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 040 */ agg_result = new UnsafeRow(1); /* 041 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 042 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 047 */ // initialize aggregation buffer /* 048 */ agg_bufIsNull = false; /* 049 */ agg_bufValue = 0L; /* 050 */ /* 051 */ if (scan_batch == null) { /* 052 */ scan_nextBatch(); /* 053 */ } /* 054 */ while (scan_batch != null) { /* 055 */ int numRows = scan_batch.numRows(); /* 056 */ while (scan_batchIdx < numRows) { /* 057 */ int scan_rowIdx = scan_batchIdx++; /* 058 */ // do aggregate /* 059 */ // common sub-expressions /* 060 */ /* 061 */ // evaluate aggregate function /* 062 */ boolean agg_isNull1 = false; /* 063 */ /* 064 */ long agg_value1 = -1L; /* 065 */ agg_value1 = agg_bufValue + 1L; /* 066 */ // update aggregation buffer /* 067 */ agg_bufIsNull = false; /* 068 */ agg_bufValue = agg_value1; /* 069 */ if (shouldStop()) return; /* 070 */ } /* 071 */ scan_batch = null; /* 072 */ scan_nextBatch(); /* 073 */ } /* 074 */ scan_scanTime.add(scan_scanTime1 / (1000 * 1000)); /* 075 */ scan_scanTime1 = 0; /* 076 */ /* 077 */ } /* 078 */ /* 079 */ private void scan_nextBatch() throws java.io.IOException { /* 080 */ long getBatchStart = System.nanoTime(); /* 081 */ if (scan_input.hasNext()) { /* 082 */ scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next(); /* 083 */ scan_numOutputRows.add(scan_batch.numRows()); /* 084 */ scan_batchIdx = 0; /* 085 */ /* 086 */ } /* 087 */ scan_scanTime1 += System.nanoTime() - getBatchStart; /* 088 */ } /* 089 */ /* 090 */ protected void processNext() throws java.io.IOException { /* 091 */ while (!agg_initAgg) { /* 092 */ agg_initAgg = true; /* 093 */ long agg_beforeAgg = System.nanoTime(); /* 094 */ agg_doAggregateWithoutKey(); /* 095 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 096 */ /* 097 */ // output the result /* 098 */ /* 099 */ agg_numOutputRows.add(1); /* 100 */ agg_rowWriter.zeroOutNullBytes(); /* 101 */ /* 102 */ if (agg_bufIsNull) { /* 103 */ agg_rowWriter.setNullAt(0); /* 104 */ } else { /* 105 */ agg_rowWriter.write(0, agg_bufValue); /* 106 */ } /* 107 */ append(agg_result); /* 108 */ } /* 109 */ } /* 110 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private scala.collection.Iterator scan_input; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows; /* 013 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime; /* 014 */ private long scan_scanTime1; /* 015 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch; /* 016 */ private int scan_batchIdx; /* 017 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 019 */ private UnsafeRow agg_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ agg_initAgg = false; /* 031 */ /* 032 */ scan_input = inputs[0]; /* 033 */ this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 034 */ this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 035 */ scan_scanTime1 = 0; /* 036 */ scan_batch = null; /* 037 */ scan_batchIdx = 0; /* 038 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 039 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 040 */ agg_result = new UnsafeRow(1); /* 041 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 042 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 047 */ // initialize aggregation buffer /* 048 */ agg_bufIsNull = false; /* 049 */ agg_bufValue = 0L; /* 050 */ /* 051 */ if (scan_batch == null) { /* 052 */ scan_nextBatch(); /* 053 */ } /* 054 */ while (scan_batch != null) { /* 055 */ int numRows = scan_batch.numRows(); /* 056 */ int scan_localEnd = numRows - scan_batchIdx; /* 057 */ for (int scan_localIdx = 0; scan_localIdx < scan_localEnd; scan_localIdx++) { /* 058 */ int scan_rowIdx = scan_batchIdx + scan_localIdx; /* 059 */ // do aggregate /* 060 */ // common sub-expressions /* 061 */ /* 062 */ // evaluate aggregate function /* 063 */ boolean agg_isNull1 = false; /* 064 */ /* 065 */ long agg_value1 = -1L; /* 066 */ agg_value1 = agg_bufValue + 1L; /* 067 */ // update aggregation buffer /* 068 */ agg_bufIsNull = false; /* 069 */ agg_bufValue = agg_value1; /* 070 */ // shouldStop check is eliminated /* 071 */ } /* 072 */ scan_batchIdx = numRows; /* 073 */ scan_batch = null; /* 074 */ scan_nextBatch(); /* 075 */ } /* 079 */ } /* 080 */ /* 081 */ private void scan_nextBatch() throws java.io.IOException { /* 082 */ long getBatchStart = System.nanoTime(); /* 083 */ if (scan_input.hasNext()) { /* 084 */ scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next(); /* 085 */ scan_numOutputRows.add(scan_batch.numRows()); /* 086 */ scan_batchIdx = 0; /* 087 */ /* 088 */ } /* 089 */ scan_scanTime1 += System.nanoTime() - getBatchStart; /* 090 */ } /* 091 */ /* 092 */ protected void processNext() throws java.io.IOException { /* 093 */ while (!agg_initAgg) { /* 094 */ agg_initAgg = true; /* 095 */ long agg_beforeAgg = System.nanoTime(); /* 096 */ agg_doAggregateWithoutKey(); /* 097 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 098 */ /* 099 */ // output the result /* 100 */ /* 101 */ agg_numOutputRows.add(1); /* 102 */ agg_rowWriter.zeroOutNullBytes(); /* 103 */ /* 104 */ if (agg_bufIsNull) { /* 105 */ agg_rowWriter.setNullAt(0); /* 106 */ } else { /* 107 */ agg_rowWriter.write(0, agg_bufValue); /* 108 */ } /* 109 */ append(agg_result); /* 110 */ } /* 111 */ } /* 112 */ } ``` ## How was this patch tested? Tested existing test suites Author: Kazuaki Ishizaki Closes #17378 from kiszk/SPARK-20046. --- .../sql/execution/ColumnarBatchScan.scala | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04fba17be4bfa..e86116680a57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -111,17 +111,27 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val numRows = ctx.freshName("numRows") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $idx = $rowidx + 1; return; }" + } else { + "// shouldStop check is eliminated" + } s""" |if ($batch == null) { | $nextBatch(); |} |while ($batch != null) { - | int numRows = $batch.numRows(); - | while ($idx < numRows) { - | int $rowidx = $idx++; + | int $numRows = $batch.numRows(); + | int $localEnd = $numRows - $idx; + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | int $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} - | if (shouldStop()) return; + | $shouldStop | } + | $idx = $numRows; | $batch = null; | $nextBatch(); |} From 362ee93296a0de6342b4339e941e6a11f445c5b2 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Sun, 26 Mar 2017 10:39:05 +0100 Subject: [PATCH 0120/1765] logging improvements ## What changes were proposed in this pull request? Adding additional information to existing logging messages: - YarnAllocator: log the executor ID together with the container id when a container for an executor is launched. - NettyRpcEnv: log the receiver address when there is a timeout waiting for an answer to a remote call. - ExecutorAllocationManager: fix a typo in the logging message for the list of executors to be removed. ## How was this patch tested? Build spark and submit the word count example to a YARN cluster using cluster mode Author: Juan Rodriguez Hortala Closes #17411 from juanrh/logging-improvements. --- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 2 +- .../main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 3 ++- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1366251d0618f..261b3329a7b9c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -439,7 +439,7 @@ private[spark] class ExecutorAllocationManager( executorsRemoved } else { logWarning(s"Unable to reach the cluster manager to kill executor/s " + - "executorIdsToBeRemoved.mkString(\",\") or no executor eligible to kill!") + s"${executorIdsToBeRemoved.mkString(",")} or no executor eligible to kill!") Seq.empty[String] } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index ff5e39a8dcbc8..b316e5443f639 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -236,7 +236,8 @@ private[netty] class NettyRpcEnv( val timeoutCancelable = timeoutScheduler.schedule(new Runnable { override def run(): Unit = { - onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " + + s"in ${timeout.duration}")) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) promise.future.onComplete { v => diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index abd2de75c6450..25556763da904 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -494,7 +494,8 @@ private[yarn] class YarnAllocator( val containerId = container.getId val executorId = executorIdCounter.toString assert(container.getResource.getMemory >= resource.getMemory) - logInfo(s"Launching container $containerId on host $executorHostname") + logInfo(s"Launching container $containerId on host $executorHostname " + + s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { numExecutorsRunning += 1 From 617ab6445ea33d8297f0691723fd19bae19228dc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 26 Mar 2017 22:47:31 +0200 Subject: [PATCH 0121/1765] [SPARK-20086][SQL] CollapseWindow should not collapse dependent adjacent windows ## What changes were proposed in this pull request? The `CollapseWindow` is currently to aggressive when collapsing adjacent windows. It also collapses windows in the which the parent produces a column that is consumed by the child; this creates an invalid window which will fail at runtime. This PR fixes this by adding a check for dependent adjacent windows to the `CollapseWindow` rule. ## How was this patch tested? Added a new test case to `CollapseWindowSuite` Author: Herman van Hovell Closes #17432 from hvanhovell/SPARK-20086. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 8 +++++--- .../sql/catalyst/optimizer/CollapseWindowSuite.scala | 11 +++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ee7de86921496..dbe3ded4bbf15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -597,12 +597,14 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. - * - If the partition specs and order specs are the same, collapse into the parent. + * - If the partition specs and order specs are the same and the window expression are + * independent, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case w @ Window(we1, ps1, os1, Window(we2, ps2, os2, grandChild)) if ps1 == ps2 && os1 == os2 => - w.copy(windowExpressions = we2 ++ we1, child = grandChild) + case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala index 3f7d1d9fd99af..52054c2f8bd8d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala @@ -78,4 +78,15 @@ class CollapseWindowSuite extends PlanTest { comparePlans(optimized2, correctAnswer2) } + + test("Don't collapse adjacent windows with dependent columns") { + val query = testRelation + .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1) + .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1) + .analyze + + val expected = query.analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } } From 0bc8847aa216497549c78ad49ec7ac066a059b15 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 26 Mar 2017 16:49:27 -0700 Subject: [PATCH 0122/1765] [SPARK-19281][PYTHON][ML] spark.ml Python API for FPGrowth ## What changes were proposed in this pull request? - Add `HasSupport` and `HasConfidence` `Params`. - Add new module `pyspark.ml.fpm`. - Add `FPGrowth` / `FPGrowthModel` wrappers. - Provide tests for new features. ## How was this patch tested? Unit tests. Author: zero323 Closes #17218 from zero323/SPARK-19281. --- dev/sparktestsupport/modules.py | 5 +- python/docs/pyspark.ml.rst | 8 ++ python/pyspark/ml/fpm.py | 216 ++++++++++++++++++++++++++++++++ python/pyspark/ml/tests.py | 53 ++++++-- 4 files changed, 273 insertions(+), 9 deletions(-) create mode 100644 python/pyspark/ml/fpm.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 10ad1fe3aa2c6..eaf1f3a1db2ff 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -423,15 +423,16 @@ def __hash__(self): "python/pyspark/ml/" ], python_test_goals=[ - "pyspark.ml.feature", "pyspark.ml.classification", "pyspark.ml.clustering", + "pyspark.ml.evaluation", + "pyspark.ml.feature", + "pyspark.ml.fpm", "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", "pyspark.ml.tests", - "pyspark.ml.evaluation", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 26f7415e1a423..a68183445d78b 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -80,3 +80,11 @@ pyspark.ml.evaluation module :members: :undoc-members: :inherited-members: + +pyspark.ml.fpm module +---------------------------- + +.. automodule:: pyspark.ml.fpm + :members: + :undoc-members: + :inherited-members: diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py new file mode 100644 index 0000000000000..b30d4edb19908 --- /dev/null +++ b/python/pyspark/ml/fpm.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import keyword_only, since +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * + +__all__ = ["FPGrowth", "FPGrowthModel"] + + +class HasSupport(Params): + """ + Mixin for param support. + """ + + minSupport = Param( + Params._dummy(), + "minSupport", + """Minimal support level of the frequent pattern. [0.0, 1.0]. + Any pattern that appears more than (minSupport * size-of-the-dataset) + times will be output""", + typeConverter=TypeConverters.toFloat) + + def setMinSupport(self, value): + """ + Sets the value of :py:attr:`minSupport`. + """ + return self._set(minSupport=value) + + def getMinSupport(self): + """ + Gets the value of minSupport or its default value. + """ + return self.getOrDefault(self.minSupport) + + +class HasConfidence(Params): + """ + Mixin for param confidence. + """ + + minConfidence = Param( + Params._dummy(), + "minConfidence", + """Minimal confidence for generating Association Rule. [0.0, 1.0] + Note that minConfidence has no effect during fitting.""", + typeConverter=TypeConverters.toFloat) + + def setMinConfidence(self, value): + """ + Sets the value of :py:attr:`minConfidence`. + """ + return self._set(minConfidence=value) + + def getMinConfidence(self): + """ + Gets the value of minConfidence or its default value. + """ + return self.getOrDefault(self.minConfidence) + + +class HasItemsCol(Params): + """ + Mixin for param itemsCol: items column name. + """ + + itemsCol = Param(Params._dummy(), "itemsCol", + "items column name", typeConverter=TypeConverters.toString) + + def setItemsCol(self, value): + """ + Sets the value of :py:attr:`itemsCol`. + """ + return self._set(itemsCol=value) + + def getItemsCol(self): + """ + Gets the value of itemsCol or its default value. + """ + return self.getOrDefault(self.itemsCol) + + +class FPGrowthModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + Model fitted by FPGrowth. + + .. versionadded:: 2.2.0 + """ + @property + @since("2.2.0") + def freqItemsets(self): + """ + DataFrame with two columns: + * `items` - Itemset of the same type as the input column. + * `freq` - Frequency of the itemset (`LongType`). + """ + return self._call_java("freqItemsets") + + @property + @since("2.2.0") + def associationRules(self): + """ + Data with three columns: + * `antecedent` - Array of the same type as the input column. + * `consequent` - Array of the same type as the input column. + * `confidence` - Confidence for the rule (`DoubleType`). + """ + return self._call_java("associationRules") + + +class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, + HasSupport, HasConfidence, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + Li et al., PFP: Parallel FP-Growth for Query Recommendation [LI2008]_. + PFP distributes computation in such a way that each worker executes an + independent group of mining tasks. The FP-Growth algorithm is described in + Han et al., Mining frequent patterns without candidate generation [HAN2000]_ + + .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 + .. [HAN2000] http://dx.doi.org/10.1145/335191.335372 + + .. note:: null values in the feature column are ignored during fit(). + .. note:: Internally `transform` `collects` and `broadcasts` association rules. + + >>> from pyspark.sql.functions import split + >>> data = (spark.read + ... .text("data/mllib/sample_fpgrowth.txt") + ... .select(split("value", "\s+").alias("items"))) + >>> data.show(truncate=False) + +------------------------+ + |items | + +------------------------+ + |[r, z, h, k, p] | + |[z, y, x, w, v, u, t, s]| + |[s, x, o, n, r] | + |[x, z, y, m, t, s, q, e]| + |[z] | + |[x, z, y, r, q, t, p] | + +------------------------+ + >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7) + >>> fpm = fp.fit(data) + >>> fpm.freqItemsets.show(5) + +---------+----+ + | items|freq| + +---------+----+ + | [s]| 3| + | [s, x]| 3| + |[s, x, z]| 2| + | [s, z]| 2| + | [r]| 3| + +---------+----+ + only showing top 5 rows + >>> fpm.associationRules.show(5) + +----------+----------+----------+ + |antecedent|consequent|confidence| + +----------+----------+----------+ + | [t, s]| [y]| 1.0| + | [t, s]| [x]| 1.0| + | [t, s]| [z]| 1.0| + | [p]| [r]| 1.0| + | [p]| [z]| 1.0| + +----------+----------+----------+ + only showing top 5 rows + >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) + >>> sorted(fpm.transform(new_data).first().prediction) + ['x', 'y', 'z'] + + .. versionadded:: 2.2.0 + """ + @keyword_only + def __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", + predictionCol="prediction", numPartitions=None): + """ + __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ + predictionCol="prediction", numPartitions=None) + """ + super(FPGrowth, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid) + self._setDefault(minSupport=0.3, minConfidence=0.8, + itemsCol="items", predictionCol="prediction") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.2.0") + def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", + predictionCol="prediction", numPartitions=None): + """ + setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ + predictionCol="prediction", numPartitions=None) + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return FPGrowthModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index cc559db58720f..527db9b66793a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -42,7 +42,7 @@ import array as pyarray import numpy as np from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) + abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros) from numpy import sum as array_sum import inspect @@ -50,18 +50,20 @@ from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import * from pyspark.ml.clustering import * +from pyspark.ml.common import _java2py, _py2java from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * -from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector +from pyspark.ml.fpm import FPGrowth, FPGrowthModel +from pyspark.ml.linalg import ( + DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, + SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector) from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ - GeneralizedLinearRegression +from pyspark.ml.regression import ( + DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression) from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.ml.common import _java2py, _py2java from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand @@ -1243,6 +1245,43 @@ def test_tweedie_distribution(self): self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) +class FPGrowthTests(SparkSessionTestCase): + def setUp(self): + super(FPGrowthTests, self).setUp() + self.data = self.spark.createDataFrame( + [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], + ["items"]) + + def test_association_rules(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_association_rules = self.spark.createDataFrame( + [([3], [1], 1.0), ([2], [1], 1.0)], + ["antecedent", "consequent", "confidence"] + ) + actual_association_rules = fpm.associationRules + + self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) + self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) + + def test_freq_itemsets(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_freq_itemsets = self.spark.createDataFrame( + [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], + ["items", "freq"] + ) + actual_freq_itemsets = fpm.freqItemsets + + self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) + self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) + + def tearDown(self): + del self.data + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From 3fbf0a5f9297f438bc92db11f106d4a0ae568613 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 26 Mar 2017 18:40:00 -0700 Subject: [PATCH 0123/1765] [MINOR][DOCS] Match several documentation changes in Scala to R/Python ## What changes were proposed in this pull request? This PR proposes to match minor documentations changes in https://github.com/apache/spark/pull/17399 and https://github.com/apache/spark/pull/17380 to R/Python. ## How was this patch tested? Manual tests in Python , Python tests via `./python/run-tests.py --module=pyspark-sql` and lint-checks for Python/R. Author: hyukjinkwon Closes #17429 from HyukjinKwon/minor-match-doc. --- R/pkg/R/functions.R | 6 +++--- python/pyspark/sql/functions.py | 8 ++++---- python/pyspark/sql/tests.py | 8 ++++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 2cff3ac08c3ae..449476dec5339 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2632,8 +2632,8 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, -#' and returns the result as a string column. +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places +#' with HALF_EVEN round mode, and returns the result as a string column. #' #' If x is 0, the result has no decimal point or fractional part. #' If x < 0, the result will be null. @@ -3548,7 +3548,7 @@ setMethod("row_number", #' array_contains #' -#' Returns true if the array contain the value. +#' Returns null if the array is null, true if the array contains the value, and false otherwise. #' #' @param x A Column #' @param value A value to be checked if contained in the column diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9121e60f35b8..843ae3816f061 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1327,8 +1327,8 @@ def encode(col, charset): @since(1.5) def format_number(col, d): """ - Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places, - and returns the result as a string. + Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places + with HALF_EVEN round mode, and returns the result as a string. :param col: the column name of the numeric value to be formatted :param d: the N decimal places @@ -1675,8 +1675,8 @@ def array(*cols): @since(1.5) def array_contains(col, value): """ - Collection function: returns True if the array contains the given value. The collection - elements and value must be of the same type. + Collection function: returns null if the array is null, true if the array contains the + given value, and false otherwise. :param col: name of column containing array :param value: value to check for in array diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b93b7ed192104..db41b4edb6dde 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1129,6 +1129,14 @@ def test_rand_functions(self): rndn2 = df.select('key', functions.randn(0)).collect() self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_array_contains_function(self): + from pyspark.sql.functions import array_contains + + df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) + actual = df.select(array_contains(df.data, 1).alias('b')).collect() + # The value argument can be implicitly castable to the element's type of the array. + self.assertEqual([Row(b=True), Row(b=False)], actual) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), From 890493458de396cfcffdd71233cfdd39e834944b Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Mon, 27 Mar 2017 23:41:27 +0800 Subject: [PATCH 0124/1765] [SPARK-20104][SQL] Don't estimate IsNull or IsNotNull predicates for non-leaf node ## What changes were proposed in this pull request? In current stage, we don't have advanced statistics such as sketches or histograms. As a result, some operator can't estimate `nullCount` accurately. E.g. left outer join estimation does not accurately update `nullCount` currently. So for `IsNull` and `IsNotNull` predicates, we only estimate them when the child is a leaf node, whose `nullCount` is accurate. ## How was this patch tested? A new test case is added in `FilterEstimationSuite`. Author: wangzhenhua Closes #17438 from wzhfy/nullEstimation. --- .../statsEstimation/FilterEstimation.scala | 12 ++++++--- .../FilterEstimationSuite.scala | 25 ++++++++++++++++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index b10785b05d6c7..f14df93160b75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -24,7 +24,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -174,10 +174,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case InSet(ar: Attribute, set) => evaluateInSet(ar, set, update) - case IsNull(ar: Attribute) => + // In current stage, we don't have advanced statistics such as sketches or histograms. + // As a result, some operator can't estimate `nullCount` accurately. E.g. left outer join + // estimation does not accurately update `nullCount` currently. + // So for IsNull and IsNotNull predicates, we only estimate them when the child is a leaf + // node, whose `nullCount` is accurate. + // This is a limitation due to lack of advanced stats. We should remove it in the future. + case IsNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = true, update) - case IsNotNull(ar: Attribute) => + case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = false, update) case _ => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 4691913c8c986..07abe1ed28533 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -340,6 +341,28 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 2) } + // This is a limitation test. We should remove it after the limitation is removed. + test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { + val attrIntLargerRange = AttributeReference("c1", IntegerType)() + val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 10, avgLen = 4, maxLen = 4) + val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) + val largerTable = StatsTestPlan( + outputList = Seq(attrIntLargerRange), + rowCount = 30, + attributeStats = AttributeMap(Seq(attrIntLargerRange -> colStatIntLargerRange))) + val nonLeafChild = Join(largerTable, smallerTable, LeftOuter, + Some(EqualTo(attrIntLargerRange, attrInt))) + + Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach { predicate => + validateEstimatedStats( + Filter(predicate, nonLeafChild), + // column stats don't change + Seq(attrInt -> colStatInt, attrIntLargerRange -> colStatIntLargerRange), + expectedRowCount = 30) + } + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, From 0588dc7c0a9f3180dddae0dc202a6d41eb43464f Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 27 Mar 2017 08:53:45 -0700 Subject: [PATCH 0125/1765] [SPARK-20088] Do not create new SparkContext in SparkR createSparkContext ## What changes were proposed in this pull request? Instead of creating new `JavaSparkContext` we use `SparkContext.getOrCreate`. ## How was this patch tested? Existing tests Author: Hossein Closes #17423 from falaki/SPARK-20088. --- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 72ae0340aa3d1..295355c7bf018 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -136,7 +136,7 @@ private[r] object RRDD { .mkString(File.separator)) } - val jsc = new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(SparkContext.getOrCreate(sparkConf)) jars.foreach { jar => jsc.addJar(jar) } From 314cf51ded52834cfbaacf58d3d05a220965ca2a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 27 Mar 2017 10:23:28 -0700 Subject: [PATCH 0126/1765] [SPARK-20102] Fix nightly packaging and RC packaging scripts w/ two minor build fixes ## What changes were proposed in this pull request? The master snapshot publisher builds are currently broken due to two minor build issues: 1. For unknown reasons, the LFTP `mkdir -p` command began throwing errors when the remote directory already exists. This change of behavior might have been caused by configuration changes in the ASF's SFTP server, but I'm not entirely sure of that. To work around this problem, this patch updates the script to ignore errors from the `lftp mkdir -p` commands. 2. The PySpark `setup.py` file references a non-existent `pyspark.ml.stat` module, causing Python packaging to fail by complaining about a missing directory. The fix is to simply drop that line from the setup script. ## How was this patch tested? The LFTP fix was tested by manually running the failing commands on AMPLab Jenkins against the ASF SFTP server. The PySpark fix was tested locally. Author: Josh Rosen Closes #17437 from JoshRosen/spark-20102. --- dev/create-release/release-build.sh | 8 ++++---- python/setup.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index e1db997a7d410..7976d8a039544 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -246,7 +246,7 @@ if [[ "$1" == "package" ]]; then dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" echo "Copying release tarballs to $dest_dir" # Put to new directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' LFTP mput -O $dest_dir 'SparkR_*' @@ -254,7 +254,7 @@ if [[ "$1" == "package" ]]; then LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' LFTP mput -O $dest_dir 'SparkR_*' @@ -271,13 +271,13 @@ if [[ "$1" == "docs" ]]; then PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build echo "Copying release documentation to $dest_dir" # Put to new directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mirror -R _site $dest_dir # Delete /latest directory and rename new upload to /latest LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mirror -R _site $dest_dir cd .. exit 0 diff --git a/python/setup.py b/python/setup.py index 47eab98e0f7b3..f50035435e26b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -167,7 +167,6 @@ def _supports_symlinks(): 'pyspark.ml', 'pyspark.ml.linalg', 'pyspark.ml.param', - 'pyspark.ml.stat', 'pyspark.sql', 'pyspark.streaming', 'pyspark.bin', From 3fada2f502107bd5572fb895471943de7b2c38e4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 27 Mar 2017 10:43:00 -0700 Subject: [PATCH 0127/1765] [SPARK-20105][TESTS][R] Add tests for checkType and type string in structField in R ## What changes were proposed in this pull request? It seems `checkType` and the type string in `structField` are not being tested closely. This string format currently seems SparkR-specific (see https://github.com/apache/spark/blob/d1f6c64c4b763c05d6d79ae5497f298dc3835f3e/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala#L93-L131) but resembles SQL type definition. Therefore, it seems nicer if we test positive/negative cases in R side. ## How was this patch tested? Unit tests in `test_sparkSQL.R`. Author: hyukjinkwon Closes #17439 from HyukjinKwon/r-typestring-tests. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 53 +++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 394d1a04e09c3..5acf8719d1201 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -140,6 +140,59 @@ test_that("structType and structField", { expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) +test_that("structField type strings", { + # positive cases + primitiveTypes <- list(byte = "ByteType", + integer = "IntegerType", + float = "FloatType", + double = "DoubleType", + string = "StringType", + binary = "BinaryType", + boolean = "BooleanType", + timestamp = "TimestampType", + date = "DateType") + + complexTypes <- list("map" = "MapType(StringType,IntegerType,true)", + "array" = "ArrayType(StringType,true)", + "struct" = "StructType(StructField(a,StringType,true))") + + typeList <- c(primitiveTypes, complexTypes) + typeStrings <- names(typeList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- typeList[[i]] + testField <- structField("_col", typeString) + expect_is(testField, "structField") + expect_true(testField$nullable()) + expect_equal(testField$dataType.toString(), expected) + } + + # negative cases + primitiveErrors <- list(Byte = "Byte", + INTEGER = "INTEGER", + numeric = "numeric", + character = "character", + raw = "raw", + logical = "logical") + + complexErrors <- list("map" = " integer", + "array" = "String", + "struct" = "string ", + "map " = "map ", + "array< string>" = " string", + "struct" = " string") + + errorList <- c(primitiveErrors, complexErrors) + typeStrings <- names(errorList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- paste0("Unsupported type for SparkDataframe: ", errorList[[i]]) + expect_error(structField("_col", typeString), expected) + } +}) + test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) From 1d00761b9176a1f42976057ca78638c5b0763abc Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 27 Mar 2017 17:37:24 -0700 Subject: [PATCH 0128/1765] [MINOR][SPARKR] Move 'Data type mapping between R and Spark' to right place in SparkR doc. Section ```Data type mapping between R and Spark``` was put in the wrong place in SparkR doc currently, we should move it to a separate section. ## What changes were proposed in this pull request? Before this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/24340911/bc01a532-126a-11e7-9a08-0d60d13a547c.png) After this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/24340938/d9d32a9a-126a-11e7-8891-d2f5b46e0c71.png) Author: Yanbo Liang Closes #17440 from yanboliang/sparkr-doc. --- docs/sparkr.md | 138 ++++++++++++++++++++++++------------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index d7ffd9b3f1229..a1a35a7757e57 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -394,75 +394,6 @@ head(result[order(result$max_eruption, decreasing = TRUE), ]) {% endhighlight %}
      -#### Data type mapping between R and Spark - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
      RSpark
      bytebyte
      integerinteger
      floatfloat
      doubledouble
      numericdouble
      characterstring
      stringstring
      binarybinary
      rawbinary
      logicalboolean
      POSIXcttimestamp
      POSIXlttimestamp
      Datedate
      arrayarray
      listarray
      envmap
      - #### Run local R functions distributed using `spark.lapply` ##### spark.lapply @@ -557,6 +488,75 @@ SparkR supports a subset of the available R formula operators for model fitting, The following example shows how to save/load a MLlib model by SparkR. {% include_example read_write r/ml/ml.R %} +# Data type mapping between R and Spark + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      RSpark
      bytebyte
      integerinteger
      floatfloat
      doubledouble
      numericdouble
      characterstring
      stringstring
      binarybinary
      rawbinary
      logicalboolean
      POSIXcttimestamp
      POSIXlttimestamp
      Datedate
      arrayarray
      listarray
      envmap
      + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a From a250933c625ed720d15a0e479e9c51113605b102 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Tue, 28 Mar 2017 09:47:29 +0800 Subject: [PATCH 0129/1765] [SPARK-19803][CORE][TEST] Proactive replication test failures ## What changes were proposed in this pull request? Executors cache a list of their peers that is refreshed by default every minute. The cached stale references were randomly being used for replication. Since those executors were removed from the master, they did not occur in the block locations as reported by the master. This was fixed by 1. Refreshing peer cache in the block manager before trying to pro-actively replicate. This way the probability of replicating to a failed executor is eliminated. 2. Explicitly stopping the block manager in the tests. This shuts down the RPC endpoint use by the block manager. This way, even if a block manager tries to replicate using a stale reference, the replication logic should take care of refreshing the list of peers after failure. ## How was this patch tested? Tested manually Author: Shubham Chopra Author: Kay Ousterhout Author: Shubham Chopra Closes #17325 from shubhamchopra/SPARK-19803. --- .../spark/storage/BlockInfoManager.scala | 6 ++++ .../apache/spark/storage/BlockManager.scala | 6 +++- .../BlockManagerReplicationSuite.scala | 29 ++++++++++++------- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 490d45d12b8e3..3db59837fbebd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -371,6 +371,12 @@ private[storage] class BlockInfoManager extends Logging { blocksWithReleasedLocks } + /** Returns the number of locks held by the given task. Used only for testing. */ + private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = { + readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) + + writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0) + } + /** * Returns the number of blocks tracked. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 245d94ac4f8b1..991346a40af4e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1187,7 +1187,7 @@ private[spark] class BlockManager( blockId: BlockId, existingReplicas: Set[BlockManagerId], maxReplicas: Int): Unit = { - logInfo(s"Pro-actively replicating $blockId") + logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") blockInfoManager.lockForReading(blockId).foreach { info => val data = doGetLocalBytes(blockId, info) val storageLevel = StorageLevel( @@ -1196,9 +1196,13 @@ private[spark] class BlockManager( useOffHeap = info.level.useOffHeap, deserialized = info.level.deserialized, replication = maxReplicas) + // we know we are called as a result of an executor removal, so we refresh peer cache + // this way, we won't try to replicate to a missing executor with a stale reference + getPeers(forceFetch = true) try { replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { + logDebug(s"Releasing lock for $blockId") releaseLock(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d907add920c8a..d5715f8469f71 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -493,27 +493,34 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav assert(blockLocations.size === replicationFactor) // remove a random blockManager - val executorsToRemove = blockLocations.take(replicationFactor - 1) + val executorsToRemove = blockLocations.take(replicationFactor - 1).toSet logInfo(s"Removing $executorsToRemove") - executorsToRemove.foreach{exec => - master.removeExecutor(exec.executorId) + initialStores.filter(bm => executorsToRemove.contains(bm.blockManagerId)).foreach { bm => + master.removeExecutor(bm.blockManagerId.executorId) + bm.stop() // giving enough time for replication to happen and new block be reported to master - Thread.sleep(200) + eventually(timeout(5 seconds), interval(100 millis)) { + val newLocations = master.getLocations(blockId).toSet + assert(newLocations.size === replicationFactor) + } } - val newLocations = eventually(timeout(5 seconds), interval(10 millis)) { + val newLocations = eventually(timeout(5 seconds), interval(100 millis)) { val _newLocations = master.getLocations(blockId).toSet assert(_newLocations.size === replicationFactor) _newLocations } logInfo(s"New locations : $newLocations") - // there should only be one common block manager between initial and new locations - assert(newLocations.intersect(blockLocations.toSet).size === 1) - // check if all the read locks have been released - initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => - val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER) - assert(locks.size === 0, "Read locks unreleased!") + // new locations should not contain stopped block managers + assert(newLocations.forall(bmId => !executorsToRemove.contains(bmId)), + "New locations contain stopped block managers.") + + // Make sure all locks have been released. + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + assert(bm.blockInfoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) === 0) + } } } } From 8a6f33f0483dcee81467e6374a796b5dbd53ea30 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 27 Mar 2017 19:04:16 -0700 Subject: [PATCH 0130/1765] [SPARK-19876][SS] Follow up: Refactored BatchCommitLog to simplify logic ## What changes were proposed in this pull request? Existing logic seemingly writes null to the BatchCommitLog, even though it does additional checks to write '{}' (valid json) to the log. This PR simplifies the logic by disallowing use of `log.add(batchId, metadata)` and instead using `log.add(batchId)`. No question of specifying metadata, so no confusion related to null. ## How was this patch tested? Existing tests pass. Author: Tathagata Das Closes #17444 from tdas/SPARK-19876-1. --- .../execution/streaming/BatchCommitLog.scala | 28 +++++++++++-------- .../execution/streaming/HDFSMetadataLog.scala | 1 + .../execution/streaming/StreamExecution.scala | 2 +- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala index fb1a4fb9b12f5..a34938f911f76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -45,33 +45,39 @@ import org.apache.spark.sql.SparkSession class BatchCommitLog(sparkSession: SparkSession, path: String) extends HDFSMetadataLog[String](sparkSession, path) { + import BatchCommitLog._ + + def add(batchId: Long): Unit = { + super.add(batchId, EMPTY_JSON) + } + + override def add(batchId: Long, metadata: String): Boolean = { + throw new UnsupportedOperationException( + "BatchCommitLog does not take any metadata, use 'add(batchId)' instead") + } + override protected def deserialize(in: InputStream): String = { // called inside a try-finally where the underlying stream is closed in the caller val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file in the offset commit log") } - parseVersion(lines.next().trim, BatchCommitLog.VERSION) - // read metadata - lines.next().trim match { - case BatchCommitLog.SERIALIZED_VOID => null - case metadata => metadata - } + parseVersion(lines.next.trim, VERSION) + EMPTY_JSON } override protected def serialize(metadata: String, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(s"v${BatchCommitLog.VERSION}".getBytes(UTF_8)) + out.write(s"v${VERSION}".getBytes(UTF_8)) out.write('\n') - // write metadata or void - out.write((if (metadata == null) BatchCommitLog.SERIALIZED_VOID else metadata) - .getBytes(UTF_8)) + // write metadata + out.write(EMPTY_JSON.getBytes(UTF_8)) } } object BatchCommitLog { private val VERSION = 1 - private val SERIALIZED_VOID = "{}" + private val EMPTY_JSON = "{}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 60ce64261c4a4..46bfc297931fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -106,6 +106,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: * metadata has already been stored, this method will return `false`. */ override def add(batchId: Long, metadata: T): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written writeBatch(batchId, metadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 34e9262af7cb2..5f548172f5ced 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -305,7 +305,7 @@ class StreamExecution( if (dataAvailable) { // Update committed offsets. committedOffsets ++= availableOffsets - batchCommitLog.add(currentBatchId, null) + batchCommitLog.add(currentBatchId) logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 From ea361165e1ddce4d8aa0242ae3e878d7b39f1de2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 28 Mar 2017 10:07:24 +0800 Subject: [PATCH 0131/1765] [SPARK-20100][SQL] Refactor SessionState initialization ## What changes were proposed in this pull request? The current SessionState initialization code path is quite complex. A part of the creation is done in the SessionState companion objects, a part of the creation is one inside the SessionState class, and a part is done by passing functions. This PR refactors this code path, and consolidates SessionState initialization into a builder class. This SessionState will not do any initialization and just becomes a place holder for the various Spark SQL internals. This also lays the ground work for two future improvements: 1. This provides us with a start for removing the `HiveSessionState`. Removing the `HiveSessionState` would also require us to move resource loading into a separate class, and to (re)move metadata hive. 2. This makes it easier to customize the Spark Session. Currently you will need to create a custom version of the builder. I have added hooks to facilitate this. A future step will be to create a semi stable API on top of this. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17433 from hvanhovell/SPARK-20100. --- .../sql/catalyst/catalog/SessionCatalog.scala | 46 +-- .../sql/catalyst/optimizer/Optimizer.scala | 16 +- .../catalog/SessionCatalogSuite.scala | 22 +- .../spark/sql/execution/SparkOptimizer.scala | 12 +- .../spark/sql/execution/SparkPlanner.scala | 11 +- .../streaming/IncrementalExecution.scala | 23 +- .../spark/sql/internal/SessionState.scala | 180 +++-------- .../sql/internal/sessionStateBuilders.scala | 279 ++++++++++++++++++ .../spark/sql/test/TestSQLContext.scala | 23 +- .../spark/sql/hive/HiveSessionCatalog.scala | 76 +---- .../spark/sql/hive/HiveSessionState.scala | 259 +++++++--------- .../apache/spark/sql/hive/test/TestHive.scala | 60 ++-- .../sql/hive/HiveSessionCatalogSuite.scala | 112 ------- 13 files changed, 547 insertions(+), 572 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a469d12451643..72ab075408899 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,7 +54,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: CatalystConf, hadoopConf: Configuration, - parser: ParserInterface) extends Logging { + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends Logging { import SessionCatalog._ import CatalogTypes.TablePartitionSpec @@ -69,8 +70,8 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - CatalystSqlParser) - functionResourceLoader = DummyFunctionResourceLoader + CatalystSqlParser, + DummyFunctionResourceLoader) } // For testing only. @@ -90,9 +91,7 @@ class SessionCatalog( // check whether the temporary table or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") - protected var currentDb = formatDatabaseName(DEFAULT_DATABASE) - - @volatile var functionResourceLoader: FunctionResourceLoader = _ + protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), @@ -1059,9 +1058,6 @@ class SessionCatalog( * by a tuple (resource type, resource uri). */ def loadFunctionResources(resources: Seq[FunctionResource]): Unit = { - if (functionResourceLoader == null) { - throw new IllegalStateException("functionResourceLoader has not yet been initialized") - } resources.foreach(functionResourceLoader.loadResource) } @@ -1259,28 +1255,16 @@ class SessionCatalog( } /** - * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and - * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied. + * Copy the current state of the catalog to another catalog. + * + * This function is synchronized on this [[SessionCatalog]] (the source) to make sure the copied + * state is consistent. The target [[SessionCatalog]] is not synchronized, and should not be + * because the target [[SessionCatalog]] should not be published at this point. The caller must + * synchronize on the target if this assumption does not hold. */ - def newSessionCatalogWith( - conf: CatalystConf, - hadoopConf: Configuration, - functionRegistry: FunctionRegistry, - parser: ParserInterface): SessionCatalog = { - val catalog = new SessionCatalog( - externalCatalog, - globalTempViewManager, - functionRegistry, - conf, - hadoopConf, - parser) - - synchronized { - catalog.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2)) - } - - catalog + private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { + target.currentDb = currentDb + // copy over temporary tables + tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dbe3ded4bbf15..dbf479d215134 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,20 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.annotation.tailrec -import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -79,7 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: - Batch("Operator Optimizations", fixedPoint, + Batch("Operator Optimizations", fixedPoint, Seq( // Operator push down PushProjectionThroughUnion, ReorderJoin(conf), @@ -117,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) RemoveRedundantProject, SimplifyCreateStructOps, SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: + SimplifyCreateMapOps) ++ + extendedOperatorOptimizationRules: _*) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: Batch("Join Reorder", Once, @@ -146,6 +141,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) s.withNewPlan(newPlan) } } + + /** + * Override to provide additional rules for the operator optimization batch. + */ + def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index ca4ce1c11707a..56bca73a8857a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.hadoop.conf.Configuration - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ @@ -1331,17 +1329,15 @@ abstract class SessionCatalogSuite extends PlanTest { } } - test("clone SessionCatalog - temp views") { + test("copy SessionCatalog state - temp views") { withEmptyCatalog { original => val tempTable1 = Range(1, 10, 1, 10) original.createTempView("copytest1", tempTable1, overrideIfExists = false) // check if tables copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) + val clone = new SessionCatalog(original.externalCatalog) + original.copyStateTo(clone) + assert(original ne clone) assert(clone.getTempView("copytest1") == Some(tempTable1)) @@ -1355,7 +1351,7 @@ abstract class SessionCatalogSuite extends PlanTest { } } - test("clone SessionCatalog - current db") { + test("copy SessionCatalog state - current db") { withEmptyCatalog { original => val db1 = "db1" val db2 = "db2" @@ -1368,11 +1364,9 @@ abstract class SessionCatalogSuite extends PlanTest { original.setCurrentDatabase(db1) // check if current db copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) + val clone = new SessionCatalog(original.externalCatalog) + original.copyStateTo(clone) + assert(original ne clone) assert(clone.getCurrentDatabase == db1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 981728331d361..2cdfb7a7828c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -30,9 +30,17 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = super.batches :+ + override def batches: Seq[Batch] = (super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + + /** + * Optimization batches that are executed after the regular optimization batches, but before the + * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add + * custom optimizer batches to the Spark optimizer. + */ + def postHocOptimizationBatches: Seq[Batch] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 678241656c011..6566502bd8a8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -27,13 +27,14 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val extraStrategies: Seq[Strategy]) + val experimentalMethods: ExperimentalMethods) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - extraStrategies ++ ( + experimentalMethods.extraStrategies ++ + extraPlanningStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: SpecialLimits :: @@ -42,6 +43,12 @@ class SparkPlanner( InMemoryScans :: BasicOperators :: Nil) + /** + * Override to add extra planning strategies to the planner. These strategies are tried after + * the strategies defined in [[ExperimentalMethods]], and before the regular strategies. + */ + def extraPlanningStrategies: Seq[Strategy] = Nil + override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = { plan.collect { case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 0f0e4a91f8cc7..622e049630db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, Strategy} +import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} @@ -40,20 +40,17 @@ class IncrementalExecution( offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { - // TODO: make this always part of planning. - val streamingExtraStrategies = - sparkSession.sessionState.planner.StatefulAggregationStrategy +: - sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +: - sparkSession.sessionState.planner.StreamingRelationStrategy +: - sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: - sparkSession.sessionState.experimentalMethods.extraStrategies - // Modified planner with stateful operations. - override def planner: SparkPlanner = - new SparkPlanner( + override val planner: SparkPlanner = new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - streamingExtraStrategies) + sparkSession.sessionState.experimentalMethods) { + override def extraPlanningStrategies: Seq[Strategy] = + StatefulAggregationStrategy :: + FlatMapGroupsWithStateStrategy :: + StreamingRelationStrategy :: + StreamingDeduplicationStrategy :: Nil + } /** * See [SPARK-18339] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index ce80604bd3657..b5b0bb0bfc401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -22,22 +22,21 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager - /** * A class that holds all session-specific state in a given [[SparkSession]]. + * * @param sparkContext The [[SparkContext]]. * @param sharedState The shared state. * @param conf SQL-specific key-value configurations. @@ -46,9 +45,11 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @param catalog Internal catalog for managing table and database states. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param streamingQueryManager Interface to start and stop - * [[org.apache.spark.sql.streaming.StreamingQuery]]s. - * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] + * @param optimizer Logical query plan optimizer. + * @param planner Planner that converts optimized logical plans to physical plans + * @param streamingQueryManager Interface to start and stop streaming queries. + * @param createQueryExecution Function used to create QueryExecution objects. + * @param createClone Function used to create clones of the session state. */ private[sql] class SessionState( sparkContext: SparkContext, @@ -59,8 +60,11 @@ private[sql] class SessionState( val catalog: SessionCatalog, val sqlParser: ParserInterface, val analyzer: Analyzer, + val optimizer: Optimizer, + val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, - val queryExecutionCreator: LogicalPlan => QueryExecution) { + createQueryExecution: LogicalPlan => QueryExecution, + createClone: (SparkSession, SessionState) => SessionState) { def newHadoopConf(): Configuration = SessionState.newHadoopConf( sparkContext.hadoopConfiguration, @@ -76,41 +80,12 @@ private[sql] class SessionState( hadoopConf } - /** - * A class for loading resources specified by a function. - */ - val functionResourceLoader: FunctionResourceLoader = { - new FunctionResourceLoader { - override def loadResource(resource: FunctionResource): Unit = { - resource.resourceType match { - case JarResource => addJar(resource.uri) - case FileResource => sparkContext.addFile(resource.uri) - case ArchiveResource => - throw new AnalysisException( - "Archive is not allowed to be loaded. If YARN mode is used, " + - "please use --archives options while calling spark-submit.") - } - } - } - } - /** * Interface exposed to the user for registering user-defined functions. * Note that the user-defined functions must be deterministic. */ val udf: UDFRegistration = new UDFRegistration(functionRegistry) - /** - * Logical query plan optimizer. - */ - val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) - - /** - * Planner that converts optimized logical plans to physical plans. - */ - def planner: SparkPlanner = - new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies) - /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. @@ -120,38 +95,13 @@ private[sql] class SessionState( /** * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ - def clone(newSparkSession: SparkSession): SessionState = { - val sparkContext = newSparkSession.sparkContext - val confCopy = conf.clone() - val functionRegistryCopy = functionRegistry.clone() - val sqlParser: ParserInterface = new SparkSqlParser(confCopy) - val catalogCopy = catalog.newSessionCatalogWith( - confCopy, - SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), - functionRegistryCopy, - sqlParser) - val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) - - SessionState.mergeSparkConf(confCopy, sparkContext.getConf) - - new SessionState( - sparkContext, - newSparkSession.sharedState, - confCopy, - experimentalMethods.clone(), - functionRegistryCopy, - catalogCopy, - sqlParser, - SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), - new StreamingQueryManager(newSparkSession), - queryExecutionCreator) - } + def clone(newSparkSession: SparkSession): SessionState = createClone(newSparkSession, this) // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan) + def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) @@ -179,53 +129,12 @@ private[sql] class SessionState( } } - private[sql] object SessionState { - - def apply(sparkSession: SparkSession): SessionState = { - apply(sparkSession, new SQLConf) - } - - def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = { - val sparkContext = sparkSession.sparkContext - - // Automatically extract all entries and put them in our SQLConf - mergeSparkConf(sqlConf, sparkContext.getConf) - - val functionRegistry = FunctionRegistry.builtin.clone() - - val sqlParser: ParserInterface = new SparkSqlParser(sqlConf) - - val catalog = new SessionCatalog( - sparkSession.sharedState.externalCatalog, - sparkSession.sharedState.globalTempViewManager, - functionRegistry, - sqlConf, - newHadoopConf(sparkContext.hadoopConfiguration, sqlConf), - sqlParser) - - val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf) - - val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession) - - val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan) - - val sessionState = new SessionState( - sparkContext, - sparkSession.sharedState, - sqlConf, - new ExperimentalMethods, - functionRegistry, - catalog, - sqlParser, - analyzer, - streamingQueryManager, - queryExecutionCreator) - // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before - // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller - // cannot use SessionCatalog before we return SessionState. - catalog.functionResourceLoader = sessionState.functionResourceLoader - sessionState + /** + * Create a new [[SessionState]] for the given session. + */ + def apply(session: SparkSession): SessionState = { + new SessionStateBuilder(session).build() } def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { @@ -233,34 +142,33 @@ private[sql] object SessionState { sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } newHadoopConf } +} - /** - * Create an logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`. - */ - private def createAnalyzer( - sparkSession: SparkSession, - catalog: SessionCatalog, - sqlConf: SQLConf): Analyzer = { - new Analyzer(catalog, sqlConf) { - override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(sqlConf) :: - DataSourceAnalysis(sqlConf) :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) - } - } +/** + * Concrete implementation of a [[SessionStateBuilder]]. + */ +@Experimental +@InterfaceStability.Unstable +class SessionStateBuilder( + session: SparkSession, + parentState: Option[SessionState] = None) + extends BaseSessionStateBuilder(session, parentState) { + override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _) +} - /** - * Extract entries from `SparkConf` and put them in the `SQLConf` - */ - def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { - sparkConf.getAll.foreach { case (k, v) => - sqlConf.setConfString(k, v) +/** + * Session shared [[FunctionResourceLoader]]. + */ +@InterfaceStability.Unstable +class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + resource.resourceType match { + case JarResource => session.sessionState.addJar(resource.uri) + case FileResource => session.sparkContext.addFile(resource.uri) + case ArchiveResource => + throw new AnalysisException( + "Archive is not allowed to be loaded. If YARN mode is used, " + + "please use --archives options while calling spark-submit.") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala new file mode 100644 index 0000000000000..6b5559adb1db4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.streaming.StreamingQueryManager + +/** + * Builder class that coordinates construction of a new [[SessionState]]. + * + * The builder explicitly defines all components needed by the session state, and creates a session + * state when `build` is called. Components should only be initialized once. This is not a problem + * for most components as they are only used in the `build` function. However some components + * (`conf`, `catalog`, `functionRegistry`, `experimentalMethods` & `sqlParser`) are as dependencies + * for other components and are shared as a result. These components are defined as lazy vals to + * make sure the component is created only once. + * + * A developer can modify the builder by providing custom versions of components, or by using the + * hooks provided for the analyzer, optimizer & planner. There are some dependencies between the + * components (they are documented per dependency), a developer should respect these when making + * modifications in order to prevent initialization problems. + * + * A parent [[SessionState]] can be used to initialize the new [[SessionState]]. The new session + * state will clone the parent sessions state's `conf`, `functionRegistry`, `experimentalMethods` + * and `catalog` fields. Note that the state is cloned when `build` is called, and not before. + */ +@Experimental +@InterfaceStability.Unstable +abstract class BaseSessionStateBuilder( + val session: SparkSession, + val parentState: Option[SessionState] = None) { + type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder + + /** + * Function that produces a new instance of the SessionStateBuilder. This is used by the + * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own + * [[SessionStateBuilder]]. + */ + protected def newBuilder: NewBuilder + + /** + * Extract entries from `SparkConf` and put them in the `SQLConf` + */ + protected def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { + sparkConf.getAll.foreach { case (k, v) => + sqlConf.setConfString(k, v) + } + } + + /** + * SQL-specific key-value configurations. + * + * These either get cloned from a pre-existing instance or newly created. The conf is always + * merged with its [[SparkConf]]. + */ + protected lazy val conf: SQLConf = { + val conf = parentState.map(_.conf.clone()).getOrElse(new SQLConf) + mergeSparkConf(conf, session.sparkContext.conf) + conf + } + + /** + * Internal catalog managing functions registered by the user. + * + * This either gets cloned from a pre-existing version or cloned from the built-in registry. + */ + protected lazy val functionRegistry: FunctionRegistry = { + parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone() + } + + /** + * Experimental methods that can be used to define custom optimization rules and custom planning + * strategies. + * + * This either gets cloned from a pre-existing version or newly created. + */ + protected lazy val experimentalMethods: ExperimentalMethods = { + parentState.map(_.experimentalMethods.clone()).getOrElse(new ExperimentalMethods) + } + + /** + * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * + * Note: this depends on the `conf` field. + */ + protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + + /** + * Catalog for managing table and database states. If there is a pre-existing catalog, the state + * of that catalog (temp tables & current database) will be copied into the new catalog. + * + * Note: this depends on the `conf`, `functionRegistry` and `sqlParser` fields. + */ + protected lazy val catalog: SessionCatalog = { + val catalog = new SessionCatalog( + session.sharedState.externalCatalog, + session.sharedState.globalTempViewManager, + functionRegistry, + conf, + SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), + sqlParser, + new SessionFunctionResourceLoader(session)) + parentState.foreach(_.catalog.copyStateTo(catalog)) + catalog + } + + /** + * Logical query plan analyzer for resolving unresolved attributes and relations. + * + * Note: this depends on the `conf` and `catalog` fields. + */ + protected def analyzer: Analyzer = new Analyzer(catalog, conf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + customPostHocResolutionRules + + override val extendedCheckRules: Seq[LogicalPlan => Unit] = + PreWriteCheck +: + HiveOnlyCheck +: + customCheckRules + } + + /** + * Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating + * your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of + * creating your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating + * your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + + /** + * Logical query plan optimizer. + * + * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + */ + protected def optimizer: Optimizer = { + new SparkOptimizer(catalog, conf, experimentalMethods) { + override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = + super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules + } + } + + /** + * Custom operator optimization rules to add to the Optimizer. Prefer overriding this instead + * of creating your own Optimizer. + * + * Note that this may NOT depend on the `optimizer` function. + */ + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Planner that converts optimized logical plans to physical plans. + * + * Note: this depends on the `conf` and `experimentalMethods` fields. + */ + protected def planner: SparkPlanner = { + new SparkPlanner(session.sparkContext, conf, experimentalMethods) { + override def extraPlanningStrategies: Seq[Strategy] = + super.extraPlanningStrategies ++ customPlanningStrategies + } + } + + /** + * Custom strategies to add to the planner. Prefer overriding this instead of creating + * your own Planner. + * + * Note that this may NOT depend on the `planner` function. + */ + protected def customPlanningStrategies: Seq[Strategy] = Nil + + /** + * Create a query execution object. + */ + protected def createQueryExecution: LogicalPlan => QueryExecution = { plan => + new QueryExecution(session, plan) + } + + /** + * Interface to start and stop streaming queries. + */ + protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session) + + /** + * Function used to make clones of the session state. + */ + protected def createClone: (SparkSession, SessionState) => SessionState = { + val createBuilder = newBuilder + (session, state) => createBuilder(session, Option(state)).build() + } + + /** + * Build the [[SessionState]]. + */ + def build(): SessionState = { + new SessionState( + session.sparkContext, + session.sharedState, + conf, + experimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + optimizer, + planner, + streamingQueryManager, + createQueryExecution, + createClone) + } +} + +/** + * Helper class for using SessionStateBuilders during tests. + */ +private[sql] trait WithTestConf { self: BaseSessionStateBuilder => + def overrideConfs: Map[String, String] + + override protected lazy val conf: SQLConf = { + val conf = parentState.map(_.conf.clone()).getOrElse { + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + overrideConfs.foreach { case (key, value) => setConfString(key, value) } + } + } + } + mergeSparkConf(conf, session.sparkContext.conf) + conf + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 898a2fb4f329b..b01977a23890f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf} /** * A special [[SparkSession]] prepared for testing. @@ -35,16 +35,9 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - override lazy val sessionState: SessionState = SessionState( - this, - new SQLConf { - clear() - override def clear(): Unit = { - super.clear() - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } - } - }) + override lazy val sessionState: SessionState = { + new TestSQLSessionStateBuilder(this, None).build() + } // Needed for Java tests def loadTestData(): Unit = { @@ -67,3 +60,11 @@ private[sql] object TestSQLContext { // Fewer shuffle partitions to speed up testing. SQLConf.SHUFFLE_PARTITIONS.key -> "5") } + +private[sql] class TestSQLSessionStateBuilder( + session: SparkSession, + state: Option[SessionState]) + extends SessionStateBuilder(session, state) with WithTestConf { + override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs + override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 6b7599e3d3401..2cc20a791d80c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -25,8 +25,8 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} @@ -47,14 +47,16 @@ private[sql] class HiveSessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, - parser: ParserInterface) + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( externalCatalog, globalTempViewManager, functionRegistry, conf, hadoopConf, - parser) { + parser, + functionResourceLoader) { // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | @@ -69,47 +71,6 @@ private[sql] class HiveSessionCatalog( metastoreCatalog.hiveDefaultTableFilePath(name) } - /** - * Create a new [[HiveSessionCatalog]] with the provided parameters. `externalCatalog` and - * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied. - */ - def newSessionCatalogWith( - newSparkSession: SparkSession, - conf: SQLConf, - hadoopConf: Configuration, - functionRegistry: FunctionRegistry, - parser: ParserInterface): HiveSessionCatalog = { - val catalog = HiveSessionCatalog( - newSparkSession, - functionRegistry, - conf, - hadoopConf, - parser) - - synchronized { - catalog.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2)) - } - - catalog - } - - /** - * The parent class [[SessionCatalog]] cannot access the [[SparkSession]] class, so we cannot add - * a [[SparkSession]] parameter to [[SessionCatalog.newSessionCatalogWith]]. However, - * [[HiveSessionCatalog]] requires a [[SparkSession]] parameter, so we can a new version of - * `newSessionCatalogWith` and disable this one. - * - * TODO Refactor HiveSessionCatalog to not use [[SparkSession]] directly. - */ - override def newSessionCatalogWith( - conf: CatalystConf, - hadoopConf: Configuration, - functionRegistry: FunctionRegistry, - parser: ParserInterface): HiveSessionCatalog = throw new UnsupportedOperationException( - "to clone HiveSessionCatalog, use the other clone method that also accepts a SparkSession") - // For testing only private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { val key = metastoreCatalog.getQualifiedTableName(table) @@ -250,28 +211,3 @@ private[sql] class HiveSessionCatalog( "histogram_numeric" ) } - -private[sql] object HiveSessionCatalog { - - def apply( - sparkSession: SparkSession, - functionRegistry: FunctionRegistry, - conf: SQLConf, - hadoopConf: Configuration, - parser: ParserInterface): HiveSessionCatalog = { - // Catalog for handling data source tables. TODO: This really doesn't belong here since it is - // essentially a cache for metastore tables. However, it relies on a lot of session-specific - // things so it would be a lot of work to split its functionality between HiveSessionCatalog - // and HiveCatalog. We should still do it at some point... - val metastoreCatalog = new HiveMetastoreCatalog(sparkSession) - - new HiveSessionCatalog( - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], - sparkSession.sharedState.globalTempViewManager, - metastoreCatalog, - functionRegistry, - conf, - hadoopConf, - parser) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index cb8bcb8591bd6..49ff8478f1ae2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -18,20 +18,23 @@ package org.apache.spark.sql.hive import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionFunctionResourceLoader, SessionState, SharedState, SQLConf} import org.apache.spark.sql.streaming.StreamingQueryManager /** * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. + * * @param sparkContext The [[SparkContext]]. * @param sharedState The shared state. * @param conf SQL-specific key-value configurations. @@ -40,12 +43,14 @@ import org.apache.spark.sql.streaming.StreamingQueryManager * @param catalog Internal catalog for managing table and database states that uses Hive client for * interacting with the metastore. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - * @param metadataHive The Hive metadata client. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param streamingQueryManager Interface to start and stop - * [[org.apache.spark.sql.streaming.StreamingQuery]]s. - * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] - * @param plannerCreator Lambda to create a planner that takes into account Hive-specific strategies + * @param optimizer Logical query plan optimizer. + * @param planner Planner that converts optimized logical plans to physical plans and that takes + * Hive-specific strategies into account. + * @param streamingQueryManager Interface to start and stop streaming queries. + * @param createQueryExecution Function used to create QueryExecution objects. + * @param createClone Function used to create clones of the session state. + * @param metadataHive The Hive metadata client. */ private[hive] class HiveSessionState( sparkContext: SparkContext, @@ -55,11 +60,13 @@ private[hive] class HiveSessionState( functionRegistry: FunctionRegistry, override val catalog: HiveSessionCatalog, sqlParser: ParserInterface, - val metadataHive: HiveClient, analyzer: Analyzer, + optimizer: Optimizer, + planner: SparkPlanner, streamingQueryManager: StreamingQueryManager, - queryExecutionCreator: LogicalPlan => QueryExecution, - val plannerCreator: () => SparkPlanner) + createQueryExecution: LogicalPlan => QueryExecution, + createClone: (SparkSession, SessionState) => SessionState, + val metadataHive: HiveClient) extends SessionState( sparkContext, sharedState, @@ -69,14 +76,11 @@ private[hive] class HiveSessionState( catalog, sqlParser, analyzer, + optimizer, + planner, streamingQueryManager, - queryExecutionCreator) { self => - - /** - * Planner that takes into account Hive-specific strategies. - */ - override def planner: SparkPlanner = plannerCreator() - + createQueryExecution, + createClone) { // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days @@ -121,150 +125,115 @@ private[hive] class HiveSessionState( def hiveThriftServerAsync: Boolean = { conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) } +} +private[hive] object HiveSessionState { /** - * Get an identical copy of the `HiveSessionState`. - * This should ideally reuse the `SessionState.clone` but cannot do so. - * Doing that will throw an exception when trying to clone the catalog. + * Create a new [[HiveSessionState]] for the given session. */ - override def clone(newSparkSession: SparkSession): HiveSessionState = { - val sparkContext = newSparkSession.sparkContext - val confCopy = conf.clone() - val functionRegistryCopy = functionRegistry.clone() - val experimentalMethodsCopy = experimentalMethods.clone() - val sqlParser: ParserInterface = new SparkSqlParser(confCopy) - val catalogCopy = catalog.newSessionCatalogWith( - newSparkSession, - confCopy, - SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), - functionRegistryCopy, - sqlParser) - val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) - - val hiveClient = - newSparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .newSession() - - SessionState.mergeSparkConf(confCopy, sparkContext.getConf) - - new HiveSessionState( - sparkContext, - newSparkSession.sharedState, - confCopy, - experimentalMethodsCopy, - functionRegistryCopy, - catalogCopy, - sqlParser, - hiveClient, - HiveSessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), - new StreamingQueryManager(newSparkSession), - queryExecutionCreator, - HiveSessionState.createPlannerCreator( - newSparkSession, - confCopy, - experimentalMethodsCopy)) + def apply(session: SparkSession): HiveSessionState = { + new HiveSessionStateBuilder(session).build() } - } -private[hive] object HiveSessionState { - - def apply(sparkSession: SparkSession): HiveSessionState = { - apply(sparkSession, new SQLConf) - } - - def apply(sparkSession: SparkSession, conf: SQLConf): HiveSessionState = { - val initHelper = SessionState(sparkSession, conf) - - val sparkContext = sparkSession.sparkContext - - val catalog = HiveSessionCatalog( - sparkSession, - initHelper.functionRegistry, - initHelper.conf, - SessionState.newHadoopConf(sparkContext.hadoopConfiguration, initHelper.conf), - initHelper.sqlParser) - - val metadataHive: HiveClient = - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .newSession() - - val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, initHelper.conf) +/** + * Builder that produces a [[HiveSessionState]]. + */ +@Experimental +@InterfaceStability.Unstable +class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) + extends BaseSessionStateBuilder(session, parentState) { - val plannerCreator = createPlannerCreator( - sparkSession, - initHelper.conf, - initHelper.experimentalMethods) + private def externalCatalog: HiveExternalCatalog = + session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - val hiveSessionState = new HiveSessionState( - sparkContext, - sparkSession.sharedState, - initHelper.conf, - initHelper.experimentalMethods, - initHelper.functionRegistry, - catalog, - initHelper.sqlParser, - metadataHive, - analyzer, - initHelper.streamingQueryManager, - initHelper.queryExecutionCreator, - plannerCreator) - catalog.functionResourceLoader = hiveSessionState.functionResourceLoader - hiveSessionState + /** + * Create a [[HiveSessionCatalog]]. + */ + override protected lazy val catalog: HiveSessionCatalog = { + val catalog = new HiveSessionCatalog( + externalCatalog, + session.sharedState.globalTempViewManager, + new HiveMetastoreCatalog(session), + functionRegistry, + conf, + SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), + sqlParser, + new SessionFunctionResourceLoader(session)) + parentState.foreach(_.catalog.copyStateTo(catalog)) + catalog } /** - * Create an logical query plan `Analyzer` with rules specific to a `HiveSessionState`. + * A logical query plan `Analyzer` with rules specific to Hive. */ - private def createAnalyzer( - sparkSession: SparkSession, - catalog: HiveSessionCatalog, - sqlConf: SQLConf): Analyzer = { - new Analyzer(catalog, sqlConf) { - override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = - new ResolveHiveSerdeTable(sparkSession) :: - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - new DetermineTableStats(sparkSession) :: - catalog.ParquetConversions :: - catalog.OrcConversions :: - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(sqlConf) :: - DataSourceAnalysis(sqlConf) :: - HiveAnalysis :: Nil + override protected def analyzer: Analyzer = new Analyzer(catalog, conf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new ResolveHiveSerdeTable(session) +: + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + new DetermineTableStats(session) +: + catalog.ParquetConversions +: + catalog.OrcConversions +: + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + HiveAnalysis +: + customPostHocResolutionRules + + override val extendedCheckRules: Seq[LogicalPlan => Unit] = + PreWriteCheck +: + customCheckRules + } - override val extendedCheckRules = Seq(PreWriteCheck) + /** + * Planner that takes into account Hive-specific strategies. + */ + override protected def planner: SparkPlanner = { + new SparkPlanner(session.sparkContext, conf, experimentalMethods) with HiveStrategies { + override val sparkSession: SparkSession = session + + override def extraPlanningStrategies: Seq[Strategy] = + super.extraPlanningStrategies ++ customPlanningStrategies + + override def strategies: Seq[Strategy] = { + experimentalMethods.extraStrategies ++ + extraPlanningStrategies ++ Seq( + FileSourceStrategy, + DataSourceStrategy, + SpecialLimits, + InMemoryScans, + HiveTableScans, + Scripts, + Aggregation, + JoinSelection, + BasicOperators + ) + } } } - private def createPlannerCreator( - associatedSparkSession: SparkSession, - sqlConf: SQLConf, - experimentalMethods: ExperimentalMethods): () => SparkPlanner = { - () => - new SparkPlanner( - associatedSparkSession.sparkContext, - sqlConf, - experimentalMethods.extraStrategies) - with HiveStrategies { - - override val sparkSession: SparkSession = associatedSparkSession + override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _) - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy, - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } - } + override def build(): HiveSessionState = { + val metadataHive: HiveClient = externalCatalog.client.newSession() + new HiveSessionState( + session.sparkContext, + session.sharedState, + conf, + experimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + optimizer, + planner, + streamingQueryManager, + createQueryExecution, + createClone, + metadataHive) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b63ed76967bd9..32ca69605ef4d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -148,12 +148,14 @@ class TestHiveContext( * * @param sc SparkContext * @param existingSharedState optional [[SharedState]] + * @param parentSessionState optional parent [[SessionState]] * @param loadTestTables if true, load the test tables. They can only be loaded when running * in the JVM, i.e when calling from Python this flag has to be false. */ private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, @transient private val existingSharedState: Option[TestHiveSharedState], + @transient private val parentSessionState: Option[HiveSessionState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -161,6 +163,7 @@ private[hive] class TestHiveSparkSession( this( sc, existingSharedState = None, + parentSessionState = None, loadTestTables) } @@ -168,6 +171,7 @@ private[hive] class TestHiveSparkSession( this( sc, existingSharedState = Some(new TestHiveSharedState(sc, Some(hiveClient))), + parentSessionState = None, loadTestTables) } @@ -192,36 +196,21 @@ private[hive] class TestHiveSparkSession( @transient override lazy val sessionState: HiveSessionState = { - val testConf = - new SQLConf { - clear() - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - override def clear(): Unit = { - super.clear() - TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) } - } - } - val queryExecutionCreator = (plan: LogicalPlan) => new TestHiveQueryExecution(this, plan) - val initHelper = HiveSessionState(this, testConf) - SessionState.mergeSparkConf(testConf, sparkContext.getConf) - - new HiveSessionState( - sparkContext, - sharedState, - testConf, - initHelper.experimentalMethods, - initHelper.functionRegistry, - initHelper.catalog, - initHelper.sqlParser, - initHelper.metadataHive, - initHelper.analyzer, - initHelper.streamingQueryManager, - queryExecutionCreator, - initHelper.plannerCreator) + new TestHiveSessionStateBuilder(this, parentSessionState).build() } override def newSession(): TestHiveSparkSession = { - new TestHiveSparkSession(sc, Some(sharedState), loadTestTables) + new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) + } + + override def cloneSession(): SparkSession = { + val result = new TestHiveSparkSession( + sparkContext, + Some(sharedState), + Some(sessionState), + loadTestTables) + result.sessionState // force copy of SessionState + result } private var cacheTables: Boolean = false @@ -595,3 +584,18 @@ private[hive] object TestHiveContext { } } + +private[sql] class TestHiveSessionStateBuilder( + session: SparkSession, + state: Option[SessionState]) + extends HiveSessionStateBuilder(session, state) + with WithTestConf { + + override def overrideConfs: Map[String, String] = TestHiveContext.overrideConfs + + override def createQueryExecution: (LogicalPlan) => QueryExecution = { plan => + new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan) + } + + override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala deleted file mode 100644 index 3b0f59b15916c..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI - -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Utils - -class HiveSessionCatalogSuite extends TestHiveSingleton { - - test("clone HiveSessionCatalog") { - val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - - val tempTableName1 = "copytest1" - val tempTableName2 = "copytest2" - try { - val tempTable1 = Range(1, 10, 1, 10) - original.createTempView(tempTableName1, tempTable1, overrideIfExists = false) - - // check if tables copied over - val clone = original.newSessionCatalogWith( - spark, - new SQLConf, - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - assert(original ne clone) - assert(clone.getTempView(tempTableName1) == Some(tempTable1)) - - // check if clone and original independent - clone.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = false, purge = false) - assert(original.getTempView(tempTableName1) == Some(tempTable1)) - - val tempTable2 = Range(1, 20, 2, 10) - original.createTempView(tempTableName2, tempTable2, overrideIfExists = false) - assert(clone.getTempView(tempTableName2).isEmpty) - } finally { - // Drop the created temp views from the global singleton HiveSession. - original.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = true, purge = true) - original.dropTable(TableIdentifier(tempTableName2), ignoreIfNotExists = true, purge = true) - } - } - - test("clone SessionCatalog - current db") { - val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val originalCurrentDatabase = original.getCurrentDatabase - val db1 = "db1" - val db2 = "db2" - val db3 = "db3" - try { - original.createDatabase(newDb(db1), ignoreIfExists = true) - original.createDatabase(newDb(db2), ignoreIfExists = true) - original.createDatabase(newDb(db3), ignoreIfExists = true) - - original.setCurrentDatabase(db1) - - // check if tables copied over - val clone = original.newSessionCatalogWith( - spark, - new SQLConf, - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - - // check if current db copied over - assert(original ne clone) - assert(clone.getCurrentDatabase == db1) - - // check if clone and original independent - clone.setCurrentDatabase(db2) - assert(original.getCurrentDatabase == db1) - original.setCurrentDatabase(db3) - assert(clone.getCurrentDatabase == db2) - } finally { - // Drop the created databases from the global singleton HiveSession. - original.dropDatabase(db1, ignoreIfNotExists = true, cascade = true) - original.dropDatabase(db2, ignoreIfNotExists = true, cascade = true) - original.dropDatabase(db3, ignoreIfNotExists = true, cascade = true) - original.setCurrentDatabase(originalCurrentDatabase) - } - } - - def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) - - def newDb(name: String): CatalogDatabase = { - CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) - } -} From 6c70a38c2e60e1b69a310aee1a92ee0b3815c02d Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Tue, 28 Mar 2017 10:09:49 +0800 Subject: [PATCH 0132/1765] [SPARK-19088][SQL] Optimize sequence type deserialization codegen ## What changes were proposed in this pull request? Optimization of arbitrary Scala sequence deserialization introduced by #16240. The previous implementation constructed an array which was then converted by `to`. This required two passes in most cases. This implementation attempts to remedy that by using `Builder`s provided by the `newBuilder` method on every Scala collection's companion object to build the resulting collection directly. Example codegen for simple `List` (obtained using `Seq(List(1)).toDS().map(identity).queryExecution.debug.codegen`): Before: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean deserializetoobject_resultIsNull; /* 010 */ private java.lang.Object[] deserializetoobject_argValue; /* 011 */ private boolean MapObjects_loopIsNull1; /* 012 */ private int MapObjects_loopValue0; /* 013 */ private boolean deserializetoobject_resultIsNull1; /* 014 */ private scala.collection.generic.CanBuildFrom deserializetoobject_argValue1; /* 015 */ private UnsafeRow deserializetoobject_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 018 */ private scala.collection.immutable.List mapelements_argValue; /* 019 */ private UnsafeRow mapelements_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 022 */ private scala.collection.immutable.List serializefromobject_argValue; /* 023 */ private UnsafeRow serializefromobject_result; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 025 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 027 */ /* 028 */ public GeneratedIterator(Object[] references) { /* 029 */ this.references = references; /* 030 */ } /* 031 */ /* 032 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 033 */ partitionIndex = index; /* 034 */ this.inputs = inputs; /* 035 */ inputadapter_input = inputs[0]; /* 036 */ /* 037 */ deserializetoobject_result = new UnsafeRow(1); /* 038 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 039 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 040 */ /* 041 */ mapelements_result = new UnsafeRow(1); /* 042 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 043 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 044 */ /* 045 */ serializefromobject_result = new UnsafeRow(1); /* 046 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 047 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 048 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 049 */ /* 050 */ } /* 051 */ /* 052 */ protected void processNext() throws java.io.IOException { /* 053 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 054 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 055 */ ArrayData inputadapter_value = inputadapter_row.getArray(0); /* 056 */ /* 057 */ deserializetoobject_resultIsNull = false; /* 058 */ /* 059 */ if (!deserializetoobject_resultIsNull) { /* 060 */ ArrayData deserializetoobject_value3 = null; /* 061 */ /* 062 */ if (!false) { /* 063 */ Integer[] deserializetoobject_convertedArray = null; /* 064 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 065 */ deserializetoobject_convertedArray = new Integer[deserializetoobject_dataLength]; /* 066 */ /* 067 */ int deserializetoobject_loopIndex = 0; /* 068 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 069 */ MapObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex)); /* 070 */ MapObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 071 */ /* 072 */ if (MapObjects_loopIsNull1) { /* 073 */ throw new RuntimeException(((java.lang.String) references[0])); /* 074 */ } /* 075 */ if (false) { /* 076 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 077 */ } else { /* 078 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue0; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_loopIndex += 1; /* 082 */ } /* 083 */ /* 084 */ deserializetoobject_value3 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /* 085 */ } /* 086 */ boolean deserializetoobject_isNull2 = true; /* 087 */ java.lang.Object[] deserializetoobject_value2 = null; /* 088 */ if (!false) { /* 089 */ deserializetoobject_isNull2 = false; /* 090 */ if (!deserializetoobject_isNull2) { /* 091 */ Object deserializetoobject_funcResult = null; /* 092 */ deserializetoobject_funcResult = deserializetoobject_value3.array(); /* 093 */ if (deserializetoobject_funcResult == null) { /* 094 */ deserializetoobject_isNull2 = true; /* 095 */ } else { /* 096 */ deserializetoobject_value2 = (java.lang.Object[]) deserializetoobject_funcResult; /* 097 */ } /* 098 */ /* 099 */ } /* 100 */ deserializetoobject_isNull2 = deserializetoobject_value2 == null; /* 101 */ } /* 102 */ deserializetoobject_resultIsNull = deserializetoobject_isNull2; /* 103 */ deserializetoobject_argValue = deserializetoobject_value2; /* 104 */ } /* 105 */ /* 106 */ boolean deserializetoobject_isNull1 = deserializetoobject_resultIsNull; /* 107 */ final scala.collection.Seq deserializetoobject_value1 = deserializetoobject_resultIsNull ? null : scala.collection.mutable.WrappedArray.make(deserializetoobject_argValue); /* 108 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null; /* 109 */ boolean deserializetoobject_isNull = true; /* 110 */ scala.collection.immutable.List deserializetoobject_value = null; /* 111 */ if (!deserializetoobject_isNull1) { /* 112 */ deserializetoobject_resultIsNull1 = false; /* 113 */ /* 114 */ if (!deserializetoobject_resultIsNull1) { /* 115 */ boolean deserializetoobject_isNull6 = false; /* 116 */ final scala.collection.generic.CanBuildFrom deserializetoobject_value6 = false ? null : scala.collection.immutable.List.canBuildFrom(); /* 117 */ deserializetoobject_isNull6 = deserializetoobject_value6 == null; /* 118 */ deserializetoobject_resultIsNull1 = deserializetoobject_isNull6; /* 119 */ deserializetoobject_argValue1 = deserializetoobject_value6; /* 120 */ } /* 121 */ /* 122 */ deserializetoobject_isNull = deserializetoobject_resultIsNull1; /* 123 */ if (!deserializetoobject_isNull) { /* 124 */ Object deserializetoobject_funcResult1 = null; /* 125 */ deserializetoobject_funcResult1 = deserializetoobject_value1.to(deserializetoobject_argValue1); /* 126 */ if (deserializetoobject_funcResult1 == null) { /* 127 */ deserializetoobject_isNull = true; /* 128 */ } else { /* 129 */ deserializetoobject_value = (scala.collection.immutable.List) deserializetoobject_funcResult1; /* 130 */ } /* 131 */ /* 132 */ } /* 133 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 134 */ } /* 135 */ /* 136 */ boolean mapelements_isNull = true; /* 137 */ scala.collection.immutable.List mapelements_value = null; /* 138 */ if (!false) { /* 139 */ mapelements_argValue = deserializetoobject_value; /* 140 */ /* 141 */ mapelements_isNull = false; /* 142 */ if (!mapelements_isNull) { /* 143 */ Object mapelements_funcResult = null; /* 144 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 145 */ if (mapelements_funcResult == null) { /* 146 */ mapelements_isNull = true; /* 147 */ } else { /* 148 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult; /* 149 */ } /* 150 */ /* 151 */ } /* 152 */ mapelements_isNull = mapelements_value == null; /* 153 */ } /* 154 */ /* 155 */ if (mapelements_isNull) { /* 156 */ throw new RuntimeException(((java.lang.String) references[2])); /* 157 */ } /* 158 */ serializefromobject_argValue = mapelements_value; /* 159 */ /* 160 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue); /* 161 */ serializefromobject_holder.reset(); /* 162 */ /* 163 */ // Remember the current cursor so that we can calculate how many bytes are /* 164 */ // written later. /* 165 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 166 */ /* 167 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 168 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 169 */ // grow the global buffer before writing data. /* 170 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 171 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 172 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 173 */ /* 174 */ } else { /* 175 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 176 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 177 */ /* 178 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 179 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 180 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index); /* 181 */ } else { /* 182 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index); /* 183 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 184 */ } /* 185 */ } /* 186 */ } /* 187 */ /* 188 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 189 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 190 */ append(serializefromobject_result); /* 191 */ if (shouldStop()) return; /* 192 */ } /* 193 */ } /* 194 */ } ``` After: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean CollectObjects_loopIsNull1; /* 010 */ private int CollectObjects_loopValue0; /* 011 */ private UnsafeRow deserializetoobject_result; /* 012 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 013 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 014 */ private scala.collection.immutable.List mapelements_argValue; /* 015 */ private UnsafeRow mapelements_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 018 */ private scala.collection.immutable.List serializefromobject_argValue; /* 019 */ private UnsafeRow serializefromobject_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 023 */ /* 024 */ public GeneratedIterator(Object[] references) { /* 025 */ this.references = references; /* 026 */ } /* 027 */ /* 028 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 029 */ partitionIndex = index; /* 030 */ this.inputs = inputs; /* 031 */ inputadapter_input = inputs[0]; /* 032 */ /* 033 */ deserializetoobject_result = new UnsafeRow(1); /* 034 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 035 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 036 */ /* 037 */ mapelements_result = new UnsafeRow(1); /* 038 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 039 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 040 */ /* 041 */ serializefromobject_result = new UnsafeRow(1); /* 042 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 043 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 044 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 045 */ /* 046 */ } /* 047 */ /* 048 */ protected void processNext() throws java.io.IOException { /* 049 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 050 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 051 */ ArrayData inputadapter_value = inputadapter_row.getArray(0); /* 052 */ /* 053 */ scala.collection.immutable.List deserializetoobject_value = null; /* 054 */ /* 055 */ if (!false) { /* 056 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 057 */ scala.collection.mutable.Builder CollectObjects_builderValue2 = scala.collection.immutable.List$.MODULE$.newBuilder(); /* 058 */ CollectObjects_builderValue2.sizeHint(deserializetoobject_dataLength); /* 059 */ /* 060 */ int deserializetoobject_loopIndex = 0; /* 061 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 062 */ CollectObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex)); /* 063 */ CollectObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 064 */ /* 065 */ if (CollectObjects_loopIsNull1) { /* 066 */ throw new RuntimeException(((java.lang.String) references[0])); /* 067 */ } /* 068 */ if (false) { /* 069 */ CollectObjects_builderValue2.$plus$eq(null); /* 070 */ } else { /* 071 */ CollectObjects_builderValue2.$plus$eq(CollectObjects_loopValue0); /* 072 */ } /* 073 */ /* 074 */ deserializetoobject_loopIndex += 1; /* 075 */ } /* 076 */ /* 077 */ deserializetoobject_value = (scala.collection.immutable.List) CollectObjects_builderValue2.result(); /* 078 */ } /* 079 */ /* 080 */ boolean mapelements_isNull = true; /* 081 */ scala.collection.immutable.List mapelements_value = null; /* 082 */ if (!false) { /* 083 */ mapelements_argValue = deserializetoobject_value; /* 084 */ /* 085 */ mapelements_isNull = false; /* 086 */ if (!mapelements_isNull) { /* 087 */ Object mapelements_funcResult = null; /* 088 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 089 */ if (mapelements_funcResult == null) { /* 090 */ mapelements_isNull = true; /* 091 */ } else { /* 092 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult; /* 093 */ } /* 094 */ /* 095 */ } /* 096 */ mapelements_isNull = mapelements_value == null; /* 097 */ } /* 098 */ /* 099 */ if (mapelements_isNull) { /* 100 */ throw new RuntimeException(((java.lang.String) references[2])); /* 101 */ } /* 102 */ serializefromobject_argValue = mapelements_value; /* 103 */ /* 104 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue); /* 105 */ serializefromobject_holder.reset(); /* 106 */ /* 107 */ // Remember the current cursor so that we can calculate how many bytes are /* 108 */ // written later. /* 109 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 110 */ /* 111 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 112 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 113 */ // grow the global buffer before writing data. /* 114 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 115 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 116 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 117 */ /* 118 */ } else { /* 119 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 120 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 121 */ /* 122 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 123 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 124 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index); /* 125 */ } else { /* 126 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index); /* 127 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 128 */ } /* 129 */ } /* 130 */ } /* 131 */ /* 132 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 133 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 134 */ append(serializefromobject_result); /* 135 */ if (shouldStop()) return; /* 136 */ } /* 137 */ } /* 138 */ } ``` Benchmark results before: ``` OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH AMD A10-4600M APU with Radeon(tm) HD Graphics collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Seq 269 / 370 0.0 269125.8 1.0X List 154 / 176 0.0 154453.5 1.7X mutable.Queue 210 / 233 0.0 209691.6 1.3X ``` Benchmark results after: ``` OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH AMD A10-4600M APU with Radeon(tm) HD Graphics collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Seq 255 / 316 0.0 254697.3 1.0X List 152 / 177 0.0 152410.0 1.7X mutable.Queue 213 / 235 0.0 213470.0 1.2X ``` ## How was this patch tested? ```bash ./build/mvn -DskipTests clean package && ./dev/run-tests ``` Additionally in Spark Shell: ```scala case class QueueClass(q: scala.collection.immutable.Queue[Int]) spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect ``` Author: Michal Senkyr Closes #16541 from michalsenkyr/dataset-seq-builder. --- .../spark/sql/catalyst/ScalaReflection.scala | 51 ++------------- .../expressions/objects/objects.scala | 64 +++++++++++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 8 --- 3 files changed, 54 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c4af284f73d16..1c7720afe1ca3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection { } } - val array = Invoke( - MapObjects(mapFunction, getPath, dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val wrappedArray = StaticInvoke( - scala.collection.mutable.WrappedArray.getClass, - ObjectType(classOf[Seq[_]]), - "make", - array :: Nil) - - if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) { - wrappedArray - } else { - // Convert to another type using `to` - val cls = mirror.runtimeClass(t.typeSymbol.asClass) - import scala.collection.generic.CanBuildFrom - import scala.reflect.ClassTag - - // Some canBuildFrom methods take an implicit ClassTag parameter - val cbfParams = try { - cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) - StaticInvoke( - ClassTag.getClass, - ObjectType(classOf[ClassTag[_]]), - "apply", - StaticInvoke( - cls, - ObjectType(classOf[Class[_]]), - "getClass" - ) :: Nil - ) :: Nil - } catch { - case _: NoSuchMethodException => Nil - } - - Invoke( - wrappedArray, - "to", - ObjectType(cls), - StaticInvoke( - cls, - ObjectType(classOf[CanBuildFrom[_, _, _]]), - "canBuildFrom", - cbfParams - ) :: Nil - ) + val cls = t.dealias.companion.decl(TermName("newBuilder")) match { + case NoSymbol => classOf[Seq[_]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) } + MapObjects(mapFunction, getPath, dataType, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 771ac28e5107a..bb584f7d087e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -429,24 +430,34 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) */ def apply( function: Expression => Expression, inputData: Expression, - elementType: DataType): MapObjects = { - val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() - val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + elementType: DataType, + customCollectionCls: Option[Class[_]] = None): MapObjects = { + val id = curId.getAndIncrement() + val loopValue = s"MapObjects_loopValue$id" + val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) + val builderValue = s"MapObjects_builderValue$id" + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, + customCollectionCls, builderValue) } } /** * Applies the given expression to every element of a collection of items, returning the result - * as an ArrayType. This is similar to a typical map operation, but where the lambda function - * is expressed using catalyst expressions. + * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda + * function is expressed using catalyst expressions. + * + * The type of the result is determined as follows: + * - ArrayType - when customCollectionCls is None + * - ObjectType(collection) - when customCollectionCls contains a collection class * - * The following collection ObjectTypes are currently supported: + * The following collection ObjectTypes are currently supported on input: * Seq, Array, ArrayData, java.util.List * * @param loopValue the name of the loop variable that used when iterate the collection, and used @@ -458,13 +469,19 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. + * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) + * @param builderValue The name of the builder variable used to construct the resulting collection + * (used only when returning ObjectType) */ case class MapObjects private( loopValue: String, loopIsNull: String, loopVarDataType: DataType, lambdaFunction: Expression, - inputData: Expression) extends Expression with NonSQLExpression { + inputData: Expression, + customCollectionCls: Option[Class[_]], + builderValue: String) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -474,7 +491,8 @@ case class MapObjects private( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def dataType: DataType = - ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) + customCollectionCls.map(ObjectType.apply).getOrElse( + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -557,15 +575,33 @@ case class MapObjects private( case _ => s"$loopIsNull = $loopValue == null;" } + val (initCollection, addElement, getResult): (String, String => String, String) = + customCollectionCls match { + case Some(cls) => + // collection + val collObjectName = s"${cls.getName}$$.MODULE$$" + val getBuilderVar = s"$collObjectName.newBuilder()" + + (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength);""", + genValue => s"$builderValue.$$plus$$eq($genValue);", + s"(${cls.getName}) $builderValue.result();") + case None => + // array + (s"""$convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor;""", + genValue => s"$convertedArray[$loopIndex] = $genValue;", + s"new ${classOf[GenericArrayData].getName}($convertedArray);") + } + val code = s""" ${genInputData.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType - $convertedType[] $convertedArray = null; int $dataLength = $getLength; - $convertedArray = $arrayConstructor; + $initCollection int $loopIndex = 0; while ($loopIndex < $dataLength) { @@ -574,15 +610,15 @@ case class MapObjects private( ${genFunction.code} if (${genFunction.isNull}) { - $convertedArray[$loopIndex] = null; + ${addElement("null")} } else { - $convertedArray[$loopIndex] = $genFunctionValue; + ${addElement(genFunctionValue)} } $loopIndex += 1; } - ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); + ${ev.value} = $getResult } """ ev.copy(code = code, isNull = genInputData.isNull) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 650a35398f3e8..70ad064f93ebc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite { ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) - - // Check whether conversion is skipped when using WrappedArray[_] supertype - // (would otherwise needlessly add overhead) - import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke - val seqDeserializer = deserializerFor[Seq[Int]] - assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject == - scala.collection.mutable.WrappedArray.getClass) - assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make") } private val dataTypeForComplexData = dataTypeFor[ComplexData] From a9abff281bcb15fdc91111121c8bcb983a9d91cb Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 28 Mar 2017 09:37:28 +0200 Subject: [PATCH 0133/1765] [SPARK-20119][TEST-MAVEN] Fix the test case fail in DataSourceScanExecRedactionSuite ### What changes were proposed in this pull request? Changed the pattern to match the first n characters in the location field so that the string truncation does not affect it. ### How was this patch tested? N/A Author: Xiao Li Closes #17448 from gatorsmile/fixTestCAse. --- .../spark/sql/execution/DataSourceScanExecRedactionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 986fa878ee29b..05a2b2c862c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -31,7 +31,7 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { override def beforeAll(): Unit = { sparkConf.set("spark.redaction.string.regex", - "spark-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}") + "file:/[\\w_]+") super.beforeAll() } From 91559d277f42ee83b79f5d8eb7ba037cf5c108da Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 28 Mar 2017 13:43:23 +0200 Subject: [PATCH 0134/1765] [SPARK-20094][SQL] Preventing push down of IN subquery to Join operator ## What changes were proposed in this pull request? TPCDS q45 fails becuase: `ReorderJoin` collects all predicates and try to put them into join condition when creating ordered join. If a predicate with an IN subquery (`ListQuery`) is in a join condition instead of a filter condition, `RewritePredicateSubquery.rewriteExistentialExpr` would fail to convert the subquery to an `ExistenceJoin`, and thus result in error. We should prevent push down of IN subquery to Join operator. ## How was this patch tested? Add a new test case in `FilterPushdownSuite`. Author: wangzhenhua Closes #17428 from wzhfy/noSubqueryInJoinCond. --- .../sql/catalyst/expressions/predicates.scala | 6 ++++++ .../optimizer/FilterPushdownSuite.scala | 20 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e5d1a1e2996c5..1235204591bbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,6 +90,12 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + case l: ListQuery => + // A ListQuery defines the query which we want to search in an IN subquery expression. + // Currently the only way to evaluate an IN subquery is to convert it to a + // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. + // It cannot be evaluated as part of a Join operator. + false case e: SubqueryExpression => // non-correlated subquery will be replaced as literal e.children.isEmpty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 6feea4060f46a..d846786473eb0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -836,6 +836,26 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, answer) } + test("SPARK-20094: don't push predicate with IN subquery into join condition") { + val x = testRelation.subquery('x) + val z = testRelation.subquery('z) + val w = testRelation1.subquery('w) + + val queryPlan = x + .join(z) + .where(("x.b".attr === "z.b".attr) && + ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) + .analyze + + val expectedPlan = x + .join(z, Inner, Some("x.b".attr === "z.b".attr)) + .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) + .analyze + + val optimized = Optimize.execute(queryPlan) + comparePlans(optimized, expectedPlan) + } + test("Window: predicate push down -- basic") { val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) From 4fcc214d9eb5e98b2eed3e28cc23b0c511cd9007 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 28 Mar 2017 22:22:38 +0800 Subject: [PATCH 0135/1765] [SPARK-20124][SQL] Join reorder should keep the same order of final project attributes ## What changes were proposed in this pull request? Join reorder algorithm should keep exactly the same order of output attributes in the top project. For example, if user want to select a, b, c, after reordering, we should output a, b, c in the same order as specified by user, instead of b, a, c or other orders. ## How was this patch tested? A new test case is added in `JoinReorderSuite`. Author: wangzhenhua Closes #17453 from wzhfy/keepOrderInProject. --- .../optimizer/CostBasedJoinReorder.scala | 24 ++++++++++++------- .../catalyst/optimizer/JoinReorderSuite.scala | 13 ++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 4 ++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index fc37720809ba2..cbd506465ae6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -40,10 +40,10 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr val result = plan transformDown { // Start reordering with a joinable item, which is an InnerLike join with conditions. case j @ Join(_, _, _: InnerLike, Some(cond)) => - reorder(j, j.outputSet) + reorder(j, j.output) case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) if projectList.forall(_.isInstanceOf[Attribute]) => - reorder(p, p.outputSet) + reorder(p, p.output) } // After reordering is finished, convert OrderedJoin back to Join result transformDown { @@ -52,7 +52,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } } - private def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { + private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) // TODO: Compute the set of star-joins and use them in the join enumeration // algorithm to prune un-optimal plan choices. @@ -140,7 +140,7 @@ object JoinReorderDP extends PredicateHelper with Logging { conf: SQLConf, items: Seq[LogicalPlan], conditions: Set[Expression], - topOutput: AttributeSet): LogicalPlan = { + output: Seq[Attribute]): LogicalPlan = { val startTime = System.nanoTime() // Level i maintains all found plans for i + 1 items. @@ -152,9 +152,10 @@ object JoinReorderDP extends PredicateHelper with Logging { // Build plans for next levels until the last level has only one plan. This plan contains // all items that can be joined, so there's no need to continue. + val topOutputSet = AttributeSet(output) while (foundPlans.size < items.length && foundPlans.last.size > 1) { // Build plans for the next level. - foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) + foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet) } val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) @@ -163,7 +164,14 @@ object JoinReorderDP extends PredicateHelper with Logging { // The last level must have one and only one plan, because all items are joinable. assert(foundPlans.size == items.length && foundPlans.last.size == 1) - foundPlans.last.head._2.plan + foundPlans.last.head._2.plan match { + case p @ Project(projectList, j: Join) if projectList != output => + assert(topOutputSet == p.outputSet) + // Keep the same order of final output attributes. + p.copy(projectList = output) + case finalPlan => + finalPlan + } } /** Find all possible plans at the next level, based on existing levels. */ @@ -254,10 +262,10 @@ object JoinReorderDP extends PredicateHelper with Logging { val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds val remainingConds = conditions -- collectedJoinConds val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput - val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val neededFromNewJoin = newJoin.output.filter(neededAttr.contains) val newPlan = if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { - Project(neededFromNewJoin.toSeq, newJoin) + Project(neededFromNewJoin, newJoin) } else { newJoin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 05b839b0119f4..d74008c1b3027 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -198,6 +198,19 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } + test("keep the order of attributes in the final output") { + val outputLists = Seq("t1.k-1-2", "t1.v-1-10", "t3.v-1-100").permutations + while (outputLists.hasNext) { + val expectedOrder = outputLists.next().map(nameToAttr) + val expectedPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(expectedOrder: _*) + // The plan should not change after optimization + assertEqualPlans(expectedPlan, expectedPlan) + } + } + private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2a9d0570148ad..c73dfaf3f8fe3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -126,8 +126,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { case (j1: Join, j2: Join) => (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => - (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) case _ => plan1 == plan2 } From f82461fc1197f6055d9cf972d82260b178e10a7c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 28 Mar 2017 23:14:31 +0800 Subject: [PATCH 0136/1765] [SPARK-20126][SQL] Remove HiveSessionState ## What changes were proposed in this pull request? Commit https://github.com/apache/spark/commit/ea361165e1ddce4d8aa0242ae3e878d7b39f1de2 moved most of the logic from the SessionState classes into an accompanying builder. This makes the existence of the `HiveSessionState` redundant. This PR removes the `HiveSessionState`. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17457 from hvanhovell/SPARK-20126. --- .../sql/execution/command/resources.scala | 2 +- .../spark/sql/internal/SessionState.scala | 47 +++--- .../sql/internal/sessionStateBuilders.scala | 8 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 12 +- .../server/SparkSQLOperationManager.scala | 6 +- .../execution/HiveCompatibilitySuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 4 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 9 +- .../spark/sql/hive/HiveSessionState.scala | 144 +++--------------- .../apache/spark/sql/hive/test/TestHive.scala | 23 ++- .../sql/hive/HiveMetastoreCatalogSuite.scala | 6 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 7 +- .../sql/hive/execution/HiveDDLSuite.scala | 6 +- .../apache/spark/sql/hive/parquetSuites.scala | 21 ++- 14 files changed, 104 insertions(+), 193 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index 20b08946675d0..2e859cf1ef253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -37,7 +37,7 @@ case class AddJarCommand(path: String) extends RunnableCommand { } override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.sessionState.addJar(path) + sparkSession.sessionState.resourceLoader.addJar(path) Seq(Row(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index b5b0bb0bfc401..c6241d923d7b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -63,6 +63,7 @@ private[sql] class SessionState( val optimizer: Optimizer, val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, + val resourceLoader: SessionResourceLoader, createQueryExecution: LogicalPlan => QueryExecution, createClone: (SparkSession, SessionState) => SessionState) { @@ -106,27 +107,6 @@ private[sql] class SessionState( def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } - - /** - * Add a jar path to [[SparkContext]] and the classloader. - * - * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs - * to add the jar to its hive client for the current session. Hence, it still needs to be in - * [[SessionState]]. - */ - def addJar(path: String): Unit = { - sparkContext.addJar(path) - val uri = new Path(path).toUri - val jarURL = if (uri.getScheme == null) { - // `path` is a local file path without a URL scheme - new File(path).toURI.toURL - } else { - // `path` is a URL with a scheme - uri.toURL - } - sharedState.jarClassLoader.addURL(jarURL) - Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader) - } } private[sql] object SessionState { @@ -160,10 +140,10 @@ class SessionStateBuilder( * Session shared [[FunctionResourceLoader]]. */ @InterfaceStability.Unstable -class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResourceLoader { +class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { - case JarResource => session.sessionState.addJar(resource.uri) + case JarResource => addJar(resource.uri) case FileResource => session.sparkContext.addFile(resource.uri) case ArchiveResource => throw new AnalysisException( @@ -171,4 +151,25 @@ class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResou "please use --archives options while calling spark-submit.") } } + + /** + * Add a jar path to [[SparkContext]] and the classloader. + * + * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs + * to add the jar to its hive client for the current session. Hence, it still needs to be in + * [[SessionState]]. + */ + def addJar(path: String): Unit = { + session.sparkContext.addJar(path) + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + session.sharedState.jarClassLoader.addURL(jarURL) + Thread.currentThread().setContextClassLoader(session.sharedState.jarClassLoader) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala index 6b5559adb1db4..b8f645fdee85a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala @@ -109,6 +109,11 @@ abstract class BaseSessionStateBuilder( */ protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + /** + * ResourceLoader that is used to load function resources and jars. + */ + protected lazy val resourceLoader: SessionResourceLoader = new SessionResourceLoader(session) + /** * Catalog for managing table and database states. If there is a pre-existing catalog, the state * of that catalog (temp tables & current database) will be copied into the new catalog. @@ -123,7 +128,7 @@ abstract class BaseSessionStateBuilder( conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), sqlParser, - new SessionFunctionResourceLoader(session)) + resourceLoader) parentState.foreach(_.catalog.copyStateTo(catalog)) catalog } @@ -251,6 +256,7 @@ abstract class BaseSessionStateBuilder( optimizer, planner, streamingQueryManager, + resourceLoader, createQueryExecution, createClone) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index c0b299411e94a..01c4eb131a564 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ @@ -49,10 +49,12 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext = sparkSession.sparkContext sqlContext = sparkSession.sqlContext - val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) - sessionState.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) - sessionState.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + val metadataHive = sparkSession + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + .client.newSession() + metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) + metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) + metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) sparkSession.conf.set("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 49ab664009341..a0e5012633f5e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -26,7 +26,7 @@ import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.hive.HiveSessionState +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} /** @@ -49,8 +49,8 @@ private[thriftserver] class SparkSQLOperationManager() val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") - val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - val runInBackground = async && sessionState.hiveThriftServerAsync + val conf = sqlContext.sessionState.conf + val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f78660f7c14b6..0a53aaca404e6 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -39,7 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc + private val originalConvertMetastoreOrc = TestHive.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5393c57c9a28f..02a5117f005e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -48,10 +48,6 @@ class HiveContext private[hive](_sparkSession: SparkSession) new HiveContext(sparkSession.newSession()) } - protected[sql] override def sessionState: HiveSessionState = { - sparkSession.sessionState.asInstanceOf[HiveSessionState] - } - /** * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2e060ab9f6801..305bd007c93f7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ */ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { // these are def_s and not val/lazy val since the latter would introduce circular references - private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] + private def sessionState = sparkSession.sessionState private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache import HiveMetastoreCatalog._ @@ -281,12 +281,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log object ParquetConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - sessionState.convertMetastoreParquet + sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) } private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging + val mergeSchema = sessionState.conf.getConf( + HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) convertToLogicalRelation(relation, options, fileFormatClass, "parquet") @@ -316,7 +317,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log object OrcConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - sessionState.convertMetastoreOrc + sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 49ff8478f1ae2..f49e6bb418644 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -17,121 +17,24 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} -import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} +import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionFunctionResourceLoader, SessionState, SharedState, SQLConf} -import org.apache.spark.sql.streaming.StreamingQueryManager - +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. - * - * @param sparkContext The [[SparkContext]]. - * @param sharedState The shared state. - * @param conf SQL-specific key-value configurations. - * @param experimentalMethods The experimental methods. - * @param functionRegistry Internal catalog for managing functions registered by the user. - * @param catalog Internal catalog for managing table and database states that uses Hive client for - * interacting with the metastore. - * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param optimizer Logical query plan optimizer. - * @param planner Planner that converts optimized logical plans to physical plans and that takes - * Hive-specific strategies into account. - * @param streamingQueryManager Interface to start and stop streaming queries. - * @param createQueryExecution Function used to create QueryExecution objects. - * @param createClone Function used to create clones of the session state. - * @param metadataHive The Hive metadata client. + * Entry object for creating a Hive aware [[SessionState]]. */ -private[hive] class HiveSessionState( - sparkContext: SparkContext, - sharedState: SharedState, - conf: SQLConf, - experimentalMethods: ExperimentalMethods, - functionRegistry: FunctionRegistry, - override val catalog: HiveSessionCatalog, - sqlParser: ParserInterface, - analyzer: Analyzer, - optimizer: Optimizer, - planner: SparkPlanner, - streamingQueryManager: StreamingQueryManager, - createQueryExecution: LogicalPlan => QueryExecution, - createClone: (SparkSession, SessionState) => SessionState, - val metadataHive: HiveClient) - extends SessionState( - sparkContext, - sharedState, - conf, - experimentalMethods, - functionRegistry, - catalog, - sqlParser, - analyzer, - optimizer, - planner, - streamingQueryManager, - createQueryExecution, - createClone) { - - // ------------------------------------------------------ - // Helper methods, partially leftover from pre-2.0 days - // ------------------------------------------------------ - - override def addJar(path: String): Unit = { - metadataHive.addJar(path) - super.addJar(path) - } - - /** - * When true, enables an experimental feature where metastore tables that use the parquet SerDe - * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive - * SerDe. - */ - def convertMetastoreParquet: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) - } - - /** - * When true, also tries to merge possibly different but compatible Parquet schemas in different - * Parquet data files. - * - * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. - */ - def convertMetastoreParquetWithSchemaMerging: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - } - - /** - * When true, enables an experimental feature where metastore tables that use the Orc SerDe - * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive - * SerDe. - */ - def convertMetastoreOrc: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) - } - - /** - * When true, Hive Thrift server will execute SQL queries asynchronously using a thread pool." - */ - def hiveThriftServerAsync: Boolean = { - conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) - } -} - private[hive] object HiveSessionState { /** - * Create a new [[HiveSessionState]] for the given session. + * Create a new Hive aware [[SessionState]]. for the given session. */ - def apply(session: SparkSession): HiveSessionState = { + def apply(session: SparkSession): SessionState = { new HiveSessionStateBuilder(session).build() } } @@ -147,6 +50,14 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session private def externalCatalog: HiveExternalCatalog = session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + /** + * Create a Hive aware resource loader. + */ + override protected lazy val resourceLoader: HiveSessionResourceLoader = { + val client: HiveClient = externalCatalog.client.newSession() + new HiveSessionResourceLoader(session, client) + } + /** * Create a [[HiveSessionCatalog]]. */ @@ -159,7 +70,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), sqlParser, - new SessionFunctionResourceLoader(session)) + resourceLoader) parentState.foreach(_.catalog.copyStateTo(catalog)) catalog } @@ -217,23 +128,14 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session } override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _) +} - override def build(): HiveSessionState = { - val metadataHive: HiveClient = externalCatalog.client.newSession() - new HiveSessionState( - session.sparkContext, - session.sharedState, - conf, - experimentalMethods, - functionRegistry, - catalog, - sqlParser, - analyzer, - optimizer, - planner, - streamingQueryManager, - createQueryExecution, - createClone, - metadataHive) +class HiveSessionResourceLoader( + session: SparkSession, + client: HiveClient) + extends SessionResourceLoader(session) { + override def addJar(path: String): Unit = { + client.addJar(path) + super.addJar(path) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 32ca69605ef4d..0bcf219922764 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -34,7 +34,6 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.catalog.ExternalCatalog import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand @@ -81,7 +80,7 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: ExternalCatalog = { + override lazy val externalCatalog: TestHiveExternalCatalog = { new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, @@ -123,8 +122,6 @@ class TestHiveContext( new TestHiveContext(sparkSession.newSession()) } - override def sessionState: HiveSessionState = sparkSession.sessionState - def setCacheTables(c: Boolean): Unit = { sparkSession.setCacheTables(c) } @@ -155,7 +152,7 @@ class TestHiveContext( private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, @transient private val existingSharedState: Option[TestHiveSharedState], - @transient private val parentSessionState: Option[HiveSessionState], + @transient private val parentSessionState: Option[SessionState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -195,10 +192,12 @@ private[hive] class TestHiveSparkSession( } @transient - override lazy val sessionState: HiveSessionState = { + override lazy val sessionState: SessionState = { new TestHiveSessionStateBuilder(this, parentSessionState).build() } + lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) } @@ -492,7 +491,7 @@ private[hive] class TestHiveSparkSession( sessionState.catalog.clearTempTables() sessionState.catalog.tableRelationCache.invalidateAll() - sessionState.metadataHive.reset() + metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } @@ -509,14 +508,14 @@ private[hive] class TestHiveSparkSession( sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break // other sql exec here. - sessionState.metadataHive.runSqlHive("RESET") + metadataHive.runSqlHive("RESET") // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 - sessionState.metadataHive.runSqlHive("set hive.table.parameters.default=") - sessionState.metadataHive.runSqlHive("set datanucleus.cache.collections=true") - sessionState.metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") + metadataHive.runSqlHive("set hive.table.parameters.default=") + metadataHive.runSqlHive("set datanucleus.cache.collections=true") + metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. - sessionState.metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") sessionState.catalog.setCurrentDatabase("default") } catch { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 079358b29a191..d8fd68b63d1eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -115,7 +115,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) checkAnswer(table("t"), testDF) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -147,7 +147,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) checkAnswer(table("t"), testDF) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -176,7 +176,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) checkAnswer(table("t"), Row(1, "val_1")) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f02b7218d6eee..55e02acfa4ce3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -379,8 +379,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - val expectedPath = - sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) + val expectedPath = sessionState.catalog.defaultTablePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.delete(filesystemPath, true) @@ -486,7 +485,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("DROP TABLE savedJsonTable") intercept[AnalysisException] { read.json( - sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) + sessionState.catalog.defaultTablePath(TableIdentifier("savedJsonTable")).toString) } } @@ -756,7 +755,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv serde = None, compressed = false, properties = Map( - "path" -> sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier(tableName))) + "path" -> sessionState.catalog.defaultTablePath(TableIdentifier(tableName)).toString) ), properties = Map( DATASOURCE_PROVIDER -> "json", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 04bc79d430324..f0a995c274b64 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -128,11 +128,11 @@ class HiveDDLSuite dbPath: Option[String] = None): Boolean = { val expectedTablePath = if (dbPath.isEmpty) { - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table).toString + new Path(new Path(dbPath.get), tableIdentifier.table) } - val filesystemPath = new Path(expectedTablePath) + val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.exists(filesystemPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 81af24979d822..9fc2923bb6fd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.execution.HiveTableScanExec @@ -448,10 +449,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } + private def getCachedDataSourceTable(id: TableIdentifier): LogicalPlan = { + sessionState.catalog.asInstanceOf[HiveSessionCatalog].getCachedDataSourceTable(id) + } + test("Caching converted data source Parquet Relations") { def checkCached(tableIdentifier: TableIdentifier): Unit = { // Converted test_parquet should be cached. - sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { + getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case other => @@ -479,14 +484,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { var tableIdentifier = TableIdentifier("test_insert_parquet", Some("default")) // First, make sure the converted test_parquet is not cached. - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. sessionState.refreshTable("test_insert_parquet") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet @@ -499,7 +504,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("select a, b from jt").collect()) // Invalidate the cache. sessionState.refreshTable("test_insert_parquet") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Create a partitioned table. sql( @@ -517,7 +522,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) tableIdentifier = TableIdentifier("test_parquet_partitioned_cache_test", Some("default")) - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -526,14 +531,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") @@ -549,7 +554,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin).collect()) sessionState.refreshTable("test_parquet_partitioned_cache_test") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } From 17eddb35a280e77da7520343e0bf2a86b329ed62 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 28 Mar 2017 10:41:11 -0700 Subject: [PATCH 0137/1765] [SPARK-19995][YARN] Register tokens to current UGI to avoid re-issuing of tokens in yarn client mode ## What changes were proposed in this pull request? In the current Spark on YARN code, we will obtain tokens from provided services, but we're not going to add these tokens to the current user's credentials. This will make all the following operations to these services still require TGT rather than delegation tokens. This is unnecessary since we already got the tokens, also this will lead to failure in user impersonation scenario, because the TGT is granted by real user, not proxy user. So here changing to put all the tokens to the current UGI, so that following operations to these services will honor tokens rather than TGT, and this will further handle the proxy user issue mentioned above. ## How was this patch tested? Local verified in secure cluster. vanzin tgravescs mridulm dongjoon-hyun please help to review, thanks a lot. Author: jerryshao Closes #17335 from jerryshao/SPARK-19995. --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ccb0f8fdbbc21..3218d221143e5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -371,6 +371,9 @@ private[spark] class Client( val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials) if (credentials != null) { + // Add credentials to current user's UGI, so that following operations don't need to use the + // Kerberos tgt to get delegations again in the client side. + UserGroupInformation.getCurrentUser.addCredentials(credentials) logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } From d4fac410e0554b7ccd44be44b7ce2fe07ed7f206 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 28 Mar 2017 11:47:43 -0700 Subject: [PATCH 0138/1765] [SPARK-20125][SQL] Dataset of type option of map does not work ## What changes were proposed in this pull request? When we build the deserializer expression for map type, we will use `StaticInvoke` to call `ArrayBasedMapData.toScalaMap`, and declare the return type as `scala.collection.immutable.Map`. If the map is inside an Option, we will wrap this `StaticInvoke` with `WrapOption`, which requires the input to be `scala.collect.Map`. Ideally this should be fine, as `scala.collection.immutable.Map` extends `scala.collect.Map`, but our `ObjectType` is too strict about this, this PR fixes it. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17454 from cloud-fan/map. --- .../main/scala/org/apache/spark/sql/types/ObjectType.scala | 5 +++++ .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index b18fba29af0f9..2d49fe076786a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -44,4 +44,9 @@ case class ObjectType(cls: Class[_]) extends DataType { def asNullable: DataType = this override def simpleString: String = cls.getName + + override def acceptsType(other: DataType): Boolean = other match { + case ObjectType(otherCls) => cls.isAssignableFrom(otherCls) + case _ => false + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6417e7a8b6038..68e071a1a694f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1154,10 +1154,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " + "circular reference of class")) } + + test("SPARK-20125: option of map") { + val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS() + checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) +case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) case class Generic[T](id: T, value: Double) From 92e385e0b55d70a48411e90aa0f2ed141c4d07c8 Mon Sep 17 00:00:00 2001 From: liujianhui Date: Tue, 28 Mar 2017 12:13:45 -0700 Subject: [PATCH 0139/1765] [SPARK-19868] conflict TasksetManager lead to spark stopped ## What changes were proposed in this pull request? We must set the taskset to zombie before the DAGScheduler handles the taskEnded event. It's possible the taskEnded event will cause the DAGScheduler to launch a new stage attempt (this happens when map output data was lost), and if this happens before the taskSet has been set to zombie, it will appear that we have conflicting task sets. Author: liujianhui Closes #17208 from liujianhuiouc/spark-19868. --- .../spark/scheduler/TaskSetManager.scala | 15 ++++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 27 ++++++++++++++++++- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a177aab5f95de..a41b059fa7dec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -713,13 +713,7 @@ private[spark] class TaskSetManager( successfulTaskDurations.insert(info.duration) } removeRunningTask(tid) - // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the - // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not - // "deserialize" the value when holding a lock to avoid blocking other threads. So we call - // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. - // Note: "result.value()" only deserializes the value when it's called at the first time, so - // here "result.value()" just returns the value and won't block other threads. - sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) + // Kill any other attempts for the same task (since those are unnecessary now that one // attempt completed successfully). for (attemptInfo <- taskAttempts(index) if attemptInfo.running) { @@ -746,6 +740,13 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. + sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) maybeFinishTaskSet() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 132caef0978fb..9ca6b8b0fe635 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,8 +22,10 @@ import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{anyInt, anyString} +import org.mockito.Matchers.{any, anyInt, anyString} import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.internal.config @@ -1056,6 +1058,29 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.isZombie) } + + test("SPARK-19868: DagScheduler only notified of taskEnd when state is ready") { + // dagScheduler.taskEnded() is async, so it may *seem* ok to call it before we've set all + // appropriate state, eg. isZombie. However, this sets up a race that could go the wrong way. + // This is a super-focused regression test which checks the zombie state as soon as + // dagScheduler.taskEnded() is called, to ensure we haven't introduced a race. + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler = mockDAGScheduler + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie === true) + } + }) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption.isDefined) + // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon + manager.handleSuccessfulTask(0, createTaskResult(0)) + } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) From 7d432af8f3c47973550ea253dae0c23cd2961bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Tue, 28 Mar 2017 16:14:01 -0700 Subject: [PATCH 0140/1765] [SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails for uppercase impurity type Gini MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix bug: DecisionTreeModel can't recongnize Impurity "Gini" when loading TODO: + [x] add unit test + [x] fix the bug Author: 颜发才(Yan Facai) Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity. --- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index a5bdc2c6d2c94..98a3021461eb8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator { * the given stats. */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity match { + impurity.toLowerCase match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 10de50306a5ce..964fcfbdd87a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("SPARK-20043: " + + "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") { + val rdd = TreeTests.getTreeReadWriteData(sc) + val data: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + val model = dt.fit(data) + + testDefaultReadWrite(model) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { From a5c87707eaec5cacdfb703eb396dfc264bc54cda Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 28 Mar 2017 19:19:16 -0700 Subject: [PATCH 0141/1765] [SPARK-20040][ML][PYTHON] pyspark wrapper for ChiSquareTest ## What changes were proposed in this pull request? A pyspark wrapper for spark.ml.stat.ChiSquareTest ## How was this patch tested? unit tests doctests Author: Bago Amirbekian Closes #17421 from MrBago/chiSquareTestWrapper. --- dev/sparktestsupport/modules.py | 1 + .../apache/spark/ml/stat/ChiSquareTest.scala | 6 +- python/docs/pyspark.ml.rst | 8 ++ python/pyspark/ml/stat.py | 93 +++++++++++++++++++ python/pyspark/ml/tests.py | 31 +++++-- 5 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 python/pyspark/ml/stat.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index eaf1f3a1db2ff..246f5188a518d 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -431,6 +431,7 @@ def __hash__(self): "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", "pyspark.ml.regression", + "pyspark.ml.stat", "pyspark.ml.tuning", "pyspark.ml.tests", ], diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala index 21eba9a49809f..5b38ca73e8014 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -46,9 +46,9 @@ object ChiSquareTest { statistics: Vector) /** - * Conduct Pearson's independence test for every feature against the label across the input RDD. - * For each feature, the (feature, label) pairs are converted into a contingency matrix for which - * the Chi-squared statistic is computed. All label and feature values must be categorical. + * Conduct Pearson's independence test for every feature against the label. For each feature, the + * (feature, label) pairs are converted into a contingency matrix for which the Chi-squared + * statistic is computed. All label and feature values must be categorical. * * The null hypothesis is that the occurrence of the outcomes is statistically independent. * diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index a68183445d78b..930646de9cd86 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -65,6 +65,14 @@ pyspark.ml.regression module :undoc-members: :inherited-members: +pyspark.ml.stat module +---------------------- + +.. automodule:: pyspark.ml.stat + :members: + :undoc-members: + :inherited-members: + pyspark.ml.tuning module ------------------------ diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py new file mode 100644 index 0000000000000..db043ff68feca --- /dev/null +++ b/python/pyspark/ml/stat.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import since, SparkContext +from pyspark.ml.common import _java2py, _py2java +from pyspark.ml.wrapper import _jvm + + +class ChiSquareTest(object): + """ + .. note:: Experimental + + Conduct Pearson's independence test for every feature against the label. For each feature, + the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared + statistic is computed. All label and feature values must be categorical. + + The null hypothesis is that the occurrence of the outcomes is statistically independent. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def test(dataset, featuresCol, labelCol): + """ + Perform a Pearson's independence test using dataset. + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest + args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)] + return _java2py(sc, javaTestObj.test(*args)) + + +if __name__ == "__main__": + import doctest + import pyspark.ml.stat + from pyspark.sql import SparkSession + + globs = pyspark.ml.stat.__dict__.copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + spark = SparkSession.builder \ + .master("local[2]") \ + .appName("ml.stat tests") \ + .getOrCreate() + sc = spark.sparkContext + globs['sc'] = sc + globs['spark'] = spark + + failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + spark.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 527db9b66793a..571ac4bc1c366 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -41,9 +41,7 @@ import tempfile import array as pyarray import numpy as np -from numpy import ( - abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros) -from numpy import sum as array_sum +from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect from pyspark import keyword_only, SparkContext @@ -54,20 +52,19 @@ from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel -from pyspark.ml.linalg import ( - DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, - SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector) +from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ + SparseMatrix, SparseVector, Vector, VectorUDT, Vectors from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import ( - DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression) +from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ + LinearRegression +from pyspark.ml.stat import ChiSquareTest from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams, JavaWrapper from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand -from pyspark.sql.utils import IllegalArgumentException from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -1741,6 +1738,22 @@ def test_new_java_array(self): self.assertEqual(_java2py(self.sc, java_array), []) +class ChiSquareTestTests(SparkSessionTestCase): + + def test_chisquaretest(self): + data = [[0, Vectors.dense([0, 1, 2])], + [1, Vectors.dense([1, 1, 1])], + [2, Vectors.dense([2, 1, 0])]] + df = self.spark.createDataFrame(data, ['label', 'feat']) + res = ChiSquareTest.test(df, 'feat', 'label') + # This line is hitting the collect bug described in #17218, commented for now. + # pValues = res.select("degreesOfFreedom").collect()) + self.assertIsInstance(res, DataFrame) + fieldNames = set(field.name for field in res.schema.fields) + expectedFields = ["pValues", "degreesOfFreedom", "statistics"] + self.assertTrue(all(field in fieldNames for field in expectedFields)) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: From 9712bd3954c029de5c828f27b57d46e4a6325a38 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Mar 2017 00:02:15 -0700 Subject: [PATCH 0142/1765] [SPARK-20134][SQL] SQLMetrics.postDriverMetricUpdates to simplify driver side metric updates ## What changes were proposed in this pull request? It is not super intuitive how to update SQLMetric on the driver side. This patch introduces a new SQLMetrics.postDriverMetricUpdates function to do that, and adds documentation to make it more obvious. ## How was this patch tested? Updated a test case to use this method. Author: Reynold Xin Closes #17464 from rxin/SPARK-20134. --- .../execution/basicPhysicalOperators.scala | 8 +------- .../exchange/BroadcastExchangeExec.scala | 8 +------- .../sql/execution/metric/SQLMetrics.scala | 20 +++++++++++++++++++ .../spark/sql/execution/ui/SQLListener.scala | 7 +++++++ .../sql/execution/ui/SQLListenerSuite.scala | 8 +++++--- 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d876688a8aabd..66a8e044ab879 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -628,13 +628,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum longMetric("dataSize") += dataSize - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. - if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) rows } }(SubqueryExec.executionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a765..efcaca9338ad6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -97,13 +97,7 @@ case class BroadcastExchangeExec( val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. - if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { case oe: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index dbc27d8b237f3..ef982a4ebd10d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -22,9 +22,15 @@ import java.util.Locale import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +/** + * A metric used in a SQL query plan. This is implemented as an [[AccumulatorV2]]. Updates on + * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates + * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. + */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will @@ -126,4 +132,18 @@ object SQLMetrics { s"\n$sum ($min, $med, $max)" } } + + /** + * Updates metrics based on the driver side value. This is useful for certain metrics that + * are only updated on the driver, e.g. subquery execution time, or number of files. + */ + def postDriverMetricUpdates( + sc: SparkContext, executionId: String, metrics: Seq[SQLMetric]): Unit = { + // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` + // directly without setting an execution id. We should be tolerant to it. + if (executionId != null) { + sc.listenerBus.post( + SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 12d3bc9281f35..b4a91230a0012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,6 +47,13 @@ case class SparkListenerSQLExecutionStart( case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent +/** + * A message used to update SQL metric value for driver-side updates (which doesn't get reflected + * automatically). + * + * @param executionId The execution id for a query, so we can find the query plan. + * @param accumUpdates Map from accumulator id to the metric value (metrics are always 64-bit ints). + */ @DeveloperApi case class SparkListenerDriverAccumUpdates( executionId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e41c00ecec271..e6cd41e4facf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -477,9 +477,11 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue - sc.listenerBus.post(SparkListenerDriverAccumUpdates( - sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).toLong, - metrics.values.map(m => m.id -> m.value).toSeq)) + + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + metrics.values.toSeq) sc.emptyRDD } } From b56ad2b1ec19fd60fa9d4926d12244fd3f56aca4 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 29 Mar 2017 20:27:41 +0800 Subject: [PATCH 0143/1765] [SPARK-19556][CORE] Do not encrypt block manager data in memory. This change modifies the way block data is encrypted to make the more common cases faster, while penalizing an edge case. As a side effect of the change, all data that goes through the block manager is now encrypted only when needed, including the previous path (broadcast variables) where that did not happen. The way the change works is by not encrypting data that is stored in memory; so if a serialized block is in memory, it will only be encrypted once it is evicted to disk. The penalty comes when transferring that encrypted data from disk. If the data ends up in memory again, it is as efficient as before; but if the evicted block needs to be transferred directly to a remote executor, then there's now a performance penalty, since the code now uses a custom FileRegion implementation to decrypt the data before transferring. This also means that block data transferred between executors now is not encrypted (and thus relies on the network library encryption support for secrecy). Shuffle blocks are still transferred in encrypted form, since they're handled in a slightly different way by the code. This also keeps compatibility with existing external shuffle services, which transfer encrypted shuffle blocks, and avoids having to make the external service aware of encryption at all. The serialization and deserialization APIs in the SerializerManager now do not do encryption automatically; callers need to explicitly wrap their streams with an appropriate crypto stream before using those. As a result of these changes, some of the workarounds added in SPARK-19520 are removed here. Testing: a new trait ("EncryptionFunSuite") was added that provides an easy way to run a test twice, with encryption on and off; broadcast, block manager and caching tests were modified to use this new trait so that the existing tests exercise both encrypted and non-encrypted paths. I also ran some applications with encryption turned on to verify that they still work, including streaming tests that failed without the fix for SPARK-19520. Author: Marcelo Vanzin Closes #17295 from vanzin/SPARK-19556. --- .../apache/spark/network/util/JavaUtils.java | 15 ++ .../spark/broadcast/TorrentBroadcast.scala | 35 +-- .../spark/security/CryptoStreamUtils.scala | 87 ++++++- .../spark/serializer/SerializerManager.scala | 24 +- .../apache/spark/storage/BlockManager.scala | 172 ++++++++----- .../storage/BlockManagerManagedBuffer.scala | 33 ++- .../org/apache/spark/storage/DiskStore.scala | 236 +++++++++++++++--- .../apache/spark/storage/StorageUtils.scala | 32 --- .../spark/storage/memory/MemoryStore.scala | 2 +- .../spark/util/io/ChunkedByteBuffer.scala | 14 -- .../org/apache/spark/DistributedSuite.scala | 12 +- .../spark/broadcast/BroadcastSuite.scala | 16 +- .../security/CryptoStreamUtilsSuite.scala | 46 +++- .../spark/security/EncryptionFunSuite.scala | 39 +++ .../spark/storage/BlockManagerSuite.scala | 77 +++--- .../apache/spark/storage/DiskStoreSuite.scala | 115 ++++++++- .../rdd/WriteAheadLogBackedBlockRDD.scala | 6 +- .../receiver/ReceivedBlockHandler.scala | 11 +- .../streaming/ReceivedBlockHandlerSuite.scala | 5 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 3 +- 20 files changed, 710 insertions(+), 270 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index f3eaf22c0166e..51d7fda0cb260 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,9 +18,11 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; @@ -344,4 +346,17 @@ public static byte[] bufferToArray(ByteBuffer buffer) { } } + /** + * Fills a buffer with data read from the channel. + */ + public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException { + int expected = dst.remaining(); + while (dst.hasRemaining()) { + if (channel.read(dst) < 0) { + throw new EOFException(String.format("Not enough bytes in channel (expected %d).", + expected)); + } + } + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d01c47e645d..039df75ce74fd 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} +import org.apache.spark.storage._ import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Fetch torrent blocks from the driver and/or other executors. */ - private def readBlocks(): Array[ChunkedByteBuffer] = { + private def readBlocks(): Array[BlockData] = { // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. - val blocks = new Array[ChunkedByteBuffer](numBlocks) + val blocks = new Array[BlockData](numBlocks) val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) throw new SparkException( s"Failed to store $pieceId of $broadcastId in local BlockManager") } - blocks(pid) = b + blocks(pid) = new ByteBufferBlockData(b, true) case None => throw new SparkException(s"Failed to get $pieceId of $broadcastId") } @@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks().flatMap(_.getChunks()) + val blocks = readBlocks() logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks, SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + obj + } finally { + blocks.foreach(_.dispose()) } - obj } } } @@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging { } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[InputStream], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") - val is = new SequenceInputStream( - blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) + val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index cdd3b8d8512b1..78dabb42ac9d2 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,20 +16,23 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import scala.collection.JavaConverters._ +import com.google.common.io.ByteStreams import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.network.util.CryptoUtils +import org.apache.spark.network.util.{CryptoUtils, JavaUtils} /** * A util class for manipulating IO encryption and decryption streams. @@ -48,12 +51,27 @@ private[spark] object CryptoStreamUtils extends Logging { os: OutputStream, sparkConf: SparkConf, key: Array[Byte]): OutputStream = { - val properties = toCryptoConf(sparkConf) - val iv = createInitializationVector(properties) + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) os.write(iv) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoOutputStream(transformationStr, properties, os, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `WritableByteChannel` for encryption. + */ + def createWritableChannel( + channel: WritableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): WritableByteChannel = { + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) + val helper = new CryptoHelperChannel(channel) + + helper.write(ByteBuffer.wrap(iv)) + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)) } /** @@ -63,12 +81,27 @@ private[spark] object CryptoStreamUtils extends Logging { is: InputStream, sparkConf: SparkConf, key: Array[Byte]): InputStream = { - val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) - is.read(iv, 0, iv.length) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoInputStream(transformationStr, properties, is, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + ByteStreams.readFully(is, iv) + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `ReadableByteChannel` for decryption. + */ + def createReadableChannel( + channel: ReadableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): ReadableByteChannel = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val buf = ByteBuffer.wrap(iv) + JavaUtils.readFully(channel, buf) + + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)) } def toCryptoConf(conf: SparkConf): Properties = { @@ -102,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging { } iv } + + /** + * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the + * underlying channel. Since the callers of this API are using blocking I/O, there are no + * concerns with regards to CPU usage here. + */ + private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel { + + override def write(src: ByteBuffer): Int = { + val count = src.remaining() + while (src.hasRemaining()) { + sink.write(src) + } + count + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + + } + + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { + + val keySpec = new SecretKeySpec(key, "AES") + val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) + val conf = toCryptoConf(sparkConf) + + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 96b288b9cfb81..bb7ed8709ba8a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -148,14 +148,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -167,30 +167,26 @@ private[spark] class SerializerManager( val byteStream = new BufferedOutputStream(outputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag]( blockId: BlockId, - values: Iterator[T], - allowEncryption: Boolean = true): ChunkedByteBuffer = { - dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]], - allowEncryption = allowEncryption) + values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) } /** Serializes into a chunked byte buffer. */ def dataSerializeWithExplicitClassTag( blockId: BlockId, values: Iterator[_], - classTag: ClassTag[_], - allowEncryption: Boolean = true): ChunkedByteBuffer = { + classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(classTag, autoPick).newInstance() - val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream - ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -200,15 +196,13 @@ private[spark] class SerializerManager( */ def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream, - maybeEncrypted: Boolean = true) + inputStream: InputStream) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] - val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapForCompression(blockId, decrypted)) + .deserializeStream(wrapForCompression(blockId, inputStream)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 991346a40af4e..fcda9fa65303a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap @@ -35,7 +36,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -55,6 +56,55 @@ private[spark] class BlockResult( val readMethod: DataReadMethod.Value, val bytes: Long) +/** + * Abstracts away how blocks are stored and provides different ways to read the underlying block + * data. Callers should call [[dispose()]] when they're done with the block. + */ +private[spark] trait BlockData { + + def toInputStream(): InputStream + + /** + * Returns a Netty-friendly wrapper for the block's data. + * + * @see [[ManagedBuffer#convertToNetty()]] + */ + def toNetty(): Object + + def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer + + def toByteBuffer(): ByteBuffer + + def size: Long + + def dispose(): Unit + +} + +private[spark] class ByteBufferBlockData( + val buffer: ChunkedByteBuffer, + val shouldDispose: Boolean) extends BlockData { + + override def toInputStream(): InputStream = buffer.toInputStream(dispose = false) + + override def toNetty(): Object = buffer.toNetty + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + buffer.copy(allocator) + } + + override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer + + override def size: Long = buffer.size + + override def dispose(): Unit = { + if (shouldDispose) { + buffer.dispose() + } + } + +} + /** * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). @@ -94,7 +144,7 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager) + private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. @@ -304,7 +354,8 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) + case Some(blockData) => + new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true) case None => // If this block manager receives a request for a block that it doesn't have then it's // likely that the master has outdated block statuses for this block. Therefore, we send @@ -463,21 +514,22 @@ private[spark] class BlockManager( val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { + val diskData = diskStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { - val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( blockId, - diskBytes.toInputStream(dispose = true))(info.classTag) + diskData.toInputStream())(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { - val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - .map {_.toInputStream(dispose = false)} - .getOrElse { diskBytes.toInputStream(dispose = true) } + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map { _.toInputStream(dispose = false) } + .getOrElse { diskData.toInputStream() } serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, + releaseLockAndDispose(blockId, diskData)) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -488,7 +540,7 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[BlockData] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -496,9 +548,9 @@ private[spark] class BlockManager( val shuffleBlockResolver = shuffleManager.shuffleBlockResolver // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. - Option( - new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) + val buf = new ChunkedByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + Some(new ByteBufferBlockData(buf, true)) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } @@ -510,7 +562,7 @@ private[spark] class BlockManager( * Must be called while holding a read lock on the block. * Releases the read lock upon exception; keeps the read lock upon successful return. */ - private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = { + private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = { val level = info.level logDebug(s"Level for block $blockId is $level") // In order, try to read the serialized bytes from memory, then from disk, then fall back to @@ -525,17 +577,19 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerializeWithExplicitClassTag( - blockId, memoryStore.getValues(blockId).get, info.classTag) + new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag), true) } else { handleLocalReadFailure(blockId) } } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { - memoryStore.getBytes(blockId).get + new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) } else if (level.useDisk && diskStore.contains(blockId)) { - val diskBytes = diskStore.getBytes(blockId) - maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) + val diskData = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map(new ByteBufferBlockData(_, false)) + .getOrElse(diskData) } else { handleLocalReadFailure(blockId) } @@ -761,43 +815,15 @@ private[spark] class BlockManager( * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing * so may corrupt or change the data stored by the `BlockManager`. * - * @param encrypt If true, asks the block manager to encrypt the data block before storing, - * when I/O encryption is enabled. This is required for blocks that have been - * read from unencrypted sources, since all the BlockManager read APIs - * automatically do decryption. * @return true if the block was stored or false if an error occurred. */ def putBytes[T: ClassTag]( blockId: BlockId, bytes: ChunkedByteBuffer, level: StorageLevel, - tellMaster: Boolean = true, - encrypt: Boolean = false): Boolean = { + tellMaster: Boolean = true): Boolean = { require(bytes != null, "Bytes is null") - - val bytesToStore = - if (encrypt && securityManager.ioEncryptionKey.isDefined) { - try { - val data = bytes.toByteBuffer - val in = new ByteBufferInputStream(data) - val byteBufOut = new ByteBufferOutputStream(data.remaining()) - val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, - securityManager.ioEncryptionKey.get) - try { - ByteStreams.copy(in, out) - } finally { - in.close() - out.close() - } - new ChunkedByteBuffer(byteBufOut.toByteBuffer) - } finally { - bytes.dispose() - } - } else { - bytes - } - - doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster) + doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster) } /** @@ -828,8 +854,9 @@ private[spark] class BlockManager( val replicationFuture = if (level.replication > 1) { Future { // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bytes, level, classTag) + // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing + // buffers that are owned by the caller. + replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag) }(futureExecutionContext) } else { null @@ -1008,8 +1035,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -1024,8 +1052,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - partiallySerializedValues.finishWritingToStream(fileOutputStream) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + partiallySerializedValues.finishWritingToStream(out) } size = diskStore.getSize(blockId) } else { @@ -1035,8 +1064,9 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1065,7 +1095,7 @@ private[spark] class BlockManager( try { replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { - bytesToReplicate.unmap() + bytesToReplicate.dispose() } logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) @@ -1089,29 +1119,29 @@ private[spark] class BlockManager( blockInfo: BlockInfo, blockId: BlockId, level: StorageLevel, - diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { + diskData: BlockData): Option[ChunkedByteBuffer] = { require(!level.deserialized) if (level.useMemory) { // Synchronize on blockInfo to guard against a race condition where two readers both try to // put values read from disk into the MemoryStore. blockInfo.synchronized { if (memoryStore.contains(blockId)) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { val allocator = level.memoryMode match { case MemoryMode.ON_HEAP => ByteBuffer.allocate _ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ } - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { + val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. - diskBytes.copy(allocator) + diskData.toChunkedByteBuffer(allocator) }) if (putSucceeded) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { None @@ -1203,7 +1233,7 @@ private[spark] class BlockManager( replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { logDebug(s"Releasing lock for $blockId") - releaseLock(blockId) + releaseLockAndDispose(blockId, data) } } } @@ -1214,7 +1244,7 @@ private[spark] class BlockManager( */ private def replicate( blockId: BlockId, - data: ChunkedByteBuffer, + data: BlockData, level: StorageLevel, classTag: ClassTag[_], existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { @@ -1256,7 +1286,7 @@ private[spark] class BlockManager( peer.port, peer.executorId, blockId, - new NettyManagedBuffer(data.toNetty), + new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1339,10 +1369,11 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { fileOutputStream => + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream( blockId, - fileOutputStream, + out, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => @@ -1434,6 +1465,11 @@ private[spark] class BlockManager( } } + def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { + blockInfoManager.unlock(blockId) + data.dispose() + } + def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index f66f942798550..1ea0d378cbe87 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -17,31 +17,52 @@ package org.apache.spark.storage -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import java.io.InputStream +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.io.ChunkedByteBuffer /** - * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] + * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]] * so that the corresponding block's read lock can be released once this buffer's references * are released. * + * If `dispose` is set to true, the [[BlockData]]will be disposed when the buffer's reference + * count drops to zero. + * * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, - chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { + data: BlockData, + dispose: Boolean) extends ManagedBuffer { + + private val refCount = new AtomicInteger(1) + + override def size(): Long = data.size + + override def nioByteBuffer(): ByteBuffer = data.toByteBuffer() + + override def createInputStream(): InputStream = data.toInputStream() + + override def convertToNetty(): Object = data.toNetty() override def retain(): ManagedBuffer = { - super.retain() + refCount.incrementAndGet() val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this - } + } override def release(): ManagedBuffer = { blockInfoManager.unlock(blockId) - super.release() + if (refCount.decrementAndGet() == 0 && dispose) { + data.dispose() + } + this } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ca23e2391ed02..c6656341fcd15 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,48 +17,67 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, IOException, RandomAccessFile} +import java.io._ import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel} import java.nio.channels.FileChannel.MapMode +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.ConcurrentHashMap -import com.google.common.io.Closeables +import scala.collection.mutable.ListBuffer -import org.apache.spark.SparkConf +import com.google.common.io.{ByteStreams, Closeables, Files} +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. */ -private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging { +private[spark] class DiskStore( + conf: SparkConf, + diskManager: DiskBlockManager, + securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val blockSizes = new ConcurrentHashMap[String, Long]() - def getSize(blockId: BlockId): Long = { - diskManager.getFile(blockId.name).length - } + def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) /** * Invokes the provided callback function to write the specific block. * * @throws IllegalStateException if the block already exists in the disk store. */ - def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = { + def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = { if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val fileOutputStream = new FileOutputStream(file) + val out = new CountingWritableChannel(openForWrite(file)) var threwException: Boolean = true try { - writeFunc(fileOutputStream) + writeFunc(out) + blockSizes.put(blockId.name, out.getCount) threwException = false } finally { try { - Closeables.close(fileOutputStream, threwException) + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } } finally { if (threwException) { remove(blockId) @@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { - put(blockId) { fileOutputStream => - val channel = fileOutputStream.getChannel - Utils.tryWithSafeFinally { - bytes.writeFully(channel) - } { - channel.close() - } + put(blockId) { channel => + bytes.writeFully(channel) } } - def getBytes(blockId: BlockId): ChunkedByteBuffer = { + def getBytes(blockId: BlockId): BlockData = { val file = diskManager.getFile(blockId.name) - val channel = new RandomAccessFile(file, "r").getChannel - Utils.tryWithSafeFinally { - // For small files, directly read rather than memory map - if (file.length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(file.length.toInt) - channel.position(0) - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { - throw new IOException("Reached EOF before filling buffer\n" + - s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + val blockSize = getSize(blockId) + + securityManager.getIOEncryptionKey() match { + case Some(key) => + // Encrypted blocks cannot be memory mapped; return a special object that does decryption + // and provides InputStream / FileRegion implementations for reading the data. + new EncryptedBlockData(file, blockSize, conf, key) + + case _ => + val channel = new FileInputStream(file).getChannel() + if (blockSize < minMemoryMapBytes) { + // For small files, directly read rather than memory map. + Utils.tryWithSafeFinally { + val buf = ByteBuffer.allocate(blockSize.toInt) + JavaUtils.readFully(channel, buf) + buf.flip() + new ByteBufferBlockData(new ChunkedByteBuffer(buf), true) + } { + channel.close() + } + } else { + Utils.tryWithSafeFinally { + new ByteBufferBlockData( + new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true) + } { + channel.close() } } - buf.flip() - new ChunkedByteBuffer(buf) - } else { - new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)) - } - } { - channel.close() } } def remove(blockId: BlockId): Boolean = { + blockSizes.remove(blockId.name) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() @@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e val file = diskManager.getFile(blockId.name) file.exists() } + + private def openForWrite(file: File): WritableByteChannel = { + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + } + +} + +private class EncryptedBlockData( + file: File, + blockSize: Long, + conf: SparkConf, + key: Array[Byte]) extends BlockData { + + override def toInputStream(): InputStream = Channels.newInputStream(open()) + + override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize) + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + val source = open() + try { + var remaining = blockSize + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, Int.MaxValue) + val chunk = allocator(chunkSize.toInt) + remaining -= chunkSize + JavaUtils.readFully(source, chunk) + chunk.flip() + chunks += chunk + } + + new ChunkedByteBuffer(chunks.toArray) + } finally { + source.close() + } + } + + override def toByteBuffer(): ByteBuffer = { + // This is used by the block transfer service to replicate blocks. The upload code reads + // all bytes into memory to send the block to the remote executor, so it's ok to do this + // as long as the block fits in a Java array. + assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.") + val dst = ByteBuffer.allocate(blockSize.toInt) + val in = open() + try { + JavaUtils.readFully(in, dst) + dst.flip() + dst + } finally { + Closeables.close(in, true) + } + } + + override def size: Long = blockSize + + override def dispose(): Unit = { } + + private def open(): ReadableByteChannel = { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } + } + +} + +private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) + extends AbstractReferenceCounted with FileRegion { + + private var _transferred = 0L + + private val buffer = ByteBuffer.allocateDirect(64 * 1024) + buffer.flip() + + override def count(): Long = blockSize + + override def position(): Long = 0 + + override def transfered(): Long = _transferred + + override def transferTo(target: WritableByteChannel, pos: Long): Long = { + assert(pos == transfered(), "Invalid position.") + + var written = 0L + var lastWrite = -1L + while (lastWrite != 0) { + if (!buffer.hasRemaining()) { + buffer.clear() + source.read(buffer) + buffer.flip() + } + if (buffer.hasRemaining()) { + lastWrite = target.write(buffer) + written += lastWrite + } else { + lastWrite = 0 + } + } + + _transferred += written + written + } + + override def deallocate(): Unit = source.close() +} + +private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { + + private var count = 0L + + def getCount: Long = count + + override def write(src: ByteBuffer): Int = { + val written = sink.write(src) + if (written > 0) { + count += written + } + written + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 5efdd23f79a21..241aacd74b586 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -236,14 +236,6 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { - // Ewwww... Reflection!!! See the unmap method for justification - private val memoryMappedBufferFileDescriptorField = { - val mappedBufferClass = classOf[java.nio.MappedByteBuffer] - val fdField = mappedBufferClass.getDeclaredField("fd") - fdField.setAccessible(true) - fdField - } - /** * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun * API that will cause errors if one attempts to read from the disposed buffer. However, neither @@ -251,8 +243,6 @@ private[spark] object StorageUtils extends Logging { * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of * off-heap memory or huge numbers of open files. There's unfortunately no standard API to * manually dispose of these kinds of buffers. - * - * See also [[unmap]] */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { @@ -261,28 +251,6 @@ private[spark] object StorageUtils extends Logging { } } - /** - * Attempt to unmap a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that will - * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of - * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage - * collection may lead to huge numbers of open files. There's unfortunately no standard API to - * manually unmap memory-mapped buffers. - * - * See also [[dispose]] - */ - def unmap(buffer: ByteBuffer): Unit = { - if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the - // JDK does not provide a public API to distinguish between direct buffers and memory-mapped - // buffers. As an alternative, we peek beneath the curtains and look for a non-null file - // descriptor in mappedByteBuffer - if (memoryMappedBufferFileDescriptorField.get(buffer) != null) { - logTrace(s"Unmapping $buffer") - cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) - } - } - } - private def cleanDirectBuffer(buffer: DirectBuffer) = { val cleaner = buffer.cleaner() if (cleaner != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index fb54dd66a39a9..90e3af2d0ec74 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -344,7 +344,7 @@ private[spark] class MemoryStore( val serializationStream: SerializationStream = { val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 1667516663b35..2f905c8af0f63 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -138,8 +138,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { /** * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. * See [[StorageUtils.dispose]] for more information. - * - * See also [[unmap]] */ def dispose(): Unit = { if (!disposed) { @@ -148,18 +146,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } } - /** - * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See - * [[StorageUtils.unmap]] for more information. - * - * See also [[dispose]] - */ - def unmap(): Unit = { - if (!disposed) { - chunks.foreach(StorageUtils.unmap) - disposed = true - } - } } /** diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 4e36adc8baf3f..84f7f1fc8eb09 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -28,7 +29,8 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext + with EncryptionFunSuite { val clusterUrl = "local-cluster[2,1,1024]" @@ -149,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } - private def testCaching(storageLevel: StorageLevel): Unit = { - sc = new SparkContext(clusterUrl, "test") + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { + sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) @@ -187,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - test(testName) { - testCaching(storageLevel) + encryptionTest(testName) { conf => + testCaching(conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 6646068d5080b..82760fe92f76a 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -24,8 +24,10 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +import org.apache.spark.util.io.ChunkedByteBuffer // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -43,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends SparkFunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite { test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") @@ -61,9 +63,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } - test("Accessing TorrentBroadcast variables in a local cluster") { + encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf => val numSlaves = 4 - val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -85,7 +86,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val size = 1 + rand.nextInt(1024 * 10) val data: Array[Byte] = new Array[Byte](size) rand.nextBytes(data) - val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b => + new ChunkedByteBuffer(b).toInputStream(dispose = true) + } val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) assert(unblockified === data) } @@ -137,9 +140,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("Cache broadcast to disk") { - val conf = new SparkConf() - .setMaster("local") + encryptionTest("Cache broadcast to disk") { conf => + conf.setMaster("local") .setAppName("test") .set("spark.memory.useLegacyMode", "true") .set("spark.storage.memoryFraction", "0.0") diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 0f3a4a03618ed..608052f5ed855 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.channels.Channels import java.nio.charset.StandardCharsets.UTF_8 -import java.util.UUID +import java.nio.file.Files +import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams @@ -121,6 +123,46 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } + test("crypto stream wrappers") { + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = createConf() + val key = createKey(conf) + val file = Files.createTempFile("crypto", ".test").toFile() + + val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key) + try { + ByteStreams.copy(new ByteArrayInputStream(testData), outStream) + } finally { + outStream.close() + } + + val inStream = createCryptoInputStream(new FileInputStream(file), conf, key) + try { + val inStreamData = ByteStreams.toByteArray(inStream) + assert(Arrays.equals(inStreamData, testData)) + } finally { + inStream.close() + } + + val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key) + try { + val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData)) + ByteStreams.copy(inByteChannel, outChannel) + } finally { + outChannel.close() + } + + val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key) + try { + val inChannelData = ByteStreams.toByteArray(Channels.newInputStream(inChannel)) + assert(Arrays.equals(inChannelData, testData)) + } finally { + inChannel.close() + } + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala new file mode 100644 index 0000000000000..3f52dc41abf6d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ + +trait EncryptionFunSuite { + + this: SparkFunSuite => + + /** + * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok + * for the test to modify the provided SparkConf. + */ + final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + Seq(false, true).foreach { encrypt => + test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(conf) + } + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 64a67b4c4cbab..a8b9604899838 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -42,6 +43,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -49,7 +51,8 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { + with PrivateMethodTester with LocalSparkContext with ResetSystemProperties + with EncryptionFunSuite { import BlockManagerSuite._ @@ -75,16 +78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master, - transferService: Option[BlockTransferService] = Option.empty): BlockManager = { - conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) - val serializer = new KryoSerializer(conf) + transferService: Option[BlockTransferService] = Option.empty, + testConf: Option[SparkConf] = None): BlockManager = { + val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) + bmConf.set("spark.testing.memory", maxMem.toString) + bmConf.set("spark.memory.offHeap.size", maxMem.toString) + val serializer = new KryoSerializer(bmConf) + val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(bmConf)) + } else { + None + } + val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey) val transfer = transferService .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) - val memManager = UnifiedMemoryManager(conf, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) - val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val memManager = UnifiedMemoryManager(bmConf, numCores = 1) + val serializerManager = new SerializerManager(serializer, bmConf) + val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf, + memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") blockManager @@ -610,8 +621,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("on-disk storage") { - store = makeBlockManager(1200) + encryptionTest("on-disk storage") { _conf => + store = makeBlockManager(1200, testConf = Some(_conf)) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -623,34 +634,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } - test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false) + encryptionTest("disk and memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = conf) } - test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true) + encryptionTest("disk and memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = conf) } - test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false) + encryptionTest("disk and memory storage with serialization") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false, testConf = conf) } - test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) + encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true, testConf = conf) } - test("disk and off-heap memory storage") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + encryptionTest("disk and off-heap memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false, testConf = conf) } - test("disk and off-heap memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = conf) } def testDiskAndMemoryStorage( storageLevel: StorageLevel, - getAsBytes: Boolean): Unit = { - store = makeBlockManager(12000) + getAsBytes: Boolean, + testConf: SparkConf): Unit = { + store = makeBlockManager(12000, testConf = Some(testConf)) val accessMethod = if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock val a1 = new Array[Byte](4000) @@ -678,8 +690,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } - test("LRU with mixed storage levels") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) @@ -700,8 +712,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } - test("in-memory LRU with streams") { - store = makeBlockManager(12000) + encryptionTest("in-memory LRU with streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -728,8 +740,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } - test("LRU with mixed storage levels and streams") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels and streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -1325,7 +1337,8 @@ private object BlockManagerSuite { val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { - wrapGet(store.getLocalBytes) + val allocator = ByteBuffer.allocate _ + wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9e6b02b9eac4d..67fc084e8a13d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import java.util.{Arrays, Random} -import org.apache.spark.{SparkConf, SparkFunSuite} +import com.google.common.io.{ByteStreams, Files} +import io.netty.channel.FileRegion + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + // It will cause error when we tried to re-open the filestore and the // memory-mapped byte buffer tot he file has not been GC on Windows. assume(!Utils.isWindows) @@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite { val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) val blockId = BlockId("rdd_1_2") - val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId) + val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId) + val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer // Not possible to do isInstanceOf due to visibility of HeapByteBuffer assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), @@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite { assert(Arrays.equals(mapped.toArray, bytes)) assert(Arrays.equals(notMapped.toArray, bytes)) } + + test("block size tracking") { + val conf = new SparkConf() + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(new Array[Byte](32)) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === 32L) + diskStore.remove(blockId) + assert(diskStore.getSize(blockId) === 0L) + } + + test("block data encryption") { + val testDir = Utils.createTempDir() + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = new SparkConf() + val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(testData) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === testData.length) + + val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name)) + assert(!Arrays.equals(testData, diskData)) + + val blockData = diskStore.getBytes(blockId) + assert(blockData.isInstanceOf[EncryptedBlockData]) + assert(blockData.size === testData.length) + Map( + "input stream" -> readViaInputStream _, + "chunked byte buffer" -> readViaChunkedByteBuffer _, + "nio byte buffer" -> readViaNioBuffer _, + "managed buffer" -> readViaManagedBuffer _ + ).foreach { case (name, fn) => + val readData = fn(blockData) + assert(readData.length === blockData.size, s"Size of data read via $name did not match.") + assert(Arrays.equals(testData, readData), s"Data read via $name did not match.") + } + } + + private def readViaInputStream(data: BlockData): Array[Byte] = { + val is = data.toInputStream() + try { + ByteStreams.toByteArray(is) + } finally { + is.close() + } + } + + private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = { + val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _) + try { + buf.toArray + } finally { + buf.dispose() + } + } + + private def readViaNioBuffer(data: BlockData): Array[Byte] = { + JavaUtils.bufferToArray(data.toByteBuffer()) + } + + private def readViaManagedBuffer(data: BlockData): Array[Byte] = { + val region = data.toNetty().asInstanceOf[FileRegion] + val byteChannel = new ByteArrayWritableChannel(data.size.toInt) + + while (region.transfered() < region.count()) { + region.transferTo(byteChannel, region.transfered()) + } + + byteChannel.close() + byteChannel.getData + } + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index d0864fd3678b2..844760ab61d2e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -158,16 +158,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel, - encrypt = true) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } serializerManager .dataDeserializeStream( blockId, - new ChunkedByteBuffer(dataRead).toInputStream(), - maybeEncrypted = false)(elementClassTag) + new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 2b488038f0620..80c07958b41f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -87,8 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler( putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes( - blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true, - encrypt = true) + blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -176,11 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) - serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val countIterator = new CountingIterator(iterator) - val serializedBlock = serializerManager.dataSerialize(blockId, countIterator, - allowEncryption = false) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => @@ -195,8 +193,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( blockId, serializedBlock, effectiveStorageLevel, - tellMaster = true, - encrypt = true) + tellMaster = true) if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c2b0389b8c6f0..3c4a2716caf90 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -175,8 +175,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) reader.close() serializerManager.dataDeserializeStream( generateBlockId(), - new ChunkedByteBuffer(bytes).toInputStream(), - maybeEncrypted = false)(ClassTag.Any).toList + new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList } loggedData shouldEqual data } @@ -357,7 +356,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) } def dataToByteBuffer(b: Seq[String]) = - serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false) + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 2ac0dc96916c5..aa69be7ca9939 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -250,8 +250,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false) - .toByteBuffer) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments From c622a87c44e0621e1b3024fdca9b2aa3c508615b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 29 Mar 2017 10:09:58 -0700 Subject: [PATCH 0144/1765] [SPARK-20059][YARN] Use the correct classloader for HBaseCredentialProvider ## What changes were proposed in this pull request? Currently we use system classloader to find HBase jars, if it is specified by `--jars`, then it will be failed with ClassNotFound issue. So here changing to use child classloader. Also putting added jars and main jar into classpath of submitted application in yarn cluster mode, otherwise HBase jars specified with `--jars` will never be honored in cluster mode, and fetching tokens in client side will always be failed. ## How was this patch tested? Unit test and local verification. Author: jerryshao Closes #17388 from jerryshao/SPARK-20059. --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 7 ++++++- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 7 ++++++- .../deploy/yarn/security/HBaseCredentialProvider.scala | 5 +++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e50eb6635651..77005aa9040b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -485,12 +485,17 @@ object SparkSubmit extends CommandLineUtils { // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - if (deployMode == CLIENT) { + // Also add the main application jar and any added jars to classpath in case YARN client + // requires these jars. + if (deployMode == CLIENT || isYarnCluster) { childMainClass = args.mainClass if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } if (args.jars != null) { childClasspath ++= args.jars.split(",") } + } + + if (deployMode == CLIENT) { if (args.childArgs != null) { childArgs ++= args.childArgs } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 9417930d02405..a591b98bca488 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -213,7 +213,12 @@ class SparkSubmitSuite childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include regex ("--jar .*thejar.jar") mainClass should be ("org.apache.spark.deploy.yarn.Client") - classpath should have length (0) + + // In yarn cluster mode, also adding jars to classpath + classpath(0) should endWith ("thejar.jar") + classpath(1) should endWith ("one.jar") + classpath(2) should endWith ("two.jar") + classpath(3) should endWith ("three.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.driver.memory") should be ("4g") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala index 5571df09a2ec9..5adeb8e605ff4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { @@ -36,7 +37,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val obtainToken = mirror.classLoader. loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). getMethod("obtainToken", classOf[Configuration]) @@ -60,7 +61,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide private def hbaseConf(conf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val confCreate = mirror.classLoader. loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). getMethod("create", classOf[Configuration]) From d6ddfdf60e77340256873b5acf08e85f95cf3bc2 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 29 Mar 2017 11:41:17 -0700 Subject: [PATCH 0145/1765] [SPARK-19955][PYSPARK] Jenkins Python Conda based test. ## What changes were proposed in this pull request? Allow Jenkins Python tests to use the installed conda to test Python 2.7 support & test pip installability. ## How was this patch tested? Updated shell scripts, ran tests locally with installed conda, ran tests in Jenkins. Author: Holden Karau Closes #17355 from holdenk/SPARK-19955-support-python-tests-with-conda. --- dev/run-pip-tests | 66 +++++++++++++++++++++++++++---------------- dev/run-tests-jenkins | 3 +- python/run-tests.py | 6 ++-- 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/dev/run-pip-tests b/dev/run-pip-tests index af1b1feb70cd1..d51dde12a03c5 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -35,9 +35,28 @@ function delete_virtualenv() { } trap delete_virtualenv EXIT +PYTHON_EXECS=() # Some systems don't have pip or virtualenv - in those cases our tests won't work. -if ! hash virtualenv 2>/dev/null; then - echo "Missing virtualenv skipping pip installability tests." +if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then + echo "virtualenv installed - using. Note if this is a conda virtual env you may wish to set USE_CONDA" + # Figure out which Python execs we should test pip installation with + if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') + elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') + fi + if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') + fi +elif hash conda 2>/dev/null; then + echo "Using conda virtual enviroments" + PYTHON_EXECS=('3.5') + USE_CONDA=1 +else + echo "Missing virtualenv & conda, skipping pip installability tests" exit 0 fi if ! hash pip 2>/dev/null; then @@ -45,22 +64,8 @@ if ! hash pip 2>/dev/null; then exit 0 fi -# Figure out which Python execs we should test pip installation with -PYTHON_EXECS=() -if hash python2 2>/dev/null; then - # We do this since we are testing with virtualenv and the default virtual env python - # is in /usr/bin/python - PYTHON_EXECS+=('python2') -elif hash python 2>/dev/null; then - # If python2 isn't installed fallback to python if available - PYTHON_EXECS+=('python') -fi -if hash python3 2>/dev/null; then - PYTHON_EXECS+=('python3') -fi - # Determine which version of PySpark we are building for archive name -PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_VERSION=$(python3 -c "exec(open('python/pyspark/version.py').read());print(__version__)") PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" # The pip install options we use for all the pip commands PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " @@ -75,18 +80,24 @@ for python in "${PYTHON_EXECS[@]}"; do echo "Using $VIRTUALENV_BASE for virtualenv" VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python rm -rf "$VIRTUALENV_PATH" - mkdir -p "$VIRTUALENV_PATH" - virtualenv --python=$python "$VIRTUALENV_PATH" - source "$VIRTUALENV_PATH"/bin/activate - # Upgrade pip & friends - pip install --upgrade pip pypandoc wheel - pip install numpy # Needed so we can verify mllib imports + if [ -n "$USE_CONDA" ]; then + conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools + source activate "$VIRTUALENV_PATH" + else + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + fi + # Upgrade pip & friends if using virutal env + if [ ! -n "USE_CONDA" ]; then + pip install --upgrade pip pypandoc wheel numpy + fi echo "Creating pip installable source dist" cd "$FWDIR"/python # Delete the egg info file if it exists, this can cache the setup file. rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" - $python setup.py sdist + python setup.py sdist echo "Installing dist into virtual env" @@ -112,6 +123,13 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" + # conda / virtualenv enviroments need to be deactivated differently + if [ -n "$USE_CONDA" ]; then + source deactivate + else + deactivate + fi + done done diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index e79accf9e987a..f41f1ac79e381 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,7 +22,8 @@ # Environment variables are populated by the code here: #+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 -FWDIR="$(cd "`dirname $0`"/..; pwd)" +FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" +export PATH=/home/anaconda/bin:$PATH exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py index 53a0aef229b08..b2e50435bb192 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -111,9 +111,9 @@ def run_individual_python_test(test_name, pyspark_python): def get_default_python_executables(): - python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] - if "python2.6" not in python_execs: - LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + python_execs = [x for x in ["python2.7", "python3.4", "pypy"] if which(x)] + if "python2.7" not in python_execs: + LOGGER.warning("Not testing against `python2.7` because it could not be found; falling" " back to `python` instead") python_execs.insert(0, "python") return python_execs From 142f6d14928c780cc9e8d6d7749c5d7c08a30972 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Wed, 29 Mar 2017 12:35:19 -0700 Subject: [PATCH 0146/1765] [SPARK-20048][SQL] Cloning SessionState does not clone query execution listeners ## What changes were proposed in this pull request? Bugfix from [SPARK-19540.](https://github.com/apache/spark/pull/16826) Cloning SessionState does not clone query execution listeners, so cloned session is unable to listen to events on queries. ## How was this patch tested? - Unit test Author: Kunal Khamar Closes #17379 from kunalkhamar/clone-bugfix. --- .../org/apache/spark/sql/SparkSession.scala | 22 ++++---- ...rs.scala => BaseSessionStateBuilder.scala} | 24 ++++++++- .../spark/sql/internal/SessionState.scala | 38 ++++--------- .../sql/util/QueryExecutionListener.scala | 10 ++++ .../apache/spark/sql/SessionStateSuite.scala | 53 +++++++++++++++++++ .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- ...te.scala => HiveSessionStateBuilder.scala} | 14 +---- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../sql/hive/HiveSessionStateSuite.scala | 2 +- 9 files changed, 111 insertions(+), 56 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/internal/{sessionStateBuilders.scala => BaseSessionStateBuilder.scala} (92%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{HiveSessionState.scala => HiveSessionStateBuilder.scala} (92%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 49562578b23cd..a97297892b5e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -194,7 +194,7 @@ class SparkSession private( * * @since 2.0.0 */ - def udf: UDFRegistration = sessionState.udf + def udf: UDFRegistration = sessionState.udfRegistration /** * :: Experimental :: @@ -990,28 +990,28 @@ object SparkSession { /** Reference to the root SparkSession. */ private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" + private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME = + "org.apache.spark.sql.hive.HiveSessionStateBuilder" private def sessionStateClassName(conf: SparkConf): String = { conf.get(CATALOG_IMPLEMENTATION) match { - case "hive" => HIVE_SESSION_STATE_CLASS_NAME - case "in-memory" => classOf[SessionState].getCanonicalName + case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME + case "in-memory" => classOf[SessionStateBuilder].getCanonicalName } } /** * Helper method to create an instance of `SessionState` based on `className` from conf. - * The result is either `SessionState` or `HiveSessionState`. + * The result is either `SessionState` or a Hive based `SessionState`. */ private def instantiateSessionState( className: String, sparkSession: SparkSession): SessionState = { - try { - // get `SessionState.apply(SparkSession)` + // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])` val clazz = Utils.classForName(className) - val method = clazz.getMethod("apply", sparkSession.getClass) - method.invoke(null, sparkSession).asInstanceOf[SessionState] + val ctor = clazz.getConstructors.head + ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build() } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) @@ -1023,7 +1023,7 @@ object SparkSession { */ private[spark] def hiveClassesArePresent: Boolean = { try { - Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME) + Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME) Utils.classForName("org.apache.hadoop.hive.conf.HiveConf") true } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index b8f645fdee85a..2b14eca919fa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.StreamingQueryManager +import org.apache.spark.sql.util.ExecutionListenerManager /** * Builder class that coordinates construction of a new [[SessionState]]. @@ -133,6 +134,14 @@ abstract class BaseSessionStateBuilder( catalog } + /** + * Interface exposed to the user for registering user-defined functions. + * + * Note 1: The user-defined functions must be deterministic. + * Note 2: This depends on the `functionRegistry` field. + */ + protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry) + /** * Logical query plan analyzer for resolving unresolved attributes and relations. * @@ -232,6 +241,16 @@ abstract class BaseSessionStateBuilder( */ protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session) + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + * + * This gets cloned from parent if available, otherwise is a new instance is created. + */ + protected def listenerManager: ExecutionListenerManager = { + parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + } + /** * Function used to make clones of the session state. */ @@ -245,17 +264,18 @@ abstract class BaseSessionStateBuilder( */ def build(): SessionState = { new SessionState( - session.sparkContext, session.sharedState, conf, experimentalMethods, functionRegistry, + udfRegistration, catalog, sqlParser, analyzer, optimizer, planner, streamingQueryManager, + listenerManager, resourceLoader, createQueryExecution, createClone) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c6241d923d7b3..1b341a12fc609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -32,43 +32,46 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.streaming.StreamingQueryManager -import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener} /** * A class that holds all session-specific state in a given [[SparkSession]]. * - * @param sparkContext The [[SparkContext]]. - * @param sharedState The shared state. + * @param sharedState The state shared across sessions, e.g. global view manager, external catalog. * @param conf SQL-specific key-value configurations. - * @param experimentalMethods The experimental methods. + * @param experimentalMethods Interface to add custom planning strategies and optimizers. * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param udfRegistration Interface exposed to the user for registering user-defined functions. * @param catalog Internal catalog for managing table and database states. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. * @param optimizer Logical query plan optimizer. - * @param planner Planner that converts optimized logical plans to physical plans + * @param planner Planner that converts optimized logical plans to physical plans. * @param streamingQueryManager Interface to start and stop streaming queries. + * @param listenerManager Interface to register custom [[QueryExecutionListener]]s. + * @param resourceLoader Session shared resource loader to load JARs, files, etc. * @param createQueryExecution Function used to create QueryExecution objects. * @param createClone Function used to create clones of the session state. */ private[sql] class SessionState( - sparkContext: SparkContext, sharedState: SharedState, val conf: SQLConf, val experimentalMethods: ExperimentalMethods, val functionRegistry: FunctionRegistry, + val udfRegistration: UDFRegistration, val catalog: SessionCatalog, val sqlParser: ParserInterface, val analyzer: Analyzer, val optimizer: Optimizer, val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, + val listenerManager: ExecutionListenerManager, val resourceLoader: SessionResourceLoader, createQueryExecution: LogicalPlan => QueryExecution, createClone: (SparkSession, SessionState) => SessionState) { def newHadoopConf(): Configuration = SessionState.newHadoopConf( - sparkContext.hadoopConfiguration, + sharedState.sparkContext.hadoopConfiguration, conf) def newHadoopConfWithOptions(options: Map[String, String]): Configuration = { @@ -81,18 +84,6 @@ private[sql] class SessionState( hadoopConf } - /** - * Interface exposed to the user for registering user-defined functions. - * Note that the user-defined functions must be deterministic. - */ - val udf: UDFRegistration = new UDFRegistration(functionRegistry) - - /** - * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s - * that listen for execution metrics. - */ - val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - /** * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ @@ -110,13 +101,6 @@ private[sql] class SessionState( } private[sql] object SessionState { - /** - * Create a new [[SessionState]] for the given session. - */ - def apply(session: SparkSession): SessionState = { - new SessionStateBuilder(session).build() - } - def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { val newHadoopConf = new Configuration(hadoopConf) sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } @@ -155,7 +139,7 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoade /** * Add a jar path to [[SparkContext]] and the classloader. * - * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs + * Note: this method seems not access any session state, but a Hive based `SessionState` needs * to add the jar to its hive client for the current session. Hence, it still needs to be in * [[SessionState]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 26ad0eadd9d4c..f6240d85fba6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -98,6 +98,16 @@ class ExecutionListenerManager private[sql] () extends Logging { listeners.clear() } + /** + * Get an identical copy of this listener manager. + */ + @DeveloperApi + override def clone(): ExecutionListenerManager = writeLock { + val newListenerManager = new ExecutionListenerManager + listeners.foreach(newListenerManager.register) + newListenerManager + } + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { readLock { withErrorHandling { listener => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 2d5e37242a58b..5638c8eeda842 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -19,10 +19,13 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener class SessionStateSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { @@ -122,6 +125,56 @@ class SessionStateSuite extends SparkFunSuite } } + test("fork new session and inherit listener manager") { + class CommandCollector extends QueryExecutionListener { + val commands: ArrayBuffer[String] = ArrayBuffer.empty[String] + override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + commands += funcName + } + } + val collectorA = new CommandCollector + val collectorB = new CommandCollector + val collectorC = new CommandCollector + + try { + def runCollectQueryOn(sparkSession: SparkSession): Unit = { + val tupleEncoder = Encoders.tuple(Encoders.scalaInt, Encoders.STRING) + val df = sparkSession.createDataset(Seq(1 -> "a"))(tupleEncoder).toDF("i", "j") + df.select("i").collect() + } + + activeSession.listenerManager.register(collectorA) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.listenerManager ne activeSession.listenerManager) + runCollectQueryOn(forkedSession) + assert(collectorA.commands.length == 1) // forked should callback to A + assert(collectorA.commands(0) == "collect") + + // independence + // => changes to forked do not affect original + forkedSession.listenerManager.register(collectorB) + runCollectQueryOn(activeSession) + assert(collectorB.commands.isEmpty) // original should not callback to B + assert(collectorA.commands.length == 2) // original should still callback to A + assert(collectorA.commands(1) == "collect") + // <= changes to original do not affect forked + activeSession.listenerManager.register(collectorC) + runCollectQueryOn(forkedSession) + assert(collectorC.commands.isEmpty) // forked should not callback to C + assert(collectorA.commands.length == 3) // forked should still callback to A + assert(collectorB.commands.length == 1) // forked should still callback to B + assert(collectorA.commands(2) == "collect") + assert(collectorB.commands(0) == "collect") + } finally { + activeSession.listenerManager.unregister(collectorA) + activeSession.listenerManager.unregister(collectorC) + } + } + test("fork new sessions and run query on inherited table") { def checkTableExists(sparkSession: SparkSession): Unit = { QueryTest.checkAnswer(sparkSession.sql( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 0c79b6f4211ff..390b9b6d68cab 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -38,7 +38,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.util.ShutdownHookManager /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala similarity index 92% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index f49e6bb418644..8048c2ba2c2e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -28,19 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * Entry object for creating a Hive aware [[SessionState]]. - */ -private[hive] object HiveSessionState { - /** - * Create a new Hive aware [[SessionState]]. for the given session. - */ - def apply(session: SparkSession): SessionState = { - new HiveSessionStateBuilder(session).build() - } -} - -/** - * Builder that produces a [[HiveSessionState]]. + * Builder that produces a Hive aware [[SessionState]]. */ @Experimental @InterfaceStability.Unstable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 0bcf219922764..d9bb1f8c7edcc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal._ +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 67c77fb62f4e1..958ad3e1c3ce8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton /** - * Run all tests from `SessionStateSuite` with a `HiveSessionState`. + * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. */ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton with BeforeAndAfterEach { From c4008480b781379ac0451b9220300d83c054c60d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 29 Mar 2017 12:37:49 -0700 Subject: [PATCH 0147/1765] [SPARK-20009][SQL] Support DDL strings for defining schema in functions.from_json ## What changes were proposed in this pull request? This pr added `StructType.fromDDL` to convert a DDL format string into `StructType` for defining schemas in `functions.from_json`. ## How was this patch tested? Added tests in `JsonFunctionsSuite`. Author: Takeshi Yamamuro Closes #17406 from maropu/SPARK-20009. --- .../apache/spark/sql/types/StructType.scala | 6 ++ .../spark/sql/types/DataTypeSuite.scala | 85 ++++++++++++++----- .../org/apache/spark/sql/functions.scala | 15 +++- .../apache/spark/sql/JsonFunctionsSuite.scala | 7 ++ .../sql/sources/SimpleTextRelation.scala | 2 +- 5 files changed, 90 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 8d8b5b86d5aa1..54006e20a3eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -417,6 +417,12 @@ object StructType extends AbstractDataType { } } + /** + * Creates StructType for a given DDL-formatted string, which is a comma separated list of field + * definitions, e.g., a INT, b STRING. + */ + def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl) + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 61e1ec7c7ab35..05cb999af6a50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite { assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) } - def checkDataTypeJsonRepr(dataType: DataType): Unit = { - test(s"JSON - $dataType") { + def checkDataTypeFromJson(dataType: DataType): Unit = { + test(s"from Json - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) } } - checkDataTypeJsonRepr(NullType) - checkDataTypeJsonRepr(BooleanType) - checkDataTypeJsonRepr(ByteType) - checkDataTypeJsonRepr(ShortType) - checkDataTypeJsonRepr(IntegerType) - checkDataTypeJsonRepr(LongType) - checkDataTypeJsonRepr(FloatType) - checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) - checkDataTypeJsonRepr(DateType) - checkDataTypeJsonRepr(TimestampType) - checkDataTypeJsonRepr(StringType) - checkDataTypeJsonRepr(BinaryType) - checkDataTypeJsonRepr(ArrayType(DoubleType, true)) - checkDataTypeJsonRepr(ArrayType(StringType, false)) - checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) - checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + def checkDataTypeFromDDL(dataType: DataType): Unit = { + test(s"from DDL - $dataType") { + val parsed = StructType.fromDDL(s"a ${dataType.sql}") + val expected = new StructType().add("a", dataType) + assert(parsed.sameType(expected)) + } + } + + checkDataTypeFromJson(NullType) + + checkDataTypeFromJson(BooleanType) + checkDataTypeFromDDL(BooleanType) + + checkDataTypeFromJson(ByteType) + checkDataTypeFromDDL(ByteType) + + checkDataTypeFromJson(ShortType) + checkDataTypeFromDDL(ShortType) + + checkDataTypeFromJson(IntegerType) + checkDataTypeFromDDL(IntegerType) + + checkDataTypeFromJson(LongType) + checkDataTypeFromDDL(LongType) + + checkDataTypeFromJson(FloatType) + checkDataTypeFromDDL(FloatType) + + checkDataTypeFromJson(DoubleType) + checkDataTypeFromDDL(DoubleType) + + checkDataTypeFromJson(DecimalType(10, 5)) + checkDataTypeFromDDL(DecimalType(10, 5)) + + checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT) + checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT) + + checkDataTypeFromJson(DateType) + checkDataTypeFromDDL(DateType) + + checkDataTypeFromJson(TimestampType) + checkDataTypeFromDDL(TimestampType) + + checkDataTypeFromJson(StringType) + checkDataTypeFromDDL(StringType) + + checkDataTypeFromJson(BinaryType) + checkDataTypeFromDDL(BinaryType) + + checkDataTypeFromJson(ArrayType(DoubleType, true)) + checkDataTypeFromDDL(ArrayType(DoubleType, true)) + + checkDataTypeFromJson(ArrayType(StringType, false)) + checkDataTypeFromDDL(ArrayType(StringType, false)) + + checkDataTypeFromJson(MapType(IntegerType, StringType, true)) + checkDataTypeFromDDL(MapType(IntegerType, StringType, true)) + + checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false)) val metadata = new MetadataBuilder() .putString("name", "age") @@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", IntegerType, nullable = true), StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false, metadata))) - checkDataTypeJsonRepr(structType) + checkDataTypeFromJson(structType) + checkDataTypeFromDDL(structType) def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { test(s"Check the default size of $dataType") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acdb8e2d3edc8..0f9203065ef05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection @@ -3055,13 +3056,21 @@ object functions { * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string as a json string + * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, + * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL + * format is also supported for the schema. * * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = - from_json(e, DataType.fromJson(schema), options) + def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + val dataType = try { + DataType.fromJson(schema) + } catch { + case NonFatal(_) => StructType.fromDDL(schema) + } + from_json(e, dataType, options) + } /** * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 170c238c53438..8465e8d036a6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } + test("from_json uses DDL strings for defining a schema") { + val df = Seq("""{"a": 1, "b": "haa"}""").toDS() + checkAnswer( + df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 1607c97cd6acb..9f4009bfe402a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.sql.{sources, Row, SparkSession} +import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection From 5c8ef376e874497766ba0cc4d97429e33a3d9c61 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 29 Mar 2017 12:43:22 -0700 Subject: [PATCH 0148/1765] [SPARK-17075][SQL][FOLLOWUP] Add Estimation of Constant Literal ### What changes were proposed in this pull request? `FalseLiteral` and `TrueLiteral` should have been eliminated by optimizer rule `BooleanSimplification`, but null literals might be added by optimizer rule `NullPropagation`. For safety, our filter estimation should handle all the eligible literal cases. Our optimizer rule BooleanSimplification is unable to remove the null literal in many cases. For example, `a < 0 or null`. Thus, we need to handle null literal in filter estimation. `Not` can be pushed down below `And` and `Or`. Then, we could see two consecutive `Not`, which need to be collapsed into one. Because of the limited expression support for filter estimation, we just need to handle the case `Not(null)` for avoiding incorrect error due to the boolean operation on null. For details, see below matrix. ``` not NULL = NULL NULL or false = NULL NULL or true = true NULL or NULL = NULL NULL and false = false NULL and true = NULL NULL and NULL = NULL ``` ### How was this patch tested? Added the test cases. Author: Xiao Li Closes #17446 from gatorsmile/constantFilterEstimation. --- .../statsEstimation/FilterEstimation.scala | 39 ++++++++- .../FilterEstimationSuite.scala | 87 +++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index f14df93160b75..b32374c5742ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -24,6 +24,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -104,12 +105,23 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) Some(percent1 + percent2 - (percent1 * percent2)) + // Not-operator pushdown case Not(And(cond1, cond2)) => calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + // Not-operator pushdown case Not(Or(cond1, cond2)) => calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) + // Collapse two consecutive Not operators which could be generated after Not-operator pushdown + case Not(Not(cond)) => + calculateFilterSelectivity(cond, update = false) + + // The foldable Not has been processed in the ConstantFolding rule + // This is a top-down traversal. The Not could be pushed down by the above two cases. + case Not(l @ Literal(null, _)) => + calculateSingleCondition(l, update = false) + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { case Some(percent) => Some(1.0 - percent) @@ -134,13 +146,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { + case l: Literal => + evaluateLiteral(l) + // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. // EqualTo/EqualNullSafe does not care about the order - case op @ Equality(ar: Attribute, l: Literal) => + case Equality(ar: Attribute, l: Literal) => evaluateEquality(ar, l, update) - case op @ Equality(l: Literal, ar: Attribute) => + case Equality(l: Literal, ar: Attribute) => evaluateEquality(ar, l, update) case op @ LessThan(ar: Attribute, l: Literal) => @@ -342,6 +357,26 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } + /** + * Returns a percentage of rows meeting a Literal expression. + * This method evaluates all the possible literal cases in Filter. + * + * FalseLiteral and TrueLiteral should be eliminated by optimizer, but null literal might be added + * by optimizer rule NullPropagation. For safety, we handle all the cases here. + * + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateLiteral(literal: Literal): Option[Double] = { + literal match { + case Literal(null, _) => Some(0.0) + case FalseLiteral => Some(0.0) + case TrueLiteral => Some(1.0) + // Ideally, we should not hit the following branch + case _ => None + } + } + /** * Returns a percentage of rows meeting "IN" operator expression. * This method evaluates the equality predicate for all data types. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 07abe1ed28533..1966c96c05294 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ @@ -76,6 +77,82 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attrDouble -> colStatDouble, attrString -> colStatString)) + test("true") { + validateEstimatedStats( + Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 10) + } + + test("false") { + validateEstimatedStats( + Filter(FalseLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("null") { + validateEstimatedStats( + Filter(Literal(null, IntegerType), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(null)") { + validateEstimatedStats( + Filter(Not(Literal(null, IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(Not(null))") { + validateEstimatedStats( + Filter(Not(Not(Literal(null, IntegerType))), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 AND null") { + val condition = And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 OR null") { + val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 3) + } + + test("Not(cint < 3 AND null)") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + + test("Not(cint < 3 OR null)") { + val condition = Not(Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(cint < 3 AND Not(null))") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), @@ -163,6 +240,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 10) } + test("cint IS NOT NULL && null") { + // 'cint < null' will be optimized to 'cint IS NOT NULL && null'. + // More similar cases can be found in the Optimizer NullPropagation. + val condition = And(IsNotNull(attrInt), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + test("cint > 3 AND cint <= 6") { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( From fe1d6b05d47e384e3710ae428db499e89697267f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 29 Mar 2017 15:23:24 -0700 Subject: [PATCH 0149/1765] [SPARK-20120][SQL] spark-sql support silent mode ## What changes were proposed in this pull request? It is similar to Hive silent mode, just show the query result. see: [Hive LanguageManual+Cli](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Cli) and [the implementation of Hive silent mode](https://github.com/apache/hive/blob/release-1.2.1/ql/src/java/org/apache/hadoop/hive/ql/session/SessionState.java#L948-L950). This PR set the Logger level to `WARN` to get similar result. ## How was this patch tested? manual tests ![manual test spark sql silent mode](https://cloud.githubusercontent.com/assets/5399861/24390165/989b7780-13b9-11e7-8496-6e68f55757e3.gif) Author: Yuming Wang Closes #17449 from wangyum/SPARK-20120. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 390b9b6d68cab..1bc5c3c62f045 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging @@ -275,6 +276,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + if (sessionState.getIsSilent) { + Logger.getRootLogger.setLevel(Level.WARN) + } + private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } From dd2e7d528cb7468cdc077403f314c7ee0f214ac5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 29 Mar 2017 17:32:01 -0700 Subject: [PATCH 0150/1765] [SPARK-19088][SQL] Fix 2.10 build. ## What changes were proposed in this pull request? Commit 6c70a38 broke the build for scala 2.10. The commit uses some reflections which are not available in Scala 2.10. This PR fixes them. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #17473 from ueshin/issues/SPARK-19088. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1c7720afe1ca3..da37eb00dcd97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,7 +307,8 @@ object ScalaReflection extends ScalaReflection { } } - val cls = t.dealias.companion.decl(TermName("newBuilder")) match { + val companion = t.normalize.typeSymbol.companionSymbol.typeSignature + val cls = companion.declaration(newTermName("newBuilder")) match { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } From 22f07fefe11f0147f1e8d83d9b77707640d5dc97 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 29 Mar 2017 18:57:35 -0700 Subject: [PATCH 0151/1765] [SPARK-20146][SQL] fix comment missing issue for thrift server ## What changes were proposed in this pull request? The column comment was missing while constructing the Hive TableSchema. This fix will preserve the original comment. ## How was this patch tested? I have added a new test case to test the column with/without comment. Author: bomeng Closes #17470 from bomeng/SPARK-20146. --- .../SparkExecuteStatementOperation.scala | 2 +- .../SparkExecuteStatementOperationSuite.scala | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 517b01f183926..ff3784cab9e26 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -292,7 +292,7 @@ object SparkExecuteStatementOperation { def getTableSchema(structType: StructType): TableSchema = { val schema = structType.map { field => val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString - new FieldSchema(field.name, attrTypeString, "") + new FieldSchema(field.name, attrTypeString, field.getComment.getOrElse("")) } new TableSchema(schema.asJava) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala index 32ded0d254ef8..06e3980662048 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{NullType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType} class SparkExecuteStatementOperationSuite extends SparkFunSuite { test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { @@ -30,4 +30,16 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite { assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) } + + test("SPARK-20146 Comment should be preserved") { + val field1 = StructField("column1", StringType).withComment("comment 1") + val field2 = StructField("column2", IntegerType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE) + assert(columns.get(0).getComment() == "comment 1") + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE) + assert(columns.get(1).getComment() == "") + } } From 60977889eaecdf28adc6164310eaa5afed488fa1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Mar 2017 19:06:51 -0700 Subject: [PATCH 0152/1765] [SPARK-20136][SQL] Add num files and metadata operation timing to scan operator metrics ## What changes were proposed in this pull request? This patch adds explicit metadata operation timing and number of files in data source metrics. Those would be useful to include for performance profiling. Screenshot of a UI with this change (num files and metadata time are new metrics): screen shot 2017-03-29 at 12 29 28 am ## How was this patch tested? N/A Author: Reynold Xin Closes #17465 from rxin/SPARK-20136. --- .../sql/execution/DataSourceScanExec.scala | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 28156b277f597..239151495f4bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -171,8 +171,20 @@ case class FileSourceScanExec( false } - @transient private lazy val selectedPartitions = - relation.location.listFiles(partitionFilters, dataFilters) + @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { + val startTime = System.nanoTime() + val ret = relation.location.listFiles(partitionFilters, dataFilters) + val timeTaken = (System.nanoTime() - startTime) / 1000 / 1000 + + metrics("numFiles").add(ret.map(_.files.size.toLong).sum) + metrics("metadataTime").add(timeTaken) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics("numFiles") :: metrics("metadataTime") :: Nil) + + ret + } override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { @@ -293,6 +305,8 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), + "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { From 79636054f60dd639e9d326e1328717e97df13304 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 29 Mar 2017 20:59:48 -0700 Subject: [PATCH 0153/1765] [SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages ## What changes were proposed in this pull request? The internal FileCommitProtocol interface returns all task commit messages in bulk to the implementation when a job finishes. However, it is sometimes useful to access those messages before the job completes, so that the driver gets incremental progress updates before the job finishes. This adds an `onTaskCommit` listener to the internal api. ## How was this patch tested? Unit tests. cc rxin Author: Eric Liang Closes #17475 from ericl/file-commit-api-ext. --- .../internal/io/FileCommitProtocol.scala | 7 +++++ .../datasources/FileFormatWriter.scala | 22 +++++++++---- .../sql/test/DataFrameReaderWriterSuite.scala | 31 ++++++++++++++++++- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 2394cf361c33a..7efa9416362a0 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -121,6 +121,13 @@ abstract class FileCommitProtocol { def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = { fs.delete(path, recursive) } + + /** + * Called on the driver after a task commits. This can be used to access task commit messages + * before the job has finished. These same task commit messages will be passed to commitJob() + * if the entire job succeeds. + */ + def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 7957224ce48b5..bda64d4b91bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -80,6 +80,9 @@ object FileFormatWriter extends Logging { """.stripMargin) } + /** The result of a successful write task. */ + private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String]) + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -172,8 +175,9 @@ object FileFormatWriter extends Logging { global = false, child = queryExecution.executedPlan).execute() } - - val ret = sparkSession.sparkContext.runJob(rdd, + val ret = new Array[WriteTaskResult](rdd.partitions.length) + sparkSession.sparkContext.runJob( + rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -182,10 +186,16 @@ object FileFormatWriter extends Logging { sparkAttemptNumber = taskContext.attemptNumber(), committer, iterator = iter) + }, + 0 until rdd.partitions.length, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res }) - val commitMsgs = ret.map(_._1) - val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment) + val commitMsgs = ret.map(_.commitMsg) + val updatedPartitions = ret.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") @@ -205,7 +215,7 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = { + iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -238,7 +248,7 @@ object FileFormatWriter extends Logging { // Execute the task to write rows out and commit the task. val outputPartitions = writeTask.execute(iterator) writeTask.releaseResources() - (committer.commitTask(taskAttemptContext), outputPartitions) + WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions) })(catchBlock = { // If there is an error, release resource and then abort the task try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8287776f8f558..7c71e7280c6d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,9 +18,12 @@ package org.apache.spark.sql.test import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ @@ -41,7 +44,6 @@ object LastOptions { } } - /** Dummy provider. */ class DefaultSource extends RelationProvider @@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema } } +object MessageCapturingCommitProtocol { + val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]() +} + +class MessageCapturingCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + // captures commit messages for testing + override def onTaskCommit(msg: TaskCommitMessage): Unit = { + MessageCapturingCommitProtocol.commitMessages.offer(msg) + } +} + + class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) } + test("write path implements onTaskCommit API correctly") { + withSQLConf( + "spark.sql.sources.commitProtocolClass" -> + classOf[MessageCapturingCommitProtocol].getCanonicalName) { + withTempDir { dir => + val path = dir.getCanonicalPath + MessageCapturingCommitProtocol.commitMessages.clear() + spark.range(10).repartition(10).write.mode("overwrite").parquet(path) + assert(MessageCapturingCommitProtocol.commitMessages.size() == 10) + } + } + } + test("read a data source that does not extend SchemaRelationProvider") { val dfReader = spark.read .option("from", "1") From 471de5db53ed77711523a3f016d6e9c530b651e5 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 29 Mar 2017 21:38:26 -0700 Subject: [PATCH 0154/1765] [MINOR][SPARKR] Add run command comment in examples ## What changes were proposed in this pull request? There are two examples in r folder missing the run commands. In this PR, I just add the missing comment, which is consistent with other examples. ## How was this patch tested? Manual test. Author: wm624@hotmail.com Closes #17474 from wangmiao1981/stat. --- examples/src/main/r/RSparkSQLExample.R | 3 +++ examples/src/main/r/dataframe.R | 3 +++ 2 files changed, 6 insertions(+) diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index e647f0e1e9f17..3734568d872d0 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/RSparkSQLExample.R + library(SparkR) # $example on:init_session$ diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 82b85f2f590f6..311350497f873 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/dataframe.R + library(SparkR) # Initialize SparkSession From edc87d76efea7b4d19d9d0c4ddba274a3ccb8752 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 30 Mar 2017 10:39:57 +0100 Subject: [PATCH 0155/1765] [SPARK-20107][DOC] Add spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version option to configuration.md ## What changes were proposed in this pull request? Add `spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version` option to `configuration.md`. Set `spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version=2` can speed up [HadoopMapReduceCommitProtocol.commitJob](https://github.com/apache/spark/blob/v2.1.0/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala#L121) for many output files. All cloudera's hadoop 2.6.0-cdh5.4.0 or higher versions(see: https://github.com/cloudera/hadoop-common/commit/1c1236182304d4075276c00c4592358f428bc433 and https://github.com/cloudera/hadoop-common/commit/16b2de27321db7ce2395c08baccfdec5562017f0) and apache's hadoop 2.7.0 or higher versions support this improvement. More see: 1. [MAPREDUCE-4815](https://issues.apache.org/jira/browse/MAPREDUCE-4815): Speed up FileOutputCommitter#commitJob for many output files. 2. [MAPREDUCE-6406](https://issues.apache.org/jira/browse/MAPREDUCE-6406): Update the default version for the property mapreduce.fileoutputcommitter.algorithm.version to 2. ## How was this patch tested? Manual test and exist tests. Author: Yuming Wang Closes #17442 from wangyum/SPARK-20107. --- docs/configuration.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 4729f1b0404c1..a9753925407d7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1137,6 +1137,15 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. + + spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version + 1 + + The file output committer algorithm version, valid algorithm version number: 1 or 2. + Version 2 may have better performance, but version 1 may handle failures better in certain situations, + as per MAPREDUCE-4815. + + ### Networking From b454d4402e5ee7d1a7385d1fe3737581f84d2c72 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Thu, 30 Mar 2017 22:21:57 +0800 Subject: [PATCH 0156/1765] [SPARK-15354][CORE] Topology aware block replication strategies ## What changes were proposed in this pull request? Implementations of strategies for resilient block replication for different resource managers that replicate the 3-replica strategy used by HDFS, where the first replica is on an executor, the second replica within the same rack as the executor and a third replica on a different rack. The implementation involves providing two pluggable classes, one running in the driver that provides topology information for every host at cluster start and the second prioritizing a list of peer BlockManagerIds. The prioritization itself can be thought of an optimization problem to find a minimal set of peers that satisfy certain objectives and replicating to these peers first. The objectives can be used to express richer constraints over and above HDFS like 3-replica strategy. ## How was this patch tested? This patch was tested with unit tests for storage, along with new unit tests to verify prioritization behaviour. Author: Shubham Chopra Closes #13932 from shubhamchopra/PrioritizerStrategy. --- .../apache/spark/storage/BlockManager.scala | 3 - .../storage/BlockReplicationPolicy.scala | 145 ++++++++++++++++-- .../BlockManagerReplicationSuite.scala | 33 +++- .../storage/BlockReplicationPolicySuite.scala | 73 +++++++-- 4 files changed, 222 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fcda9fa65303a..46a078b2f9f93 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -49,7 +49,6 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer - /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], @@ -1258,7 +1257,6 @@ private[spark] class BlockManager( replication = 1) val numPeersToReplicateTo = level.replication - 1 - val startTime = System.nanoTime var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas @@ -1313,7 +1311,6 @@ private[spark] class BlockManager( numPeersToReplicateTo - peersReplicatedTo.size) } } - logDebug(s"Replicating $blockId of ${data.size} bytes to " + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bb8a684b4c7a8..353eac60df171 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -53,6 +53,46 @@ trait BlockReplicationPolicy { numReplicas: Int): List[BlockManagerId] } +object BlockReplicationUtils { + // scalastyle:off line.size.limit + /** + * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while + * minimizing space usage. Please see + * here. + * + * @param n total number of indices + * @param m number of samples needed + * @param r random number generator + * @return list of m random unique indices + */ + // scalastyle:on line.size.limit + private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { + val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => + val t = r.nextInt(i) + 1 + if (set.contains(t)) set + i else set + t + } + indices.map(_ - 1).toList + } + + /** + * Get a random sample of size m from the elems + * + * @param elems + * @param m number of samples needed + * @param r random number generator + * @tparam T + * @return a random list of size m. If there are fewer than m elements in elems, we just + * randomly shuffle elems + */ + def getRandomSample[T](elems: Seq[T], m: Int, r: Random): List[T] = { + if (elems.size > m) { + getSampleIds(elems.size, m, r).map(elems(_)) + } else { + r.shuffle(elems).toList + } + } +} + @DeveloperApi class RandomBlockReplicationPolicy extends BlockReplicationPolicy @@ -67,6 +107,7 @@ class RandomBlockReplicationPolicy * @param peersReplicatedTo Set of peers already replicated to * @param blockId BlockId of the block being replicated. This can be used as a source of * randomness if needed. + * @param numReplicas Number of peers we need to replicate to * @return A prioritized list of peers. Lower the index of a peer, higher its priority */ override def prioritize( @@ -78,7 +119,7 @@ class RandomBlockReplicationPolicy val random = new Random(blockId.hashCode) logDebug(s"Input peers : ${peers.mkString(", ")}") val prioritizedPeers = if (peers.size > numReplicas) { - getSampleIds(peers.size, numReplicas, random).map(peers(_)) + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) } else { if (peers.size < numReplicas) { logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.") @@ -88,26 +129,96 @@ class RandomBlockReplicationPolicy logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}") prioritizedPeers } +} + +@DeveloperApi +class BasicBlockReplicationPolicy + extends BlockReplicationPolicy + with Logging { - // scalastyle:off line.size.limit /** - * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see - * here. + * Method to prioritize a bunch of candidate peers of a block manager. This implementation + * replicates the behavior of block replication in HDFS. For a given number of replicas needed, + * we choose a peer within the rack, one outside and remaining blockmanagers are chosen at + * random, in that order till we meet the number of replicas needed. + * This works best with a total replication factor of 3, like HDFS. * - * @param n total number of indices - * @param m number of samples needed - * @param r random number generator - * @return list of m random unique indices + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @param numReplicas Number of peers we need to replicate to + * @return A prioritized list of peers. Lower the index of a peer, higher its priority */ - // scalastyle:on line.size.limit - private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { - val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => - val t = r.nextInt(i) + 1 - if (set.contains(t)) set + i else set + t + override def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] = { + + logDebug(s"Input peers : $peers") + logDebug(s"BlockManagerId : $blockManagerId") + + val random = new Random(blockId.hashCode) + + // if block doesn't have topology info, we can't do much, so we randomly shuffle + // if there is, we see what's needed from peersReplicatedTo and based on numReplicas, + // we choose whats needed + if (blockManagerId.topologyInfo.isEmpty || numReplicas == 0) { + // no topology info for the block. The best we can do is randomly choose peers + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we have topology information, we see what is left to be done from peersReplicatedTo + val doneWithinRack = peersReplicatedTo.exists(_.topologyInfo == blockManagerId.topologyInfo) + val doneOutsideRack = peersReplicatedTo.exists { p => + p.topologyInfo.isDefined && p.topologyInfo != blockManagerId.topologyInfo + } + + if (doneOutsideRack && doneWithinRack) { + // we are done, we just return a random sample + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we separate peers within and outside rack + val (inRackPeers, outOfRackPeers) = peers + .filter(_.host != blockManagerId.host) + .partition(_.topologyInfo == blockManagerId.topologyInfo) + + val peerWithinRack = if (doneWithinRack) { + // we are done with in-rack replication, so don't need anymore peers + Seq.empty + } else { + if (inRackPeers.isEmpty) { + Seq.empty + } else { + Seq(inRackPeers(random.nextInt(inRackPeers.size))) + } + } + + val peerOutsideRack = if (doneOutsideRack || numReplicas - peerWithinRack.size <= 0) { + Seq.empty + } else { + if (outOfRackPeers.isEmpty) { + Seq.empty + } else { + Seq(outOfRackPeers(random.nextInt(outOfRackPeers.size))) + } + } + + val priorityPeers = peerWithinRack ++ peerOutsideRack + val numRemainingPeers = numReplicas - priorityPeers.size + val remainingPeers = if (numRemainingPeers > 0) { + val rPeers = peers.filter(p => !priorityPeers.contains(p)) + BlockReplicationUtils.getRandomSample(rPeers, numRemainingPeers, random) + } else { + Seq.empty + } + + (priorityPeers ++ remainingPeers).toList + } + } - // we shuffle the result to ensure a random arrangement within the sample - // to avoid any bias from set implementations - r.shuffle(indices.map(_ - 1).toList) } + } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d5715f8469f71..13020acdd3dbe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.internal.Logging import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -36,6 +37,7 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.Utils trait BlockManagerReplicationBehavior extends SparkFunSuite with Matchers @@ -43,6 +45,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite with LocalSparkContext { val conf: SparkConf + protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) @@ -55,7 +58,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. @@ -471,7 +473,7 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav conf.set("spark.storage.replication.proactive", "true") conf.set("spark.storage.exceptionOnPinLeak", "true") - (2 to 5).foreach{ i => + (2 to 5).foreach { i => test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { testProactiveReplication(i) } @@ -524,3 +526,30 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav } } } + +class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + // number of racks to test with + val numRacks = 3 + + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return random topology + */ + override def getTopologyForHost(hostname: String): Option[String] = { + Some(s"/Rack-${Utils.random.nextInt(numRacks)}") + } +} + +class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { + val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set( + "spark.storage.replication.policy", + classOf[BasicBlockReplicationPolicy].getName) + conf.set( + "spark.storage.replication.topologyMapper", + classOf[DummyTopologyMapper].getName) +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index 800c3899f1a72..ecad0f5352e59 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,34 +18,34 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{LocalSparkContext, SparkFunSuite} -class BlockReplicationPolicySuite extends SparkFunSuite +class RandomBlockReplicationPolicyBehavior extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { // Implicitly convert strings to BlockIds for test clarity. - private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + val replicationPolicy: BlockReplicationPolicy = new RandomBlockReplicationPolicy + + val blockId = "test-block" /** * Test if we get the required number of peers when using random sampling from - * RandomBlockReplicationPolicy + * BlockReplicationPolicy */ - test(s"block replication - random block replication policy") { + test("block replication - random block replication policy") { val numBlockManagers = 10 val storeSize = 1000 - val blockManagers = (1 to numBlockManagers).map { i => - BlockManagerId(s"store-$i", "localhost", 1000 + i, None) - } + val blockManagers = generateBlockManagerIds(numBlockManagers, Seq("/Rack-1")) val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) - val replicationPolicy = new RandomBlockReplicationPolicy - val blockId = "test-block" - (1 to 10).foreach {numReplicas => + (1 to 10).foreach { numReplicas => logDebug(s"Num replicas : $numReplicas") val randomPeers = replicationPolicy.prioritize( candidateBlockManager, @@ -68,7 +68,60 @@ class BlockReplicationPolicySuite extends SparkFunSuite logDebug(s"Random peers : ${secondPass.mkString(", ")}") assert(secondPass.toSet.size === numReplicas) } + } + + protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { + (1 to count).map{i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(racks(Random.nextInt(racks.size)))) + } + } +} + +class TopologyAwareBlockReplicationPolicyBehavior extends RandomBlockReplicationPolicyBehavior { + override val replicationPolicy = new BasicBlockReplicationPolicy + + test("All peers in the same rack") { + val racks = Seq("/default-rack") + val numBlockManager = 10 + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(numBlockManager, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 10001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + assert(prioritizedPeers.toSet.size == numReplicas) + assert(prioritizedPeers.forall(p => p.host != blockManager.host)) + } } + test("Peers in 2 racks") { + val racks = Seq("/Rack-1", "/Rack-2") + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(10, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 9001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + val priorityPeers = prioritizedPeers.take(2) + assert(priorityPeers.forall(p => p.host != blockManager.host)) + if(numReplicas > 1) { + // both these conditions should be satisfied when numReplicas > 1 + assert(priorityPeers.exists(p => p.topologyInfo == blockManager.topologyInfo)) + assert(priorityPeers.exists(p => p.topologyInfo != blockManager.topologyInfo)) + } + } + } } From 0197262a358fd174a188f8246ae777e53157610e Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 30 Mar 2017 16:07:27 +0100 Subject: [PATCH 0157/1765] [DOCS] Docs-only improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …adoc ## What changes were proposed in this pull request? Use recommended values for row boundaries in Window's scaladoc, i.e. `Window.unboundedPreceding`, `Window.unboundedFollowing`, and `Window.currentRow` (that were introduced in 2.1.0). ## How was this patch tested? Local build Author: Jacek Laskowski Closes #17417 from jaceklaskowski/window-expression-scaladoc. --- .../apache/spark/memory/MemoryConsumer.java | 2 -- .../sort/BypassMergeSortShuffleWriter.java | 5 ++-- .../spark/ExecutorAllocationClient.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 2 +- .../apache/spark/serializer/Serializer.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +-- .../shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../shuffle/sort/SortShuffleManager.scala | 4 ++-- .../org/apache/spark/util/AccumulatorV2.scala | 2 +- .../spark/examples/ml/DataFrameExample.scala | 2 +- .../apache/spark/ml/stat/Correlation.scala | 2 +- .../sql/catalyst/analysis/ResolveHints.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 6 ++--- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/windowExpressions.scala | 2 +- .../sql/catalyst/optimizer/objects.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 6 ++--- .../spark/sql/catalyst/plans/QueryPlan.scala | 5 ++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../parser/ExpressionParserSuite.scala | 3 ++- .../scala/org/apache/spark/sql/Column.scala | 18 +++++++-------- .../org/apache/spark/sql/DatasetHolder.scala | 3 ++- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../sql/execution/command/databases.scala | 2 +- .../sql/execution/streaming/Source.scala | 2 +- .../apache/spark/sql/expressions/Window.scala | 23 ++++++++++--------- .../spark/sql/expressions/WindowSpec.scala | 20 ++++++++-------- .../org/apache/spark/sql/functions.scala | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- .../scheduler/InputInfoTracker.scala | 2 +- 30 files changed, 68 insertions(+), 71 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index fc1f3a80239ba..48cf4b9455e4d 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -60,8 +60,6 @@ protected long getUsed() { /** * Force spill during building. - * - * For testing. */ public void spill() throws IOException { spill(Long.MAX_VALUE, this); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 4a15559e55cbd..323a5d3c52831 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -52,8 +52,7 @@ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path * writes incoming records to separate files, one file per reduce partition, then concatenates these * per-partition files to form a single output file, regions of which are served to reducers. - * Records are not buffered in memory. This is essentially identical to - * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * Records are not buffered in memory. It writes output in a format * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. *

      * This write path is inefficient for shuffles with large numbers of reduce partitions because it @@ -61,7 +60,7 @@ * {@link SortShuffleManager} only selects this write path when *

        *
      • no Ordering is specified,
      • - *
      • no Aggregator is specific, and
      • + *
      • no Aggregator is specified, and
      • *
      • the number of partitions is less than * spark.shuffle.sort.bypassMergeThreshold.
      • *
      diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index e4b9f8111efca..9112d93a86b2a 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -71,13 +71,12 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill every executor on the specified host. - * Results in a call to killExecutors for each executor on the host, with the replace - * and force arguments set to true. + * * @return whether the request is acknowledged by the cluster manager. */ def killExecutorsOnHost(host: String): Boolean - /** + /** * Request that the cluster manager kill the specified executor. * @return whether the request is acknowledged by the cluster manager. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 46ef23f316a61..7fd2918960cd0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -149,7 +149,7 @@ private[spark] abstract class Task[T]( def preferredLocations: Seq[TaskLocation] = Nil - // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskSetManager. var epoch: Long = -1 // Task context, to be initialized in run(). diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 008b0387899f6..01bbda0b5e6b3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -77,7 +77,7 @@ abstract class Serializer { * position = 0 * serOut.write(obj1) * serOut.flush() - * position = # of bytes writen to stream so far + * position = # of bytes written to stream so far * obj1Bytes = output[0:position-1] * serOut.write(obj2) * serOut.flush() diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 8b2e26cdd94fb..ba3e0e395e958 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -95,8 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, - // the ExternalSorter won't spill to disk. + // Create an ExternalSorter to sort the data. val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 91858f0912b65..15540485170d0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -61,7 +61,7 @@ private[spark] class IndexShuffleBlockResolver( /** * Remove data file and index file that contain the output data from one map. - * */ + */ def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { @@ -132,7 +132,7 @@ private[spark] class IndexShuffleBlockResolver( * replace them with new ones. * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. - * */ + */ def writeIndexFileAndCommit( shuffleId: Int, mapId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 5e977a16febe1..bfb4dc698e325 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -82,13 +82,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Obtains a [[ShuffleHandle]] to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't // need map-side aggregation, then write numPartitions files directly and just concatenate // them at the end. This avoids doing serialization and deserialization twice to merge diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 00e0cf257cd4a..7479de55140ea 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -279,7 +279,7 @@ private[spark] object AccumulatorContext { /** - * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers. * * @since 2.0.0 */ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index e07c9a4717c3a..0658bddf16961 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.util.Utils /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with + * An example of how to use [[DataFrame]] for ML. Run with * {{{ * ./bin/run-example ml.DataFrameExample [options] * }}} diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala index a7243ccbf28cc..d3c84b77d26ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * API for correlation functions in MLlib, compatible with Dataframes and Datasets. * - * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]] + * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset#stat]] * to spark.ml's Vector types. */ @Since("2.2.0") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 70438eb5912b8..920033a9a8480 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin /** * Collection of rules related to hints. The only hint currently available is broadcast join hint. * - * Note that this is separatedly into two rules because in the future we might introduce new hint + * Note that this is separately into two rules because in the future we might introduce new hint * rules that have different ordering requirements from broadcast. */ object ResolveHints { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 93fc565a53419..ec003cdc17b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -229,9 +229,9 @@ case class ExpressionEncoder[T]( // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here - // (intermediate value is not an attribute). We assume that all serializer expressions use a same - // `BoundReference` to refer to the object, and throw exception if they don't. - assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.") + // (intermediate value is not an attribute). We assume that all serializer expressions use the + // same `BoundReference` to refer to the object, and throw exception if they don't. + assert(serializer.forall(_.references.isEmpty), "serializer cannot reference any attributes.") assert(serializer.flatMap { ser => val boundRefs = ser.collect { case b: BoundReference => b } assert(boundRefs.nonEmpty, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b93a5d0b7a0e5..1db26d9c415a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -491,7 +491,7 @@ abstract class BinaryExpression extends Expression { * A [[BinaryExpression]] that is an operator, with two properties: * * 1. The string representation is "x symbol y", rather than "funcName(x, y)". - * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * 2. Two inputs are expected to be of the same type. If the two inputs have different types, * the analyzer will find the tightest common type and do the proper type casting. */ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 07d294b108548..b2a3888ff7b08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -695,7 +695,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { * * This documentation has been based upon similar documentation for the Hive and Presto projects. * - * @param children to base the rank on; a change in the value of one the children will trigger a + * @param children to base the rank on; a change in the value of one of the children will trigger a * change in rank. This is an internal parameter and will be assigned by the * Analyser. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 174d546e22809..257dbfac8c3e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -65,7 +65,7 @@ object EliminateSerialization extends Rule[LogicalPlan] { /** * Combines two adjacent [[TypedFilter]]s, which operate on same type object in condition, into one, - * mering the filter functions into one conjunctive function. + * merging the filter functions into one conjunctive function. */ object CombineTypedFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd238e05d4102..162051a8c0e4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -492,7 +492,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add an [[Aggregate]] to a logical plan. + * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan. */ private def withAggregation( ctx: AggregationContext, @@ -519,7 +519,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add a Hint to a logical plan. + * Add a [[Hint]] to a logical plan. */ private def withHints( ctx: HintContext, @@ -545,7 +545,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a single relation referenced in a FROM claused. This method is used when a part of the + * Create a single relation referenced in a FROM clause. This method is used when a part of the * join condition is nested, for example: * {{{ * select * from t1 join (t2 cross join t3) on col1 = col2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 9fd95a4b368ce..2d8ec2053a4cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -230,14 +230,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def producedAttributes: AttributeSet = AttributeSet.empty /** - * Attributes that are referenced by expressions but not provided by this nodes children. + * Attributes that are referenced by expressions but not provided by this node's children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. */ def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** - * Runs [[transform]] with `rule` on all expressions present in this query operator. + * Runs [[transformExpressionsDown]] with `rule` on all expressions present + * in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e22b429aec68b..f71a976bd7a24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -32,7 +32,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { private var _analyzed: Boolean = false /** - * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. */ private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index c2e62e739776f..d1c6b50536cd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** - * Test basic expression parsing. If a type of expression is supported it should be tested here. + * Test basic expression parsing. + * If the type of an expression is supported it should be tested here. * * Please note that some of the expressions test don't have to be sound expressions, only their * structure needs to be valid. Unsound expressions should be caught by the Analyzer or diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ae0703513cf42..43de2de7e7094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -84,8 +84,8 @@ class TypedColumn[-T, U]( } /** - * Gives the TypedColumn a name (alias). - * If the current TypedColumn has metadata associated with it, this metadata will be propagated + * Gives the [[TypedColumn]] a name (alias). + * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated * to the new column. * * @group expr_ops @@ -99,16 +99,14 @@ class TypedColumn[-T, U]( /** * A column that will be computed based on the data in a `DataFrame`. * - * A new column is constructed based on the input columns present in a dataframe: + * A new column can be constructed based on the input columns present in a DataFrame: * * {{{ - * df("columnName") // On a specific DataFrame. + * df("columnName") // On a specific `df` DataFrame. * col("columnName") // A generic column no yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. - * expr("a + 1") // A column that is constructed from a parsed SQL Expression. - * lit("abc") // A column that produces a literal (constant) value. * }}} * * [[Column]] objects can be composed to form complex expressions: @@ -118,7 +116,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * @note The internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via [[expr]], but this method is for * debugging purposes only and can change in any future Spark releases. * * @groupname java_expr_ops Java-specific expression operators @@ -1100,7 +1098,7 @@ class Column(val expr: Expression) extends Logging { def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** - * Prints the expression to the console for debugging purpose. + * Prints the expression to the console for debugging purposes. * * @group df_ops * @since 1.3.0 @@ -1154,8 +1152,8 @@ class Column(val expr: Expression) extends Logging { * {{{ * val w = Window.partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w.rangeBetween(Long.MinValue, 2)), - * avg("price").over(w.rowsBetween(0, 4)) + * sum("price").over(w.rangeBetween(Window.unboundedPreceding, 2)), + * avg("price").over(w.rowsBetween(Window.currentRow, 4)) * ) * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 18bccee98f610..582d4a3670b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability * * To use this, import implicit conversions in SQL: * {{{ - * import sqlContext.implicits._ + * val spark: SparkSession = ... + * import spark.implicits._ * }}} * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index a97297892b5e0..b60499253c42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -60,7 +60,7 @@ import org.apache.spark.util.Utils * The builder can also be used to create a new session: * * {{{ - * SparkSession.builder() + * SparkSession.builder * .master("local") * .appName("Word Count") * .config("spark.some.config.option", "some-value") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index e5a6a5f60b8a6..470c736da98b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.StringType /** * A command for users to list the databases/schemas. - * If a databasePattern is supplied then the databases that only matches the + * If a databasePattern is supplied then the databases that only match the * pattern would be listed. * The syntax of using this command in SQL is: * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 75ffe90f2bb70..311942f6dbd84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. */ -trait Source { +trait Source { /** Returns the schema of the data from this source */ def schema: StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index f3cf3052ea3ea..00053485e614c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -113,7 +113,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative positions from the current row. For example, "0" means + * Both `start` and `end` are positions relative to the current row. For example, "0" means * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * @@ -131,9 +131,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -150,7 +150,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -162,7 +162,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * Both `start` and `end` are relative to the current row. For example, "0" means "current row", * while "-1" means one off before the current row, and "5" means the five off after the * current row. * @@ -183,9 +183,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -202,7 +202,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. @@ -221,7 +221,8 @@ object Window { * * {{{ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * Window.partitionBy("country").orderBy("date") + * .rowsBetween(Window.unboundedPreceding, Window.currentRow) * * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index de7d7a1772753..6279d48c94de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -86,7 +86,7 @@ class WindowSpec private[sql]( * after the current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A row based boundary is based on the position of the row within the partition. @@ -99,9 +99,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -118,7 +118,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. @@ -134,7 +134,7 @@ class WindowSpec private[sql]( * current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A range based boundary is based on the actual value of the ORDER BY @@ -150,9 +150,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -169,7 +169,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0f9203065ef05..f07e04368389f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2968,7 +2968,7 @@ object functions { * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the + * @param options options to control how the json is parsed. Accepts the same options as the * json data source. * * @group collection_funcs diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 8048c2ba2c2e4..2f3dfa05e9ef7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * Builder that produces a Hive aware [[SessionState]]. + * Builder that produces a Hive-aware `SessionState`. */ @Experimental @InterfaceStability.Unstable diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 8e1a090618433..639ac6de4f5d3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch " + s"$batchTime is already added into InputInfoTracker, this is an illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) From 258bff2c3f54490ddca898e276029db9adf575d9 Mon Sep 17 00:00:00 2001 From: samelamin Date: Thu, 30 Mar 2017 16:08:26 +0100 Subject: [PATCH 0158/1765] [SPARK-19999] Workaround JDK-8165231 to identify PPC64 architectures as supporting unaligned access java.nio.Bits.unaligned() does not return true for the ppc64le arch. see https://bugs.openjdk.java.net/browse/JDK-8165231 ## What changes were proposed in this pull request? check architecture ## How was this patch tested? unit test Author: samelamin Author: samelamin Closes #17472 from samelamin/SPARK-19999. --- .../org/apache/spark/unsafe/Platform.java | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index f13c24ae5e017..1321b83181150 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -46,18 +46,22 @@ public final class Platform { private static final boolean unaligned; static { boolean _unaligned; - // use reflection to access unaligned field - try { - Class bitsClass = - Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); - unalignedMethod.setAccessible(true); - _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); - } catch (Throwable t) { - // We at least know x86 and x64 support unaligned access. - String arch = System.getProperty("os.arch", ""); - //noinspection DynamicRegexReplaceableByCompiledPattern - _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + String arch = System.getProperty("os.arch", ""); + if (arch.equals("ppc64le") || arch.equals("ppc64")) { + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + _unaligned = true; + } else { + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. + //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + } } unaligned = _unaligned; } From e9d268f63e7308486739aa56ece02815bfb432d6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 30 Mar 2017 16:11:03 +0100 Subject: [PATCH 0159/1765] [SPARK-20096][SPARK SUBMIT][MINOR] Expose the right queue name not null if set by --conf or configure file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? while submit apps with -v or --verbose, we can print the right queue name, but if we set a queue name with `spark.yarn.queue` by --conf or in the spark-default.conf, we just got `null` for the queue in Parsed arguments. ``` bin/spark-shell -v --conf spark.yarn.queue=thequeue Using properties file: /home/hadoop/spark-2.1.0-bin-apache-hdp2.7.3/conf/spark-defaults.conf .... Adding default property: spark.yarn.queue=default Parsed arguments: master yarn deployMode client ... queue null .... verbose true Spark properties used, including those specified through --conf and those from the properties file /home/hadoop/spark-2.1.0-bin-apache-hdp2.7.3/conf/spark-defaults.conf: spark.yarn.queue -> thequeue .... ``` ## How was this patch tested? ut and local verify Author: Kent Yao Closes #17430 from yaooqinn/SPARK-20096. --- .../apache/spark/deploy/SparkSubmitArguments.scala | 1 + .../org/apache/spark/deploy/SparkSubmitSuite.scala | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0614d80b60e1c..0144fd1056bac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -190,6 +190,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a591b98bca488..7c2ec01a03d04 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -148,6 +148,17 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("print the right queue name") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "--conf", "spark.yarn.queue=thequeue", + "userjar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.queue should be ("thequeue") + appArgs.toString should include ("thequeue") + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", From 669a11b61bc217a13217f1ef48d781329c45575e Mon Sep 17 00:00:00 2001 From: "Seigneurin, Alexis (CONT)" Date: Thu, 30 Mar 2017 16:12:17 +0100 Subject: [PATCH 0160/1765] [DOCS][MINOR] Fixed a few typos in the Structured Streaming documentation Fixed a few typos. There is one more I'm not sure of: ``` Append mode uses watermark to drop old aggregation state. But the output of a windowed aggregation is delayed the late threshold specified in `withWatermark()` as by the modes semantics, rows can be added to the Result Table only once after they are ``` Not sure how to change `is delayed the late threshold`. Author: Seigneurin, Alexis (CONT) Closes #17443 from aseigneurin/typos. --- docs/structured-streaming-programming-guide.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index ff07ad11943bd..b5cf9f1644986 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -717,11 +717,11 @@ However, to run this query for days, it's necessary for the system to bound the intermediate in-memory state it accumulates. This means the system needs to know when an old aggregate can be dropped from the in-memory state because the application is not going to receive late data for that aggregate any more. To enable this, in Spark 2.1, we have introduced -**watermarking**, which let's the engine automatically track the current event time in the data and +**watermarking**, which lets the engine automatically track the current event time in the data and attempt to clean up old state accordingly. You can define the watermark of a query by -specifying the event time column and the threshold on how late the data is expected be in terms of +specifying the event time column and the threshold on how late the data is expected to be in terms of event time. For a specific window starting at time `T`, the engine will maintain state and allow late -data to be update the state until `(max event time seen by the engine - late threshold > T)`. +data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will be dropped. Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. @@ -792,7 +792,7 @@ This watermark lets the engine maintain intermediate state for additional 10 min data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in windows `12:05 - 12:15` and `12:10 - 12:20`. Since, it is still ahead of the watermark `12:04` in the trigger, the engine still maintains the intermediate counts as state and correctly updates the -counts of the related windows. However, when the watermark is updated to 12:11, the intermediate +counts of the related windows. However, when the watermark is updated to `12:11`, the intermediate state for window `(12:00 - 12:10)` is cleared, and all subsequent data (e.g. `(12:04, donkey)`) is considered "too late" and therefore ignored. Note that after every trigger, the updated counts (i.e. purple rows) are written to sink as the trigger output, as dictated by @@ -825,7 +825,7 @@ section for detailed explanation of the semantics of each output mode. same column as the timestamp column used in the aggregate. For example, `df.withWatermark("time", "1 min").groupBy("time2").count()` is invalid in Append output mode, as watermark is defined on a different column -as the aggregation column. +from the aggregation column. - `withWatermark` must be called before the aggregation for the watermark details to be used. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append @@ -909,7 +909,7 @@ track of all the data received in the stream. This is therefore fundamentally ha efficiently. ## Starting Streaming Queries -Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the `DataStreamWriter` +Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. @@ -1396,15 +1396,15 @@ You can directly get the current status and metrics of an active query using `lastProgress()` returns a `StreamingQueryProgress` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryProgress) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryProgress.html) -and an dictionary with the same fields in Python. It has all the information about +and a dictionary with the same fields in Python. It has all the information about the progress made in the last trigger of the stream - what data was processed, what were the processing rates, latencies, etc. There is also `streamingQuery.recentProgress` which returns an array of last few progresses. -In addition, `streamingQuery.status()` returns `StreamingQueryStatus` object +In addition, `streamingQuery.status()` returns a `StreamingQueryStatus` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html) -and an dictionary with the same fields in Python. It gives information about +and a dictionary with the same fields in Python. It gives information about what the query is immediately doing - is a trigger active, is data being processed, etc. Here are a few examples. From 5e00a5de14ae2d80471c6f38c30cc6fe63e05163 Mon Sep 17 00:00:00 2001 From: Denis Bolshakov Date: Thu, 30 Mar 2017 16:15:40 +0100 Subject: [PATCH 0161/1765] [SPARK-20127][CORE] few warning have been fixed which Intellij IDEA reported Intellij IDEA ## What changes were proposed in this pull request? Few changes related to Intellij IDEA inspection. ## How was this patch tested? Changes were tested by existing unit tests Author: Denis Bolshakov Closes #17458 from dbolshak/SPARK-20127. --- .../org/apache/spark/memory/TaskMemoryManager.java | 6 +----- .../org/apache/spark/status/api/v1/TaskSorting.java | 5 ++--- .../scala/org/apache/spark/io/CompressionCodec.scala | 3 +-- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../apache/spark/ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 3 +-- .../scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/AllStagesPage.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/ExecutorTable.scala | 4 ++-- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 10 +++++----- .../scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../org/apache/spark/ui/storage/StoragePage.scala | 2 +- 13 files changed, 22 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 39fb3b249d731..aa0b373231327 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -155,11 +155,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.get(key); - if (list == null) { - list = new ArrayList<>(1); - sortedConsumers.put(key, list); - } + List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index 9307eb93a5b20..b38639e854815 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -19,6 +19,7 @@ import org.apache.spark.util.EnumUtil; +import java.util.Collections; import java.util.HashSet; import java.util.Set; @@ -30,9 +31,7 @@ public enum TaskSorting { private final Set alternateNames; TaskSorting(String... names) { alternateNames = new HashSet<>(); - for (String n: names) { - alternateNames.add(n); - } + Collections.addAll(alternateNames, names); } public static TaskSorting fromString(String str) { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 2e991ce394c42..c216fe477fd15 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -71,8 +71,7 @@ private[spark] object CompressionCodec { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { - case e: ClassNotFoundException => None - case e: IllegalArgumentException => None + case _: ClassNotFoundException | _: IllegalArgumentException => None } codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " + s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a9480cc220c8d..8b75f5d8fe1a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -124,7 +124,7 @@ private[spark] abstract class WebUI( /** Bind to the HTTP server behind this web interface. */ def bind(): Unit = { - assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") + assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") serverInfo = Some(startJettyServer(host, port, sslOptions, handlers, conf, name)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c6a07445f2a35..dbcc6402bc309 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -49,7 +49,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage }.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { - case Some(blockedByThreadId) => + case Some(_) =>
      Blocked by Thread {thread.blockedByThreadId} {thread.blockedByLock} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 2d1691e55c428..d849ce76a9e3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -48,7 +48,6 @@ private[ui] class ExecutorsPage( parent: ExecutorsTab, threadDumpEnabled: Boolean) extends WebUIPage("") { - private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val content = @@ -59,7 +58,7 @@ private[ui] class ExecutorsPage( ++ } -
      ; + UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 8ae712f8ed323..03851293eb2f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -64,7 +64,7 @@ private[ui] case class ExecutorTaskSummary( @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { - var executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() + val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() var executorEvents = new ListBuffer[SparkListenerEvent]() private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) @@ -137,7 +137,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // could have failed half-way through. The correct fix would be to keep track of the // metrics added by each attempt, but this is much more complicated. return - case e: ExceptionFailure => + case _: ExceptionFailure => taskSummary.tasksFailed += 1 case _ => taskSummary.tasksComplete += 1 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index fe6ca1099e6b0..2b0816e35747d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -34,9 +34,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { listener.synchronized { val activeStages = listener.activeStages.values.toSeq val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse.toSeq + val completedStages = listener.completedStages.reverse val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse val numFailedStages = listener.numFailedStages val subPath = "stages" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 52f41298a1729..382a6f979f2e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} {v.failedTasks} - {v.reasonToNumKilled.map(_._2).sum} + {v.reasonToNumKilled.values.sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 1cf03e1541d14..f78db5ab80d15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -226,7 +226,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { trimJobsIfNecessary(completedJobs) jobData.status = JobExecutionStatus.SUCCEEDED numCompletedJobs += 1 - case JobFailed(exception) => + case JobFailed(_) => failedJobs += jobData trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED @@ -284,7 +284,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - if (!stage.submissionTime.isEmpty) { + if (stage.submissionTime.isDefined) { jobData.completedStageIndices.add(stage.stageId) } } else { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff17775008acc..19325a2dc9169 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -142,7 +142,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.size > 0 + val hasAccumulators = externalAccumulables.nonEmpty val summary =
      @@ -339,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { + if (validTasks.isEmpty) { None } else { @@ -786,8 +786,8 @@ private[ui] object StagePage { info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { if (info.finished) { val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) + val executorOverhead = metrics.executorDeserializeTime + + metrics.resultSerializationTime math.max( 0, totalExecutionTime - metrics.executorRunTime - executorOverhead - @@ -872,7 +872,7 @@ private[ui] class TaskDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - private var _slicedTaskIds: Set[Long] = null + private var _slicedTaskIds: Set[Long] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f4caad0f58715..256b726fa7eea 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -412,7 +412,7 @@ private[ui] class StageDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = null + private var _slicedStageIds: Set[Int] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 76d7c6d414bcf..aa84788f1df88 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -151,7 +151,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Render a stream block */ private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { val replications = block._2 - assert(replications.size > 0) // This must be true because it's the result of "groupBy" + assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { streamBlockTableSubrow(block._1, replications.head, replications.size, true) } else { From c734fc504a3f6a3d3b0bd90ff54604b17df2b413 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Mar 2017 13:36:36 -0700 Subject: [PATCH 0162/1765] [SPARK-20121][SQL] simplify NullPropagation with NullIntolerant ## What changes were proposed in this pull request? Instead of iterating all expressions that can return null for null inputs, we can just check `NullIntolerant`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #17450 from cloud-fan/null. --- .../sql/catalyst/expressions/arithmetic.scala | 18 +++--- .../expressions/complexTypeExtractors.scala | 8 +-- .../sql/catalyst/expressions/package.scala | 2 +- .../expressions/regexpExpressions.scala | 10 ++-- .../expressions/stringExpressions.scala | 15 ++--- .../sql/catalyst/optimizer/expressions.scala | 59 ++++++------------- 6 files changed, 39 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4870093e9250f..f2b252259b89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -113,7 +113,7 @@ case class Abs(child: Expression) protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType @@ -146,7 +146,7 @@ object BinaryArithmetic { > SELECT 1 _FUNC_ 2; 3 """) -case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit > SELECT 2 _FUNC_ 1; 1 """) -case class Subtract(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression) > SELECT 2 _FUNC_ 3; 6 """) -case class Multiply(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression) 1.0 """) // scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) @@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression) > SELECT 2 _FUNC_ 1.8; 0.2 """) -case class Remainder(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression) > SELECT _FUNC_(-10, 3); 2 """) -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 0c256c3d890f1..de1594d119e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,7 +104,7 @@ trait ExtractValue extends Expression * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with ExtractValue { + extends UnaryExpression with ExtractValue with NullIntolerant { lazy val childSchema = child.dataType.asInstanceOf[StructType] @@ -152,7 +152,7 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression with ExtractValue { + containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" @@ -213,7 +213,7 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes with ExtractValue { + extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue { + extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { private def keyType = child.dataType.asInstanceOf[MapType].keyType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1b00c9e79da22..4c8b177237d23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -138,5 +138,5 @@ package object expressions { * input will result in null output). We will use this information during constructing IsNotNull * constraints. */ - trait NullIntolerant + trait NullIntolerant extends Expression } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4896a6225aa80..b23da537be721 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringRegexExpression extends BinaryExpression + with ImplicitCastInputTypes with NullIntolerant { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { */ @ExpressionDescription( usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.") -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression) @ExpressionDescription( usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.") -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 908aa44f81c97..5598a146997ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringPredicate extends Predicate with ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringPredicate extends BinaryExpression + with Predicate with ImplicitCastInputTypes with NullIntolerant { def compare(l: UTF8String, r: UTF8String): Boolean @@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { /** * A function that returns true if the string `left` contains the string `right`. */ -case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression) /** * A function that returns true if the string `left` starts with the string `right`. */ -case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression) /** * A function that returns true if the string `left` ends with the string `right`. */ -case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") @@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression) """) // scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 21d1cd5932620..33039127f16ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { - private def nonNullLiteral(e: Expression): Boolean = e match { - case Literal(null, _) => false - case _ => true + private def isNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => true + case _ => false } def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) - case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => - Literal.create(null, e.dataType) - case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) - case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) + case IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + + case EqualNullSafe(Literal(null, _), r) => IsNull(r) + case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + // For Coalesce, remove null literals. case e @ Coalesce(children) => - val newChildren = children.filter(nonNullLiteral) + val newChildren = children.filterNot(isNullLiteral) if (newChildren.isEmpty) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { @@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) - - // Put exceptional cases above if any - case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - case e: StringPredicate => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - // If the value expression is NULL then transform the In expression to - // Literal(null) - case In(Literal(null, _), list) => Literal.create(null, BooleanType) + // If the value expression is NULL then transform the In expression to null literal. + case In(Literal(null, _), _) => Literal.create(null, BooleanType) + // Non-leaf NullIntolerant expressions will return null, if at least one of its children is + // a null literal. + case e: NullIntolerant if e.children.exists(isNullLiteral) => + Literal.create(null, e.dataType) } } } From a8a765b3f302c078cb9519c4a17912cd38b9680c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Mar 2017 23:09:33 -0700 Subject: [PATCH 0163/1765] [SPARK-20151][SQL] Account for partition pruning in scan metadataTime metrics ## What changes were proposed in this pull request? After SPARK-20136, we report metadata timing metrics in scan operator. However, that timing metric doesn't include one of the most important part of metadata, which is partition pruning. This patch adds that time measurement to the scan metrics. ## How was this patch tested? N/A - I tried adding a test in SQLMetricsSuite but it was extremely convoluted to the point that I'm not sure if this is worth it. Author: Reynold Xin Closes #17476 from rxin/SPARK-20151. --- .../spark/sql/execution/DataSourceScanExec.scala | 5 +++-- .../sql/execution/datasources/CatalogFileIndex.scala | 7 +++++-- .../spark/sql/execution/datasources/FileIndex.scala | 10 ++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 239151495f4bd..2fa660c4d5e01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -172,12 +172,13 @@ case class FileSourceScanExec( } @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() val ret = relation.location.listFiles(partitionFilters, dataFilters) - val timeTaken = (System.nanoTime() - startTime) / 1000 / 1000 + val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000 metrics("numFiles").add(ret.map(_.files.size.toLong).sum) - metrics("metadataTime").add(timeTaken) + metrics("metadataTime").add(timeTakenMs) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index db0254f8d5581..4046396d0e614 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -69,6 +69,7 @@ class CatalogFileIndex( */ def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = { if (table.partitionColumnNames.nonEmpty) { + val startTime = System.nanoTime() val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => @@ -79,8 +80,9 @@ class CatalogFileIndex( path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) + val timeNs = System.nanoTime() - startTime new PrunedInMemoryFileIndex( - sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) + sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( sparkSession, rootPaths, table.storage.properties, partitionSchema = None) @@ -111,7 +113,8 @@ private class PrunedInMemoryFileIndex( sparkSession: SparkSession, tableBasePath: Path, fileStatusCache: FileStatusCache, - override val partitionSpec: PartitionSpec) + override val partitionSpec: PartitionSpec, + override val metadataOpsTimeNs: Option[Long]) extends InMemoryFileIndex( sparkSession, partitionSpec.partitions.map(_.path), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 6b99d38fe5729..094a66a2820f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -72,4 +72,14 @@ trait FileIndex { /** Schema of the partitioning columns, or the empty schema if the table is not partitioned. */ def partitionSchema: StructType + + /** + * Returns an optional metadata operation time, in nanoseconds, for listing files. + * + * We do file listing in query optimization (in order to get the proper statistics) and we want + * to account for file listing time in physical execution (as metrics). To do that, we save the + * file listing time in some implementations and physical execution calls it in this method + * to update the metrics. + */ + def metadataOpsTimeNs: Option[Long] = None } From 254877c2f04414c70d92fa0a00c0ecee1d73aba7 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Fri, 31 Mar 2017 09:17:22 -0700 Subject: [PATCH 0164/1765] [SPARK-20164][SQL] AnalysisException not tolerant of null query plan. ## What changes were proposed in this pull request? The query plan in an `AnalysisException` may be `null` when an `AnalysisException` object is serialized and then deserialized, since `plan` is marked `transient`. Or when someone throws an `AnalysisException` with a null query plan (which should not happen). `def getMessage` is not tolerant of this and throws a `NullPointerException`, leading to loss of information about the original exception. The fix is to add a `null` check in `getMessage`. ## How was this patch tested? - Unit test Author: Kunal Khamar Closes #17486 from kunalkhamar/spark-20164. --- .../scala/org/apache/spark/sql/AnalysisException.scala | 2 +- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index ff8576157305b..50ee6cd4085ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -43,7 +43,7 @@ class AnalysisException protected[sql] ( } override def getMessage: String = { - val planAnnotation = plan.map(p => s";\n$p").getOrElse("") + val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("") getSimpleMessage + planAnnotation } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d9e0196c57957..0dd9296a3f0ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2598,4 +2598,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(!jobStarted.get(), "Command should not trigger a Spark job.") } + + test("SPARK-20164: AnalysisException should be tolerant to null query plan") { + try { + throw new AnalysisException("", None, None, plan = null) + } catch { + case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) + } + } } From c4c03eed67c05a78dc8944f6119ea708d6b955be Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 31 Mar 2017 09:42:49 -0700 Subject: [PATCH 0165/1765] [SPARK-20084][CORE] Remove internal.metrics.updatedBlockStatuses from history files. ## What changes were proposed in this pull request? Remove accumulator updates for internal.metrics.updatedBlockStatuses from SparkListenerTaskEnd entries in the history file. These can cause history files to grow to hundreds of GB because the value of the accumulator contains all tracked blocks. ## How was this patch tested? Current History UI tests cover use of the history file. Author: Ryan Blue Closes #17412 from rdblue/SPARK-20084-remove-block-accumulator-info. --- .../org/apache/spark/util/JsonProtocol.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2cb88919c8c83..1d2cb7acefa33 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -264,8 +264,7 @@ private[spark] object JsonProtocol { ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ - ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -281,7 +280,15 @@ private[spark] object JsonProtocol { ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ ("Killed" -> taskInfo.killed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) + ("Accumulables" -> accumulablesToJson(taskInfo.accumulables)) + } + + private lazy val accumulableBlacklist = Set("internal.metrics.updatedBlockStatuses") + + def accumulablesToJson(accumulables: Traversable[AccumulableInfo]): JArray = { + JArray(accumulables + .filterNot(_.name.exists(accumulableBlacklist.contains)) + .toList.map(accumulableInfoToJson)) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { @@ -376,7 +383,7 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) + val accumUpdates = accumulablesToJson(exceptionFailure.accumUpdates) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ From b2349e6a00d569851f0ca91a60e9299306208e92 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 1 Apr 2017 00:56:18 +0800 Subject: [PATCH 0166/1765] [SPARK-20160][SQL] Move ParquetConversions and OrcConversions Out Of HiveSessionCatalog ### What changes were proposed in this pull request? `ParquetConversions` and `OrcConversions` should be treated as regular `Analyzer` rules. It is not reasonable to be part of `HiveSessionCatalog`. This PR also combines two rules `ParquetConversions` and `OrcConversions` to build a new rule `RelationConversions `. After moving these two rules out of HiveSessionCatalog, the next step is to clean up, rename and move `HiveMetastoreCatalog` because it is not related to the hive package any more. ### How was this patch tested? The existing test cases Author: Xiao Li Closes #17484 from gatorsmile/cleanup. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 96 ++----------------- .../spark/sql/hive/HiveSessionCatalog.scala | 25 +---- .../sql/hive/HiveSessionStateBuilder.scala | 3 +- .../spark/sql/hive/HiveStrategies.scala | 56 ++++++++++- .../hive/JavaMetastoreDataSourcesSuite.java | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 5 +- 7 files changed, 70 insertions(+), 122 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 305bd007c93f7..10f432570e94b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,11 +28,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} -import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -48,14 +44,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache import HiveMetastoreCatalog._ - private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase - - def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { - QualifiedTableName( - tableIdent.database.getOrElse(getCurrentDatabase).toLowerCase, - tableIdent.table.toLowerCase) - } - /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */ private val tableCreationLocks = Striped.lazyWeakLock(100) @@ -68,11 +56,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { - // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - val dbLocation = sparkSession.sharedState.externalCatalog.getDatabase(dbName).locationUri - new Path(new Path(dbLocation), tblName).toString + // For testing only + private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + val key = QualifiedTableName( + table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, + table.table.toLowerCase) + tableRelationCache.getIfPresent(key) } private def getCached( @@ -122,7 +111,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - private def convertToLogicalRelation( + def convertToLogicalRelation( relation: CatalogRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], @@ -273,78 +262,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log case NonFatal(ex) => logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex) } - - /** - * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet - * data source relations for better performance. - */ - object ParquetConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { - relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) - } - - private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { - val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.conf.getConf( - HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) - - convertToLogicalRelation(relation, options, fileFormatClass, "parquet") - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { - // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) - // Inserting into partitioned table is not supported in Parquet data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, query, overwrite, ifNotExists) - - // Read path - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && - shouldConvertMetastoreParquet(relation) => - convertToParquetRelation(relation) - } - } - } - - /** - * When scanning Metastore ORC tables, convert them to ORC data source relations - * for better performance. - */ - object OrcConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { - relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) - } - - private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { - val fileFormatClass = classOf[OrcFileFormat] - val options = Map[String, String]() - - convertToLogicalRelation(relation, options, fileFormatClass, "orc") - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { - // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) - // Inserting into partitioned table is not supported in Orc data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, query, overwrite, ifNotExists) - - // Read path - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && - shouldConvertMetastoreOrc(relation) => - convertToOrcRelation(relation) - } - } - } } + private[hive] object HiveMetastoreCatalog { def mergeWithMetastoreSchema( metastoreSchema: StructType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 2cc20a791d80c..9e3eb2dd8234a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,14 +26,12 @@ import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} @@ -43,7 +41,7 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, globalTempViewManager: GlobalTempViewManager, - private val metastoreCatalog: HiveMetastoreCatalog, + val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -58,25 +56,6 @@ private[sql] class HiveSessionCatalog( parser, functionResourceLoader) { - // ---------------------------------------------------------------- - // | Methods and fields for interacting with HiveMetastoreCatalog | - // ---------------------------------------------------------------- - - // These 2 rules must be run before all other DDL post-hoc resolution rules, i.e. - // `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. - val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions - val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions - - def hiveDefaultTableFilePath(name: TableIdentifier): String = { - metastoreCatalog.hiveDefaultTableFilePath(name) - } - - // For testing only - private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { - val key = metastoreCatalog.getQualifiedTableName(table) - tableRelationCache.getIfPresent(key) - } - override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { makeFunctionBuilder(funcName, Utils.classForName(className)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2f3dfa05e9ef7..9d3b31f39c0f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -75,8 +75,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = new DetermineTableStats(session) +: - catalog.ParquetConversions +: - catalog.OrcConversions +: + RelationConversions(conf, catalog) +: PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: DataSourceAnalysis(conf) +: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index b5ce027d51e73..0465e9c031e27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.io.IOException -import java.net.URI import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -31,9 +30,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.hive.orc.OrcFileFormat +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} /** @@ -170,6 +171,55 @@ object HiveAnalysis extends Rule[LogicalPlan] { } } +/** + * Relation conversion from metastore relations to data source relations for better performance + * + * - When writing to non-partitioned Hive-serde Parquet/Orc tables + * - When scanning Hive-serde Parquet/ORC tables + * + * This rule must be run before all other DDL post-hoc resolution rules, i.e. + * `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. + */ +case class RelationConversions( + conf: SQLConf, + sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { + private def isConvertible(relation: CatalogRelation): Boolean = { + (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && + conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET)) || + (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && + conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)) + } + + private def convert(relation: CatalogRelation): LogicalRelation = { + if (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet")) { + val options = Map(ParquetOptions.MERGE_SCHEMA -> + conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) + sessionCatalog.metastoreCatalog + .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") + } else { + val options = Map[String, String]() + sessionCatalog.metastoreCatalog + .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan transformUp { + // Write path + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists) + + // Read path + case relation: CatalogRelation + if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => + convert(relation) + } + } +} + private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 0b157a45e6e05..25bd4d0017bd8 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -72,8 +72,7 @@ public void setUp() throws IOException { path.delete(); } HiveSessionCatalog catalog = (HiveSessionCatalog) sqlContext.sessionState().catalog(); - hiveManagedPath = new Path( - catalog.hiveDefaultTableFilePath(new TableIdentifier("javaSavedTable"))); + hiveManagedPath = new Path(catalog.defaultTablePath(new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); fs.delete(hiveManagedPath, true); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 962998ea6fb68..3191b9975fbf9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -413,7 +413,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } // Table lookup will make the table cached. spark.table(tableIndent) - statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent) + statsBeforeUpdate = catalog.metastoreCatalog.getCachedDataSourceTable(tableIndent) .asInstanceOf[LogicalRelation].catalogTable.get.stats.get sql(s"INSERT INTO $tableName SELECT 2") @@ -423,7 +423,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") } spark.table(tableIndent) - statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent) + statsAfterUpdate = catalog.metastoreCatalog.getCachedDataSourceTable(tableIndent) .asInstanceOf[LogicalRelation].catalogTable.get.stats.get } (statsBeforeUpdate, statsAfterUpdate) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 9fc2923bb6fd8..23f21e6b9931e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -449,8 +449,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } - private def getCachedDataSourceTable(id: TableIdentifier): LogicalPlan = { - sessionState.catalog.asInstanceOf[HiveSessionCatalog].getCachedDataSourceTable(id) + private def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + sessionState.catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog + .getCachedDataSourceTable(table) } test("Caching converted data source Parquet Relations") { From 567a50acfb0ae26bd430c290348886d494963696 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 31 Mar 2017 10:58:43 -0700 Subject: [PATCH 0167/1765] [SPARK-20165][SS] Resolve state encoder's deserializer in driver in FlatMapGroupsWithStateExec ## What changes were proposed in this pull request? - Encoder's deserializer must be resolved at the driver where the class is defined. Otherwise there are corner cases using nested classes where resolving at the executor can fail. - Fixed flaky test related to processing time timeout. The flakiness is caused because the test thread (that adds data to memory source) has a race condition with the streaming query thread. When testing the manual clock, the goal is to add data and increment clock together atomically, such that a trigger sees new data AND updated clock simultaneously (both or none). This fix adds additional synchronization in when adding data; it makes sure that the streaming query thread is waiting on the manual clock to be incremented (so no batch is currently running) before adding data. - Added`testQuietly` on some tests that generate a lot of error logs. ## How was this patch tested? Multiple runs on existing unit tests Author: Tathagata Das Closes #17488 from tdas/SPARK-20165. --- .../FlatMapGroupsWithStateExec.scala | 28 +++++++++++-------- .../sql/streaming/FileStreamSourceSuite.scala | 4 +-- .../FlatMapGroupsWithStateSuite.scala | 7 +++-- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 23 +++++++++++++-- .../sql/streaming/StreamingQuerySuite.scala | 2 +- 6 files changed, 45 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index c7262ea97200f..e42df5dd61c70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -68,6 +68,20 @@ case class FlatMapGroupsWithStateExec( val encSchemaAttribs = stateEncoder.schema.toAttributes if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs } + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (isTimeoutEnabled) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = stateEncoder.resolveAndBind().deserializer + /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -139,19 +153,9 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converter for translating state rows to Java objects + // Converters for translating state between rows and Java objects private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateEncoder.resolveAndBind().deserializer, stateAttributes) - - // Converter for translating state Java objects to rows - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } + stateDeserializer, stateAttributes) private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) // Index of the additional metadata fields in the state row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index f705da3d6a709..171877abe6e92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -909,7 +909,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("max files per trigger - incorrect values") { + testQuietly("max files per trigger - incorrect values") { val testTable = "maxFilesPerTrigger_test" withTable(testTable) { withTempDir { case src => @@ -1326,7 +1326,7 @@ class FileStreamSourceStressTestSuite extends FileStreamSourceTest { import testImplicits._ - test("file source stress test") { + testQuietly("file source stress test") { val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index a00a1a582a971..c8e31e3ca2e04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,6 +21,8 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -574,11 +576,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assertNumStateRows(total = 1, updated = 2), StopStream, - StartStream(ProcessingTime("1 second"), triggerClock = clock), - AdvanceManualClock(10 * 1000), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "c"), - AdvanceManualClock(1 * 1000), + AdvanceManualClock(11 * 1000), CheckLastBatch(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 32920f6dfa223..388f15405e70b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -426,7 +426,7 @@ class StreamSuite extends StreamTest { CheckAnswer((1, 2), (2, 2), (3, 2))) } - test("recover from a Spark v2.1 checkpoint") { + testQuietly("recover from a Spark v2.1 checkpoint") { var inputData: MemoryStream[Int] = null var query: DataStreamWriter[Row] = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 8cf1791336814..951ff2ca0d684 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -488,8 +488,27 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case a: AddData => try { - // Add data and get the source where it was added, and the expected offset of the - // added data. + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") + } + } + // Add data val queryToUse = Option(currentStream).orElse(Option(lastStream)) val (source, offset) = a.addData(queryToUse) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 3f41ecdb7ff68..1172531fe9988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -487,7 +487,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("StreamingQuery should be Serializable but cannot be used in executors") { + testQuietly("StreamingQuery should be Serializable but cannot be used in executors") { def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = { ds.writeStream .queryName(queryName) From cf5963c961e7eba37bdd58658ed4dfff66ce3c72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Sat, 1 Apr 2017 11:48:58 +0100 Subject: [PATCH 0168/1765] =?UTF-8?q?[SPARK-20177]=20Document=20about=20co?= =?UTF-8?q?mpression=20way=20has=20some=20little=20detail=20ch=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …anges. ## What changes were proposed in this pull request? Document compression way little detail changes. 1.spark.eventLog.compress add 'Compression will use spark.io.compression.codec.' 2.spark.broadcast.compress add 'Compression will use spark.io.compression.codec.' 3,spark.rdd.compress add 'Compression will use spark.io.compression.codec.' 4.spark.io.compression.codec add 'event log describe'. eg Through the documents, I don't know what is compression mode about 'event log'. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Closes #17498 from guoxiaolongzte/SPARK-20177. --- docs/configuration.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index a9753925407d7..2687f542b8bd3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -639,6 +639,7 @@ Apart from these, the following properties are also available, and may be useful false Whether to compress logged events, if spark.eventLog.enabled is true. + Compression will use spark.io.compression.codec. @@ -773,14 +774,15 @@ Apart from these, the following properties are also available, and may be useful true Whether to compress broadcast variables before sending them. Generally a good idea. + Compression will use spark.io.compression.codec. spark.io.compression.codec lz4 - The codec used to compress internal data such as RDD partitions, broadcast variables and - shuffle outputs. By default, Spark provides three codecs: lz4, lzf, + The codec used to compress internal data such as RDD partitions, event log, broadcast variables + and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, and snappy. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, @@ -881,6 +883,7 @@ Apart from these, the following properties are also available, and may be useful StorageLevel.MEMORY_ONLY_SER in Java and Scala or StorageLevel.MEMORY_ONLY in Python). Can save substantial space at the cost of some extra CPU time. + Compression will use spark.io.compression.codec. From 89d6822f722912d2b05571a95a539092091650b5 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 1 Apr 2017 20:43:13 +0800 Subject: [PATCH 0169/1765] [SPARK-19148][SQL][FOLLOW-UP] do not expose the external table concept in Catalog ### What changes were proposed in this pull request? After we renames `Catalog`.`createExternalTable` to `createTable` in the PR: https://github.com/apache/spark/pull/16528, we also need to deprecate the corresponding functions in `SQLContext`. ### How was this patch tested? N/A Author: Xiao Li Closes #17502 from gatorsmile/deprecateCreateExternalTable. --- .../org/apache/spark/sql/SQLContext.scala | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 234ef2dffc6bc..cc2983987eb90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.beans.BeanInfo import java.util.Properties import scala.collection.immutable @@ -527,8 +526,9 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable(tableName: String, path: String): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, path) + sparkSession.catalog.createTable(tableName, path) } /** @@ -538,11 +538,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, path: String, source: String): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, path, source) + sparkSession.catalog.createTable(tableName, path, source) } /** @@ -552,11 +553,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, options) + sparkSession.catalog.createTable(tableName, source, options) } /** @@ -567,11 +569,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, options: Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, options) + sparkSession.catalog.createTable(tableName, source, options) } /** @@ -581,12 +584,13 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, schema, options) + sparkSession.catalog.createTable(tableName, source, schema, options) } /** @@ -597,12 +601,13 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, schema: StructType, options: Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, schema, options) + sparkSession.catalog.createTable(tableName, source, schema, options) } /** @@ -1089,9 +1094,9 @@ object SQLContext { * method for internal use. */ private[sql] def beansToRows( - data: Iterator[_], - beanClass: Class[_], - attrs: Seq[AttributeReference]): Iterator[InternalRow] = { + data: Iterator[_], + beanClass: Class[_], + attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => From 2287f3d0b85730995bedc489a017de5700d6e1e4 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 1 Apr 2017 22:19:08 +0800 Subject: [PATCH 0170/1765] [SPARK-20186][SQL] BroadcastHint should use child's stats ## What changes were proposed in this pull request? `BroadcastHint` should use child's statistics and set `isBroadcastable` to true. ## How was this patch tested? Added a new stats estimation test for `BroadcastHint`. Author: wangzhenhua Closes #17504 from wzhfy/broadcastHintEstimation. --- .../plans/logical/basicLogicalOperators.scala | 2 +- .../BasicStatsEstimationSuite.scala | 21 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 5cbf263d1ce42..19db42c80895c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { // set isBroadcastable to true so the child will be broadcasted override def computeStats(conf: CatalystConf): Statistics = - super.computeStats(conf).copy(isBroadcastable = true) + child.stats(conf).copy(isBroadcastable = true) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index e5dc811c8b7db..0d92c1e35565a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { // row count * (overhead + column size) size = Some(10 * (8 + 4))) + test("BroadcastHint estimation") { + val filter = Filter(Literal(true), plan) + val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false, + rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) + val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false) + checkStats( + filter, + expectedStatsCboOn = filterStatsCboOn, + expectedStatsCboOff = filterStatsCboOff) + + val broadcastHint = BroadcastHint(filter) + checkStats( + broadcastHint, + expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), + expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true)) + } + test("limit estimation: limit < child's rowCount") { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) @@ -97,8 +114,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + plan.invalidateStatsCache() assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) } From d40cbb861898de881621d5053a468af570d72127 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 2 Apr 2017 07:26:49 -0700 Subject: [PATCH 0171/1765] [SPARK-20143][SQL] DataType.fromJson should throw an exception with better message ## What changes were proposed in this pull request? Currently, `DataType.fromJson` throws `scala.MatchError` or `java.util.NoSuchElementException` in some cases when the JSON input is invalid as below: ```scala DataType.fromJson(""""abcd"""") ``` ``` java.util.NoSuchElementException: key not found: abcd at ... ``` ```scala DataType.fromJson("""{"abcd":"a"}""") ``` ``` scala.MatchError: JObject(List((abcd,JString(a)))) (of class org.json4s.JsonAST$JObject) at ... ``` ```scala DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") ``` ``` scala.MatchError: JObject(List((a,JInt(123)))) (of class org.json4s.JsonAST$JObject) at ... ``` After this PR, ```scala DataType.fromJson(""""abcd"""") ``` ``` java.lang.IllegalArgumentException: Failed to convert the JSON string 'abcd' to a data type. at ... ``` ```scala DataType.fromJson("""{"abcd":"a"}""") ``` ``` java.lang.IllegalArgumentException: Failed to convert the JSON string '{"abcd":"a"}' to a data type. at ... ``` ```scala DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") at ... ``` ``` java.lang.IllegalArgumentException: Failed to convert the JSON string '{"a":123}' to a field. ``` ## How was this patch tested? Unit test added in `DataTypeSuite`. Author: hyukjinkwon Closes #17468 from HyukjinKwon/fromjson_exception. --- .../org/apache/spark/sql/types/DataType.scala | 12 +++++++- .../spark/sql/types/DataTypeSuite.scala | 28 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 2642d9395ba88..26871259c6b6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -115,7 +115,10 @@ object DataType { name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) + case other => nonDecimalNameToType.getOrElse( + other, + throw new IllegalArgumentException( + s"Failed to convert the JSON string '$name' to a data type.")) } } @@ -164,6 +167,10 @@ object DataType { ("sqlType", v: JValue), ("type", JString("udt"))) => new PythonUserDefinedType(parseDataType(v), pyClass, serialized) + + case other => + throw new IllegalArgumentException( + s"Failed to convert the JSON string '${compact(render(other))}' to a data type.") } private def parseStructField(json: JValue): StructField = json match { @@ -179,6 +186,9 @@ object DataType { ("nullable", JBool(nullable)), ("type", dataType: JValue)) => StructField(name, parseDataType(dataType), nullable) + case other => + throw new IllegalArgumentException( + s"Failed to convert the JSON string '${compact(render(other))}' to a field.") } protected[types] def buildFormattedString( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 05cb999af6a50..f078ef013387b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import com.fasterxml.jackson.core.JsonParseException + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -246,6 +248,32 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeFromJson(structType) checkDataTypeFromDDL(structType) + test("fromJson throws an exception when given type string is invalid") { + var message = intercept[IllegalArgumentException] { + DataType.fromJson(""""abcd"""") + }.getMessage + assert(message.contains( + "Failed to convert the JSON string 'abcd' to a data type.")) + + message = intercept[IllegalArgumentException] { + DataType.fromJson("""{"abcd":"a"}""") + }.getMessage + assert(message.contains( + """Failed to convert the JSON string '{"abcd":"a"}' to a data type""")) + + message = intercept[IllegalArgumentException] { + DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") + }.getMessage + assert(message.contains( + """Failed to convert the JSON string '{"a":123}' to a field.""")) + + // Malformed JSON string + message = intercept[JsonParseException] { + DataType.fromJson("abcd") + }.getMessage + assert(message.contains("Unrecognized token 'abcd'")) + } + def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { test(s"Check the default size of $dataType") { assert(dataType.defaultSize === expectedDefaultSize) From 76de2d115364aa6a1fdaacdfae05f0c695c953b8 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sun, 2 Apr 2017 15:31:13 +0100 Subject: [PATCH 0172/1765] =?UTF-8?q?[SPARK-20123][BUILD]=20SPARK=5FHOME?= =?UTF-8?q?=20variable=20might=20have=20spaces=20in=20it(e.g.=20$SPARK?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue: https://issues.apache.org/jira/browse/SPARK-20123 ## What changes were proposed in this pull request? If $SPARK_HOME or $FWDIR variable contains spaces, then use "./dev/make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pmesos -Pyarn" build spark will failed. ## How was this patch tested? manual tests Author: zuotingbing Closes #17452 from zuotingbing/spark-bulid. --- R/check-cran.sh | 20 ++++++++++---------- R/create-docs.sh | 10 +++++----- R/create-rd.sh | 8 ++++---- R/install-dev.sh | 14 +++++++------- R/install-source-package.sh | 20 ++++++++++---------- dev/make-distribution.sh | 32 ++++++++++++++++---------------- 6 files changed, 52 insertions(+), 52 deletions(-) diff --git a/R/check-cran.sh b/R/check-cran.sh index a188b1448a67b..22cc9c6b601fc 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -20,18 +20,18 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null -. $FWDIR/find-r.sh +. "$FWDIR/find-r.sh" # Install the package (this is required for code in vignettes to run when building it later) # Build the latest docs, but not vignettes, which is built with the package next -. $FWDIR/install-dev.sh +. "$FWDIR/install-dev.sh" # Build source package with vignettes SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ -f "${SPARK_HOME}/RELEASE" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else @@ -40,16 +40,16 @@ fi if [ -d "$SPARK_JARS_DIR" ]; then # Build a zip file containing the source package with vignettes - SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD build "$FWDIR/pkg" find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete else - echo "Error Spark JARs not found in $SPARK_HOME" + echo "Error Spark JARs not found in '$SPARK_HOME'" exit 1 fi # Run check as-cran. -VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` +VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` CRAN_CHECK_OPTIONS="--as-cran" @@ -67,10 +67,10 @@ echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] then - "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" else # This will run tests and/or build vignettes, and require SPARK_HOME - SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 6bef7e75e3bd8..310dbc5fb50a3 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -33,15 +33,15 @@ export FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" export SPARK_HOME="$(cd "`dirname "${BASH_SOURCE[0]}"`"/..; pwd)" # Required for setting SPARK_SCALA_VERSION -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" echo "Using Scala $SPARK_SCALA_VERSION" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Install the package (this will also generate the Rd files) -. $FWDIR/install-dev.sh +. "$FWDIR/install-dev.sh" # Now create HTML files @@ -49,7 +49,7 @@ pushd $FWDIR > /dev/null mkdir -p pkg/html pushd pkg/html -"$R_SCRIPT_PATH/"Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' +"$R_SCRIPT_PATH/Rscript" -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' popd diff --git a/R/create-rd.sh b/R/create-rd.sh index d17e1617397d1..ff622a41a46c0 100755 --- a/R/create-rd.sh +++ b/R/create-rd.sh @@ -29,9 +29,9 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Generate Rd files if devtools is installed -"$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' +"$R_SCRIPT_PATH/Rscript" -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/R/install-dev.sh b/R/install-dev.sh index 45e6411705814..d613552718307 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -29,21 +29,21 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" -mkdir -p $LIB_DIR +mkdir -p "$LIB_DIR" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" -. $FWDIR/create-rd.sh +. "$FWDIR/create-rd.sh" # Install SparkR to $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +"$R_SCRIPT_PATH/R" CMD INSTALL --library="$LIB_DIR" "$FWDIR/pkg/" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -cd $LIB_DIR +cd "$LIB_DIR" jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/R/install-source-package.sh b/R/install-source-package.sh index c6e443c04e628..8de3569d1d482 100755 --- a/R/install-source-package.sh +++ b/R/install-source-package.sh @@ -29,28 +29,28 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" if [ -z "$VERSION" ]; then - VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` + VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` fi -if [ ! -f "$FWDIR"/SparkR_"$VERSION".tar.gz ]; then - echo -e "R source package file $FWDIR/SparkR_$VERSION.tar.gz is not found." +if [ ! -f "$FWDIR/SparkR_$VERSION.tar.gz" ]; then + echo -e "R source package file '$FWDIR/SparkR_$VERSION.tar.gz' is not found." echo -e "Please build R source package with check-cran.sh" exit -1; fi echo "Removing lib path and installing from source package" LIB_DIR="$FWDIR/lib" -rm -rf $LIB_DIR -mkdir -p $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL SparkR_"$VERSION".tar.gz --library=$LIB_DIR +rm -rf "$LIB_DIR" +mkdir -p "$LIB_DIR" +"$R_SCRIPT_PATH/R" CMD INSTALL "SparkR_$VERSION.tar.gz" --library="$LIB_DIR" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd $LIB_DIR > /dev/null +pushd "$LIB_DIR" > /dev/null jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 769cbda4fe347..48a824499acb9 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -140,7 +140,7 @@ echo "Spark version is $VERSION" if [ "$MAKE_TGZ" == "true" ]; then echo "Making spark-$VERSION-bin-$NAME.tgz" else - echo "Making distribution for Spark $VERSION in $DISTDIR..." + echo "Making distribution for Spark $VERSION in '$DISTDIR'..." fi # Build uber fat JAR @@ -170,7 +170,7 @@ cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" # Only create the yarn directory if the yarn artifacts were build. if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then - mkdir "$DISTDIR"/yarn + mkdir "$DISTDIR/yarn" cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn" fi @@ -179,7 +179,7 @@ mkdir -p "$DISTDIR/examples/jars" cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars" # Deduplicate jars that have already been packaged as part of the main Spark dependencies. -for f in "$DISTDIR/examples/jars/"*; do +for f in "$DISTDIR"/examples/jars/*; do name=$(basename "$f") if [ -f "$DISTDIR/jars/$name" ]; then rm "$DISTDIR/examples/jars/$name" @@ -188,14 +188,14 @@ done # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/" +cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files cp "$SPARK_HOME/LICENSE" "$DISTDIR" cp -r "$SPARK_HOME/licenses" "$DISTDIR" cp "$SPARK_HOME/NOTICE" "$DISTDIR" -if [ -e "$SPARK_HOME"/CHANGES.txt ]; then +if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" fi @@ -217,43 +217,43 @@ fi # Make R package - this is used for both CRAN release and packing R layout into distribution if [ "$MAKE_R" == "true" ]; then echo "Building R source package" - R_PACKAGE_VERSION=`grep Version $SPARK_HOME/R/pkg/DESCRIPTION | awk '{print $NF}'` + R_PACKAGE_VERSION=`grep Version "$SPARK_HOME/R/pkg/DESCRIPTION" | awk '{print $NF}'` pushd "$SPARK_HOME/R" > /dev/null # Build source package and run full checks # Do not source the check-cran.sh - it should be run from where it is for it to set SPARK_HOME - NO_TESTS=1 "$SPARK_HOME/"R/check-cran.sh + NO_TESTS=1 "$SPARK_HOME/R/check-cran.sh" # Move R source package to match the Spark release version if the versions are not the same. # NOTE(shivaram): `mv` throws an error on Linux if source and destination are same file if [ "$R_PACKAGE_VERSION" != "$VERSION" ]; then - mv $SPARK_HOME/R/SparkR_"$R_PACKAGE_VERSION".tar.gz $SPARK_HOME/R/SparkR_"$VERSION".tar.gz + mv "$SPARK_HOME/R/SparkR_$R_PACKAGE_VERSION.tar.gz" "$SPARK_HOME/R/SparkR_$VERSION.tar.gz" fi # Install source package to get it to generate vignettes rds files, etc. - VERSION=$VERSION "$SPARK_HOME/"R/install-source-package.sh + VERSION=$VERSION "$SPARK_HOME/R/install-source-package.sh" popd > /dev/null else echo "Skipping building R source package" fi # Copy other things -mkdir "$DISTDIR"/conf -cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf +mkdir "$DISTDIR/conf" +cp "$SPARK_HOME"/conf/*.template "$DISTDIR/conf" cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" # Remove the python distribution from dist/ if we built it if [ "$MAKE_PIP" == "true" ]; then - rm -f $DISTDIR/python/dist/pyspark-*.tar.gz + rm -f "$DISTDIR"/python/dist/pyspark-*.tar.gz fi cp -r "$SPARK_HOME/sbin" "$DISTDIR" # Copy SparkR if it exists -if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then - mkdir -p "$DISTDIR"/R/lib - cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib +if [ -d "$SPARK_HOME/R/lib/SparkR" ]; then + mkdir -p "$DISTDIR/R/lib" + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR/R/lib" + cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR/R/lib" fi if [ "$MAKE_TGZ" == "true" ]; then From 657cb9541db8508ce64d08cc3de14cd02adf16b5 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sun, 2 Apr 2017 15:39:51 +0100 Subject: [PATCH 0173/1765] [SPARK-20173][SQL][HIVE-THRIFTSERVER] Throw NullPointerException when HiveThriftServer2 is shutdown ## What changes were proposed in this pull request? If the shutdown hook called before the variable `uiTab` is set , it will throw a NullPointerException. ## How was this patch tested? manual tests Author: zuotingbing Closes #17496 from zuotingbing/SPARK-HiveThriftServer2. --- .../apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 13c6f11f461c6..14553601b1d58 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{ShutdownHookManager, Utils} */ object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) - var uiTab: Option[ThriftServerTab] = _ + var uiTab: Option[ThriftServerTab] = None var listener: HiveThriftServer2Listener = _ /** From 93dbfe705f3e7410a7267e406332ffb3c3077829 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 2 Apr 2017 11:59:27 -0700 Subject: [PATCH 0174/1765] [SPARK-20159][SPARKR][SQL] Support all catalog API in R ## What changes were proposed in this pull request? Add a set of catalog API in R ``` "currentDatabase", "listColumns", "listDatabases", "listFunctions", "listTables", "recoverPartitions", "refreshByPath", "refreshTable", "setCurrentDatabase", ``` https://github.com/apache/spark/pull/17483/files#diff-6929e6c5e59017ff954e110df20ed7ff ## How was this patch tested? manual tests, unit tests Author: Felix Cheung Closes #17483 from felixcheung/rcatalog. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 9 + R/pkg/R/SQLContext.R | 233 ----------- R/pkg/R/catalog.R | 479 ++++++++++++++++++++++ R/pkg/R/utils.R | 18 + R/pkg/inst/tests/testthat/test_sparkSQL.R | 66 ++- 6 files changed, 569 insertions(+), 237 deletions(-) create mode 100644 R/pkg/R/catalog.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 2ea90f7d3666e..00dde64324ae7 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -32,6 +32,7 @@ Collate: 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' + 'catalog.R' 'WindowSpec.R' 'backend.R' 'broadcast.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 8be7875ad2d5f..c02046c94bf4d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -358,9 +358,14 @@ export("as.DataFrame", "clearCache", "createDataFrame", "createExternalTable", + "currentDatabase", "dropTempTable", "dropTempView", "jsonFile", + "listColumns", + "listDatabases", + "listFunctions", + "listTables", "loadDF", "parquetFile", "read.df", @@ -370,7 +375,11 @@ export("as.DataFrame", "read.parquet", "read.stream", "read.text", + "recoverPartitions", + "refreshByPath", + "refreshTable", "setCheckpointDir", + "setCurrentDatabase", "spark.lapply", "spark.addFile", "spark.getSparkFilesRootDirectory", diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index b75fb0159d503..a1edef7608fa1 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -569,200 +569,6 @@ tableToDF <- function(tableName) { dataFrame(sdf) } -#' Tables -#' -#' Returns a SparkDataFrame containing names of tables in the given database. -#' -#' @param databaseName name of the database -#' @return a SparkDataFrame -#' @rdname tables -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tables("hive") -#' } -#' @name tables -#' @method tables default -#' @note tables since 1.4.0 -tables.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName) - dataFrame(jdf) -} - -tables <- function(x, ...) { - dispatchFunc("tables(databaseName = NULL)", x, ...) -} - -#' Table Names -#' -#' Returns the names of tables in the given database as an array. -#' -#' @param databaseName name of the database -#' @return a list of table names -#' @rdname tableNames -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tableNames("hive") -#' } -#' @name tableNames -#' @method tableNames default -#' @note tableNames since 1.4.0 -tableNames.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "getTableNames", - sparkSession, - databaseName) -} - -tableNames <- function(x, ...) { - dispatchFunc("tableNames(databaseName = NULL)", x, ...) -} - -#' Cache Table -#' -#' Caches the specified table in-memory. -#' -#' @param tableName The name of the table being cached -#' @return SparkDataFrame -#' @rdname cacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' cacheTable("table") -#' } -#' @name cacheTable -#' @method cacheTable default -#' @note cacheTable since 1.4.0 -cacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "cacheTable", tableName)) -} - -cacheTable <- function(x, ...) { - dispatchFunc("cacheTable(tableName)", x, ...) -} - -#' Uncache Table -#' -#' Removes the specified table from the in-memory cache. -#' -#' @param tableName The name of the table being uncached -#' @return SparkDataFrame -#' @rdname uncacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' uncacheTable("table") -#' } -#' @name uncacheTable -#' @method uncacheTable default -#' @note uncacheTable since 1.4.0 -uncacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "uncacheTable", tableName)) -} - -uncacheTable <- function(x, ...) { - dispatchFunc("uncacheTable(tableName)", x, ...) -} - -#' Clear Cache -#' -#' Removes all cached tables from the in-memory cache. -#' -#' @rdname clearCache -#' @export -#' @examples -#' \dontrun{ -#' clearCache() -#' } -#' @name clearCache -#' @method clearCache default -#' @note clearCache since 1.4.0 -clearCache.default <- function() { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "clearCache")) -} - -clearCache <- function() { - dispatchFunc("clearCache()") -} - -#' (Deprecated) Drop Temporary Table -#' -#' Drops the temporary table with the given table name in the catalog. -#' If the table has been cached/persisted before, it's also unpersisted. -#' -#' @param tableName The name of the SparkSQL table to be dropped. -#' @seealso \link{dropTempView} -#' @rdname dropTempTable-deprecated -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempTable("table") -#' } -#' @name dropTempTable -#' @method dropTempTable default -#' @note dropTempTable since 1.4.0 -dropTempTable.default <- function(tableName) { - if (class(tableName) != "character") { - stop("tableName must be a string.") - } - dropTempView(tableName) -} - -dropTempTable <- function(x, ...) { - .Deprecated("dropTempView") - dispatchFunc("dropTempView(viewName)", x, ...) -} - -#' Drops the temporary view with the given view name in the catalog. -#' -#' Drops the temporary view with the given view name in the catalog. -#' If the view has been cached before, then it will also be uncached. -#' -#' @param viewName the name of the view to be dropped. -#' @return TRUE if the view is dropped successfully, FALSE otherwise. -#' @rdname dropTempView -#' @name dropTempView -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempView("table") -#' } -#' @note since 2.0.0 - -dropTempView <- function(viewName) { - sparkSession <- getSparkSession() - if (class(viewName) != "character") { - stop("viewName must be a string.") - } - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "dropTempView", viewName) -} - #' Load a SparkDataFrame #' #' Returns the dataset in a data source as a SparkDataFrame @@ -841,45 +647,6 @@ loadDF <- function(x = NULL, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } -#' Create an external table -#' -#' Creates an external table based on the dataset in a data source, -#' Returns a SparkDataFrame associated with the external table. -#' -#' The data source is specified by the \code{source} and a set of options(...). -#' If \code{source} is not specified, the default data source configured by -#' "spark.sql.sources.default" will be used. -#' -#' @param tableName a name of the table. -#' @param path the path of files to load. -#' @param source the name of external data source. -#' @param ... additional argument(s) passed to the method. -#' @return A SparkDataFrame. -#' @rdname createExternalTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' df <- createExternalTable("myjson", path="path/to/json", source="json") -#' } -#' @name createExternalTable -#' @method createExternalTable default -#' @note createExternalTable since 1.4.0 -createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { - sparkSession <- getSparkSession() - options <- varargsToStrEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } - catalog <- callJMethod(sparkSession, "catalog") - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) - dataFrame(sdf) -} - -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R new file mode 100644 index 0000000000000..07a89f763cde1 --- /dev/null +++ b/R/pkg/R/catalog.R @@ -0,0 +1,479 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# catalog.R: SparkSession catalog functions + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns a SparkDataFrame associated with the external table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param tableName a name of the table. +#' @param path the path of files to load. +#' @param source the name of external data source. +#' @param schema the schema of the data for certain data source. +#' @param ... additional argument(s) passed to the method. +#' @return A SparkDataFrame. +#' @rdname createExternalTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createExternalTable("myjson", path="path/to/json", source="json", schema) +#' } +#' @name createExternalTable +#' @method createExternalTable default +#' @note createExternalTable since 1.4.0 +createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + catalog <- callJMethod(sparkSession, "catalog") + if (is.null(schema)) { + sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) + } else { + sdf <- callJMethod(catalog, "createExternalTable", tableName, source, schema$jobj, options) + } + dataFrame(sdf) +} + +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param tableName The name of the table being cached +#' @return SparkDataFrame +#' @rdname cacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' cacheTable("table") +#' } +#' @name cacheTable +#' @method cacheTable default +#' @note cacheTable since 1.4.0 +cacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "cacheTable", tableName)) +} + +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param tableName The name of the table being uncached +#' @return SparkDataFrame +#' @rdname uncacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' uncacheTable("table") +#' } +#' @name uncacheTable +#' @method uncacheTable default +#' @note uncacheTable since 1.4.0 +uncacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "uncacheTable", tableName)) +} + +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @rdname clearCache +#' @export +#' @examples +#' \dontrun{ +#' clearCache() +#' } +#' @name clearCache +#' @method clearCache default +#' @note clearCache since 1.4.0 +clearCache.default <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(callJMethod(catalog, "clearCache")) +} + +clearCache <- function() { + dispatchFunc("clearCache()") +} + +#' (Deprecated) Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param tableName The name of the SparkSQL table to be dropped. +#' @seealso \link{dropTempView} +#' @rdname dropTempTable-deprecated +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempTable("table") +#' } +#' @name dropTempTable +#' @method dropTempTable default +#' @note dropTempTable since 1.4.0 +dropTempTable.default <- function(tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + dropTempView(tableName) +} + +dropTempTable <- function(x, ...) { + .Deprecated("dropTempView") + dispatchFunc("dropTempView(viewName)", x, ...) +} + +#' Drops the temporary view with the given view name in the catalog. +#' +#' Drops the temporary view with the given view name in the catalog. +#' If the view has been cached before, then it will also be uncached. +#' +#' @param viewName the name of the view to be dropped. +#' @return TRUE if the view is dropped successfully, FALSE otherwise. +#' @rdname dropTempView +#' @name dropTempView +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempView("table") +#' } +#' @note since 2.0.0 +dropTempView <- function(viewName) { + sparkSession <- getSparkSession() + if (class(viewName) != "character") { + stop("viewName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "dropTempView", viewName) +} + +#' Tables +#' +#' Returns a SparkDataFrame containing names of tables in the given database. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame +#' @rdname tables +#' @seealso \link{listTables} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tables("hive") +#' } +#' @name tables +#' @method tables default +#' @note tables since 1.4.0 +tables.default <- function(databaseName = NULL) { + # rename column to match previous output schema + withColumnRenamed(listTables(databaseName), "name", "tableName") +} + +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) +} + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param databaseName (optional) name of the database +#' @return a list of table names +#' @rdname tableNames +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tableNames("hive") +#' } +#' @name tableNames +#' @method tableNames default +#' @note tableNames since 1.4.0 +tableNames.default <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) +} + +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) +} + +#' Returns the current default database +#' +#' Returns the current default database. +#' +#' @return name of the current default database. +#' @rdname currentDatabase +#' @name currentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' currentDatabase() +#' } +#' @note since 2.2.0 +currentDatabase <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "currentDatabase") +} + +#' Sets the current default database +#' +#' Sets the current default database. +#' +#' @param databaseName name of the database +#' @rdname setCurrentDatabase +#' @name setCurrentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' setCurrentDatabase("default") +#' } +#' @note since 2.2.0 +setCurrentDatabase <- function(databaseName) { + sparkSession <- getSparkSession() + if (class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "setCurrentDatabase", databaseName)) +} + +#' Returns a list of databases available +#' +#' Returns a list of databases available. +#' +#' @return a SparkDataFrame of the list of databases. +#' @rdname listDatabases +#' @name listDatabases +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listDatabases() +#' } +#' @note since 2.2.0 +listDatabases <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF")) +} + +#' Returns a list of tables in the specified database +#' +#' Returns a list of tables in the specified database. +#' This includes all temporary tables. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of tables. +#' @rdname listTables +#' @name listTables +#' @seealso \link{tables} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listTables() +#' listTables("default") +#' } +#' @note since 2.2.0 +listTables <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listTables") + } else { + handledCallJMethod(catalog, "listTables", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of columns for the given table in the specified database +#' +#' Returns a list of columns for the given table in the specified database. +#' +#' @param tableName a name of the table. +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of column descriptions. +#' @rdname listColumns +#' @name listColumns +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listColumns("mytable") +#' } +#' @note since 2.2.0 +listColumns <- function(tableName, databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + handledCallJMethod(catalog, "listColumns", tableName) + } else { + handledCallJMethod(catalog, "listColumns", databaseName, tableName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of functions registered in the specified database +#' +#' Returns a list of functions registered in the specified database. +#' This includes all temporary functions. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of function descriptions. +#' @rdname listFunctions +#' @name listFunctions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listFunctions() +#' } +#' @note since 2.2.0 +listFunctions <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listFunctions") + } else { + handledCallJMethod(catalog, "listFunctions", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Recover all the partitions in the directory of a table and update the catalog +#' +#' Recover all the partitions in the directory of a table and update the catalog. The name should +#' reference a partitioned table, and not a temporary view. +#' +#' @param tableName a name of the table. +#' @rdname recoverPartitions +#' @name recoverPartitions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' recoverPartitions("myTable") +#' } +#' @note since 2.2.0 +recoverPartitions <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "recoverPartitions", tableName)) +} + +#' Invalidate and refresh all the cached metadata of the given table +#' +#' Invalidate and refresh all the cached metadata of the given table. For performance reasons, +#' Spark SQL or the external data source library it uses might cache certain metadata about a +#' table, such as the location of blocks. When those change outside of Spark SQL, users should +#' call this function to invalidate the cache. +#' +#' If this table is cached as an InMemoryRelation, drop the original cached version and make the +#' new version cached lazily. +#' +#' @param tableName a name of the table. +#' @rdname refreshTable +#' @name refreshTable +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshTable("myTable") +#' } +#' @note since 2.2.0 +refreshTable <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshTable", tableName)) +} + +#' Invalidate and refresh all the cached data and metadata for SparkDataFrame containing path +#' +#' Invalidate and refresh all the cached data (and the associated metadata) for any SparkDataFrame +#' that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate +#' everything that is cached. +#' +#' @param path the path of the data source. +#' @rdname refreshByPath +#' @name refreshByPath +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshByPath("/path") +#' } +#' @note since 2.2.0 +refreshByPath <- function(path) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshByPath", path)) +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 810de9917e0ba..fbc89e98847bf 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -846,6 +846,24 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "analysis error - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such database - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such table - ", first), call. = FALSE) } else { stop(stacktrace, call. = FALSE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 5acf8719d1201..ad06711a79a78 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -645,16 +645,20 @@ test_that("test tableNames and tables", { df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") expect_equal(length(tableNames()), 1) - tables <- tables() + expect_equal(length(tableNames("default")), 1) + tables <- listTables() expect_equal(count(tables), 1) + expect_equal(count(tables()), count(tables)) + expect_true("tableName" %in% colnames(tables())) + expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) - tables <- tables() + tables <- listTables() expect_equal(count(tables), 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) - tables <- tables() + tables <- listTables() expect_equal(count(tables), 0) }) @@ -686,6 +690,9 @@ test_that("test cache, uncache and clearCache", { uncacheTable("table1") clearCache() expect_true(dropTempView("table1")) + + expect_error(uncacheTable("foo"), + "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'") }) test_that("insertInto() on a registered table", { @@ -2821,7 +2828,7 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { # more tests for SPARK-16538 createOrReplaceTempView(df, "table") - SparkR::tables() + SparkR::listTables() SparkR::sql("SELECT 1") suppressWarnings(SparkR::sql(sqlContext, "SELECT * FROM table")) suppressWarnings(SparkR::dropTempTable(sqlContext, "table")) @@ -2977,6 +2984,57 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) }) +test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { + expect_equal(currentDatabase(), "default") + expect_error(setCurrentDatabase("default"), NA) + expect_error(setCurrentDatabase("foo"), + "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + dbs <- collect(listDatabases()) + expect_equal(names(dbs), c("name", "description", "locationUri")) + expect_equal(dbs[[1]], "default") +}) + +test_that("catalog APIs, listTables, listColumns, listFunctions", { + tb <- listTables() + count <- count(tables()) + expect_equal(nrow(tb), count) + expect_equal(colnames(tb), c("name", "database", "description", "tableType", "isTemporary")) + + createOrReplaceTempView(as.DataFrame(cars), "cars") + + tb <- listTables() + expect_equal(nrow(tb), count + 1) + tbs <- collect(tb) + expect_true(nrow(tbs[tbs$name == "cars", ]) > 0) + expect_error(listTables("bar"), + "Error in listTables : no such database - Database 'bar' not found") + + c <- listColumns("cars") + expect_equal(nrow(c), 2) + expect_equal(colnames(c), + c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) + expect_equal(collect(c)[[1]][[1]], "speed") + expect_error(listColumns("foo", "default"), + "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + + f <- listFunctions() + expect_true(nrow(f) >= 200) # 250 + expect_equal(colnames(f), + c("name", "database", "description", "className", "isTemporary")) + expect_equal(take(orderBy(f, "className"), 1)$className, + "org.apache.spark.sql.catalyst.expressions.Abs") + expect_error(listFunctions("foo_db"), + "Error in listFunctions : analysis error - Database 'foo_db' does not exist") + + # recoverPartitions does not work with tempory view + expect_error(recoverPartitions("cars"), + "no such table - Table or view 'cars' not found in database 'default'") + expect_error(refreshTable("cars"), NA) + expect_error(refreshByPath("/"), NA) + + dropTempView("cars") +}) + compare_list <- function(list1, list2) { # get testthat to show the diff by first making the 2 lists equal in length expect_equal(length(list1), length(list2)) From 2a903a1eec46e3bd58af0fcbc57e76752d9c18b3 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 3 Apr 2017 10:56:54 +0200 Subject: [PATCH 0175/1765] [SPARK-19985][ML] Fixed copy method for some ML Models ## What changes were proposed in this pull request? Some ML Models were using `defaultCopy` which expects a default constructor, and others were not setting the parent estimator. This change fixes these by creating a new instance of the model and explicitly setting values and parent. ## How was this patch tested? Added `MLTestingUtils.checkCopy` to the offending models to tests to verify the copy is made and parent is set. Author: Bryan Cutler Closes #17326 from BryanCutler/ml-model-copy-error-SPARK-19985. --- .../MultilayerPerceptronClassifier.scala | 3 ++- .../ml/feature/BucketedRandomProjectionLSH.scala | 5 ++++- .../org/apache/spark/ml/feature/MinHashLSH.scala | 5 ++++- .../scala/org/apache/spark/ml/feature/RFormula.scala | 6 ++++-- .../MultilayerPerceptronClassifierSuite.scala | 1 + .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 6 ++++-- .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 11 ++++++++++- .../org/apache/spark/ml/feature/RFormulaSuite.scala | 1 + 8 files changed, 30 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 95c1337ed5608..ec39f964e213a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { - copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) + val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent) + copyValues(copied, extra) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index cbac16345a292..36a46ca6ff4b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = { + val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 620e1fbb09ff7..145422a059196 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -86,7 +86,10 @@ class MinHashLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): MinHashLSHModel = { + val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 389898666eb8e..5a3e2929f5f52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -268,8 +268,10 @@ class RFormulaModel private[feature]( } @Since("1.5.0") - override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, resolvedFormula, pipelineModel)) + override def copy(extra: ParamMap): RFormulaModel = { + val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent) + copyValues(copied, extra) + } @Since("2.0.0") override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 41684d92be33a..7700099caac37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) + MLTestingUtils.checkCopy(model) val result = model.transform(dataset) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 91eac9e733312..cc81da5c66e6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite .setOutputCol("values") .setBucketLength(1.0) .setSeed(12345) - val unitVectors = brp.fit(dataset).randUnitVectors + val brpModel = brp.fit(dataset) + val unitVectors = brpModel.randUnitVectors unitVectors.foreach { v: Vector => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } + MLTestingUtils.checkCopy(brpModel) } test("BucketedRandomProjectionLSH: test of LSH property") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index a2f009310fd7a..0ddf097a6eb22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } + test("Model copy and uid checks") { + val mh = new MinHashLSH() + .setInputCol("keys") + .setOutputCol("values") + val model = mh.fit(dataset) + assert(mh.uid === model.uid) + MLTestingUtils.checkCopy(model) + } + test("hashFunction") { val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0))) val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index c664460d7d8bb..5cfd59e6b88a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("id ~ v1 + v2") val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) + MLTestingUtils.checkCopy(model) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) val expected = Seq( From cff11fd20e869d14106d2d0f17df67161c44d476 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 3 Apr 2017 10:07:41 +0100 Subject: [PATCH 0176/1765] [SPARK-20166][SQL] Use XXX for ISO 8601 timezone instead of ZZ (FastDateFormat specific) in CSV/JSON timeformat options ## What changes were proposed in this pull request? This PR proposes to use `XXX` format instead of `ZZ`. `ZZ` seems a `FastDateFormat` specific. `ZZ` supports "ISO 8601 extended format time zones" but it seems `FastDateFormat` specific option. I misunderstood this is compatible format with `SimpleDateFormat` when this change is introduced. Please see [SimpleDateFormat documentation]( https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html#iso8601timezone) and [FastDateFormat documentation](https://commons.apache.org/proper/commons-lang/apidocs/org/apache/commons/lang3/time/FastDateFormat.html). It seems we better replace `ZZ` to `XXX` because they look using the same strategy - [FastDateParser.java#L930](https://github.com/apache/commons-lang/blob/8767cd4f1a6af07093c1e6c422dae8e574be7e5e/src/main/java/org/apache/commons/lang3/time/FastDateParser.java#L930), [FastDateParser.java#L932-L951 ](https://github.com/apache/commons-lang/blob/8767cd4f1a6af07093c1e6c422dae8e574be7e5e/src/main/java/org/apache/commons/lang3/time/FastDateParser.java#L932-L951) and [FastDateParser.java#L596-L601](https://github.com/apache/commons-lang/blob/8767cd4f1a6af07093c1e6c422dae8e574be7e5e/src/main/java/org/apache/commons/lang3/time/FastDateParser.java#L596-L601). I also checked the codes and manually debugged it for sure. It seems both cases use the same pattern `( Z|(?:[+-]\\d{2}(?::)\\d{2}))`. _Note that this should be rather a fix about documentation and not the behaviour change because `ZZ` seems invalid date format in `SimpleDateFormat` as documented in `DataFrameReader` and etc, and both `ZZ` and `XXX` look identically working with `FastDateFormat`_ Current documentation is as below: ``` *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • ``` ## How was this patch tested? Existing tests should cover this. Also, manually tested as below (BTW, I don't think these are worth being added as tests within Spark): **Parse** ```scala scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000-11:00") res4: java.util.Date = Tue Mar 21 20:00:00 KST 2017 scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000Z") res10: java.util.Date = Tue Mar 21 09:00:00 KST 2017 scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000-11:00") java.text.ParseException: Unparseable date: "2017-03-21T00:00:00.000-11:00" at java.text.DateFormat.parse(DateFormat.java:366) ... 48 elided scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000Z") java.text.ParseException: Unparseable date: "2017-03-21T00:00:00.000Z" at java.text.DateFormat.parse(DateFormat.java:366) ... 48 elided ``` ```scala scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000-11:00") res7: java.util.Date = Tue Mar 21 20:00:00 KST 2017 scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000Z") res1: java.util.Date = Tue Mar 21 09:00:00 KST 2017 scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000-11:00") res8: java.util.Date = Tue Mar 21 20:00:00 KST 2017 scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000Z") res2: java.util.Date = Tue Mar 21 09:00:00 KST 2017 ``` **Format** ```scala scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").format(new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000-11:00")) res6: String = 2017-03-21T20:00:00.000+09:00 ``` ```scala scala> val fd = org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") fd: org.apache.commons.lang3.time.FastDateFormat = FastDateFormat[yyyy-MM-dd'T'HH:mm:ss.SSSZZ,ko_KR,Asia/Seoul] scala> fd.format(fd.parse("2017-03-21T00:00:00.000-11:00")) res1: String = 2017-03-21T20:00:00.000+09:00 scala> val fd = org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX") fd: org.apache.commons.lang3.time.FastDateFormat = FastDateFormat[yyyy-MM-dd'T'HH:mm:ss.SSSXXX,ko_KR,Asia/Seoul] scala> fd.format(fd.parse("2017-03-21T00:00:00.000-11:00")) res2: String = 2017-03-21T20:00:00.000+09:00 ``` Author: hyukjinkwon Closes #17489 from HyukjinKwon/SPARK-20166. --- python/pyspark/sql/readwriter.py | 8 ++++---- python/pyspark/sql/streaming.py | 4 ++-- .../org/apache/spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++-- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../spark/sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 4 ++-- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 5e732b4bec8fd..d912f395dafce 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -223,7 +223,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -363,7 +363,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given @@ -653,7 +653,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -745,7 +745,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being written should be skipped. If None is set, it uses the default value, ``true``. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 27d6725615a4c..3b604963415f9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -457,7 +457,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -581,7 +581,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index c22b1ade4e64b..23ba5ed4d50dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -79,7 +79,7 @@ private[sql] class JSONOptions( val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6c238618f2af7..2b8537c3d4a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -320,7 +320,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `wholeFile` (default `false`): parse one record, which may span multiple lines, @@ -502,7 +502,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `maxColumns` (default `20480`): defines a hard limit of how many columns diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e973d0bc6d09b..338a6e1314d90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -477,7 +477,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    @@ -583,7 +583,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • *
  • `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index e7b79e0cbfd17..4994b8dc80527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -126,7 +126,7 @@ class CSVOptions( val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 997ca286597da..c3a9cfc08517a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -201,7 +201,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, @@ -252,7 +252,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d70c47f4e2379..352dba79a4c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -766,7 +766,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .load(iso8601timestampsPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", Locale.US) val expectedTimestamps = timestamps.collect().map { r => // This should be ISO8601 formatted string. Row(iso8501.format(r.toSeq.head)) From 364b0db75308ddd346b4ab1e032680e8eb4c1753 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 3 Apr 2017 10:09:11 +0100 Subject: [PATCH 0177/1765] [MINOR][DOCS] Replace non-breaking space to normal spaces that breaks rendering markdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What changes were proposed in this pull request? It seems there are several non-breaking spaces were inserted into several `.md`s and they look breaking rendering markdown files. These are different. For example, this can be checked via `python` as below: ```python >>> " " '\xc2\xa0' >>> " " ' ' ``` _Note that it seems this PR description automatically replaces non-breaking spaces into normal spaces. Please open a `vi` and copy and paste it into `python` to verify this (do not copy the characters here)._ I checked the output below in Sapari and Chrome on Mac OS and, Internal Explorer on Windows 10. **Before** ![2017-04-03 12 37 17](https://cloud.githubusercontent.com/assets/6477701/24594655/50aaba02-186a-11e7-80bb-d34b17a3398a.png) ![2017-04-03 12 36 57](https://cloud.githubusercontent.com/assets/6477701/24594654/50a855e6-186a-11e7-94e2-661e56544b0f.png) **After** ![2017-04-03 12 36 46](https://cloud.githubusercontent.com/assets/6477701/24594657/53c2545c-186a-11e7-9a73-00529afbfd75.png) ![2017-04-03 12 36 31](https://cloud.githubusercontent.com/assets/6477701/24594658/53c286c0-186a-11e7-99c9-e66b1f510fe7.png) ## How was this patch tested? Manually checking. These instances were found via ``` grep --include=*.scala --include=*.python --include=*.java --include=*.r --include=*.R --include=*.md --include=*.r -r -I " " . ``` in Mac OS. It seems there are several instances more as below: ``` ./docs/sql-programming-guide.md: │   ├── ... ./docs/sql-programming-guide.md: │   │ ./docs/sql-programming-guide.md: │   ├── country=US ./docs/sql-programming-guide.md: │   │   └── data.parquet ./docs/sql-programming-guide.md: │   ├── country=CN ./docs/sql-programming-guide.md: │   │   └── data.parquet ./docs/sql-programming-guide.md: │   └── ... ./docs/sql-programming-guide.md:    ├── ... ./docs/sql-programming-guide.md:    │ ./docs/sql-programming-guide.md:    ├── country=US ./docs/sql-programming-guide.md:    │   └── data.parquet ./docs/sql-programming-guide.md:    ├── country=CN ./docs/sql-programming-guide.md:    │   └── data.parquet ./docs/sql-programming-guide.md:    └── ... ./sql/core/src/test/README.md:│   ├── *.avdl # Testing Avro IDL(s) ./sql/core/src/test/README.md:│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) ./sql/core/src/test/README.md:│   ├── gen-avro.sh # Script used to generate Java code for Avro ./sql/core/src/test/README.md:│   └── gen-thrift.sh # Script used to generate Java code for Thrift ``` These seems generated via `tree` command which inserts non-breaking spaces. They do not look causing any problem for rendering within code blocks and I did not fix it to reduce the overhead to manually replace it when it is overwritten via `tree` command in the future. Author: hyukjinkwon Closes #17517 from HyukjinKwon/non-breaking-space. --- README.md | 2 +- docs/building-spark.md | 2 +- docs/monitoring.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d0eca1ddea283..1e521a7e7b178 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. -## Contributing +## Contributing Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) for information on how to get started contributing to the project. diff --git a/docs/building-spark.md b/docs/building-spark.md index 8353b7a520b8e..e99b70f7a8b47 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -154,7 +154,7 @@ Developers who compile Spark frequently may want to speed up compilation; e.g., developers who build with SBT). For more information about how to do this, refer to the [Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). -## Encrypted Filesystems +## Encrypted Filesystems When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`: diff --git a/docs/monitoring.md b/docs/monitoring.md index 80519525af0c3..6cbc6660e816c 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -257,7 +257,7 @@ In the API, an application is referenced by its application ID, `[app-id]`. When running on YARN, each application may have multiple attempts, but there are attempt IDs only for applications in cluster mode, not applications in client mode. Applications in YARN cluster mode can be identified by their `[attempt-id]`. In the API listed below, when running in YARN cluster mode, -`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID. +`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID. From fb5869f2cf94217b3e254e2d0820507dc83a25cc Mon Sep 17 00:00:00 2001 From: Denis Bolshakov Date: Mon, 3 Apr 2017 10:16:07 +0100 Subject: [PATCH 0178/1765] [SPARK-9002][CORE] KryoSerializer initialization does not include 'Array[Int]' [SPARK-9002][CORE] KryoSerializer initialization does not include 'Array[Int]' ## What changes were proposed in this pull request? Array[Int] has been registered in KryoSerializer. The following file has been changed core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala ## How was this patch tested? First, the issue was reproduced by new unit test. Then, the issue was fixed to pass the failed test. Author: Denis Bolshakov Closes #17482 from dbolshak/SPARK-9002. --- .../org/apache/spark/serializer/KryoSerializer.scala | 7 +++++++ .../apache/spark/serializer/KryoSerializerSuite.scala | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 03815631a604c..6fc66e2374bd9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -384,9 +384,16 @@ private[serializer] object KryoSerializer { classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], + classOf[Array[Boolean]], classOf[Array[Byte]], classOf[Array[Short]], + classOf[Array[Int]], classOf[Array[Long]], + classOf[Array[Float]], + classOf[Array[Double]], + classOf[Array[Char]], + classOf[Array[String]], + classOf[Array[Array[String]]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a30653bb36fa1..7c3922e47fbb9 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -76,6 +76,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("basic types") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -106,6 +109,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("pairs") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -130,12 +136,16 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("Scala data structures") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(List[Int]()) check(List[Int](1, 2, 3)) + check(Seq[Int](1, 2, 3)) check(List[String]()) check(List[String]("x", "y", "z")) check(None) From 4d28e8430d11323f08657ca8f3251ca787c45501 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 3 Apr 2017 11:42:33 +0200 Subject: [PATCH 0179/1765] [SPARK-19969][ML] Imputer doc and example ## What changes were proposed in this pull request? Add docs and examples for spark.ml.feature.Imputer. Currently scala and Java examples are included. Python example will be added after https://github.com/apache/spark/pull/17316 ## How was this patch tested? local doc generation and example execution Author: Yuhao Yang Closes #17324 from hhbyyh/imputerdoc. --- docs/ml-features.md | 66 +++++++++++++++++ .../spark/examples/ml/JavaImputerExample.java | 71 +++++++++++++++++++ .../src/main/python/ml/imputer_example.py | 50 +++++++++++++ .../spark/examples/ml/ImputerExample.scala | 56 +++++++++++++++ .../org/apache/spark/ml/feature/Imputer.scala | 2 +- 5 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java create mode 100644 examples/src/main/python/ml/imputer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index dad1c6db18f8b..e19fba249fb2d 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1284,6 +1284,72 @@ for more details on the API. + +## Imputer + +The `Imputer` transformer completes missing values in a dataset, either using the mean or the +median of the columns in which the missing values are located. The input columns should be of +`DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly +creates incorrect values for columns containing categorical features. + +**Note** all `null` values in the input columns are treated as missing, and so are also imputed. + +**Examples** + +Suppose that we have a DataFrame with the columns `a` and `b`: + +~~~ + a | b +------------|----------- + 1.0 | Double.NaN + 2.0 | Double.NaN + Double.NaN | 3.0 + 4.0 | 4.0 + 5.0 | 5.0 +~~~ + +In this example, Imputer will replace all occurrences of `Double.NaN` (the default for the missing value) +with the mean (the default imputation strategy) computed from the other values in the corresponding columns. +In this example, the surrogate values for columns `a` and `b` are 3.0 and 4.0 respectively. After +transformation, the missing values in the output columns will be replaced by the surrogate value for +the relevant column. + +~~~ + a | b | out_a | out_b +------------|------------|-------|------- + 1.0 | Double.NaN | 1.0 | 4.0 + 2.0 | Double.NaN | 2.0 | 4.0 + Double.NaN | 3.0 | 3.0 | 3.0 + 4.0 | 4.0 | 4.0 | 4.0 + 5.0 | 5.0 | 5.0 | 5.0 +~~~ + +
    +
    + +Refer to the [Imputer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Imputer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ImputerExample.scala %} +
    + +
    + +Refer to the [Imputer Java docs](api/java/org/apache/spark/ml/feature/Imputer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaImputerExample.java %} +
    + +
    + +Refer to the [Imputer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Imputer) +for more details on the API. + +{% include_example python/ml/imputer_example.py %} +
    +
    + # Feature Selectors ## VectorSlicer diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java new file mode 100644 index 0000000000000..ac40ccd9dbd75 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.Imputer; +import org.apache.spark.ml.feature.ImputerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +import static org.apache.spark.sql.types.DataTypes.*; + +/** + * An example demonstrating Imputer. + * Run with: + * bin/run-example ml.JavaImputerExample + */ +public class JavaImputerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaImputerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.0, Double.NaN), + RowFactory.create(2.0, Double.NaN), + RowFactory.create(Double.NaN, 3.0), + RowFactory.create(4.0, 4.0), + RowFactory.create(5.0, 5.0) + ); + StructType schema = new StructType(new StructField[]{ + createStructField("a", DoubleType, false), + createStructField("b", DoubleType, false) + }); + Dataset df = spark.createDataFrame(data, schema); + + Imputer imputer = new Imputer() + .setInputCols(new String[]{"a", "b"}) + .setOutputCols(new String[]{"out_a", "out_b"}); + + ImputerModel model = imputer.fit(df); + model.transform(df).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py new file mode 100644 index 0000000000000..b8437f827e56d --- /dev/null +++ b/examples/src/main/python/ml/imputer_example.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating Imputer. +Run with: + bin/spark-submit examples/src/main/python/ml/imputer_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("ImputerExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (1.0, float("nan")), + (2.0, float("nan")), + (float("nan"), 3.0), + (4.0, 4.0), + (5.0, 5.0) + ], ["a", "b"]) + + imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + model = imputer.fit(df) + + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala new file mode 100644 index 0000000000000..49e98d0c622ca --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Imputer +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating Imputer. + * Run with: + * bin/run-example ml.ImputerExample + */ +object ImputerExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder + .appName("ImputerExample") + .getOrCreate() + + // $example on$ + val df = spark.createDataFrame(Seq( + (1.0, Double.NaN), + (2.0, Double.NaN), + (Double.NaN, 3.0), + (4.0, 4.0), + (5.0, 5.0) + )).toDF("a", "b") + + val imputer = new Imputer() + .setInputCols(Array("a", "b")) + .setOutputCols(Array("out_a", "out_b")) + + val model = imputer.fit(df) + model.transform(df).show() + // $example off$ + + spark.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index ec4c6ad75ee23..a41bd8e689d56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.types._ private[feature] trait ImputerParams extends Params with HasInputCols { /** - * The imputation strategy. + * The imputation strategy. Currently only "mean" and "median" are supported. * If "mean", then replace missing values using the mean value of the feature. * If "median", then replace missing values using the approximate median value of the feature. * Default: mean From 4fa1a43af6b5a6abaef7e04cacb2617a2e92d816 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 3 Apr 2017 17:44:39 +0800 Subject: [PATCH 0180/1765] [SPARK-19641][SQL] JSON schema inference in DROPMALFORMED mode produces incorrect schema for non-array/object JSONs ## What changes were proposed in this pull request? Currently, when we infer the types for vaild JSON strings but object or array, we are producing empty schemas regardless of parse modes as below: ```scala scala> spark.read.option("mode", "DROPMALFORMED").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() root ``` ```scala scala> spark.read.option("mode", "FAILFAST").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() root ``` This PR proposes to handle parse modes in type inference. After this PR, ```scala scala> spark.read.option("mode", "DROPMALFORMED").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() root |-- a: long (nullable = true) ``` ``` scala> spark.read.option("mode", "FAILFAST").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() java.lang.RuntimeException: Failed to infer a common schema. Struct types are expected but string was found. ``` This PR is based on https://github.com/NathanHowell/spark/commit/e233fd03346a73b3b447fa4c24f3b12c8b2e53ae and I and NathanHowell talked about this in https://issues.apache.org/jira/browse/SPARK-19641 ## How was this patch tested? Unit tests in `JsonSuite` for both `DROPMALFORMED` and `FAILFAST` modes. Author: hyukjinkwon Closes #17492 from HyukjinKwon/SPARK-19641. --- .../datasources/json/JsonInferSchema.scala | 77 +++++++++++-------- .../datasources/json/JsonSuite.scala | 34 +++++++- 2 files changed, 78 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e15c30b4374bb..fb632cf2bb70e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions -import org.apache.spark.sql.catalyst.util.PermissiveMode +import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode + val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards @@ -55,20 +55,24 @@ private[sql] object JsonInferSchema { Some(inferField(parser, configOptions)) } } catch { - case _: JsonParseException if shouldHandleCorruptRecord => - Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) - case _: JsonParseException => - None + case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { + case PermissiveMode => + Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) + case DropMalformedMode => + None + case FailFastMode => + throw e + } } } - }.fold(StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord)) + }.fold(StructType(Nil))( + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep - StructType(Seq()) + StructType(Nil) } } @@ -202,19 +206,33 @@ private[sql] object JsonInferSchema { private def withCorruptField( struct: StructType, - columnNameOfCorruptRecords: String): StructType = { - if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { - // If this given struct does not have a column used for corrupt records, - // add this field. - val newFields: Array[StructField] = - StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(newFields, structFieldComparator) - StructType(newFields) - } else { - // Otherwise, just return this struct. + other: DataType, + columnNameOfCorruptRecords: String, + parseMode: ParseMode) = parseMode match { + case PermissiveMode => + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) + } else { + // Otherwise, just return this struct. + struct + } + + case DropMalformedMode => + // If corrupt record handling is disabled we retain the valid schema and discard the other. struct - } + + case FailFastMode => + // If `other` is not struct type, consider it as malformed one and throws an exception. + throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + + s" but ${other.catalogString} was found.") } /** @@ -222,21 +240,20 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - // If we see any other data type at the root level, we get records that cannot be - // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct - case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) - case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b09cef76d2be7..2ab03819964be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1041,7 +1041,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - .collect() } assert(exceptionOne.getMessage.contains("JsonParseException")) @@ -1082,6 +1081,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(jsonDFTwo.schema === schemaTwo) } + test("SPARK-19641: Additional corrupt records: DROPMALFORMED mode") { + val schema = new StructType().add("dummy", StringType) + // `DROPMALFORMED` mode should skip corrupt records + val jsonDF = spark.read + .option("mode", "DROPMALFORMED") + .json(additionalCorruptRecords) + checkAnswer( + jsonDF, + Row("test")) + assert(jsonDF.schema === schema) + } + test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") { val schema = StructType( StructField("a", StringType, true) :: @@ -1882,6 +1893,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-19641: Handle multi-line corrupt documents (DROPMALFORMED)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + checkAnswer(jsonDF, Seq(Row("test"))) + } + } + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { withTempPath { dir => val path = dir.getCanonicalPath @@ -1903,9 +1932,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("wholeFile", true) .option("mode", "FAILFAST") .json(path) - .collect() } - assert(exceptionOne.getMessage.contains("Failed to parse a value")) + assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) val exceptionTwo = intercept[SparkException] { spark.read From 703c42c398fefd3f7f60e1c503c4df50251f8dcf Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Mon, 3 Apr 2017 08:48:49 -0700 Subject: [PATCH 0181/1765] [SPARK-20194] Add support for partition pruning to in-memory catalog ## What changes were proposed in this pull request? This patch implements `listPartitionsByFilter()` for `InMemoryCatalog` and thus resolves an outstanding TODO causing the `PruneFileSourcePartitions` optimizer rule not to apply when "spark.sql.catalogImplementation" is set to "in-memory" (which is the default). The change is straightforward: it extracts the code for further filtering of the list of partitions returned by the metastore's `getPartitionsByFilter()` out from `HiveExternalCatalog` into `ExternalCatalogUtils` and calls this new function from `InMemoryCatalog` on the whole list of partitions. Now that this method is implemented we can always pass the `CatalogTable` to the `DataSource` in `FindDataSourceTable`, so that the latter is resolved to a relation with a `CatalogFileIndex`, which is what the `PruneFileSourcePartitions` rule matches for. ## How was this patch tested? Ran existing tests and added new test for `listPartitionsByFilter` in `ExternalCatalogSuite`, which is subclassed by both `InMemoryCatalogSuite` and `HiveExternalCatalogSuite`. Author: Adrian Ionescu Closes #17510 from adrian-ionescu/InMemoryCatalog. --- .../catalog/ExternalCatalogUtils.scala | 33 +++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 8 ++-- .../catalog/ExternalCatalogSuite.scala | 41 +++++++++++++++++++ .../datasources/DataSourceStrategy.scala | 5 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 33 +++------------ .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/HiveExternalCatalogSuite.scala | 8 ---- 7 files changed, 85 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index a8693dcca539d..254eedfe77517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate} object ExternalCatalogUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't @@ -125,6 +126,38 @@ object ExternalCatalogUtils { } escapePathName(col) + "=" + partitionString } + + def prunePartitionsByFilter( + catalogTable: CatalogTable, + inputPartitions: Seq[CatalogTablePartition], + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + if (predicates.isEmpty) { + inputPartitions + } else { + val partitionSchema = catalogTable.partitionSchema + val partitionColumnNames = catalogTable.partitionColumnNames.toSet + + val nonPartitionPruningPredicates = predicates.filterNot { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + if (nonPartitionPruningPredicates.nonEmpty) { + throw new AnalysisException("Expected only partition pruning predicates: " + + nonPartitionPruningPredicates) + } + + val boundPredicate = + InterpretedPredicate.create(predicates.reduce(And).transform { + case att: AttributeReference => + val index = partitionSchema.indexWhere(_.name == att.name) + BoundReference(index, partitionSchema(index).dataType, nullable = true) + }) + + inputPartitions.filter { p => + boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + } + } + } } object CatalogUtils { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index cdf618aef97c3..9ca1c71d1dcb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types.StructType @@ -556,9 +556,9 @@ class InMemoryCatalog( table: String, predicates: Seq[Expression], defaultTimeZoneId: String): Seq[CatalogTablePartition] = { - // TODO: Provide an implementation - throw new UnsupportedOperationException( - "listPartitionsByFilter is not implemented for InMemoryCatalog") + val catalogTable = getTable(db, table) + val allPartitions = listPartitions(db, table) + prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId) } // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 7820f39d96426..42db4398e5072 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.TimeZone import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -28,6 +29,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -436,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty) } + test("list partitions by filter") { + val tz = TimeZone.getDefault.getID + val catalog = newBasicCatalog() + + def checkAnswer( + table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition]) + : Unit = { + + assertResult(expected.map(_.spec)) { + catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz) + .map(_.spec).toSet + } + } + + val tbl2 = catalog.getTable("db2", "tbl2") + + checkAnswer(tbl2, Seq.empty, Set(part1, part2)) + checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 2), Set.empty) + checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) + checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1)) + + intercept[AnalysisException] { + try { + checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty) + } catch { + // HiveExternalCatalog may be the first one to notice and throw an exception, which will + // then be caught and converted to a RuntimeException with a descriptive message. + case ex: RuntimeException if ex.getMessage.contains("MetaException") => + throw new AnalysisException(ex.getMessage) + } + } + } + test("drop partitions") { val catalog = newBasicCatalog() assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index bddf5af23e060..c350d8bcbae97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) val cache = sparkSession.sessionState.catalog.tableRelationCache - val withHiveSupport = - sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { @@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] bucketSpec = table.bucketSpec, className = table.provider.get, options = table.storage.properties ++ pathOption, - // TODO: improve `InMemoryCatalog` and remove this limitation. - catalogTable = if (withHiveSupport) Some(table) else None) + catalogTable = Some(table)) LogicalRelation( dataSource.resolveRelation(checkFilesExist = false), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 33b21be37203b..f0e35dff57f7b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient { val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) - val partitionColumnNames = catalogTable.partitionColumnNames.toSet - val nonPartitionPruningPredicates = predicates.filterNot { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - if (nonPartitionPruningPredicates.nonEmpty) { - sys.error("Expected only partition pruning predicates: " + - predicates.reduceLeft(And)) - } + val partColNameMap = buildLowerCasePartColNameMap(catalogTable) - val partitionSchema = catalogTable.partitionSchema - val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table)) - - if (predicates.nonEmpty) { - val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part => + val clientPrunedPartitions = + client.getPartitionsByFilter(rawTable, predicates).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } - val boundPredicate = - InterpretedPredicate.create(predicates.reduce(And).transform { - case att: AttributeReference => - val index = partitionSchema.indexWhere(_.name == att.name) - BoundReference(index, partitionSchema(index).dataType, nullable = true) - }) - clientPrunedPartitions.filter { p => - boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) - } - } else { - client.getPartitions(catalogTable).map { part => - part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) - } - } + prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId) } // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index d55c41e5c9f29..2e35f39839488 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -584,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { */ def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - val varcharKeys = table.getPartitionKeys.asScala + lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 4349f1aa23be0..bd54c043c6ec4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.types.StructType @@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { import utils._ - test("list partitions by filter") { - val catalog = newBasicCatalog() - val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT") - assert(selectedPartitions.length == 1) - assert(selectedPartitions.head.spec == part1.spec) - } - test("SPARK-18647: do not put provider in table properties for Hive serde table") { val catalog = newBasicCatalog() val hiveTable = CatalogTable( From 58c9e6e77ae26345291dd9fce2c57aadcc36f66c Mon Sep 17 00:00:00 2001 From: samelamin Date: Mon, 3 Apr 2017 17:16:31 -0700 Subject: [PATCH 0182/1765] [SPARK-20145] Fix range case insensitive bug in SQL ## What changes were proposed in this pull request? Range in SQL should be case insensitive ## How was this patch tested? unit test Author: samelamin Author: samelamin Closes #17487 from samelamin/SPARK-20145. --- .../ResolveTableValuedFunctions.scala | 4 +--- .../inputs/table-valued-functions.sql | 6 ++++++ .../results/table-valued-functions.sql.out | 20 ++++++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 6b3bb68538dd1..8841309939c24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName) match { + builtinFunctions.get(u.functionName.toLowerCase()) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 2e6dcd538b7ac..d0d2df7b243d5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -18,3 +18,9 @@ select * from range(1, 1, 1, 1, 1); -- range call with null select * from range(1, null); + +-- range call with a mixed-case function name +select * from RaNgE(2); + +-- Explain +EXPLAIN select * from RaNgE(2); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index d769bcef0aca7..acd4ecf14617e 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 9 -- !query 0 @@ -85,3 +85,21 @@ struct<> -- !query 6 output java.lang.IllegalArgumentException Invalid arguments for resolved function: 1, null + + +-- !query 7 +select * from RaNgE(2) +-- !query 7 schema +struct +-- !query 7 output +0 +1 + + +-- !query 8 +EXPLAIN select * from RaNgE(2) +-- !query 8 schema +struct +-- !query 8 output +== Physical Plan == +*Range (0, 2, step=1, splits=None) From e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 3 Apr 2017 17:27:12 -0700 Subject: [PATCH 0183/1765] [SPARK-19408][SQL] filter estimation on two columns of same table ## What changes were proposed in this pull request? In SQL queries, we also see predicate expressions involving two columns such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. Note that, if column-1 and column-2 belong to different tables, then it is a join operator's work, NOT a filter operator's work. This PR estimates filter selectivity on two columns of same table. For example, multiple tpc-h queries have this predicate "WHERE l_commitdate < l_receiptdate" ## How was this patch tested? We added 6 new test cases to test various logical predicates involving two columns of same table. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ron Hu Author: U-CHINA\r00754707 Closes #17415 from ron8hu/filterTwoColumns. --- .../statsEstimation/FilterEstimation.scala | 233 +++++++++++++++++- .../FilterEstimationSuite.scala | 140 ++++++++++- 2 files changed, 363 insertions(+), 10 deletions(-) mode change 100644 => 100755 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala mode change 100644 => 100755 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala old mode 100644 new mode 100755 index b32374c5742ef..03c76cd41d816 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -201,6 +201,21 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = false, update) + case op @ Equality(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + case _ => // TODO: it's difficult to support string operators without advanced statistics. // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) @@ -257,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo /** * Returns a percentage of rows meeting a binary comparison expression. * - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -448,7 +463,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting a binary comparison expression. * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns. * - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -550,6 +565,220 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo Some(percent.toDouble) } + /** + * Returns a percentage of rows meeting a binary comparison expression containing two columns. + * In SQL queries, we also see predicate expressions involving two columns + * such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. + * Note that, if column-1 and column-2 belong to different tables, then it is a join + * operator's work, NOT a filter operator's work. + * + * @param op a binary comparison operator, including =, <=>, <, <=, >, >= + * @param attrLeft the left Attribute (or a column) + * @param attrRight the right Attribute (or a column) + * @param update a boolean flag to specify if we need to update ColumnStat of the given columns + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateBinaryForTwoColumns( + op: BinaryComparison, + attrLeft: Attribute, + attrRight: Attribute, + update: Boolean): Option[Double] = { + + if (!colStatsMap.contains(attrLeft)) { + logDebug("[CBO] No statistics for " + attrLeft) + return None + } + if (!colStatsMap.contains(attrRight)) { + logDebug("[CBO] No statistics for " + attrRight) + return None + } + + attrLeft.dataType match { + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft) + return None + case _ => + } + + val colStatLeft = colStatsMap(attrLeft) + val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) + .asInstanceOf[NumericRange] + val maxLeft = BigDecimal(statsRangeLeft.max) + val minLeft = BigDecimal(statsRangeLeft.min) + + val colStatRight = colStatsMap(attrRight) + val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) + .asInstanceOf[NumericRange] + val maxRight = BigDecimal(statsRangeRight.max) + val minRight = BigDecimal(statsRangeRight.min) + + // determine the overlapping degree between predicate range and column's range + val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + // Left < Right or Left <= Right + // - no overlap: + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exists, we set it to partial overlap.) + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + case _: LessThan => + (minLeft >= maxRight, (maxLeft < minRight) && allNotNull) + case _: LessThanOrEqual => + (minLeft > maxRight, (maxLeft <= minRight) && allNotNull) + + // Left > Right or Left >= Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exists, we set it to partial overlap.) + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + case _: GreaterThan => + (maxLeft <= minRight, (minLeft > maxRight) && allNotNull) + case _: GreaterThanOrEqual => + (maxLeft < minRight, (minLeft >= maxRight) && allNotNull) + + // Left = Right or Left <=> Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: + // minLeft maxLeft + // minRight maxRight + // --------+------------------+-------> + case _: EqualTo => + ((maxLeft < minRight) || (maxRight < minLeft), + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + case _: EqualNullSafe => + // For null-safe equality, we use a very restrictive condition to evaluate its overlap. + // If null values exists, we set it to partial overlap. + (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + } + + var percent = BigDecimal(1.0) + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // For partial overlap, we use an empirical value 1/3 as suggested by the book + // "Database Systems, the complete book". + percent = 1.0 / 3.0 + + if (update) { + // Need to adjust new min/max after the filter condition is applied + + val ndvLeft = BigDecimal(colStatLeft.distinctCount) + var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvLeft < 1) newNdvLeft = 1 + val ndvRight = BigDecimal(colStatRight.distinctCount) + var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvRight < 1) newNdvRight = 1 + + var newMaxLeft = colStatLeft.max + var newMinLeft = colStatLeft.min + var newMaxRight = colStatRight.max + var newMinRight = colStatRight.min + + op match { + case _: LessThan | _: LessThanOrEqual => + // the left side should be less than the right side. + // If not, we need to adjust it to narrow the range. + // Left < Right or Left <= Right + // minRight < minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxRight < maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft > minRight) newMinRight = colStatLeft.min + if (maxLeft > maxRight) newMaxLeft = colStatRight.max + + case _: GreaterThan | _: GreaterThanOrEqual => + // the left side should be greater than the right side. + // If not, we need to adjust it to narrow the range. + // Left > Right or Left >= Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + if (minLeft < minRight) newMinLeft = colStatRight.min + if (maxLeft < maxRight) newMaxRight = colStatLeft.max + + case _: EqualTo | _: EqualNullSafe => + // need to set new min to the larger min value, and + // set the new max to the smaller max value. + // Left = Right or Left <=> Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // minRight <= minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + // + // maxRight <= maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft < minRight) { + newMinLeft = colStatRight.min + } else { + newMinRight = colStatLeft.min + } + if (maxLeft < maxRight) { + newMaxRight = colStatLeft.max + } else { + newMaxLeft = colStatRight.max + } + } + + val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + max = newMaxLeft) + colStatsMap(attrLeft) = newStatsLeft + val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + max = newMaxRight) + colStatsMap(attrRight) = newStatsRight + } + } + + Some(percent.toDouble) + } + } class ColumnStatsMap { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala old mode 100644 new mode 100755 index 1966c96c05294..cffb0d8739287 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -33,49 +33,74 @@ import org.apache.spark.sql.types._ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. - // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) - // only 2 values + // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) - // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. + // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") val attrDate = AttributeReference("cdate", DateType)() val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. + // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) - // Sixth column cstring has 10 String values: + // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint < cint2 + val attrInt2 = AttributeReference("cint2", IntegerType)() + val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 + // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint = cint3 without overlap at all. + val attrInt3 = AttributeReference("cint3", IntegerType)() + val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint4 has values in the range from 1 to 10 + // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test complete overlap + val attrInt4 = AttributeReference("cint4", IntegerType)() + val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, attrBool -> colStatBool, attrDate -> colStatDate, attrDecimal -> colStatDecimal, attrDouble -> colStatDouble, - attrString -> colStatString)) + attrString -> colStatString, + attrInt2 -> colStatInt2, + attrInt3 -> colStatInt3, + attrInt4 -> colStatInt4 + )) test("true") { validateEstimatedStats( @@ -450,6 +475,89 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } } + test("cint = cint2") { + // partial overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint > cint2") { + // partial overlap case + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint < cint2") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint4") { + // complete overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint < cint4") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint3") { + // no records qualify due to no overlap + val emptyColStats = Seq[(Attribute, ColumnStat)]() + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + + test("cint < cint3") { + // all table records qualify. + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint > cint3") { + // no records qualify due to no overlap + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, @@ -491,7 +599,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), rowCount = Some(expectedRowCount), attributeStats = expectedAttributeMap) - assert(filter.stats(conf) == expectedStats) + + val filterStats = filter.stats(conf) + assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) + assert(filterStats.rowCount == expectedStats.rowCount) + val rowCountValue = filterStats.rowCount.getOrElse(0) + // check the output column stats if the row count is > 0. + // When row count is 0, the output is set to empty. + if (rowCountValue != 0) { + // Need to check attributeStats one by one because we may have multiple output columns. + // Due to update operation, the output columns may be in different order. + assert(expectedColStats.size == filterStats.attributeStats.size) + expectedColStats.foreach { kv => + val filterColumnStat = filterStats.attributeStats.get(kv._1).get + assert(filterColumnStat == kv._2) + } + } } } + } From 3bfb639cb7352aec572ef6686d3471bd78748ffa Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 4 Apr 2017 09:53:05 +0900 Subject: [PATCH 0184/1765] [SPARK-10364][SQL] Support Parquet logical type TIMESTAMP_MILLIS ## What changes were proposed in this pull request? **Description** from JIRA The TimestampType in Spark SQL is of microsecond precision. Ideally, we should convert Spark SQL timestamp values into Parquet TIMESTAMP_MICROS. But unfortunately parquet-mr hasn't supported it yet. For the read path, we should be able to read TIMESTAMP_MILLIS Parquet values and pad a 0 microsecond part to read values. For the write path, currently we are writing timestamps as INT96, similar to Impala and Hive. One alternative is that, we can have a separate SQL option to let users be able to write Spark SQL timestamp values as TIMESTAMP_MILLIS. Of course, in this way the microsecond part will be truncated. ## How was this patch tested? Added new tests in ParquetQuerySuite and ParquetIOSuite Author: Dilip Biswal Closes #15332 from dilipbiswal/parquet-time-millis. --- .../sql/catalyst/util/DateTimeUtils.scala | 19 +++++ .../apache/spark/sql/internal/SQLConf.scala | 9 +++ .../SpecificParquetRecordReaderBase.java | 1 + .../parquet/VectorizedColumnReader.java | 27 ++++++- .../parquet/ParquetFileFormat.scala | 14 +++- .../parquet/ParquetRowConverter.scala | 9 ++- .../parquet/ParquetSchemaConverter.scala | 25 ++++-- .../parquet/ParquetWriteSupport.scala | 15 ++++ .../test-data/timemillis-in-i64.parquet | Bin 0 -> 517 bytes .../datasources/parquet/ParquetIOSuite.scala | 16 +++- .../parquet/ParquetQuerySuite.scala | 73 ++++++++++++++++++ .../parquet/ParquetSchemaSuite.scala | 33 ++++++-- 12 files changed, 221 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/timemillis-in-i64.parquet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 9b94c1e2b40bb..f614965520f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -44,6 +44,7 @@ object DateTimeUtils { final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L + final val MILLIS_PER_SECOND = 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY @@ -237,6 +238,24 @@ object DateTimeUtils { (day.toInt, micros * 1000L) } + /* + * Converts the timestamp to milliseconds since epoch. In spark timestamp values have microseconds + * precision, so this conversion is lossy. + */ + def toMillis(us: SQLTimestamp): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the millseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floor(us.toDouble / MILLIS_PER_SECOND).toLong + } + + /* + * Converts millseconds since epoch to SQLTimestamp. + */ + def fromMillis(millis: Long): SQLTimestamp = { + millis * 1000L + } + /** * Parses a given UTF8 date string to the corresponding a corresponding [[Long]] value. * The return type is [[Option]] in order to distinguish between 0L and null. The following diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5566b06aa3553..06dc0b41204fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -227,6 +227,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_INT64_AS_TIMESTAMP_MILLIS = buildConf("spark.sql.parquet.int64AsTimestampMillis") + .doc("When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + + "extended type. In this mode, the microsecond portion of the timestamp value will be" + + "truncated.") + .booleanConf + .createWithDefault(false) + val PARQUET_CACHE_METADATA = buildConf("spark.sql.parquet.cacheMetadata") .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.") .booleanConf @@ -935,6 +942,8 @@ class SQLConf extends Serializable with Logging { def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS) + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index bf8717483575f..eb97118872ea1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -197,6 +197,7 @@ protected void initialize(String path, List columns) throws IOException config.set("spark.sql.parquet.binaryAsString", "false"); config.set("spark.sql.parquet.int96AsTimestamp", "false"); config.set("spark.sql.parquet.writeLegacyFormat", "false"); + config.set("spark.sql.parquet.int64AsTimestampMillis", "false"); this.file = new Path(path); long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index cb51cb499eede..9d641b528723a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -28,6 +28,7 @@ import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.PrimitiveType; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -155,9 +156,13 @@ void readBatch(int total, ColumnVector column) throws IOException { // Read and decode dictionary ids. defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + + // Timestamp values encoded as INT64 can't be lazily decoded as we need to post process + // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || + (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && + column.dataType() != DataTypes.TimestampType) || descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { @@ -250,7 +255,15 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } - } else { + } else if (column.dataType() == DataTypes.TimestampType) { + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + column.putLong(i, + DateTimeUtils.fromMillis(dictionary.decodeToLong(dictionaryIds.getDictId(i)))); + } + } + } + else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; @@ -362,7 +375,15 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOExc if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType())) { defColumn.readLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.TimestampType) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, DateTimeUtils.fromMillis(dataColumn.readLong())); + } else { + column.putNull(rowId + i); + } + } } else { throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 062aa5c8ea624..2f3a2c62b912c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -125,6 +125,10 @@ class ParquetFileFormat SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) + conf.set( + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) @@ -300,6 +304,9 @@ class ParquetFileFormat hadoopConf.setBoolean( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) // Try to push down filters when filter push-down is enabled. val pushed = @@ -410,7 +417,8 @@ object ParquetFileFormat extends Logging { val converter = new ParquetSchemaConverter( sparkSession.sessionState.conf.isParquetBinaryAsString, sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.writeLegacyParquetFormat) + sparkSession.sessionState.conf.writeLegacyParquetFormat, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) converter.convert(schema) } @@ -510,6 +518,7 @@ object ParquetFileFormat extends Logging { sparkSession: SparkSession): Option[StructType] = { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + val writeTimestampInMillis = sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) @@ -554,7 +563,8 @@ object ParquetFileFormat extends Logging { new ParquetSchemaConverter( assumeBinaryIsString = assumeBinaryIsString, assumeInt96IsTimestamp = assumeInt96IsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = writeTimestampInMillis) if (footers.isEmpty) { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 33dcf2f3fd167..32e6c60cd9766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.{GroupType, MessageType, Type} +import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} @@ -252,6 +252,13 @@ private[parquet] class ParquetRowConverter( case StringType => new ParquetStringConverter(updater) + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(DateTimeUtils.fromMillis(value)) + } + } + case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. new ParquetPrimitiveConverter(updater) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 66d4027edf9f1..0b805e4362883 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -51,22 +51,29 @@ import org.apache.spark.sql.types._ * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. * When set to false, use standard format defined in parquet-format spec. This argument only * affects Parquet write path. + * @param writeTimestampInMillis Whether to write timestamp values as INT64 annotated by logical + * type TIMESTAMP_MILLIS. + * */ private[parquet] class ParquetSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get) { + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + writeTimestampInMillis: Boolean = SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, - writeLegacyParquetFormat = conf.writeLegacyParquetFormat) + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + writeTimestampInMillis = conf.isParquetINT64AsTimestampMillis) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean) + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean, + writeTimestampInMillis = conf.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean) + /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. @@ -158,7 +165,7 @@ private[parquet] class ParquetSchemaConverter( case INT_64 | null => LongType case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() - case TIMESTAMP_MILLIS => typeNotImplemented() + case TIMESTAMP_MILLIS => TimestampType case _ => illegalType() } @@ -370,10 +377,16 @@ private[parquet] class ParquetSchemaConverter( // we may resort to microsecond precision in the future. // // For Parquet, we plan to write all `TimestampType` value as `TIMESTAMP_MICROS`, but it's - // currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using) - // hasn't implemented `TIMESTAMP_MICROS` yet. + // currently not implemented yet because parquet-mr 1.8.1 (the version we're currently using) + // hasn't implemented `TIMESTAMP_MICROS` yet, however it supports TIMESTAMP_MILLIS. We will + // encode timestamp values as TIMESTAMP_MILLIS annotating INT64 if + // 'spark.sql.parquet.int64AsTimestampMillis' is set. // // TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that. + + case TimestampType if writeTimestampInMillis => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + case TimestampType => Types.primitive(INT96, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index a31d2b9c37e9d..38b0e33937f3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -66,6 +66,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions private var writeLegacyParquetFormat: Boolean = _ + // Whether to write timestamp value with milliseconds precision. + private var writeTimestampInMillis: Boolean = _ + // Reusable byte array used to write timestamps as Parquet INT96 values private val timestampBuffer = new Array[Byte](12) @@ -80,6 +83,13 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean } + + this.writeTimestampInMillis = { + assert(configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key) != null) + configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean + } + + this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) val messageType = new ParquetSchemaConverter(configuration).convert(schema) @@ -153,6 +163,11 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit recordConsumer.addBinary( Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) + case TimestampType if writeTimestampInMillis => + (row: SpecializedGetters, ordinal: Int) => + val millis = DateTimeUtils.toMillis(row.getLong(ordinal)) + recordConsumer.addLong(millis) + case TimestampType => (row: SpecializedGetters, ordinal: Int) => { // TODO Writes `TimestampType` values as `TIMESTAMP_MICROS` once parquet-mr implements it diff --git a/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet b/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet new file mode 100644 index 0000000000000000000000000000000000000000..d3c39e2c26eece8d20c154283c1a3fac40859efd GIT binary patch literal 517 zcmaKq&r8EF6vxvVq=*L*6I$q@1U4LW!R~j5m)&*{9h)~%N!wJ5ZP&G_B4Z$)_Gg>! z2h)o=Bros#KJR@4nT)0m0_Y4~*hrOuhBQ;xPQZ2@B3vajb1xuRAvY3%f6_n-nk_eZ z{MSi^;0URPJw7cmmcKn0{wq%yQYBvlIuudDYv%wT8@6HAH4{Oj1~g+UAQh|F!(m;! zKKMIC8>cwLDv)TpLE$eH;x7fSm3sOQyjC!jwBDHKFO+3Wnxh+^v{=Mc8eWuK(0u+u z6E0Z51k<0EM0{qP3`rsK(ig-gVZ`I0Aj5|xNm)`!)w86qE39sXU`ZxZX&J}Ni)B&B z;)2^`-{FdO^DrxK6L_yM4}mci^_Jf literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index dbdcd230a4de9..57a0af1dda971 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -107,11 +107,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { | required binary g(ENUM); | required binary h(DECIMAL(32,0)); | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + | required int64 j(TIMESTAMP_MILLIS); |} """.stripMargin) val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), - DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0), + TimestampType) withTempPath { location => val path = new Path(location.getCanonicalPath) @@ -607,6 +609,18 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("read dictionary and plain encoded timestamp_millis written as INT64") { + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // timestamp column in this file is encoded using combination of plain + // and dictionary encodings. + readResourceParquetFile("test-data/timemillis-in-i64.parquet"), + (1 to 3).map(i => Row(new java.sql.Timestamp(10)))) + } + } + } + test("SPARK-12589 copy() on rows returned from reader works for strings") { withTempPath { dir => val data = (1, "abc") ::(2, "helloabcde") :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 200e356c72fd7..c36609586c807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import java.sql.Timestamp import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.parquet.hadoop.ParquetOutputFormat @@ -162,6 +163,78 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-10634 timestamp written and read as INT64 - TIMESTAMP_MILLIS") { + val data = (1 to 10).map(i => Row(i, new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", IntegerType, false), + StructField("time", TimestampType, false)).toArray) + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "true") { + withTempPath { file => + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + val df2 = spark.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + } + } + } + + test("SPARK-10634 timestamp written and read as INT64 - truncation") { + withTable("ts") { + sql("create table ts (c1 int, c2 timestamp) using parquet") + sql("insert into ts values (1, '2016-01-01 10:11:12.123456')") + sql("insert into ts values (2, null)") + sql("insert into ts values (3, '1965-01-01 10:11:12.123456')") + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123456")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.123456")))) + } + + // The microsecond portion is truncated when written as TIMESTAMP_MILLIS. + withTable("ts") { + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "true") { + sql("create table ts (c1 int, c2 timestamp) using parquet") + sql("insert into ts values (1, '2016-01-01 10:11:12.123456')") + sql("insert into ts values (2, null)") + sql("insert into ts values (3, '1965-01-01 10:11:12.125456')") + sql("insert into ts values (4, '1965-01-01 10:11:12.125')") + sql("insert into ts values (5, '1965-01-01 10:11:12.1')") + sql("insert into ts values (6, '1965-01-01 10:11:12.123456789')") + sql("insert into ts values (7, '0001-01-01 00:00:00.000000')") + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(4, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(5, Timestamp.valueOf("1965-01-01 10:11:12.1")), + Row(6, Timestamp.valueOf("1965-01-01 10:11:12.123")), + Row(7, Timestamp.valueOf("0001-01-01 00:00:00.000")))) + + // Read timestamps that were encoded as TIMESTAMP_MILLIS annotated as INT64 + // with PARQUET_INT64_AS_TIMESTAMP_MILLIS set to false. + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "false") { + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(4, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(5, Timestamp.valueOf("1965-01-01 10:11:12.1")), + Row(6, Timestamp.valueOf("1965-01-01 10:11:12.123")), + Row(7, Timestamp.valueOf("0001-01-01 00:00:00.000")))) + } + } + } + } + test("Enabling/disabling merging partfiles when merging parquet schema") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 6aa940afbb2c4..ce992674d719f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -53,11 +53,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = int64AsTimestampMillis) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -77,11 +79,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = int64AsTimestampMillis) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -97,7 +101,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { testCatalystToParquet( testName, @@ -105,7 +110,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - writeLegacyParquetFormat) + writeLegacyParquetFormat, + int64AsTimestampMillis) testParquetToCatalyst( testName, @@ -113,7 +119,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - writeLegacyParquetFormat) + writeLegacyParquetFormat, + int64AsTimestampMillis) } } @@ -965,6 +972,18 @@ class ParquetSchemaSuite extends ParquetSchemaTest { int96AsTimestamp = true, writeLegacyParquetFormat = true) + testSchema( + "Timestamp written and read as INT64 with TIMESTAMP_MILLIS", + StructType(Seq(StructField("f1", TimestampType))), + """message root { + | optional INT64 f1 (TIMESTAMP_MILLIS); + |} + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = false, + writeLegacyParquetFormat = true, + int64AsTimestampMillis = true) + private def testSchemaClipping( testName: String, parquetSchema: String, From 51d3c854c54369aec1bfd55cefcd080dcd178d5f Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 3 Apr 2017 23:30:12 -0700 Subject: [PATCH 0185/1765] [SPARK-20067][SQL] Unify and Clean Up Desc Commands Using Catalog Interface ### What changes were proposed in this pull request? This PR is to unify and clean up the outputs of `DESC EXTENDED/FORMATTED` and `SHOW TABLE EXTENDED` by moving the logics into the Catalog interface. The output formats are improved. We also add the missing attributes. It impacts the DDL commands like `SHOW TABLE EXTENDED`, `DESC EXTENDED` and `DESC FORMATTED`. In addition, by following what we did in Dataset API `printSchema`, we can use `treeString` to show the schema in the more readable way. Below is the current way: ``` Schema: STRUCT<`a`: STRING (nullable = true), `b`: INT (nullable = true), `c`: STRING (nullable = true), `d`: STRING (nullable = true)> ``` After the change, it should look like ``` Schema: root |-- a: string (nullable = true) |-- b: integer (nullable = true) |-- c: string (nullable = true) |-- d: string (nullable = true) ``` ### How was this patch tested? `describe.sql` and `show-tables.sql` Author: Xiao Li Closes #17394 from gatorsmile/descFollowUp. --- .../sql/catalyst/catalog/interface.scala | 136 ++++-- .../spark/sql/execution/SparkSqlParser.scala | 3 +- .../spark/sql/execution/command/tables.scala | 124 ++--- .../resources/sql-tests/inputs/describe.sql | 53 ++- .../sql-tests/results/change-column.sql.out | 9 + .../sql-tests/results/describe.sql.out | 422 ++++++++++++++---- .../sql-tests/results/show-tables.sql.out | 67 +-- .../apache/spark/sql/SQLQueryTestSuite.scala | 19 +- .../sql/execution/SparkSqlParserSuite.scala | 6 +- .../sql/execution/command/DDLSuite.scala | 12 - .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 - .../spark/sql/sources/DDLTestSuite.scala | 123 ----- .../spark/sql/sources/DataSourceTest.scala | 56 +++ .../sql/hive/MetastoreDataSourcesSuite.scala | 8 +- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 93 +--- .../HiveOperatorQueryableSuite.scala | 53 --- .../sql/hive/execution/HiveQuerySuite.scala | 56 --- .../sql/hive/execution/SQLQuerySuite.scala | 131 +----- 19 files changed, 642 insertions(+), 737 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 70ed44e025f51..3f25f9e7258f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI import java.util.Date +import scala.collection.mutable + import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException @@ -57,20 +59,25 @@ case class CatalogStorageFormat( properties: Map[String, String]) { override def toString: String = { - val serdePropsToString = CatalogUtils.maskCredentials(properties) match { - case props if props.isEmpty => "" - case props => "Properties: " + props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") - } - val output = - Seq(locationUri.map("Location: " + _).getOrElse(""), - inputFormat.map("InputFormat: " + _).getOrElse(""), - outputFormat.map("OutputFormat: " + _).getOrElse(""), - if (compressed) "Compressed" else "", - serde.map("Serde: " + _).getOrElse(""), - serdePropsToString) - output.filter(_.nonEmpty).mkString("Storage(", ", ", ")") + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("Storage(", ", ", ")") } + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() + locationUri.foreach(l => map.put("Location", l.toString)) + serde.foreach(map.put("Serde Library", _)) + inputFormat.foreach(map.put("InputFormat", _)) + outputFormat.foreach(map.put("OutputFormat", _)) + if (compressed) map.put("Compressed", "") + CatalogUtils.maskCredentials(properties) match { + case props if props.isEmpty => // No-op + case props => + map.put("Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) + } + map + } } object CatalogStorageFormat { @@ -91,15 +98,28 @@ case class CatalogTablePartition( storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty) { - override def toString: String = { + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") - val output = - Seq( - s"Partition Values: [$specString]", - s"$storage", - s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + map.put("Partition Values", s"[$specString]") + map ++= storage.toLinkedHashMap + if (parameters.nonEmpty) { + map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + } + map + } - output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") + override def toString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("CatalogPartition(\n\t", "\n\t", ")") + } + + /** Readable string representation for the CatalogTablePartition. */ + def simpleString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("", "\n", "") } /** Return the partition location, assuming it is specified. */ @@ -154,6 +174,14 @@ case class BucketSpec( } s"$numBuckets buckets, $bucketString$sortString" } + + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + mutable.LinkedHashMap[String, String]( + "Num Buckets" -> numBuckets.toString, + "Bucket Columns" -> bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]"), + "Sort Columns" -> sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") + ) + } } /** @@ -261,40 +289,50 @@ case class CatalogTable( locationUri, inputFormat, outputFormat, serde, compressed, properties)) } - override def toString: String = { + + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() val tableProperties = properties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") val partitionColumns = partitionColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - val bucketStrings = bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - val bucketColumnsString = bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - val sortColumnsString = sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - Seq( - s"Num Buckets: $numBuckets", - if (bucketColumnNames.nonEmpty) s"Bucket Columns: $bucketColumnsString" else "", - if (sortColumnNames.nonEmpty) s"Sort Columns: $sortColumnsString" else "" - ) - - case _ => Nil + + identifier.database.foreach(map.put("Database", _)) + map.put("Table", identifier.table) + if (owner.nonEmpty) map.put("Owner", owner) + map.put("Created", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) + map.put("Type", tableType.name) + provider.foreach(map.put("Provider", _)) + bucketSpec.foreach(map ++= _.toLinkedHashMap) + comment.foreach(map.put("Comment", _)) + if (tableType == CatalogTableType.VIEW) { + viewText.foreach(map.put("View Text", _)) + viewDefaultDatabase.foreach(map.put("View Default Database", _)) + if (viewQueryColumnNames.nonEmpty) { + map.put("View Query Output Columns", viewQueryColumnNames.mkString("[", ", ", "]")) + } } - val output = - Seq(s"Table: ${identifier.quotedString}", - if (owner.nonEmpty) s"Owner: $owner" else "", - s"Created: ${new Date(createTime).toString}", - s"Last Access: ${new Date(lastAccessTime).toString}", - s"Type: ${tableType.name}", - if (schema.nonEmpty) s"Schema: ${schema.mkString("[", ", ", "]")}" else "", - if (provider.isDefined) s"Provider: ${provider.get}" else "", - if (partitionColumnNames.nonEmpty) s"Partition Columns: $partitionColumns" else "" - ) ++ bucketStrings ++ Seq( - viewText.map("View: " + _).getOrElse(""), - comment.map("Comment: " + _).getOrElse(""), - if (properties.nonEmpty) s"Properties: $tableProperties" else "", - if (stats.isDefined) s"Statistics: ${stats.get.simpleString}" else "", - s"$storage", - if (tracksPartitionsInCatalog) "Partition Provider: Catalog" else "") - - output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") + if (properties.nonEmpty) map.put("Properties", tableProperties) + stats.foreach(s => map.put("Statistics", s.simpleString)) + map ++= storage.toLinkedHashMap + if (tracksPartitionsInCatalog) map.put("Partition Provider", "Catalog") + if (partitionColumnNames.nonEmpty) map.put("Partition Columns", partitionColumns) + if (schema.nonEmpty) map.put("Schema", schema.treeString) + + map + } + + override def toString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("CatalogTable(\n", "\n", ")") + } + + /** Readable string representation for the CatalogTable. */ + def simpleString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("", "\n", "") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d4f23f9dd5185..80afb59b3e88e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -322,8 +322,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { DescribeTableCommand( visitTableIdentifier(ctx.tableIdentifier), partitionSpec, - ctx.EXTENDED != null, - ctx.FORMATTED != null) + ctx.EXTENDED != null || ctx.FORMATTED != null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index c7aeef06a0bf0..ebf03e1bf8869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -500,8 +500,7 @@ case class TruncateTableCommand( case class DescribeTableCommand( table: TableIdentifier, partitionSpec: TablePartitionSpec, - isExtended: Boolean, - isFormatted: Boolean) + isExtended: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = Seq( @@ -536,14 +535,12 @@ case class DescribeTableCommand( describePartitionInfo(metadata, result) - if (partitionSpec.isEmpty) { - if (isExtended) { - describeExtendedTableInfo(metadata, result) - } else if (isFormatted) { - describeFormattedTableInfo(metadata, result) - } - } else { + if (partitionSpec.nonEmpty) { + // Outputs the partition-specific info for the DDL command: + // "DESCRIBE [EXTENDED|FORMATTED] table_name PARTITION (partitionVal*)" describeDetailedPartitionInfo(sparkSession, catalog, metadata, result) + } else if (isExtended) { + describeFormattedTableInfo(metadata, result) } } @@ -553,76 +550,20 @@ case class DescribeTableCommand( private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (table.partitionColumnNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) describeSchema(table.partitionSchema, buffer) } } - private def describeExtendedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# Detailed Table Information", table.toString, "") - } - private def describeFormattedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + // The following information has been already shown in the previous outputs + val excludedTableInfo = Seq( + "Partition Columns", + "Schema" + ) append(buffer, "", "", "") append(buffer, "# Detailed Table Information", "", "") - append(buffer, "Database:", table.database, "") - append(buffer, "Owner:", table.owner, "") - append(buffer, "Created:", new Date(table.createTime).toString, "") - append(buffer, "Last Access:", new Date(table.lastAccessTime).toString, "") - append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_)) - .getOrElse(""), "") - append(buffer, "Table Type:", table.tableType.name, "") - append(buffer, "Comment:", table.comment.getOrElse(""), "") - table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) - - append(buffer, "Table Parameters:", "", "") - table.properties.foreach { case (key, value) => - append(buffer, s" $key", value, "") - } - - describeStorageInfo(table, buffer) - - if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) - - if (DDLUtils.isDatasourceTable(table) && table.tracksPartitionsInCatalog) { - append(buffer, "Partition Provider:", "Catalog", "") - } - } - - private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# Storage Information", "", "") - metadata.storage.serde.foreach(serdeLib => append(buffer, "SerDe Library:", serdeLib, "")) - metadata.storage.inputFormat.foreach(format => append(buffer, "InputFormat:", format, "")) - metadata.storage.outputFormat.foreach(format => append(buffer, "OutputFormat:", format, "")) - append(buffer, "Compressed:", if (metadata.storage.compressed) "Yes" else "No", "") - describeBucketingInfo(metadata, buffer) - - append(buffer, "Storage Desc Parameters:", "", "") - val maskedProperties = CatalogUtils.maskCredentials(metadata.storage.properties) - maskedProperties.foreach { case (key, value) => - append(buffer, s" $key", value, "") - } - } - - private def describeViewInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# View Information", "", "") - append(buffer, "View Text:", metadata.viewText.getOrElse(""), "") - append(buffer, "View Default Database:", metadata.viewDefaultDatabase.getOrElse(""), "") - append(buffer, "View Query Output Columns:", - metadata.viewQueryColumnNames.mkString("[", ", ", "]"), "") - } - - private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - metadata.bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - append(buffer, "Num Buckets:", numBuckets.toString, "") - append(buffer, "Bucket Columns:", bucketColumnNames.mkString("[", ", ", "]"), "") - append(buffer, "Sort Columns:", sortColumnNames.mkString("[", ", ", "]"), "") - - case _ => + table.toLinkedHashMap.filterKeys(!excludedTableInfo.contains(_)).foreach { + s => append(buffer, s._1, s._2, "") } } @@ -637,21 +578,7 @@ case class DescribeTableCommand( } DDLUtils.verifyPartitionProviderIsHive(spark, metadata, "DESC PARTITION") val partition = catalog.getPartition(table, partitionSpec) - if (isExtended) { - describeExtendedDetailedPartitionInfo(table, metadata, partition, result) - } else if (isFormatted) { - describeFormattedDetailedPartitionInfo(table, metadata, partition, result) - describeStorageInfo(metadata, result) - } - } - - private def describeExtendedDetailedPartitionInfo( - tableIdentifier: TableIdentifier, - table: CatalogTable, - partition: CatalogTablePartition, - buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "Detailed Partition Information " + partition.toString, "", "") + if (isExtended) describeFormattedDetailedPartitionInfo(table, metadata, partition, result) } private def describeFormattedDetailedPartitionInfo( @@ -661,18 +588,21 @@ case class DescribeTableCommand( buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Partition Information", "", "") - append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") - append(buffer, "Database:", table.database, "") - append(buffer, "Table:", tableIdentifier.table, "") - append(buffer, "Location:", partition.storage.locationUri.map(CatalogUtils.URIToString(_)) - .getOrElse(""), "") - append(buffer, "Partition Parameters:", "", "") - partition.parameters.foreach { case (key, value) => - append(buffer, s" $key", value, "") + append(buffer, "Database", table.database, "") + append(buffer, "Table", tableIdentifier.table, "") + partition.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) + append(buffer, "", "", "") + append(buffer, "# Storage Information", "", "") + table.bucketSpec match { + case Some(spec) => + spec.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) + case _ => } + table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) } @@ -728,7 +658,7 @@ case class ShowTablesCommand( val tableName = tableIdent.table val isTemp = catalog.isTemporaryTable(tableIdent) if (isExtended) { - val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString + val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).simpleString Row(database, tableName, isTemp, s"$information\n") } else { Row(database, tableName, isTemp) @@ -745,7 +675,7 @@ case class ShowTablesCommand( val database = table.database.getOrElse("") val tableName = table.table val isTemp = catalog.isTemporaryTable(table) - val information = partition.toString + val information = partition.simpleString Seq(Row(database, tableName, isTemp, s"$information\n")) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 56f3281440d29..6de4cf0d5afa1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,10 +1,23 @@ -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment'; +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS + COMMENT 'table_comment'; + +CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t; + +CREATE TEMPORARY VIEW temp_Data_Source_View + USING org.apache.spark.sql.sources.DDLScanSource + OPTIONS ( + From '1', + To '10', + Table 'test1'); + +CREATE VIEW v AS SELECT * FROM t; ALTER TABLE t ADD PARTITION (c='Us', d=1); DESCRIBE t; -DESC t; +DESC default.t; DESC TABLE t; @@ -27,5 +40,39 @@ DESC t PARTITION (c='Us'); -- ParseException: PARTITION specification is incomplete DESC t PARTITION (c='Us', d); --- DROP TEST TABLE +-- DESC Temp View + +DESC temp_v; + +DESC TABLE temp_v; + +DESC FORMATTED temp_v; + +DESC EXTENDED temp_v; + +DESC temp_Data_Source_View; + +-- AnalysisException DESC PARTITION is not allowed on a temporary view +DESC temp_v PARTITION (c='Us', d=1); + +-- DESC Persistent View + +DESC v; + +DESC TABLE v; + +DESC FORMATTED v; + +DESC EXTENDED v; + +-- AnalysisException DESC PARTITION is not allowed on a view +DESC v PARTITION (c='Us', d=1); + +-- DROP TEST TABLES/VIEWS DROP TABLE t; + +DROP VIEW temp_v; + +DROP VIEW temp_Data_Source_View; + +DROP VIEW v; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c79..678a3f0f0a3c6 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,6 +15,7 @@ DESC test_change -- !query 1 schema struct -- !query 1 output +# col_name data_type comment a int b string c int @@ -34,6 +35,7 @@ DESC test_change -- !query 3 schema struct -- !query 3 output +# col_name data_type comment a int b string c int @@ -53,6 +55,7 @@ DESC test_change -- !query 5 schema struct -- !query 5 output +# col_name data_type comment a int b string c int @@ -91,6 +94,7 @@ DESC test_change -- !query 8 schema struct -- !query 8 output +# col_name data_type comment a int b string c int @@ -125,6 +129,7 @@ DESC test_change -- !query 12 schema struct -- !query 12 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -143,6 +148,7 @@ DESC test_change -- !query 14 schema struct -- !query 14 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -162,6 +168,7 @@ DESC test_change -- !query 16 schema struct -- !query 16 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -186,6 +193,7 @@ DESC test_change -- !query 18 schema struct -- !query 18 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -229,6 +237,7 @@ DESC test_change -- !query 23 schema struct -- !query 23 output +# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 422d548ea8de8..de10b29f3c65b 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 31 -- !query 0 -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment' +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS + COMMENT 'table_comment' -- !query 0 schema struct<> -- !query 0 output @@ -11,7 +13,7 @@ struct<> -- !query 1 -ALTER TABLE t ADD PARTITION (c='Us', d=1) +CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t -- !query 1 schema struct<> -- !query 1 output @@ -19,187 +21,239 @@ struct<> -- !query 2 -DESCRIBE t +CREATE TEMPORARY VIEW temp_Data_Source_View + USING org.apache.spark.sql.sources.DDLScanSource + OPTIONS ( + From '1', + To '10', + Table 'test1') -- !query 2 schema -struct +struct<> -- !query 2 output -# Partition Information + + + +-- !query 3 +CREATE VIEW v AS SELECT * FROM t +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +ALTER TABLE t ADD PARTITION (c='Us', d=1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +DESCRIBE t +-- !query 5 schema +struct +-- !query 5 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 3 -DESC t --- !query 3 schema +-- !query 6 +DESC default.t +-- !query 6 schema struct --- !query 3 output -# Partition Information +-- !query 6 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 4 +-- !query 7 DESC TABLE t --- !query 4 schema +-- !query 7 schema struct --- !query 4 output -# Partition Information +-- !query 7 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 5 +-- !query 8 DESC FORMATTED t --- !query 5 schema +-- !query 8 schema struct --- !query 5 output -# Detailed Table Information -# Partition Information -# Storage Information +-- !query 8 output # col_name data_type comment -Comment: table_comment -Compressed: No -Created: -Database: default -Last Access: -Location: sql/core/spark-warehouse/t -Owner: -Partition Provider: Catalog -Storage Desc Parameters: -Table Parameters: -Table Type: MANAGED a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Partition Provider Catalog --- !query 6 +-- !query 9 DESC EXTENDED t --- !query 6 schema +-- !query 9 schema struct --- !query 6 output -# Detailed Table Information CatalogTable( - Table: `default`.`t` - Created: - Last Access: - Type: MANAGED - Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] - Provider: parquet - Partition Columns: [`c`, `d`] - Comment: table_comment - Storage(Location: sql/core/spark-warehouse/t) - Partition Provider: Catalog) -# Partition Information +-- !query 9 output # col_name data_type comment a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Partition Provider Catalog --- !query 7 +-- !query 10 DESC t PARTITION (c='Us', d=1) --- !query 7 schema +-- !query 10 schema struct --- !query 7 output -# Partition Information +-- !query 10 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 8 +-- !query 11 DESC EXTENDED t PARTITION (c='Us', d=1) --- !query 8 schema +-- !query 11 schema struct --- !query 8 output -# Partition Information +-- !query 11 output # col_name data_type comment -Detailed Partition Information CatalogPartition( - Partition Values: [c=Us, d=1] - Storage(Location: sql/core/spark-warehouse/t/c=Us/d=1) - Partition Parameters:{}) a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Partition Information +Database default +Table t +Partition Values [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 + +# Storage Information +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t --- !query 9 +-- !query 12 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 9 schema +-- !query 12 schema struct --- !query 9 output -# Detailed Partition Information -# Partition Information -# Storage Information +-- !query 12 output # col_name data_type comment -Compressed: No -Database: default -Location: sql/core/spark-warehouse/t/c=Us/d=1 -Partition Parameters: -Partition Value: [Us, 1] -Storage Desc Parameters: -Table: t a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Partition Information +Database default +Table t +Partition Values [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 + +# Storage Information +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t --- !query 10 +-- !query 13 DESC t PARTITION (c='Us', d=2) --- !query 10 schema +-- !query 13 schema struct<> --- !query 10 output +-- !query 13 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 11 +-- !query 14 DESC t PARTITION (c='Us') --- !query 11 schema +-- !query 14 schema struct<> --- !query 11 output +-- !query 14 output org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 12 +-- !query 15 DESC t PARTITION (c='Us', d) --- !query 12 schema +-- !query 15 schema struct<> --- !query 12 output +-- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -209,9 +263,193 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 13 +-- !query 16 +DESC temp_v +-- !query 16 schema +struct +-- !query 16 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 17 +DESC TABLE temp_v +-- !query 17 schema +struct +-- !query 17 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 18 +DESC FORMATTED temp_v +-- !query 18 schema +struct +-- !query 18 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 19 +DESC EXTENDED temp_v +-- !query 19 schema +struct +-- !query 19 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 20 +DESC temp_Data_Source_View +-- !query 20 schema +struct +-- !query 20 output +# col_name data_type comment +intType int test comment test1 +stringType string +dateType date +timestampType timestamp +doubleType double +bigintType bigint +tinyintType tinyint +decimalType decimal(10,0) +fixedDecimalType decimal(5,1) +binaryType binary +booleanType boolean +smallIntType smallint +floatType float +mapType map +arrayType array +structType struct + + +-- !query 21 +DESC temp_v PARTITION (c='Us', d=1) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a temporary view: temp_v; + + +-- !query 22 +DESC v +-- !query 22 schema +struct +-- !query 22 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 23 +DESC TABLE v +-- !query 23 schema +struct +-- !query 23 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 24 +DESC FORMATTED v +-- !query 24 schema +struct +-- !query 24 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table v +Created [not included in comparison] +Last Access [not included in comparison] +Type VIEW +View Text SELECT * FROM t +View Default Database default +View Query Output Columns [a, b, c, d] +Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] + + +-- !query 25 +DESC EXTENDED v +-- !query 25 schema +struct +-- !query 25 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table v +Created [not included in comparison] +Last Access [not included in comparison] +Type VIEW +View Text SELECT * FROM t +View Default Database default +View Query Output Columns [a, b, c, d] +Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] + + +-- !query 26 +DESC v PARTITION (c='Us', d=1) +-- !query 26 schema +struct<> +-- !query 26 output +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a view: v; + + +-- !query 27 DROP TABLE t --- !query 13 schema +-- !query 27 schema struct<> --- !query 13 output +-- !query 27 output + + + +-- !query 28 +DROP VIEW temp_v +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +DROP VIEW temp_Data_Source_View +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +DROP VIEW v +-- !query 30 schema +struct<> +-- !query 30 output diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 6d62e6092147b..8f2a54f7c24e2 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -118,33 +118,40 @@ SHOW TABLE EXTENDED LIKE 'show_t*' -- !query 12 schema struct -- !query 12 output -show_t3 true CatalogTable( - Table: `show_t3` - Created: - Last Access: - Type: VIEW - Schema: [StructField(e,IntegerType,true)] - Storage()) - -showdb show_t1 false CatalogTable( - Table: `showdb`.`show_t1` - Created: - Last Access: - Type: MANAGED - Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] - Provider: parquet - Partition Columns: [`c`, `d`] - Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1) - Partition Provider: Catalog) - -showdb show_t2 false CatalogTable( - Table: `showdb`.`show_t2` - Created: - Last Access: - Type: MANAGED - Schema: [StructField(b,StringType,true), StructField(d,IntegerType,true)] - Provider: parquet - Storage(Location: sql/core/spark-warehouse/showdb.db/show_t2)) +show_t3 true Table: show_t3 +Created [not included in comparison] +Last Access [not included in comparison] +Type: VIEW +Schema: root + |-- e: integer (nullable = true) + + +showdb show_t1 false Database: showdb +Table: show_t1 +Created [not included in comparison] +Last Access [not included in comparison] +Type: MANAGED +Provider: parquet +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1 +Partition Provider: Catalog +Partition Columns: [`c`, `d`] +Schema: root + |-- a: string (nullable = true) + |-- b: integer (nullable = true) + |-- c: string (nullable = true) + |-- d: string (nullable = true) + + +showdb show_t2 false Database: showdb +Table: show_t2 +Created [not included in comparison] +Last Access [not included in comparison] +Type: MANAGED +Provider: parquet +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t2 +Schema: root + |-- b: string (nullable = true) + |-- d: integer (nullable = true) -- !query 13 @@ -166,10 +173,8 @@ SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1) -- !query 14 schema struct -- !query 14 output -showdb show_t1 false CatalogPartition( - Partition Values: [c=Us, d=1] - Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1) - Partition Parameters:{}) +showdb show_t1 false Partition Values: [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 -- !query 15 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 4092862c430b1..4b69baffab620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.execution.command.DescribeTableCommand import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -165,8 +166,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile); - val parent = resultFile.getParentFile(); + val resultFile = new File(testCase.resultFile) + val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) } @@ -212,23 +213,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** Executes a query and returns the result as (schema of the output, normalized output). */ private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = { // Returns true if the plan is supposed to be sorted. - def isSorted(plan: LogicalPlan): Boolean = plan match { + def needSort(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false + case _: DescribeTableCommand => true case PhysicalOperation(_, _, Sort(_, true, _)) => true - case _ => plan.children.iterator.exists(isSorted) + case _ => plan.children.iterator.exists(needSort) } try { val df = session.sql(sql) val schema = df.schema + val notIncludedMsg = "[not included in comparison]" // Get answer, but also get rid of the #1234 expression ids that show up in explain plans val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") - .replaceAll("Location:.*/sql/core/", "Location: sql/core/") - .replaceAll("Created: .*", "Created: ") - .replaceAll("Last Access: .*", "Last Access: ")) + .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") + .replaceAll("Created.*", s"Created $notIncludedMsg") + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) // If the output is not pre-sorted, sort it. - if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + if (needSort(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { case a: AnalysisException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a4d012cd76115..908b955abbf07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -224,13 +224,13 @@ class SparkSqlParserSuite extends PlanTest { test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = false)) + TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table extended t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true, isFormatted = false)) + TableIdentifier("t"), Map.empty, isExtended = true)) assertEqual("describe table formatted t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = true)) + TableIdentifier("t"), Map.empty, isExtended = true)) intercept("explain describe tables x", "Unsupported SQL statement") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 648b1798c66e0..9ebf2dd839a79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -69,18 +69,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo tracksPartitionsInCatalog = true) } - test("desc table for parquet data source table using in-memory catalog") { - val tabName = "tab1" - withTable(tabName) { - sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") - - checkAnswer( - sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") - ) - } - } - test("alter table: set location (datasource table)") { testSetLocation(isDatasourceTable = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 4a02277631f14..5bd36ec25ccb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -806,10 +806,6 @@ class JDBCSuite extends SparkFunSuite sql(s"DESC FORMATTED $tableName").collect().foreach { r => assert(!r.toString().contains(password)) } - - sql(s"DESC EXTENDED $tableName").collect().foreach { r => - assert(!r.toString().contains(password)) - } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala deleted file mode 100644 index 674463feca4db..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.sources - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class DDLScanSource extends RelationProvider { - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan( - parameters("from").toInt, - parameters("TO").toInt, - parameters("Table"))(sqlContext.sparkSession) - } -} - -case class SimpleDDLScan( - from: Int, - to: Int, - table: String)(@transient val sparkSession: SparkSession) - extends BaseRelation with TableScan { - - override def sqlContext: SQLContext = sparkSession.sqlContext - - override def schema: StructType = - StructType(Seq( - StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), - StructField("stringType", StringType, nullable = false), - StructField("dateType", DateType, nullable = false), - StructField("timestampType", TimestampType, nullable = false), - StructField("doubleType", DoubleType, nullable = false), - StructField("bigintType", LongType, nullable = false), - StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), - StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), - StructField("binaryType", BinaryType, nullable = false), - StructField("booleanType", BooleanType, nullable = false), - StructField("smallIntType", ShortType, nullable = false), - StructField("floatType", FloatType, nullable = false), - StructField("mapType", MapType(StringType, StringType)), - StructField("arrayType", ArrayType(StringType)), - StructField("structType", - StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil - ) - ) - )) - - override def needConversion: Boolean = false - - override def buildScan(): RDD[Row] = { - // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] - sparkSession.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2) - }.asInstanceOf[RDD[Row]] - } -} - -class DDLTestSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = spark.sql _ - - override def beforeAll(): Unit = { - super.beforeAll() - sql( - """ - |CREATE OR REPLACE TEMPORARY VIEW ddlPeople - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - } - - sqlTest( - "describe ddlPeople", - Seq( - Row("intType", "int", "test comment test1"), - Row("stringType", "string", null), - Row("dateType", "date", null), - Row("timestampType", "timestamp", null), - Row("doubleType", "double", null), - Row("bigintType", "bigint", null), - Row("tinyintType", "tinyint", null), - Row("decimalType", "decimal(10,0)", null), - Row("fixedDecimalType", "decimal(5,1)", null), - Row("binaryType", "binary", null), - Row("booleanType", "boolean", null), - Row("smallIntType", "smallint", null), - Row("floatType", "float", null), - Row("mapType", "map", null), - Row("arrayType", "array", null), - Row("structType", "struct", null) - )) - - test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = sql("describe ddlPeople") - .queryExecution.executedPlan.output - assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) - assert(attributes.map(_.dataType).toSet === Set(StringType)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index cc77d3c4b91ac..80868fff897fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.sources +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { @@ -28,3 +32,55 @@ private[sql] abstract class DataSourceTest extends QueryTest { } } + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan( + parameters("from").toInt, + parameters("TO").toInt, + parameters("Table"))(sqlContext.sparkSession) + } +} + +case class SimpleDDLScan( + from: Int, + to: Int, + table: String)(@transient val sparkSession: SparkSession) + extends BaseRelation with TableScan { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = + StructType(Seq( + StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil + ) + ) + )) + + override def needConversion: Boolean = false + + override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] + sparkSession.sparkContext.parallelize(from to to).map { e => + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 55e02acfa4ce3..b554694815571 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -767,9 +767,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sessionState.refreshTable(tableName) val actualSchema = table(tableName).schema assert(schema === actualSchema) - - // Checks the DESCRIBE output. - checkAnswer(sql("DESCRIBE spark6655"), Row("int", "int", null) :: Nil) } } @@ -1381,7 +1378,10 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv checkAnswer(spark.table("old"), Row(1, "a")) - checkAnswer(sql("DESC old"), Row("i", "int", null) :: Row("j", "string", null) :: Nil) + val expectedSchema = StructType(Seq( + StructField("i", IntegerType, nullable = true), + StructField("j", StringType, nullable = true))) + assert(table("old").schema === expectedSchema) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 536ca8fd9d45d..e45cf977bfaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -207,6 +207,7 @@ abstract class HiveComparisonTest // This list contains indicators for those lines which do not have actual results and we // want to ignore. lazy val ignoredLineIndicators = Seq( + "# Detailed Table Information", "# Partition Information", "# col_name" ) @@ -358,7 +359,7 @@ abstract class HiveComparisonTest stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase) fail(errorMessage) } - }.toSeq + } (queryList, hiveResults, catalystResults).zipped.foreach { case (query, hive, (hiveQuery, catalyst)) => @@ -369,6 +370,7 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && (!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) && (!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) && + (!hiveQuery.logical.isInstanceOf[DescribeTableCommand]) && preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f0a995c274b64..3906968aaff10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -708,23 +708,6 @@ class HiveDDLSuite } } - test("desc table for Hive table") { - withTable("tab1") { - val tabName = "tab1" - sql(s"CREATE TABLE $tabName(c1 int)") - - assert(sql(s"DESC $tabName").collect().length == 1) - - assert( - sql(s"DESC FORMATTED $tabName").collect() - .exists(_.getString(0) == "# Storage Information")) - - assert( - sql(s"DESC EXTENDED $tabName").collect() - .exists(_.getString(0) == "# Detailed Table Information")) - } - } - test("desc table for Hive table - partitioned table") { withTable("tbl") { sql("CREATE TABLE tbl(a int) PARTITIONED BY (b int)") @@ -741,23 +724,6 @@ class HiveDDLSuite } } - test("desc formatted table for permanent view") { - withTable("tbl") { - withView("view1") { - sql("CREATE TABLE tbl(a int)") - sql("CREATE VIEW view1 AS SELECT * FROM tbl") - assert(sql("DESC FORMATTED view1").collect().containsSlice( - Seq( - Row("# View Information", "", ""), - Row("View Text:", "SELECT * FROM tbl", ""), - Row("View Default Database:", "default", ""), - Row("View Query Output Columns:", "[a]", "") - ) - )) - } - } - } - test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" @@ -766,7 +732,7 @@ class HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") + Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil ) } } @@ -1218,23 +1184,6 @@ class HiveDDLSuite sql(s"SELECT * FROM ${targetTable.identifier}")) } - test("desc table for data source table") { - withTable("tab1") { - val tabName = "tab1" - spark.range(1).write.format("json").saveAsTable(tabName) - - assert(sql(s"DESC $tabName").collect().length == 1) - - assert( - sql(s"DESC FORMATTED $tabName").collect() - .exists(_.getString(0) == "# Storage Information")) - - assert( - sql(s"DESC EXTENDED $tabName").collect() - .exists(_.getString(0) == "# Detailed Table Information")) - } - } - test("create table with the same name as an index table") { val tabName = "tab1" val indexName = tabName + "_index" @@ -1320,46 +1269,6 @@ class HiveDDLSuite } } - test("desc table for data source table - partitioned bucketed table") { - withTable("t1") { - spark - .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write - .bucketBy(2, "b").sortBy("c").partitionBy("d") - .saveAsTable("t1") - - val formattedDesc = sql("DESC FORMATTED t1").collect() - - assert(formattedDesc.containsSlice( - Seq( - Row("a", "bigint", null), - Row("b", "bigint", null), - Row("c", "bigint", null), - Row("d", "bigint", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("d", "bigint", null), - Row("", "", ""), - Row("# Detailed Table Information", "", ""), - Row("Database:", "default", "") - ) - )) - - assert(formattedDesc.containsSlice( - Seq( - Row("Table Type:", "MANAGED", "") - ) - )) - - assert(formattedDesc.containsSlice( - Seq( - Row("Num Buckets:", "2", ""), - Row("Bucket Columns:", "[b]", ""), - Row("Sort Columns:", "[c]", "") - ) - )) - } - } - test("datasource and statistics table property keys are not allowed") { import org.apache.spark.sql.hive.HiveExternalCatalog.DATASOURCE_PREFIX import org.apache.spark.sql.hive.HiveExternalCatalog.STATISTICS_PREFIX diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala deleted file mode 100644 index 0e89e990e564e..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton - -/** - * A set of tests that validates commands can also be queried by like a table - */ -class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { - import spark._ - - test("SPARK-5324 query result of describe command") { - hiveContext.loadTestTable("src") - - // Creates a temporary view with the output of a describe command - sql("desc src").createOrReplaceTempView("mydesc") - checkAnswer( - sql("desc mydesc"), - Seq( - Row("col_name", "string", "name of the column"), - Row("data_type", "string", "data type of the column"), - Row("comment", "string", "comment of the column"))) - - checkAnswer( - sql("select * from mydesc"), - Seq( - Row("key", "int", null), - Row("value", "string", null))) - - checkAnswer( - sql("select col_name, data_type, comment from mydesc"), - Seq( - Row("key", "int", null), - Row("value", "string", null))) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index dd278f683a3cd..65a902fc5438e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -789,62 +789,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(Try(q0.count()).isSuccess) } - test("DESCRIBE commands") { - sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - - sql( - """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') - |SELECT key, value - """.stripMargin) - - // Describe a table - assertResult( - Array( - Row("key", "int", null), - Row("value", "string", null), - Row("dt", "string", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("dt", "string", null)) - ) { - sql("DESCRIBE test_describe_commands1") - .select('col_name, 'data_type, 'comment) - .collect() - } - - // Describe a table with a fully qualified table name - assertResult( - Array( - Row("key", "int", null), - Row("value", "string", null), - Row("dt", "string", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("dt", "string", null)) - ) { - sql("DESCRIBE default.test_describe_commands1") - .select('col_name, 'data_type, 'comment) - .collect() - } - - // Describe a temporary view. - val testData = - TestHive.sparkContext.parallelize( - TestData(1, "str1") :: - TestData(1, "str2") :: Nil) - testData.toDF().createOrReplaceTempView("test_describe_commands2") - - assertResult( - Array( - Row("a", "int", null), - Row("b", "string", null)) - ) { - sql("DESCRIBE test_describe_commands2") - .select('col_name, 'data_type, 'comment) - .collect() - } - } - test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 55ff4bb115e59..d012797e19926 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -363,79 +363,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("describe partition") { - withTable("partitioned_table") { - sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") - sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") - - checkKeywordsExist(sql("DESC partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name") - - checkKeywordsExist(sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name", - "Detailed Partition Information CatalogPartition(", - "Partition Values: [c=Us, d=1]", - "Storage(Location:", - "Partition Parameters") - - checkKeywordsExist(sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name", - "# Detailed Partition Information", - "Partition Value:", - "Database:", - "Table:", - "Location:", - "Partition Parameters:", - "# Storage Information") - } - } - - test("describe partition - error handling") { - withTable("partitioned_table", "datasource_table") { - sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") - sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") - - val m = intercept[NoSuchPartitionException] { - sql("DESC partitioned_table PARTITION (c='Us', d=2)") - }.getMessage() - assert(m.contains("Partition not found in table")) - - val m2 = intercept[AnalysisException] { - sql("DESC partitioned_table PARTITION (c='Us')") - }.getMessage() - assert(m2.contains("Partition spec is invalid")) - - val m3 = intercept[ParseException] { - sql("DESC partitioned_table PARTITION (c='Us', d)") - }.getMessage() - assert(m3.contains("PARTITION specification is incomplete: `d`")) - - spark - .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write - .partitionBy("d") - .saveAsTable("datasource_table") - - sql("DESC datasource_table PARTITION (d=0)") - - val m5 = intercept[AnalysisException] { - spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") - sql("DESC view1 PARTITION (c='Us', d=1)") - }.getMessage() - assert(m5.contains("DESC PARTITION is not allowed on a temporary view")) - - withView("permanent_view") { - val m = intercept[AnalysisException] { - sql("CREATE VIEW permanent_view AS SELECT * FROM partitioned_table") - sql("DESC permanent_view PARTITION (c='Us', d=1)") - }.getMessage() - assert(m.contains("DESC PARTITION is not allowed on a view")) - } - } - } - test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") @@ -676,7 +603,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -686,86 +613,76 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) + + val storageCtas2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas2")).storage + assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) // the table schema may like (key: integer, value: string) sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin) // do nothing cause the table ctas4 already existed. sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) checkAnswer( sql("SELECT k, value FROM ctas1 ORDER BY k, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM src ORDER BY key, value")) checkAnswer( sql("SELECT key, value FROM ctas2 ORDER BY key, value"), sql( """ SELECT key, value FROM src - ORDER BY key, value""").collect().toSeq) + ORDER BY key, value""")) checkAnswer( sql("SELECT key, value FROM ctas3 ORDER BY key, value"), sql( """ SELECT key, value FROM src - ORDER BY key, value""").collect().toSeq) + ORDER BY key, value""")) intercept[AnalysisException] { sql( """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) - /* - Disabled because our describe table does not output the serde information right now. - checkKeywordsExist(sql("DESC EXTENDED ctas2"), - "name:key", "type:string", "name:value", "ctas2", - "org.apache.hadoop.hive.ql.io.RCFileInputFormat", - "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", - "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" - ) - */ - sql( """CREATE TABLE ctas5 | STORED AS parquet AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) + val storageCtas5 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas5")).storage + assert(storageCtas5.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(storageCtas5.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(storageCtas5.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - /* - Disabled because our describe table does not output the serde information right now. - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { - checkKeywordsExist(sql("DESC EXTENDED ctas5"), - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - } - */ // use the Hive SerDe for parquet tables withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM src ORDER BY key, value")) } } From b34f7665ddb0a40044b4c2bc7d351599c125cb13 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 3 Apr 2017 23:42:04 -0700 Subject: [PATCH 0186/1765] [SPARK-19825][R][ML] spark.ml R API for FPGrowth ## What changes were proposed in this pull request? Adds SparkR API for FPGrowth: [SPARK-19825](https://issues.apache.org/jira/browse/SPARK-19825): - `spark.fpGrowth` -model training. - `freqItemsets` and `associationRules` methods with new corresponding generics. - Scala helper: `org.apache.spark.ml.r. FPGrowthWrapper` - unit tests. ## How was this patch tested? Feature specific unit tests. Author: zero323 Closes #17170 from zero323/SPARK-19825. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 5 +- R/pkg/R/generics.R | 12 ++ R/pkg/R/mllib_fpm.R | 158 ++++++++++++++++++ R/pkg/R/mllib_utils.R | 2 + R/pkg/inst/tests/testthat/test_mllib_fpm.R | 83 +++++++++ .../apache/spark/ml/r/FPGrowthWrapper.scala | 86 ++++++++++ .../org/apache/spark/ml/r/RWrappers.scala | 2 + 8 files changed, 348 insertions(+), 1 deletion(-) create mode 100644 R/pkg/R/mllib_fpm.R create mode 100644 R/pkg/inst/tests/testthat/test_mllib_fpm.R create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 00dde64324ae7..f475ee87702e1 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -44,6 +44,7 @@ Collate: 'jvm.R' 'mllib_classification.R' 'mllib_clustering.R' + 'mllib_fpm.R' 'mllib_recommendation.R' 'mllib_regression.R' 'mllib_stat.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c02046c94bf4d..9b7e95ce30acb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -66,7 +66,10 @@ exportMethods("glm", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", - "spark.svmLinear") + "spark.svmLinear", + "spark.fpGrowth", + "spark.freqItemsets", + "spark.associationRules") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 80283e48ced7b..945676c7f10b3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1445,6 +1445,18 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) + #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R new file mode 100644 index 0000000000000..96251b2c7c195 --- /dev/null +++ b/R/pkg/R/mllib_fpm.R @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration + +#' S4 class that represents a FPGrowthModel +#' +#' @param jobj a Java object reference to the backing Scala FPGrowthModel +#' @export +#' @note FPGrowthModel since 2.2.0 +setClass("FPGrowthModel", slots = list(jobj = "jobj")) + +#' FP-growth +#' +#' A parallel FP-growth algorithm to mine frequent itemsets. +#' For more details, see +#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ +#' FP-growth}. +#' +#' @param data A SparkDataFrame for training. +#' @param minSupport Minimal support level. +#' @param minConfidence Minimal confidence level. +#' @param itemsCol Features column name. +#' @param numPartitions Number of partitions used for fitting. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model. +#' @rdname spark.fpGrowth +#' @name spark.fpGrowth +#' @aliases spark.fpGrowth,SparkDataFrame-method +#' @export +#' @examples +#' \dontrun{ +#' raw_data <- read.df( +#' "data/mllib/sample_fpgrowth.txt", +#' source = "csv", +#' schema = structType(structField("raw_items", "string"))) +#' +#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items") +#' model <- spark.fpGrowth(data) +#' +#' # Show frequent itemsets +#' frequent_itemsets <- spark.freqItemsets(model) +#' showDF(frequent_itemsets) +#' +#' # Show association rules +#' association_rules <- spark.associationRules(model) +#' showDF(association_rules) +#' +#' # Predict on new data +#' new_itemsets <- data.frame(items = c("t", "t,s")) +#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items") +#' predict(model, new_data) +#' +#' # Save and load model +#' path <- "/path/to/model" +#' write.ml(model, path) +#' read.ml(path) +#' +#' # Optional arguments +#' baskets_data <- selectExpr(createDataFrame(itemsets), "split(items, ',') as baskets") +#' another_model <- spark.fpGrowth(data, minSupport = 0.1, minConfidence = 0.5, +#' itemsCol = "baskets", numPartitions = 10) +#' } +#' @note spark.fpGrowth since 2.2.0 +setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), + function(data, minSupport = 0.3, minConfidence = 0.8, + itemsCol = "items", numPartitions = NULL) { + if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) { + stop("minSupport should be a number [0, 1].") + } + if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) { + stop("minConfidence should be a number [0, 1].") + } + if (!is.null(numPartitions)) { + numPartitions <- as.integer(numPartitions) + stopifnot(numPartitions > 0) + } + + jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit", + data@sdf, as.numeric(minSupport), as.numeric(minConfidence), + itemsCol, numPartitions) + new("FPGrowthModel", jobj = jobj) + }) + +# Get frequent itemsets. + +#' @param object a fitted FPGrowth model. +#' @return A \code{SparkDataFrame} with frequent itemsets. +#' The \code{SparkDataFrame} contains two columns: +#' \code{items} (an array of the same type as the input column) +#' and \code{freq} (frequency of the itemset). +#' @rdname spark.fpGrowth +#' @aliases freqItemsets,FPGrowthModel-method +#' @export +#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 +setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "freqItemsets")) + }) + +# Get association rules. + +#' @return A \code{SparkDataFrame} with association rules. +#' The \code{SparkDataFrame} contains three columns: +#' \code{antecedent} (an array of the same type as the input column), +#' \code{consequent} (an array of the same type as the input column), +#' and \code{condfidence} (confidence). +#' @rdname spark.fpGrowth +#' @aliases associationRules,FPGrowthModel-method +#' @export +#' @note spark.associationRules(FPGrowthModel) since 2.2.0 +setMethod("spark.associationRules", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "associationRules")) + }) + +# Makes predictions based on generated association rules + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.fpGrowth +#' @aliases predict,FPGrowthModel-method +#' @export +#' @note predict(FPGrowthModel) since 2.2.0 +setMethod("predict", signature(object = "FPGrowthModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the FPGrowth model to the output path. + +#' @param path the directory where the model is saved. +#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' @rdname spark.fpGrowth +#' @aliases write.ml,FPGrowthModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(FPGrowthModel, character) since 2.2.0 +setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 04a0a6f944412..5dfef8625061b 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -118,6 +118,8 @@ read.ml <- function(path) { new("BisectingKMeansModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) { new("LinearSVCModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) { + new("FPGrowthModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R new file mode 100644 index 0000000000000..c38f1133897dd --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib frequent pattern mining") + +# Tests for MLlib frequent pattern mining algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("spark.fpGrowth", { + data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,2", + "1,2,3", + "1,3" + ))), "split(items, ',') as items") + + model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1) + + itemsets <- collect(spark.freqItemsets(model)) + + expected_itemsets <- data.frame( + items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), list("1"))), + freq = c(2, 2, 3, 3, 4) + ) + + expect_equivalent(expected_itemsets, itemsets) + + expected_association_rules <- data.frame( + antecedent = I(list(list("2"), list("3"))), + consequent = I(list(list("1"), list("1"))), + confidence = c(1, 1) + ) + + expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) + + new_data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,3", + "2,3" + ))), "split(items, ',') as items") + + expected_predictions <- data.frame( + items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))), + prediction = I(list(list(), list(), list("1"))) + ) + + expect_equivalent(expected_predictions, collect(predict(model, new_data))) + + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) + + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) + + unlink(modelPath) + + model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) + expect_equal( + count(spark.freqItemsets(model_without_numpartitions)), + count(spark.freqItemsets(model)) + ) + +}) + +sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala new file mode 100644 index 0000000000000..b8151d8d90702 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class FPGrowthWrapper private (val fpGrowthModel: FPGrowthModel) extends MLWritable { + def freqItemsets: DataFrame = fpGrowthModel.freqItemsets + def associationRules: DataFrame = fpGrowthModel.associationRules + + def transform(dataset: Dataset[_]): DataFrame = { + fpGrowthModel.transform(dataset) + } + + override def write: MLWriter = new FPGrowthWrapper.FPGrowthWrapperWriter(this) +} + +private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] { + + def fit( + data: DataFrame, + minSupport: Double, + minConfidence: Double, + itemsCol: String, + numPartitions: Integer): FPGrowthWrapper = { + val fpGrowth = new FPGrowth() + .setMinSupport(minSupport) + .setMinConfidence(minConfidence) + .setItemsCol(itemsCol) + + if (numPartitions != null && numPartitions > 0) { + fpGrowth.setNumPartitions(numPartitions) + } + + val fpGrowthModel = fpGrowth.fit(data) + + new FPGrowthWrapper(fpGrowthModel) + } + + override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader + + class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] { + override def load(path: String): FPGrowthWrapper = { + val modelPath = new Path(path, "model").toString + val fPGrowthModel = FPGrowthModel.load(modelPath) + + new FPGrowthWrapper(fPGrowthModel) + } + } + + class FPGrowthWrapperWriter(instance: FPGrowthWrapper) extends MLWriter { + override protected def saveImpl(path: String): Unit = { + val modelPath = new Path(path, "model").toString + val rMetadataPath = new Path(path, "rMetadata").toString + + val rMetadataJson: String = compact(render( + "class" -> instance.getClass.getName + )) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.fpGrowthModel.save(modelPath) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 358e522dfe1c8..b30ce12bc6cc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -68,6 +68,8 @@ private[r] object RWrappers extends MLReader[Object] { BisectingKMeansWrapper.load(path) case "org.apache.spark.ml.r.LinearSVCWrapper" => LinearSVCWrapper.load(path) + case "org.apache.spark.ml.r.FPGrowthWrapper" => + FPGrowthWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } From c95fbea68e9dfb2c96a1d13dde17d80a37066ae6 Mon Sep 17 00:00:00 2001 From: guoxiaolongzte Date: Tue, 4 Apr 2017 09:56:17 +0100 Subject: [PATCH 0187/1765] =?UTF-8?q?[SPARK-20190][APP-ID]=20applications/?= =?UTF-8?q?/jobs'=20in=20rest=20api,status=20should=20be=20[running|s?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ucceeded|failed|unknown] ## What changes were proposed in this pull request? '/applications/[app-id]/jobs' in rest api.status should be'[running|succeeded|failed|unknown]'. now status is '[complete|succeeded|failed]'. but '/applications/[app-id]/jobs?status=complete' the server return 'HTTP ERROR 404'. Added '?status=running' and '?status=unknown'. code : public enum JobExecutionStatus { RUNNING, SUCCEEDED, FAILED, UNKNOWN; ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolongzte Closes #17507 from guoxiaolongzte/SPARK-20190. --- docs/monitoring.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 6cbc6660e816c..4d0617d253b80 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -289,7 +289,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running
    From 26e7bca2295faeef22b2d9554f316c97bc240fd7 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 4 Apr 2017 18:57:46 +0800 Subject: [PATCH 0188/1765] [SPARK-20198][SQL] Remove the inconsistency in table/function name conventions in SparkSession.Catalog APIs ### What changes were proposed in this pull request? Observed by felixcheung , in `SparkSession`.`Catalog` APIs, we have different conventions/rules for table/function identifiers/names. Most APIs accept the qualified name (i.e., `databaseName`.`tableName` or `databaseName`.`functionName`). However, the following five APIs do not accept it. - def listColumns(tableName: String): Dataset[Column] - def getTable(tableName: String): Table - def getFunction(functionName: String): Function - def tableExists(tableName: String): Boolean - def functionExists(functionName: String): Boolean To make them consistent with the other Catalog APIs, this PR does the changes, updates the function/API comments and adds the `params` to clarify the inputs we allow. ### How was this patch tested? Added the test cases . Author: Xiao Li Closes #17518 from gatorsmile/tableIdentifier. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 8 ++ .../sql/catalyst/parser/AstBuilder.scala | 13 +++ .../sql/catalyst/parser/ParseDriver.scala | 7 +- .../sql/catalyst/parser/ParserInterface.scala | 5 +- .../org/apache/spark/sql/SparkSession.scala | 7 +- .../apache/spark/sql/catalog/Catalog.scala | 109 +++++++++++++++--- .../spark/sql/internal/CatalogImpl.scala | 73 ++++++------ .../spark/sql/internal/CatalogSuite.scala | 21 ++++ 8 files changed, 186 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index c4a590ec6916b..52b5b347fa9c7 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -56,6 +56,10 @@ singleTableIdentifier : tableIdentifier EOF ; +singleFunctionIdentifier + : functionIdentifier EOF + ; + singleDataType : dataType EOF ; @@ -493,6 +497,10 @@ tableIdentifier : (db=identifier '.')? table=identifier ; +functionIdentifier + : (db=identifier '.')? function=identifier + ; + namedExpression : expression (AS? (identifier | identifierList))? ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 162051a8c0e4a..fab7e4c5b1285 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -75,6 +75,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { visitTableIdentifier(ctx.tableIdentifier) } + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { visitSparkDataType(ctx.dataType) } @@ -759,6 +764,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) } + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. + */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + /* ******************************************************************************************** * Expression parsing * ******************************************************************************************** */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index f704b0998cada..80ab75cc17fab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -22,7 +22,7 @@ import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin @@ -49,6 +49,11 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) } + /** Creates FunctionIdentifier for a given SQL string. */ + def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } + /** * Creates StructType for a given SQL string, which is a comma separated list of field * definitions which will preserve the correct Hive metadata. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 6edbe253970e9..db3598bde04d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StructType @@ -35,6 +35,9 @@ trait ParserInterface { /** Creates TableIdentifier for a given SQL string. */ def parseTableIdentifier(sqlText: String): TableIdentifier + /** Creates FunctionIdentifier for a given SQL string. */ + def parseFunctionIdentifier(sqlText: String): FunctionIdentifier + /** * Creates StructType for a given SQL string, which is a comma separated list of field * definitions which will preserve the correct Hive metadata. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b60499253c42f..95f3463dfe62b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -591,8 +591,13 @@ class SparkSession private( @transient lazy val catalog: Catalog = new CatalogImpl(self) /** - * Returns the specified table as a `DataFrame`. + * Returns the specified table/view as a `DataFrame`. * + * @param tableName is either a qualified or unqualified name that designates a table or view. + * If a database is specified, it identifies the table/view from the database. + * Otherwise, it first attempts to find a temporary view with the given name + * and then match the table/view from the current database. + * Note that, the global temporary view database is also valid here. * @since 2.0.0 */ def table(tableName: String): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 50252db789d46..137b0cbc84f8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -54,16 +54,16 @@ abstract class Catalog { def listDatabases(): Dataset[Database] /** - * Returns a list of tables in the current database. - * This includes all temporary tables. + * Returns a list of tables/views in the current database. + * This includes all temporary views. * * @since 2.0.0 */ def listTables(): Dataset[Table] /** - * Returns a list of tables in the specified database. - * This includes all temporary tables. + * Returns a list of tables/views in the specified database. + * This includes all temporary views. * * @since 2.0.0 */ @@ -88,17 +88,21 @@ abstract class Catalog { def listFunctions(dbName: String): Dataset[Function] /** - * Returns a list of columns for the given table in the current database or - * the given temporary table. + * Returns a list of columns for the given table/view or temporary view. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ @throws[AnalysisException]("table does not exist") def listColumns(tableName: String): Dataset[Column] /** - * Returns a list of columns for the given table in the specified database. + * Returns a list of columns for the given table/view in the specified database. * + * @param dbName is a name that designates a database. + * @param tableName is an unqualified name that designates a table/view. * @since 2.0.0 */ @throws[AnalysisException]("database or table does not exist") @@ -115,9 +119,11 @@ abstract class Catalog { /** * Get the table or view with the specified name. This table can be a temporary view or a - * table/view in the current database. This throws an AnalysisException when no Table - * can be found. + * table/view. This throws an AnalysisException when no Table can be found. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a table/view in + * the current database. * @since 2.1.0 */ @throws[AnalysisException]("table does not exist") @@ -134,9 +140,11 @@ abstract class Catalog { /** * Get the function with the specified name. This function can be a temporary function or a - * function in the current database. This throws an AnalysisException when the function cannot - * be found. + * function. This throws an AnalysisException when the function cannot be found. * + * @param functionName is either a qualified or unqualified name that designates a function. + * If no database identifier is provided, it refers to a temporary function + * or a function in the current database. * @since 2.1.0 */ @throws[AnalysisException]("function does not exist") @@ -146,6 +154,8 @@ abstract class Catalog { * Get the function with the specified name. This throws an AnalysisException when the function * cannot be found. * + * @param dbName is a name that designates a database. + * @param functionName is an unqualified name that designates a function in the specified database * @since 2.1.0 */ @throws[AnalysisException]("database or function does not exist") @@ -160,8 +170,11 @@ abstract class Catalog { /** * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view in the current database. + * view or a table/view. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a table/view in + * the current database. * @since 2.1.0 */ def tableExists(tableName: String): Boolean @@ -169,14 +182,19 @@ abstract class Catalog { /** * Check if the table or view with the specified name exists in the specified database. * + * @param dbName is a name that designates a database. + * @param tableName is an unqualified name that designates a table. * @since 2.1.0 */ def tableExists(dbName: String, tableName: String): Boolean /** * Check if the function with the specified name exists. This can either be a temporary function - * or a function in the current database. + * or a function. * + * @param functionName is either a qualified or unqualified name that designates a function. + * If no database identifier is provided, it refers to a function in + * the current database. * @since 2.1.0 */ def functionExists(functionName: String): Boolean @@ -184,6 +202,8 @@ abstract class Catalog { /** * Check if the function with the specified name exists in the specified database. * + * @param dbName is a name that designates a database. + * @param functionName is an unqualified name that designates a function. * @since 2.1.0 */ def functionExists(dbName: String, functionName: String): Boolean @@ -192,6 +212,9 @@ abstract class Catalog { * Creates a table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -204,6 +227,9 @@ abstract class Catalog { * Creates a table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -214,6 +240,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and returns the corresponding * DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -226,6 +255,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and returns the corresponding * DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -236,6 +268,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -251,6 +286,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -267,6 +305,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -283,6 +324,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -297,6 +341,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -313,6 +360,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -330,6 +380,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -347,6 +400,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -368,7 +424,7 @@ abstract class Catalog { * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean * in Spark 2.1. * - * @param viewName the name of the view to be dropped. + * @param viewName the name of the temporary view to be dropped. * @return true if the view is dropped successfully, false otherwise. * @since 2.0.0 */ @@ -383,15 +439,18 @@ abstract class Catalog { * preserved database `global_temp`, and we must use the qualified name to refer a global temp * view, e.g. `SELECT * FROM global_temp.view1`. * - * @param viewName the name of the view to be dropped. + * @param viewName the unqualified name of the temporary view to be dropped. * @return true if the view is dropped successfully, false otherwise. * @since 2.1.0 */ def dropGlobalTempView(viewName: String): Boolean /** - * Recover all the partitions in the directory of a table and update the catalog. + * Recovers all the partitions in the directory of a table and update the catalog. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in the + * current database. * @since 2.1.1 */ def recoverPartitions(tableName: String): Unit @@ -399,6 +458,9 @@ abstract class Catalog { /** * Returns true if the table is currently cached in-memory. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def isCached(tableName: String): Boolean @@ -406,6 +468,9 @@ abstract class Catalog { /** * Caches the specified table in-memory. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def cacheTable(tableName: String): Unit @@ -413,6 +478,9 @@ abstract class Catalog { /** * Removes the specified table from the in-memory cache. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def uncacheTable(tableName: String): Unit @@ -425,7 +493,7 @@ abstract class Catalog { def clearCache(): Unit /** - * Invalidate and refresh all the cached metadata of the given table. For performance reasons, + * Invalidates and refreshes all the cached metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a * table, such as the location of blocks. When those change outside of Spark SQL, users should * call this function to invalidate the cache. @@ -433,13 +501,16 @@ abstract class Catalog { * If this table is cached as an InMemoryRelation, drop the original cached version and make the * new version cached lazily. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def refreshTable(tableName: String): Unit /** - * Invalidate and refresh all the cached data (and the associated metadata) for any dataframe that - * contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate + * Invalidates and refreshes all the cached data (and the associated metadata) for any [[Dataset]] + * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate * everything that is cached. * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 53374859f13f4..5d1c35aba529a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.fs.Path - import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} @@ -143,11 +141,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table in the current database. + * Returns a list of columns for the given table temporary view. */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { - listColumns(TableIdentifier(tableName, None)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + listColumns(tableIdent) } /** @@ -177,7 +176,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the database with the specified name. This throws an `AnalysisException` when no + * Gets the database with the specified name. This throws an `AnalysisException` when no * `Database` can be found. */ override def getDatabase(dbName: String): Database = { @@ -185,16 +184,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view in the current database. This throws an `AnalysisException` when no `Table` - * can be found. + * Gets the table or view with the specified name. This table can be a temporary view or a + * table/view. This throws an `AnalysisException` when no `Table` can be found. */ override def getTable(tableName: String): Table = { - getTable(null, tableName) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + getTable(tableIdent.database.orNull, tableIdent.table) } /** - * Get the table or view with the specified name in the specified database. This throws an + * Gets the table or view with the specified name in the specified database. This throws an * `AnalysisException` when no `Table` can be found. */ override def getTable(dbName: String, tableName: String): Table = { @@ -202,16 +201,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the function with the specified name. This function can be a temporary function or a - * function in the current database. This throws an `AnalysisException` when no `Function` - * can be found. + * Gets the function with the specified name. This function can be a temporary function or a + * function. This throws an `AnalysisException` when no `Function` can be found. */ override def getFunction(functionName: String): Function = { - getFunction(null, functionName) + val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) + getFunction(functionIdent.database.orNull, functionIdent.funcName) } /** - * Get the function with the specified name. This returns `None` when no `Function` can be + * Gets the function with the specified name. This returns `None` when no `Function` can be * found. */ override def getFunction(dbName: String, functionName: String): Function = { @@ -219,22 +218,23 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the database with the specified name exists. + * Checks if the database with the specified name exists. */ override def databaseExists(dbName: String): Boolean = { sessionCatalog.databaseExists(dbName) } /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view in the current database. + * Checks if the table or view with the specified name exists. This can either be a temporary + * view or a table/view. */ override def tableExists(tableName: String): Boolean = { - tableExists(null, tableName) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + tableExists(tableIdent.database.orNull, tableIdent.table) } /** - * Check if the table or view with the specified name exists in the specified database. + * Checks if the table or view with the specified name exists in the specified database. */ override def tableExists(dbName: String, tableName: String): Boolean = { val tableIdent = TableIdentifier(tableName, Option(dbName)) @@ -242,15 +242,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function in the current database. + * Checks if the function with the specified name exists. This can either be a temporary function + * or a function. */ override def functionExists(functionName: String): Boolean = { - functionExists(null, functionName) + val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) + functionExists(functionIdent.database.orNull, functionIdent.funcName) } /** - * Check if the function with the specified name exists in the specified database. + * Checks if the function with the specified name exists in the specified database. */ override def functionExists(dbName: String, functionName: String): Boolean = { sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) @@ -303,7 +304,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. + * Creates a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -338,7 +339,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Drops the local temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * - * @param viewName the name of the view to be dropped. + * @param viewName the identifier of the temporary view to be dropped. * @group ddl_ops * @since 2.0.0 */ @@ -353,7 +354,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Drops the global temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * - * @param viewName the name of the view to be dropped. + * @param viewName the identifier of the global temporary view to be dropped. * @group ddl_ops * @since 2.1.0 */ @@ -365,9 +366,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Recover all the partitions in the directory of a table and update the catalog. + * Recovers all the partitions in the directory of a table and update the catalog. * - * @param tableName the name of the table to be repaired. + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in the + * current database. * @group ddl_ops * @since 2.1.1 */ @@ -378,7 +381,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns true if the table is currently cached in-memory. + * Returns true if the table or view is currently cached in-memory. * * @group cachemgmt * @since 2.0.0 @@ -388,7 +391,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Caches the specified table in-memory. + * Caches the specified table or view in-memory. * * @group cachemgmt * @since 2.0.0 @@ -398,7 +401,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Removes the specified table from the in-memory cache. + * Removes the specified table or view from the in-memory cache. * * @group cachemgmt * @since 2.0.0 @@ -408,7 +411,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Removes all cached tables from the in-memory cache. + * Removes all cached tables or views from the in-memory cache. * * @group cachemgmt * @since 2.0.0 @@ -428,7 +431,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refresh the cache entry for a table, if any. For Hive metastore table, the metadata + * Refreshes the cache entry for a table or view, if any. For Hive metastore table, the metadata * is refreshed. For data source tables, the schema will not be inferred and refreshed. * * @group cachemgmt @@ -452,7 +455,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refresh the cache entry and the associated metadata for all dataframes (if any), that contain + * Refreshes the cache entry and the associated metadata for all Dataset (if any), that contain * the given data source path. * * @group cachemgmt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 9742b3b2d5c29..6469e501c1f68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -102,6 +102,11 @@ class CatalogSuite assert(col.isPartition == tableMetadata.partitionColumnNames.contains(col.name)) assert(col.isBucket == bucketColumnNames.contains(col.name)) } + + dbName.foreach { db => + val expected = columns.collect().map(_.name).toSet + assert(spark.catalog.listColumns(s"$db.$tableName").collect().map(_.name).toSet == expected) + } } override def afterEach(): Unit = { @@ -345,6 +350,7 @@ class CatalogSuite // Find a qualified table assert(spark.catalog.getTable(db, "tbl_y").name === "tbl_y") + assert(spark.catalog.getTable(s"$db.tbl_y").name === "tbl_y") // Find an unqualified table using the current database intercept[AnalysisException](spark.catalog.getTable("tbl_y")) @@ -378,6 +384,11 @@ class CatalogSuite assert(fn2.database === db) assert(!fn2.isTemporary) + val fn2WithQualifiedName = spark.catalog.getFunction(s"$db.fn2") + assert(fn2WithQualifiedName.name === "fn2") + assert(fn2WithQualifiedName.database === db) + assert(!fn2WithQualifiedName.isTemporary) + // Find an unqualified function using the current database intercept[AnalysisException](spark.catalog.getFunction("fn2")) spark.catalog.setCurrentDatabase(db) @@ -403,6 +414,7 @@ class CatalogSuite assert(!spark.catalog.tableExists("tbl_x")) assert(!spark.catalog.tableExists("tbl_y")) assert(!spark.catalog.tableExists(db, "tbl_y")) + assert(!spark.catalog.tableExists(s"$db.tbl_y")) // Create objects. createTempTable("tbl_x") @@ -413,11 +425,15 @@ class CatalogSuite // Find a qualified table assert(spark.catalog.tableExists(db, "tbl_y")) + assert(spark.catalog.tableExists(s"$db.tbl_y")) // Find an unqualified table using the current database assert(!spark.catalog.tableExists("tbl_y")) spark.catalog.setCurrentDatabase(db) assert(spark.catalog.tableExists("tbl_y")) + + // Unable to find the table, although the temp view with the given name exists + assert(!spark.catalog.tableExists(db, "tbl_x")) } } } @@ -429,6 +445,7 @@ class CatalogSuite assert(!spark.catalog.functionExists("fn1")) assert(!spark.catalog.functionExists("fn2")) assert(!spark.catalog.functionExists(db, "fn2")) + assert(!spark.catalog.functionExists(s"$db.fn2")) // Create objects. createTempFunction("fn1") @@ -439,11 +456,15 @@ class CatalogSuite // Find a qualified function assert(spark.catalog.functionExists(db, "fn2")) + assert(spark.catalog.functionExists(s"$db.fn2")) // Find an unqualified function using the current database assert(!spark.catalog.functionExists("fn2")) spark.catalog.setCurrentDatabase(db) assert(spark.catalog.functionExists("fn2")) + + // Unable to find the function, although the temp function with the given name exists + assert(!spark.catalog.functionExists(db, "fn1")) } } } From 11238d4c62961c03376d9b2899221ec74313363a Mon Sep 17 00:00:00 2001 From: Anirudh Ramanathan Date: Tue, 4 Apr 2017 10:46:44 -0700 Subject: [PATCH 0189/1765] [SPARK-18278][SCHEDULER] Documentation to point to Kubernetes cluster scheduler ## What changes were proposed in this pull request? Adding documentation to point to Kubernetes cluster scheduler being developed out-of-repo in https://github.com/apache-spark-on-k8s/spark cc rxin srowen tnachen ash211 mccheah erikerlandson ## How was this patch tested? Docs only change Author: Anirudh Ramanathan Author: foxish Closes #17522 from foxish/upstream-doc. --- docs/cluster-overview.md | 6 +++++- docs/index.md | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 814e4406cf435..a2ad958959a50 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,7 +52,11 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. - +* [Kubernetes (experimental)](https://github.com/apache-spark-on-k8s/spark) -- In addition to the above, +there is experimental support for Kubernetes. Kubernetes is an open-source platform +for providing container-centric infrastructure. Kubernetes support is being actively +developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. +For documentation, refer to that project's README. # Submitting Applications diff --git a/docs/index.md b/docs/index.md index 19a9d3bfc6017..ad4f24ff1a5d1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -115,6 +115,7 @@ options for deployment: * [Mesos](running-on-mesos.html): deploy a private cluster using [Apache Mesos](http://mesos.apache.org) * [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN) + * [Kubernetes (experimental)](https://github.com/apache-spark-on-k8s/spark): deploy Spark on top of Kubernetes **Other Documents:** From 0736980f395f114faccbd58e78280ca63ed289c7 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 4 Apr 2017 11:38:05 -0700 Subject: [PATCH 0190/1765] [SPARK-20191][YARN] Crate wrapper for RackResolver so tests can override it. Current test code tries to override the RackResolver used by setting configuration params, but because YARN libs statically initialize the resolver the first time it's used, that means that those configs don't really take effect during Spark tests. This change adds a wrapper class that easily allows tests to override the behavior of the resolver for the Spark code that uses it. Author: Marcelo Vanzin Closes #17508 from vanzin/SPARK-20191. --- ...yPreferredContainerPlacementStrategy.scala | 6 +-- .../spark/deploy/yarn/SparkRackResolver.scala | 40 +++++++++++++++++++ .../spark/deploy/yarn/YarnAllocator.scala | 13 ++---- .../spark/deploy/yarn/YarnRMClient.scala | 2 +- .../yarn/LocalityPlacementStrategySuite.scala | 8 +--- .../deploy/yarn/YarnAllocatorSuite.scala | 22 +++------- 6 files changed, 56 insertions(+), 35 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index f2b6324db619a..257dc83621e98 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.util.RackResolver import org.apache.spark.SparkConf import org.apache.spark.internal.config._ @@ -83,7 +82,8 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack private[yarn] class LocalityPreferredContainerPlacementStrategy( val sparkConf: SparkConf, val yarnConf: Configuration, - val resource: Resource) { + val resource: Resource, + resolver: SparkRackResolver) { /** * Calculate each container's node locality and rack locality @@ -139,7 +139,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( // still be allocated with new container request. val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray val racks = hosts.map { h => - RackResolver.resolve(yarnConf, h).getNetworkLocation + resolver.resolve(yarnConf, h) }.toSet containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala new file mode 100644 index 0000000000000..c711d088f2116 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + +/** + * Wrapper around YARN's [[RackResolver]]. This allows Spark tests to easily override the + * default behavior, since YARN's class self-initializes the first time it's called, and + * future calls all use the initial configuration. + */ +private[yarn] class SparkRackResolver { + + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } + + def resolve(conf: Configuration, hostName: String): String = { + RackResolver.resolve(conf, hostName).getNetworkLocation() + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 25556763da904..ed77a6e4a1c7c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -65,16 +64,12 @@ private[yarn] class YarnAllocator( amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, - localResources: Map[String, LocalResource]) + localResources: Map[String, LocalResource], + resolver: SparkRackResolver) extends Logging { import YarnAllocator._ - // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. - if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { - Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) - } - // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] @@ -159,7 +154,7 @@ private[yarn] class YarnAllocator( // A container placement strategy based on pending tasks' locality preference private[yarn] val containerPlacementStrategy = - new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) /** * Use a different clock for YarnAllocator. This is mainly used for testing. @@ -424,7 +419,7 @@ private[yarn] class YarnAllocator( // Match remaining by rack val remainingAfterRackMatches = new ArrayBuffer[Container] for (allocatedContainer <- remainingAfterHostMatches) { - val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation + val rack = resolver.resolve(conf, allocatedContainer.getNodeId.getHost) matchContainerToRequest(allocatedContainer, rack, containersToUse, remainingAfterRackMatches) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 53fb467f6408d..72f4d273ab53b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -75,7 +75,7 @@ private[spark] class YarnRMClient extends Logging { registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, - localResources) + localResources, new SparkRackResolver()) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala index fb80ff9f31322..b7f25656e49ac 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.yarn +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet, Set} -import org.apache.hadoop.fs.CommonConfigurationKeysPublic -import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Mockito._ @@ -51,9 +50,6 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { private def runTest(): Unit = { val yarnConf = new YarnConfiguration() - yarnConf.setClass( - CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, - classOf[MockResolver], classOf[DNSToSwitchMapping]) // The numbers below have been chosen to balance being large enough to replicate the // original issue while not taking too long to run when the issue is fixed. The main @@ -62,7 +58,7 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { val resource = Resource.newInstance(8 * 1024, 4) val strategy = new LocalityPreferredContainerPlacementStrategy(new SparkConf(), - yarnConf, resource) + yarnConf, resource, new MockResolver()) val totalTasks = 32 * 1024 val totalContainers = totalTasks / 16 diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index fcc0594cf6d80..97b0e8aca3330 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,12 +17,9 @@ package org.apache.spark.deploy.yarn -import java.util.{Arrays, List => JList} - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.CommonConfigurationKeysPublic -import org.apache.hadoop.net.DNSToSwitchMapping +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest @@ -38,24 +35,16 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.ManualClock -class MockResolver extends DNSToSwitchMapping { +class MockResolver extends SparkRackResolver { - override def resolve(names: JList[String]): JList[String] = { - if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2") - else Arrays.asList("/rack1") + override def resolve(conf: Configuration, hostName: String): String = { + if (hostName == "host3") "/rack2" else "/rack1" } - override def reloadCachedMappings() {} - - def reloadCachedMappings(names: JList[String]) {} } class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new YarnConfiguration() - conf.setClass( - CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, - classOf[MockResolver], classOf[DNSToSwitchMapping]) - val sparkConf = new SparkConf() sparkConf.set("spark.driver.host", "localhost") sparkConf.set("spark.driver.port", "4040") @@ -111,7 +100,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient, appAttemptId, new SecurityManager(sparkConf), - Map()) + Map(), + new MockResolver()) } def createContainer(host: String): Container = { From 0e2ee8204415d28613a60593f2b6e2b3d4ef794f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Apr 2017 11:42:14 -0700 Subject: [PATCH 0191/1765] [MINOR][R] Reorder `Collate` fields in DESCRIPTION file ## What changes were proposed in this pull request? It seems cran check scripts corrects `R/pkg/DESCRIPTION` and follows the order in `Collate` fields. This PR proposes to fix `catalog.R`'s order so that running this script does not show up a small diff in this file every time. ## How was this patch tested? Manually via `./R/check-cran.sh`. Author: hyukjinkwon Closes #17528 from HyukjinKwon/minor-reorder-description. --- R/pkg/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f475ee87702e1..879c1f80f2c5d 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -32,10 +32,10 @@ Collate: 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' - 'catalog.R' 'WindowSpec.R' 'backend.R' 'broadcast.R' + 'catalog.R' 'client.R' 'context.R' 'deserialize.R' From 402bf2a50ddd4039ff9f376b641bd18fffa54171 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Apr 2017 11:56:21 -0700 Subject: [PATCH 0192/1765] [SPARK-20204][SQL] remove SimpleCatalystConf and CatalystConf type alias ## What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/17285 . ## How was this patch tested? existing tests Author: Wenchen Fan Closes #17521 from cloud-fan/conf. --- .../sql/catalyst/SimpleCatalystConf.scala | 51 ------------------- .../sql/catalyst/analysis/Analyzer.scala | 21 ++++---- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../analysis/ResolveInlineTables.scala | 5 +- .../SubstituteUnresolvedOrdinals.scala | 4 +- .../spark/sql/catalyst/analysis/view.scala | 4 +- .../sql/catalyst/catalog/SessionCatalog.scala | 7 +-- .../sql/catalyst/catalog/interface.scala | 5 +- .../sql/catalyst/optimizer/Optimizer.scala | 18 +++---- .../sql/catalyst/optimizer/expressions.scala | 8 +-- .../spark/sql/catalyst/optimizer/joins.scala | 3 +- .../apache/spark/sql/catalyst/package.scala | 8 --- .../plans/logical/LocalRelation.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 8 +-- .../plans/logical/basicLogicalOperators.scala | 32 ++++++------ .../statsEstimation/AggregateEstimation.scala | 4 +- .../statsEstimation/EstimationUtils.scala | 4 +- .../statsEstimation/FilterEstimation.scala | 4 +- .../statsEstimation/JoinEstimation.scala | 8 +-- .../statsEstimation/ProjectEstimation.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../sql/catalyst/analysis/AnalysisTest.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 1 - .../SubstituteUnresolvedOrdinalsSuite.scala | 6 +-- .../catalog/SessionCatalogSuite.scala | 5 +- .../optimizer/AggregateOptimizeSuite.scala | 5 +- .../BinaryComparisonSimplificationSuite.scala | 2 - .../BooleanSimplificationSuite.scala | 5 +- .../optimizer/CombiningLimitsSuite.scala | 3 +- .../optimizer/ConstantFoldingSuite.scala | 3 +- .../optimizer/DecimalAggregatesSuite.scala | 3 +- .../optimizer/EliminateSortsSuite.scala | 5 +- .../InferFiltersFromConstraintsSuite.scala | 7 ++- .../optimizer/JoinOptimizationSuite.scala | 3 +- .../catalyst/optimizer/JoinReorderSuite.scala | 7 +-- .../optimizer/LimitPushdownSuite.scala | 1 - .../optimizer/OptimizeCodegenSuite.scala | 3 +- .../catalyst/optimizer/OptimizeInSuite.scala | 11 ++-- .../optimizer/OuterJoinEliminationSuite.scala | 7 ++- .../PropagateEmptyRelationSuite.scala | 5 +- .../optimizer/PruneFiltersSuite.scala | 7 ++- .../RewriteDistinctAggregatesSuite.scala | 9 ++-- .../optimizer/SetOperationSuite.scala | 3 +- .../optimizer/StarJoinReorderSuite.scala | 7 ++- .../spark/sql/catalyst/plans/PlanTest.scala | 4 +- .../AggregateEstimationSuite.scala | 5 +- .../BasicStatsEstimationSuite.scala | 8 +-- .../StatsEstimationTestBase.scala | 7 +-- .../spark/sql/execution/ExistingRDD.scala | 7 +-- .../execution/columnar/InMemoryRelation.scala | 5 +- .../datasources/DataSourceStrategy.scala | 8 ++- .../datasources/LogicalRelation.scala | 4 +- .../sql/execution/streaming/memory.scala | 4 +- .../sql/sources/DataSourceAnalysisSuite.scala | 4 +- 54 files changed, 164 insertions(+), 220 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala deleted file mode 100644 index 8498cf1c9be79..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import java.util.TimeZone - -import org.apache.spark.sql.internal.SQLConf - - -/** - * A SQLConf that can be used for local testing. This class is only here to minimize the change - * for ticket SPARK-19944 (moves SQLConf from sql/core to sql/catalyst). This class should - * eventually be removed (test cases should just create SQLConf and set values appropriately). - */ -case class SimpleCatalystConf( - override val caseSensitiveAnalysis: Boolean, - override val orderByOrdinal: Boolean = true, - override val groupByOrdinal: Boolean = true, - override val optimizerMaxIterations: Int = 100, - override val optimizerInSetConversionThreshold: Int = 10, - override val maxCaseBranchesForCodegen: Int = 20, - override val tableRelationCacheSize: Int = 1000, - override val runSQLonFile: Boolean = true, - override val crossJoinEnabled: Boolean = false, - override val cboEnabled: Boolean = false, - override val joinReorderEnabled: Boolean = false, - override val joinReorderDPThreshold: Int = 12, - override val starSchemaDetection: Boolean = false, - override val warehousePath: String = "/user/hive/warehouse", - override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, - override val maxNestedViewDepth: Int = 100, - override val constraintPropagationEnabled: Boolean = true) - extends SQLConf { - - override def clone(): SimpleCatalystConf = this.copy() -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1b3a53c6359e6..2d53d2424a34d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -42,13 +43,13 @@ import org.apache.spark.sql.types._ * to resolve attribute references. */ object SimpleAnalyzer extends Analyzer( - new SessionCatalog( - new InMemoryCatalog, - EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) { - override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} - }, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) { + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} + }, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns @@ -89,11 +90,11 @@ object AnalysisContext { */ class Analyzer( catalog: SessionCatalog, - conf: CatalystConf, + conf: SQLConf, maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis { - def this(catalog: SessionCatalog, conf: CatalystConf) = { + def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } @@ -2331,7 +2332,7 @@ class Analyzer( } /** - * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local * time zone. */ object ResolveTimeZone extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 920033a9a8480..f8004ca300ac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.internal.SQLConf /** @@ -43,7 +43,7 @@ object ResolveHints { * * This rule must happen before common table expressions. */ - class ResolveBroadcastHints(conf: CatalystConf) extends Rule[LogicalPlan] { + class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") def resolver: Resolver = conf.resolver diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index d5b3ea8c37c66..a991dd96e2828 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -19,16 +19,17 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ -case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 38a3d3de1288e..256b18771052a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType /** * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression. */ -class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] { +class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { private def isIntLiteral(e: Expression) = e match { case Literal(_, IntegerType) => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index a5640a6c967a1..3bd54c257d98d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf /** * This file defines analysis rules related to views. @@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. */ -case class AliasViewChild(conf: CatalystConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 72ab075408899..6f8c6ee2f0f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object SessionCatalog { @@ -52,7 +53,7 @@ class SessionCatalog( val externalCatalog: ExternalCatalog, globalTempViewManager: GlobalTempViewManager, functionRegistry: FunctionRegistry, - conf: CatalystConf, + conf: SQLConf, hadoopConf: Configuration, parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends Logging { @@ -63,7 +64,7 @@ class SessionCatalog( def this( externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry, - conf: CatalystConf) { + conf: SQLConf) { this( externalCatalog, new GlobalTempViewManager("global_temp"), @@ -79,7 +80,7 @@ class SessionCatalog( this( externalCatalog, new SimpleFunctionRegistry, - SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } /** List of temporary tables, mapping from table name to their logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 3f25f9e7258f0..dc2e40424fd5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -25,12 +25,13 @@ import scala.collection.mutable import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -425,7 +426,7 @@ case class CatalogRelation( /** Only compare table identifier. */ override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dbf479d215134..577112779eea4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -28,13 +27,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * Abstract class all optimizers should inherit of, contains the standard batches (extending * Optimizers can override this. */ -abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) +abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) extends RuleExecutor[LogicalPlan] { protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) @@ -160,8 +160,8 @@ class SimpleTestOptimizer extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)), - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)), + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) /** * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change @@ -270,7 +270,7 @@ object RemoveRedundantProject extends Rule[LogicalPlan] { /** * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. */ -case class LimitPushDown(conf: CatalystConf) extends Rule[LogicalPlan] { +case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { plan match { @@ -617,7 +617,7 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -case class InferFiltersFromConstraints(conf: CatalystConf) +case class InferFiltersFromConstraints(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { inferFilters(plan) @@ -715,7 +715,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. */ -case class PruneFilters(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { +case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -1057,7 +1057,7 @@ object CombineLimits extends Rule[LogicalPlan] { * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. */ -case class CheckCartesianProducts(conf: CatalystConf) +case class CheckCartesianProducts(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Check if a join is a cartesian product. Returns true if @@ -1092,7 +1092,7 @@ case class CheckCartesianProducts(conf: CatalystConf) * This uses the same rules for increasing the precision and scale of the output as * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ -case class DecimalAggregates(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 33039127f16ce..8445ee06bd89b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -27,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /* @@ -115,7 +115,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 2. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ -case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] { +case class OptimizeIn(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case expr @ In(v, list) if expr.inSetConvertible => @@ -346,7 +346,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. */ -case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { +case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { private def isNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => true case _ => false @@ -482,7 +482,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { /** * Optimizes expressions by replacing according to CodeGen configuration. */ -case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { +case class OptimizeCodegen(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: CaseWhen if canCodegen(e) => e.toCodegen() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 5f7316566b3ba..250dd07a16eb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ @@ -440,7 +439,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe * * This rule should be executed before pushing down the Filter */ -case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { +case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 4af56afebb762..f9c88d496e899 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.internal.SQLConf - /** * Catalyst is a library for manipulating relational query plans. All classes in catalyst are * considered an internal API to Spark SQL and are subject to change between minor releases. @@ -30,10 +28,4 @@ package object catalyst { * 2.10.* builds. See SI-6240 for more details. */ protected[sql] object ScalaReflectionLock - - /** - * This class is only here to minimize the change for ticket SPARK-19944 - * (moves SQLConf from sql/core to sql/catalyst). This class should eventually be removed. - */ - type CatalystConf = SQLConf } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1faabcfcb73b5..b7177c4a2c4e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -74,7 +75,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index f71a976bd7a24..036b6256684cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * first time. If the configuration changes, the cache can be invalidated by calling * [[invalidateStatsCache()]]. */ - final def stats(conf: CatalystConf): Statistics = statsCache.getOrElse { + final def stats(conf: SQLConf): Statistics = statsCache.getOrElse { statsCache = Some(computeStats(conf)) statsCache.get } @@ -108,7 +108,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * * [[LeafNode]]s must override this. */ - protected def computeStats(conf: CatalystConf): Statistics = { + protected def computeStats(conf: SQLConf): Statistics = { if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } @@ -335,7 +335,7 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 19db42c80895c..c91de08ca5ef6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{CatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -64,7 +64,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (conf.cboEnabled) { ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) } else { @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (conf.cboEnabled) { FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) } else { @@ -191,7 +191,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val leftSize = left.stats(conf).sizeInBytes val rightSize = right.stats(conf).sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize @@ -208,7 +208,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected def validConstraints: Set[Expression] = leftConstraints - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { left.stats(conf).copy() } } @@ -247,7 +247,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -356,7 +356,7 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left @@ -382,7 +382,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output // set isBroadcastable to true so the child will be broadcasted - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = child.stats(conf).copy(isBroadcastable = true) } @@ -538,7 +538,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -571,7 +571,7 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { Statistics( @@ -687,7 +687,7 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -758,7 +758,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val childStats = child.stats(conf) val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) @@ -778,7 +778,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val childStats = child.stats(conf) if (limit == 0) { @@ -827,7 +827,7 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val ratio = upperBound - lowerBound val childStats = child.stats(conf) var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) @@ -893,7 +893,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) + override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index ce74554c17010..48b5fbb03ef1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} +import org.apache.spark.sql.internal.SQLConf object AggregateEstimation { @@ -29,7 +29,7 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(conf: CatalystConf, agg: Aggregate): Option[Statistics] = { + def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = { val childStats = agg.child.stats(conf) // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 4d18b28be8663..5577233ffa6fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.math.BigDecimal.RoundingMode -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StringType} object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(conf: CatalystConf, plans: LogicalPlan*): Boolean = + def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean = plans.forall(_.stats(conf).rowCount.isDefined) /** Check if each attribute has column stat in the corresponding statistics. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 03c76cd41d816..7bd8e6511232f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -22,14 +22,14 @@ import scala.collection.mutable import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { +case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging { private val childStats = plan.child.stats(catalystConf) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 9782c0bb0a939..3245a73c8a2eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -21,12 +21,12 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.internal.SQLConf object JoinEstimation extends Logging { @@ -34,7 +34,7 @@ object JoinEstimation extends Logging { * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(conf: CatalystConf, join: Join): Option[Statistics] = { + def estimate(conf: SQLConf, join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => InnerOuterEstimation(conf, join).doEstimate() @@ -47,7 +47,7 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging { +case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { private val leftStats = join.left.stats(conf) private val rightStats = join.right.stats(conf) @@ -288,7 +288,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } } -case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { +case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index e9084ad8b859c..d700cd3b20f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} +import org.apache.spark.sql.internal.SQLConf object ProjectEstimation { import EstimationUtils._ - def estimate(conf: CatalystConf, project: Project): Option[Statistics] = { + def estimate(conf: SQLConf, project: Project): Option[Statistics] = { if (rowCountsExist(conf, project.child)) { val childStats = project.child.stats(conf) val inputAttrStats = childStats.attributeStats diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 06dc0b41204fb..5b5d547f8fe54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1151,4 +1151,13 @@ class SQLConf extends Serializable with Logging { } result } + + // For test only + private[spark] def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + val cloned = clone() + entries.foreach { + case (entry, value) => cloned.setConfString(entry.key, value.toString) + } + cloned + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 0f059b9591460..1be25ec06c741 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf trait AnalysisTest extends PlanTest { @@ -29,7 +29,7 @@ trait AnalysisTest extends PlanTest { protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { - val conf = new SimpleCatalystConf(caseSensitive) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 6995faebfa862..8f43171f309a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index 88f68ebadc72a..2331346f325aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.internal.SQLConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val a = testRelation2.output(0) @@ -44,7 +44,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { // order by ordinal can be turned off by config comparePlans( - new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), + new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } @@ -60,7 +60,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { // group by ordinal can be turned off by config comparePlans( - new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), + new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 56bca73a8857a..9ba846fb25279 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class InMemorySessionCatalogSuite extends SessionCatalogSuite { @@ -1382,7 +1383,7 @@ abstract class SessionCatalogSuite extends PlanTest { import org.apache.spark.sql.catalyst.dsl.plans._ Seq(true, false) foreach { caseSensitive => - val conf = SimpleCatalystConf(caseSensitive) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) try { val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index b45bd977cbba1..e6132ab2e4d17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} class AggregateOptimizeSuite extends PlanTest { - override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index 2bfddb7bc2f35..b29e1cbd14943 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -30,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules._ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 4d404f55aa570..935bff7cef2e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,11 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: @@ -139,7 +138,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } - private val caseInsensitiveConf = new SimpleCatalystConf(false) + private val caseInsensitiveConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false) private val caseInsensitiveAnalyzer = new Analyzer( new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), caseInsensitiveConf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 276b8055b08d0..ac71887c16f96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -33,7 +32,7 @@ class CombiningLimitsSuite extends PlanTest { Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, SimplifyConditionals) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index d9655bbcc2ce1..25c592b9c1dde 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -34,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, - OptimizeIn(SimpleCatalystConf(true)), + OptimizeIn(conf), ConstantFolding, BooleanSimplification) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index a491f4433370d..cc4fb3a244a98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +29,7 @@ class DecimalAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), - DecimalAggregates(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + DecimalAggregates(conf)) :: Nil } val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index c5f9cc1852752..e318f36d78270 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class EliminateSortsSuite extends PlanTest { - override val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 98d8b897a9165..c8fe37462726a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class InferFiltersFromConstraintsSuite extends PlanTest { @@ -32,7 +32,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true)), + InferFiltersFromConstraints(conf), CombineFilters) :: Nil } @@ -41,8 +41,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true, - constraintPropagationEnabled = false)), + InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), CombineFilters) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 61e81808147c7..a43d78c7bd447 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.SimpleCatalystConf class JoinOptimizationSuite extends PlanTest { @@ -38,7 +37,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin(SimpleCatalystConf(true)), + ReorderJoin(conf), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index d74008c1b3027..1922eb30fdce4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -25,12 +24,14 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED, JOIN_REORDER_ENABLED} class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = SimpleCatalystConf( - caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true) + override val conf = new SQLConf().copy( + CASE_SENSITIVE -> true, CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 0f3ba6c895566..2885fd6841e9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index 4385b0e019f25..f3b65cc797ec4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ class OptimizeCodegenSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 9daede1a5f957..d8937321ecb98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD import org.apache.spark.sql.types._ class OptimizeInSuite extends PlanTest { @@ -34,10 +34,10 @@ class OptimizeInSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, - OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + OptimizeIn(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -159,12 +159,11 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) .analyze - val notOptimizedPlan = OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))(plan) + val notOptimizedPlan = OptimizeIn(conf)(plan) comparePlans(notOptimizedPlan, plan) // Reduce the threshold to turning into InSet. - val optimizedPlan = OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true, - optimizerInSetConversionThreshold = 2))(plan) + val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan) optimizedPlan match { case Filter(cond, _) if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index cbabc1fa6d929..b7136703b7541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class OuterJoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -32,7 +32,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true)), + EliminateOuterJoin(conf), PushPredicateThroughJoin) :: Nil } @@ -41,8 +41,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true, - constraintPropagationEnabled = false)), + EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), PushPredicateThroughJoin) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index f771e3e9eba65..c261a6091d476 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans._ @@ -34,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PruneFilters(conf), PropagateEmptyRelation) :: Nil } @@ -46,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + PruneFilters(conf)) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 20f7f69e86c05..741dd0cf428d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PruneFilters(conf), PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -45,8 +45,7 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true, - constraintPropagationEnabled = false)), + PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), PushDownPredicate, PushPredicateThroughJoin) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 350a1c26fd1ef..8cb939e010c68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -16,19 +16,20 @@ */ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { - override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index ca4976f0d6db0..756e0f35b2178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,7 +34,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + PruneFilters(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 93fdd98d1ac93..003ce49eaf8e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -25,12 +24,12 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} - +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, STARSCHEMA_DETECTION} class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = SimpleCatalystConf( - caseSensitiveAnalysis = true, starSchemaDetection = true) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, STARSCHEMA_DETECTION -> true) object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index c73dfaf3f8fe3..f44428c3512a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ abstract class PlanTest extends SparkFunSuite with PredicateHelper { - protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) + protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index c0b9515ca7cd0..38483a298cef0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.internal.SQLConf class AggregateEstimationSuite extends StatsEstimationTestBase { @@ -101,13 +102,13 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { val noGroupAgg = Aggregate(groupingExpressions = Nil, aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) - assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) == + assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == // overhead + count result size Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) val hasGroupAgg = Aggregate(groupingExpressions = attributes, aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) - assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) == + assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 0d92c1e35565a..b06871f96f0d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType @@ -116,10 +116,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { expectedStatsCboOff: Statistics): Unit = { // Invalidate statistics plan.invalidateStatsCache() - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn) plan.invalidateStatsCache() - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) + assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff) } /** Check estimated stats when it's the same whether cbo is turned on or off. */ @@ -136,6 +136,6 @@ private case class DummyLogicalPlan( cboStats: Statistics) extends LogicalPlan { override def output: Seq[Attribute] = Nil override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = if (conf.cboEnabled) cboStats else defaultStats } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 9b2b8dbe1bf4a..263f4e18803d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -18,16 +18,17 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED} import org.apache.spark.sql.types.{IntegerType, StringType} trait StatsEstimationTestBase extends SparkFunSuite { /** Enable stats estimation based on CBO. */ - protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) + protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true) def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes @@ -54,7 +55,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats(conf: CatalystConf): Statistics = Statistics( + override def computeStats(conf: SQLConf): Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 49336f424822f..2827b8ac00331 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Encoder, Row, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -95,7 +96,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( + @transient override def computeStats(conf: SQLConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -170,7 +171,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( + @transient override def computeStats(conf: SQLConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 36037ac003728..0a9f3e799990f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -21,12 +21,13 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -69,7 +70,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c350d8bcbae97..e5c7c383d708c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -21,12 +21,10 @@ import java.util.concurrent.Callable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.fs.Path - import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} @@ -38,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPa import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -50,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. */ -case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 04a764bee2ef2..3b14b794fd08c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -73,7 +73,7 @@ case class LogicalRelation( // expId can be different but the relation is still the same. override lazy val cleanArgs: Seq[Any] = Seq(relation) - @transient override def computeStats(conf: CatalystConf): Statistics = { + @transient override def computeStats(conf: SQLConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6d34d51d31c1e..971ce5afb1778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,11 +25,11 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -230,6 +230,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 448adcf11d656..b16c9f8fc96b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(SimpleCatalystConf(caseSensitive)) + val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { From 295747e59739ee8a697ac3eba485d3439e4a04c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Apr 2017 16:38:32 -0700 Subject: [PATCH 0193/1765] [SPARK-19716][SQL] support by-name resolution for struct type elements in array ## What changes were proposed in this pull request? Previously when we construct deserializer expression for array type, we will first cast the corresponding field to expected array type and then apply `MapObjects`. However, by doing that, we lose the opportunity to do by-name resolution for struct type inside array type. In this PR, I introduce a `UnresolvedMapObjects` to hold the lambda function and the input array expression. Then during analysis, after the input array expression is resolved, we get the actual array element type and apply by-name resolution. Then we don't need to add `Cast` for array type when constructing the deserializer expression, as the element type is determined later at analyzer. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17398 from cloud-fan/dataset. --- .../spark/sql/catalyst/ScalaReflection.scala | 66 +++++++++++-------- .../sql/catalyst/analysis/Analyzer.scala | 19 +++++- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/objects/objects.scala | 32 +++++++-- .../encoders/EncoderResolutionSuite.scala | 52 +++++++++++++++ .../sql/expressions/ReduceAggregator.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 9 +++ 7 files changed, 141 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index da37eb00dcd97..206ae2f0e5eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection { * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * don't need to cast struct type because there must be `UnresolvedExtractValue` or - * `GetStructField` wrapping it, thus we only need to handle leaf type. + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. */ def upCastToExpectedType( expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. case _ => UpCast(expr, expected, walkedTypePath) } @@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(_, elementNullable) = schemaFor(elementType) + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - // TODO: add runtime null check for primitive array - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None + val mapFunction: Expression => Expression = p => { + val converter = deserializerFor(elementType, Some(p), newTypePath) + if (elementNullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } } - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - val className = getClassNameFromType(elementType) - val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - Invoke( - MapObjects( - p => deserializerFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) + val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayCls = arrayClassFor(elementType) + + if (elementNullable) { + Invoke(arrayData, "array", arrayCls) + } else { + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => "toIntArray" + case t if t <:< definitions.LongTpe => "toLongArray" + case t if t <:< definitions.DoubleTpe => "toDoubleArray" + case t if t <:< definitions.FloatTpe => "toFloatArray" + case t if t <:< definitions.ShortTpe => "toShortArray" + case t if t <:< definitions.ByteTpe => "toByteArray" + case t if t <:< definitions.BooleanTpe => "toBooleanArray" + case other => throw new IllegalStateException("expect primitive array element type " + + "but got " + other) + } + Invoke(arrayData, primitiveMethod, arrayCls) } case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) + val Schema(_, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = p => { val converter = deserializerFor(elementType, Some(p), newTypePath) - if (nullable) { + if (elementNullable) { converter } else { AssertNotNull(converter, newTypePath) @@ -312,7 +320,7 @@ object ScalaReflection extends ScalaReflection { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - MapObjects(mapFunction, getPath, dataType, Some(cls)) + UnresolvedMapObjects(mapFunction, getPath, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2d53d2424a34d..c698ca6a8347c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ @@ -2227,8 +2227,21 @@ class Analyzer( validateTopLevelTupleFields(deserializer, inputs) val resolved = resolveExpression( deserializer, LocalRelation(inputs), throws = true) - validateNestedTupleFields(resolved) - resolved + val result = resolved transformDown { + case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => + inputData.dataType match { + case ArrayType(et, _) => + val expr = MapObjects(func, inputData, et, cls) transformUp { + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + expr + case other => + throw new AnalysisException("need an array field but got " + other.simpleString) + } + } + validateNestedTupleFields(result) + result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index de1594d119e17..ef88cfb543ebb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child" + s"Can't extract value from $child: need struct type but got ${other.simpleString}" } throw new AnalysisException(errorMsg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bb584f7d087e8..00e2ac91e67ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -448,6 +448,17 @@ object MapObjects { } } +case class UnresolvedMapObjects( + function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + /** * Applies the given expression to every element of a collection of items, returning the result * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda @@ -581,17 +592,24 @@ case class MapObjects private( // collection val collObjectName = s"${cls.getName}$$.MODULE$$" val getBuilderVar = s"$collObjectName.newBuilder()" - - (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength);""", + ( + s""" + ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength); + """, genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${cls.getName}) $builderValue.result();") + s"(${cls.getName}) $builderValue.result();" + ) case None => // array - (s"""$convertedType[] $convertedArray = null; - $convertedArray = $arrayConstructor;""", + ( + s""" + $convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor; + """, genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray);") + s"new ${classOf[GenericArrayData].getName}($convertedArray);" + ) } val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 802397d50e85c..e5a3e1fd374dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) +case class ArrayClass(arr: Seq[StringIntClass]) + +case class NestedArrayClass(nestedArr: Array[ArrayClass]) + class EncoderResolutionSuite extends PlanTest { private val str = UTF8String.fromString("hello") @@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest { encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } + test("real type doesn't match encoder schema but they are compatible: array") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val array = new GenericArrayData(Array(InternalRow(1, 2, 3))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("real type doesn't match encoder schema but they are compatible: nested array") { + val encoder = ExpressionEncoder[NestedArrayClass] + val et = new StructType().add("arr", ArrayType( + new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val attrs = Seq('nestedArr.array(et)) + val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3))) + val outerArr = new GenericArrayData(Array(InternalRow(innerArr))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr)) + } + + test("the real type is not compatible with encoder schema: non-array field") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.int) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + test("the real type is not compatible with encoder schema: array element type") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("c", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + + test("the real type is not compatible with encoder schema: nested array element type") { + val encoder = ExpressionEncoder[NestedArrayClass] + + withClue("inner element is not array") { + val attrs = Seq('nestedArr.array(new StructType().add("arr", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + withClue("nested array element type is not compatible") { + val attrs = Seq('nestedArr.array(new StructType() + .add("arr", ArrayType(new StructType().add("c", "int"))))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + } + test("nullability of array type element should not fail analysis") { val encoder = ExpressionEncoder[Seq[Int]] val attrs = 'a.array(IntegerType) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala index 174378304d4a5..e266ae55cc4d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) extends Aggregator[T, (Boolean, T), T] { - private val encoder = implicitly[Encoder[T]] + @transient private val encoder = implicitly[Encoder[T]] override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 68e071a1a694f..5b5cd28ad0c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) } + test("as seq of case class - reorder fields by name") { + val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a")))) + val ds = df.as[Seq[ClassData]] + assert(ds.collect() === Array( + Seq(ClassData("a", 0)), + Seq(ClassData("a", 1)), + Seq(ClassData("a", 2)))) + } + test("map") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( From a59759e6c059617b2fc8102cbf41acc5d409b34a Mon Sep 17 00:00:00 2001 From: Seth Hendrickson Date: Tue, 4 Apr 2017 17:04:41 -0700 Subject: [PATCH 0194/1765] [SPARK-20183][ML] Added outlierRatio arg to MLTestingUtils.testOutliersWithSmallWeights ## What changes were proposed in this pull request? This is a small piece from https://github.com/apache/spark/pull/16722 which ultimately will add sample weights to decision trees. This is to allow more flexibility in testing outliers since linear models and trees behave differently. Note: The primary author when this is committed should be sethah since this is taken from his code. ## How was this patch tested? Existing tests Author: Joseph K. Bradley Closes #17501 from jkbradley/SPARK-20183. --- .../org/apache/spark/ml/classification/LinearSVCSuite.scala | 2 +- .../spark/ml/classification/LogisticRegressionSuite.scala | 2 +- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 2 +- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 3 ++- .../test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 5 +++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 4c63a2a88c6c6..c763a4cef1afd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -164,7 +164,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC]( - dataset.as[LabeledPoint], estimator, 2, modelEquals) + dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals, 42L) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 1b64480373492..f0648d0936a12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1874,7 +1874,7 @@ class LogisticRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression]( - dataset.as[LabeledPoint], estimator, numClasses, modelEquals) + dataset.as[LabeledPoint], estimator, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 4d5d299d1408f..d41c5b533dedf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes]( - dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals) + dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 6a51e75e12a36..c6a267b7283d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -842,7 +842,8 @@ class LinearRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( - datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals) + datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals, + outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index f1ed568d5e60a..578f31c8e7dba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -260,12 +260,13 @@ object MLTestingUtils extends SparkFunSuite { data: Dataset[LabeledPoint], estimator: E with HasWeightCol, numClasses: Int, - modelEquals: (M, M) => Unit): Unit = { + modelEquals: (M, M) => Unit, + outlierRatio: Int): Unit = { import data.sqlContext.implicits._ val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap { case Instance(l, w, f) => val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1 - List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) + List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) } val trueModel = estimator.set(estimator.weightCol, "").fit(data) val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) From b28bbffbadf7ebc4349666e8f17111f6fca18c9a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 4 Apr 2017 17:51:45 -0700 Subject: [PATCH 0195/1765] [SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-20003 I was doing some test and found the issue. ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform. Currently associationRules in FPGrowthModel is a lazy val and `setMinConfidence` in FPGrowthModel has no impact once associationRules got computed . I try to cache the associationRules to avoid re-computation if `minConfidence` is not changed, but this makes FPGrowthModel somehow stateful. Let me know if there's any concern. ## How was this patch tested? new unit test and I strength the unit test for model save/load to ensure the cache mechanism. Author: Yuhao Yang Closes #17336 from hhbyyh/fpmodelminconf. --- .../org/apache/spark/ml/fpm/FPGrowth.scala | 21 ++++++- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 56 +++++++++++++------ 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 65cc80619569e..d604c1ac001a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -218,13 +218,28 @@ class FPGrowthModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) /** - * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * Cache minConfidence and associationRules to avoid redundant computation for association rules + * during transform. The associationRules will only be re-computed when minConfidence changed. + */ + @transient private var _cachedMinConf: Double = Double.NaN + + @transient private var _cachedRules: DataFrame = _ + + /** + * Get association rules fitted using the minConfidence. Returns a dataframe * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and * "consequent" are Array[T] and "confidence" is Double. */ @Since("2.2.0") - @transient lazy val associationRules: DataFrame = { - AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + @transient def associationRules: DataFrame = { + if ($(minConfidence) == _cachedMinConf) { + _cachedRules + } else { + _cachedRules = AssociationRules + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + _cachedMinConf = $(minConfidence) + _cachedRules + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 4603a618d2f93..6bec057511cd1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ @@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("items") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } + + test("FPGrowthModel setMinConfidence should affect rules generation and transform") { + val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset) + val oldRulesNum = model.associationRules.count() + val oldPredict = model.transform(dataset) + + model.setMinConfidence(0.8765) + assert(oldRulesNum > model.associationRules.count()) + assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + + // association rules should stay the same for same minConfidence + model.setMinConfidence(0.1) + assert(oldRulesNum === model.associationRules.count()) + assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + } + test("FPGrowth parameter check") { val fpGrowth = new FPGrowth().setMinSupport(0.4567) val model = fpGrowth.fit(dataset) .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + MLTestingUtils.checkCopy(model) } test("read/write") { def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { - assert(model.freqItemsets.sort("items").collect() === - model2.freqItemsets.sort("items").collect()) + assert(model.freqItemsets.collect().toSet.equals( + model2.freqItemsets.collect().toSet)) + assert(model.associationRules.collect().toSet.equals( + model2.associationRules.collect().toSet)) + assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals( + model2.setMinConfidence(0.9).associationRules.collect().toSet)) } val fPGrowth = new FPGrowth() testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, FPGrowthSuite.allParamSettings, checkModelData) } - - test("FPGrowth prediction should not contain duplicates") { - // This should generate rule 1 -> 3, 2 -> 3 - val dataset = spark.createDataFrame(Seq( - Array("1", "3"), - Array("2", "3") - ).map(Tuple1(_))).toDF("items") - val model = new FPGrowth().fit(dataset) - - val prediction = model.transform( - spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") - ).first().getAs[Seq[String]]("prediction") - - assert(prediction === Seq("3")) - } } object FPGrowthSuite { From c1b8b667506ed95c6c2808e7d3db8463435e73f6 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 4 Apr 2017 22:32:46 -0700 Subject: [PATCH 0196/1765] [SPARKR][DOC] update doc for fpgrowth ## What changes were proposed in this pull request? minor update zero323 Author: Felix Cheung Closes #17526 from felixcheung/rfpgrowthfollowup. --- R/pkg/R/mllib_clustering.R | 6 +----- R/pkg/R/mllib_fpm.R | 4 ++++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 0ebdb5a273088..97c9fa1b45840 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -498,11 +498,7 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @export #' @examples #' \dontrun{ -#' # nolint start -#' # An example "path/to/file" can be -#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") -#' # nolint end -#' text <- read.df("path/to/file", source = "libsvm") +#' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") #' model <- spark.lda(data = text, optimizer = "em") #' #' # get a summary of the model diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index 96251b2c7c195..dfcb45a1b66c9 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -27,6 +27,10 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' FP-growth #' #' A parallel FP-growth algorithm to mine frequent itemsets. +#' \code{spark.fpGrowth} fits a FP-growth model on a SparkDataFrame. Users can +#' \code{spark.freqItemsets} to get frequent itemsets, \code{spark.associationRules} to get +#' association rules, \code{predict} to make predictions on new data based on generated association +#' rules, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' For more details, see #' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ #' FP-growth}. From b6e71032d92a072b7c951e5ea641e9454b5e70ed Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 4 Apr 2017 22:46:42 -0700 Subject: [PATCH 0197/1765] Small doc fix for ReuseSubquery. --- .../main/scala/org/apache/spark/sql/execution/subquery.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 58be2d1da2816..d11045fb6ac8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -150,7 +150,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { /** - * Find out duplicated exchanges in the spark plan, then use the same exchange for all the + * Find out duplicated subqueries in the spark plan, then use the same subquery result for all the * references. */ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { @@ -159,7 +159,7 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { if (!conf.exchangeReuseEnabled) { return plan } - // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. + // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]() plan transformAllExpressions { case sub: ExecSubqueryExpression => From dad499f324c6a93650aecfeb8cde10a405372930 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Apr 2017 23:20:17 -0700 Subject: [PATCH 0198/1765] [SPARK-20209][SS] Execute next trigger immediately if previous batch took longer than trigger interval ## What changes were proposed in this pull request? For large trigger intervals (e.g. 10 minutes), if a batch takes 11 minutes, then it will wait for 9 mins before starting the next batch. This does not make sense. The processing time based trigger policy should be to do process batches as fast as possible, but no faster than 1 in every trigger interval. If batches are taking longer than trigger interval anyways, then no point waiting extra trigger interval. In this PR, I modified the ProcessingTimeExecutor to do so. Another minor change I did was to extract our StreamManualClock into a separate class so that it can be used outside subclasses of StreamTest. For example, ProcessingTimeExecutorSuite does not need to create any context for testing, just needs the StreamManualClock. ## How was this patch tested? Added new unit tests to comprehensively test this behavior. Author: Tathagata Das Closes #17525 from tdas/SPARK-20209. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 1 + .../execution/streaming/TriggerExecutor.scala | 17 ++-- .../ProcessingTimeExecutorSuite.scala | 83 ++++++++++++++++-- .../sql/streaming/FileStreamSourceSuite.scala | 1 + .../FlatMapGroupsWithStateSuite.scala | 3 +- .../spark/sql/streaming/StreamSuite.scala | 1 + .../spark/sql/streaming/StreamTest.scala | 20 +---- .../streaming/StreamingAggregationSuite.scala | 1 + .../StreamingQueryListenerSuite.scala | 1 + .../sql/streaming/StreamingQuerySuite.scala | 87 +++++++++++-------- .../streaming/util/StreamManualClock.scala | 51 +++++++++++ 11 files changed, 194 insertions(+), 72 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 6391d6269c5ab..0046ba7e43d13 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index 02996ac854f69..d188566f822b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -47,21 +47,22 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = extends TriggerExecutor with Logging { private val intervalMs = processingTime.intervalMs + require(intervalMs >= 0) - override def execute(batchRunner: () => Boolean): Unit = { + override def execute(triggerHandler: () => Boolean): Unit = { while (true) { - val batchStartTimeMs = clock.getTimeMillis() - val terminated = !batchRunner() + val triggerTimeMs = clock.getTimeMillis + val nextTriggerTimeMs = nextBatchTime(triggerTimeMs) + val terminated = !triggerHandler() if (intervalMs > 0) { - val batchEndTimeMs = clock.getTimeMillis() - val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs + val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs if (batchElapsedTimeMs > intervalMs) { notifyBatchFallingBehind(batchElapsedTimeMs) } if (terminated) { return } - clock.waitTillTime(nextBatchTime(batchEndTimeMs)) + clock.waitTillTime(nextTriggerTimeMs) } else { if (terminated) { return @@ -70,7 +71,7 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = } } - /** Called when a batch falls behind. Expose for test only */ + /** Called when a batch falls behind */ def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { logWarning("Current batch is falling behind. The trigger interval is " + s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") @@ -83,6 +84,6 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`). */ def nextBatchTime(now: Long): Long = { - now / intervalMs * intervalMs + intervalMs + if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 00d5e051de357..007554a83f548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.sql.execution.streaming -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import org.eclipse.jetty.util.ConcurrentHashSet +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{Clock, ManualClock, SystemClock} +import org.apache.spark.sql.streaming.util.StreamManualClock class ProcessingTimeExecutorSuite extends SparkFunSuite { + val timeout = 10.seconds + test("nextBatchTime") { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) assert(processingTimeExecutor.nextBatchTime(0) === 100) @@ -35,6 +45,57 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { assert(processingTimeExecutor.nextBatchTime(150) === 200) } + test("trigger timing") { + val triggerTimes = new ConcurrentHashSet[Int] + val clock = new StreamManualClock() + @volatile var continueExecuting = true + @volatile var clockIncrementInTrigger = 0L + val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), clock) + val executorThread = new Thread() { + override def run(): Unit = { + executor.execute(() => { + // Record the trigger time, increment clock if needed and + triggerTimes.add(clock.getTimeMillis.toInt) + clock.advance(clockIncrementInTrigger) + clockIncrementInTrigger = 0 // reset this so that there are no runaway triggers + continueExecuting + }) + } + } + executorThread.start() + // First batch should execute immediately, then executor should wait for next one + eventually { + assert(triggerTimes.contains(0)) + assert(clock.isStreamWaitingAt(0)) + assert(clock.isStreamWaitingFor(1000)) + } + + // Second batch should execute when clock reaches the next trigger time. + // If next trigger takes less than the trigger interval, executor should wait for next one + clockIncrementInTrigger = 500 + clock.setTime(1000) + eventually { + assert(triggerTimes.contains(1000)) + assert(clock.isStreamWaitingAt(1500)) + assert(clock.isStreamWaitingFor(2000)) + } + + // If next trigger takes less than the trigger interval, executor should immediately execute + // another one + clockIncrementInTrigger = 1500 + clock.setTime(2000) // allow another trigger by setting clock to 2000 + eventually { + // Since the next trigger will take 1500 (which is more than trigger interval of 1000) + // executor will immediately execute another trigger + assert(triggerTimes.contains(2000) && triggerTimes.contains(3500)) + assert(clock.isStreamWaitingAt(3500)) + assert(clock.isStreamWaitingFor(4000)) + } + continueExecuting = false + clock.advance(1000) + waitForThreadJoin(executorThread) + } + test("calling nextBatchTime with the result of a previous call should return the next interval") { val intervalMS = 100 val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS)) @@ -54,7 +115,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) processingTimeExecutor.execute(() => { batchCounts += 1 - // If the batch termination works well, batchCounts should be 3 after `execute` + // If the batch termination works correctly, batchCounts should be 3 after `execute` batchCounts < 3 }) assert(batchCounts === 3) @@ -66,9 +127,8 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } test("notifyBatchFallingBehind") { - val clock = new ManualClock() + val clock = new StreamManualClock() @volatile var batchFallingBehindCalled = false - val latch = new CountDownLatch(1) val t = new Thread() { override def run(): Unit = { val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { @@ -77,7 +137,6 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } } processingTimeExecutor.execute(() => { - latch.countDown() clock.waitTillTime(200) false }) @@ -85,9 +144,17 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } t.start() // Wait until the batch is running so that we don't call `advance` too early - assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds") + eventually { assert(clock.isStreamWaitingFor(200)) } clock.advance(200) - t.join() + waitForThreadJoin(t) assert(batchFallingBehindCalled === true) } + + private def eventually(body: => Unit): Unit = { + Eventually.eventually(Timeout(timeout)) { body } + } + + private def waitForThreadJoin(thread: Thread): Unit = { + failAfter(timeout) { thread.join() } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 171877abe6e92..26967782f77c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index c8e31e3ca2e04..85aa7dbe9ed86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,8 +21,6 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually.eventually -import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -35,6 +33,7 @@ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 388f15405e70b..5ab9dc2bc7763 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 951ff2ca0d684..03aa45b616880 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -214,24 +214,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { AssertOnQuery(query => { func(query); true }) } - class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { - private var waitStartTime: Option[Long] = None - - override def waitTillTime(targetTime: Long): Long = synchronized { - try { - waitStartTime = Some(getTimeMillis()) - super.waitTillTime(targetTime) - } finally { - waitStartTime = None - } - } - - def isStreamWaitingAt(time: Long): Boolean = synchronized { - waitStartTime == Some(time) - } - } - - /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. @@ -242,6 +224,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { + import org.apache.spark.sql.streaming.util.StreamManualClock + // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently // because this method assumes there is only one active query in its `StreamingQueryListener` // and it may not work correctly when multiple `testStream`s run concurrently. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 600c039cd0b9d..e5d5b4f328820 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.sql.streaming.util.StreamManualClock object FailureSinglton { var firstTime = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 03dad8a6ddbc7..b8a694c177310 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1172531fe9988..2ebbfcd22b97c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.util.ManualClock @@ -207,46 +207,53 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // Wait for manual clock to be 100 first time there is data + // getOffset should take 50 ms the first time it is called override def getOffset: Option[Offset] = { val offset = super.getOffset if (offset.nonEmpty) { - clock.waitTillTime(300) + clock.waitTillTime(1050) } offset } - // Wait for manual clock to be 300 first time there is data + // getBatch should take 100 ms the first time it is called override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - clock.waitTillTime(600) + if (start.isEmpty) clock.waitTillTime(1150) super.getBatch(start, end) } } - // This is to make sure thatquery waits for manual clock to be 600 first time there is data - val mapped = inputData.toDS().as[Long].map { x => - clock.waitTillTime(1100) + // query execution should take 350 ms the first time it is called + val mapped = inputData.toDS.coalesce(1).as[Long].map { x => + clock.waitTillTime(1500) // this will only wait the first time when clock < 1500 10 / x }.agg(count("*")).as[Long] - case class AssertStreamExecThreadToWaitForClock() + case class AssertStreamExecThreadIsWaitingForTime(targetTime: Long) extends AssertOnQuery(q => { eventually(Timeout(streamingTimeout)) { if (q.exception.isEmpty) { - assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + assert(clock.isStreamWaitingFor(targetTime)) } } if (q.exception.isDefined) { throw q.exception.get } true - }, "") + }, "") { + override def toString: String = s"AssertStreamExecThreadIsWaitingForTime($targetTime)" + } + + case class AssertClockTime(time: Long) + extends AssertOnQuery(q => clock.getTimeMillis() === time, "") { + override def toString: String = s"AssertClockTime($time)" + } var lastProgressBeforeStop: StreamingQueryProgress = null testStream(mapped, OutputMode.Complete)( - StartStream(ProcessingTime(100), triggerClock = clock), - AssertStreamExecThreadToWaitForClock(), + StartStream(ProcessingTime(1000), triggerClock = clock), + AssertStreamExecThreadIsWaitingForTime(1000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -254,33 +261,37 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // Test status and progress while offset is being fetched AddData(inputData, 1, 2), - AdvanceManualClock(100), // time = 100 to start new trigger, will block on getOffset - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being fetched - AdvanceManualClock(200), // time = 300 to unblock getOffset, will block on getBatch - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(50), // time = 1050 to unblock getOffset + AssertClockTime(1050), + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being processed - AdvanceManualClock(300), // time = 600 to unblock getBatch, will block in Spark job + AdvanceManualClock(100), // time = 1150 to unblock getBatch + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(500), // time = 1100 to unblock job - AssertOnQuery { _ => clock.getTimeMillis() === 1100 }, + AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, + AdvanceManualClock(350), // time = 1500 to unblock job + AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadToWaitForClock(), + AssertStreamExecThreadIsWaitingForTime(2000), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -293,21 +304,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.id === query.id) assert(progress.name === query.name) assert(progress.batchId === 0) - assert(progress.timestamp === "1970-01-01T00:00:00.100Z") // 100 ms in UTC + assert(progress.timestamp === "1970-01-01T00:00:01.000Z") // 100 ms in UTC assert(progress.numInputRows === 2) - assert(progress.processedRowsPerSecond === 2.0) + assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 200) - assert(progress.durationMs.get("getBatch") === 300) + assert(progress.durationMs.get("getOffset") === 50) + assert(progress.durationMs.get("getBatch") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("triggerExecution") === 1000) + assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") assert(progress.sources(0).startOffset === null) assert(progress.sources(0).endOffset !== null) - assert(progress.sources(0).processedRowsPerSecond === 2.0) + assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) assert(progress.stateOperators(0).numRowsUpdated === 1) @@ -317,9 +328,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi true }, + // Test whether input rate is updated after two batches + AssertStreamExecThreadIsWaitingForTime(2000), // blocked waiting for next trigger time AddData(inputData, 1, 2), - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(500), // allow another trigger + AssertClockTime(2000), + AssertStreamExecThreadIsWaitingForTime(3000), // will block waiting for next trigger time CheckAnswer(4), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), @@ -327,13 +341,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery { query => assert(query.recentProgress.last.eq(query.lastProgress)) assert(query.lastProgress.batchId === 1) - assert(query.lastProgress.sources(0).inputRowsPerSecond === 1.818) + assert(query.lastProgress.inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) true }, // Test status and progress after data is not available for a trigger - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // allow another trigger + AssertStreamExecThreadIsWaitingForTime(4000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -350,10 +365,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Stopped"), // Test status and progress after query terminated with error - StartStream(ProcessingTime(100), triggerClock = clock), - AdvanceManualClock(100), // ensure initial trigger completes before AddData + StartStream(ProcessingTime(1000), triggerClock = clock), + AdvanceManualClock(1000), // ensure initial trigger completes before AddData AddData(inputData, 0), - AdvanceManualClock(100), // allow another trigger + AdvanceManualClock(1000), // allow another trigger ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), @@ -678,5 +693,5 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi object StreamingQuerySuite { // Singleton reference to clock that does not get serialized in task closures - var clock: ManualClock = null + var clock: StreamManualClock = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala new file mode 100644 index 0000000000000..c769a790a4168 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.util.ManualClock + +/** + * ManualClock used for streaming tests that allows checking whether the stream is waiting + * on the clock at expected times. + */ +class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { + private var waitStartTime: Option[Long] = None + private var waitTargetTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + waitTargetTime = Some(targetTime) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + waitTargetTime = None + } + } + + /** Is the streaming thread waiting for the clock to advance when it is at the given time */ + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } + + /** Is the streaming thread waiting for clock to advance to the given time */ + def isStreamWaitingFor(target: Long): Boolean = synchronized { + waitTargetTime == Some(target) + } +} + From 6f09dc70d9808cae004ceda9ad615aa9be50f43d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oliver=20K=C3=B6th?= Date: Wed, 5 Apr 2017 08:09:42 +0100 Subject: [PATCH 0199/1765] [SPARK-20042][WEB UI] Fix log page buttons for reverse proxy mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit with spark.ui.reverseProxy=true, full path URLs like /log will point to the master web endpoint which is serving the worker UI as reverse proxy. To access a REST endpoint in the worker in reverse proxy mode , the leading /proxy/"target"/ part of the base URI must be retained. Added logic to log-view.js to handle this, similar to executorspage.js Patch was tested manually Author: Oliver Köth Closes #17370 from okoethibm/master. --- .../org/apache/spark/ui/static/log-view.js | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js index 1782b4f209c09..b5c43e5788bc3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/log-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -51,13 +51,26 @@ function noNewAlert() { window.setTimeout(function () {alert.css("display", "none");}, 4000); } + +function getRESTEndPoint() { + // If the worker is served from the master through a proxy (see doc on spark.ui.reverseProxy), + // we need to retain the leading ../proxy// part of the URL when making REST requests. + // Similar logic is contained in executorspage.js function createRESTEndPoint. + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + return words.slice(0, ind + 2).join('/') + "/log"; + } + return "/log" +} + function loadMore() { var offset = Math.max(startByte - byteLength, 0); var moreByteLength = Math.min(byteLength, startByte); $.ajax({ type: "GET", - url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + url: getRESTEndPoint() + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, success: function (data) { var oldHeight = $(".log-content")[0].scrollHeight; var newlineIndex = data.indexOf('\n'); @@ -83,14 +96,14 @@ function loadMore() { function loadNew() { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=0", + url: getRESTEndPoint() + baseParams + "&byteLength=0", success: function (data) { var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); var newDataLen = dataInfo[2] - totalLogLength; if (newDataLen != 0) { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=" + newDataLen, + url: getRESTEndPoint() + baseParams + "&byteLength=" + newDataLen, success: function (data) { var newlineIndex = data.indexOf('\n'); var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); From 71c3c48159fe7eb4a46fc2a1b78b72088ccfa824 Mon Sep 17 00:00:00 2001 From: shaolinliu Date: Wed, 5 Apr 2017 13:47:44 +0100 Subject: [PATCH 0200/1765] [SPARK-19807][WEB UI] Add reason for cancellation when a stage is killed using web UI ## What changes were proposed in this pull request? When a user kills a stage using web UI (in Stages page), StagesTab.handleKillRequest requests SparkContext to cancel the stage without giving a reason. SparkContext has cancelStage(stageId: Int, reason: String) that Spark could use to pass the information for monitoring/debugging purposes. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: shaolinliu Author: lvdongr Closes #17258 from shaolinliu/SPARK-19807. --- core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index c1f25114371f1..181465bdf9609 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -42,7 +42,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" val stageId = Option(request.getParameter("id")).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id)) + sc.foreach(_.cancelStage(id, "killed via the Web UI")) // Do a quick pause here to give Spark time to kill the stage so it shows up as // killed after the refresh. Note that this will block the serving thread so the // time should be limited in duration. From a2d8d767d933321426a4eb9df1583e017722d7d6 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Wed, 5 Apr 2017 10:21:43 -0700 Subject: [PATCH 0201/1765] [SPARK-20223][SQL] Fix typo in tpcds q77.sql ## What changes were proposed in this pull request? Fix typo in tpcds q77.sql ## How was this patch tested? N/A Author: wangzhenhua Closes #17538 from wzhfy/typoQ77. --- sql/core/src/test/resources/tpcds/q77.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/tpcds/q77.sql b/sql/core/src/test/resources/tpcds/q77.sql index 7830f96e76515..a69df9fbcd366 100755 --- a/sql/core/src/test/resources/tpcds/q77.sql +++ b/sql/core/src/test/resources/tpcds/q77.sql @@ -36,7 +36,7 @@ WITH ss AS sum(cr_net_loss) AS profit_loss FROM catalog_returns, date_dim WHERE cr_returned_date_sk = d_date_sk - AND d_date BETWEEN cast('2000-08-03]' AS DATE) AND + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND (cast('2000-08-03' AS DATE) + INTERVAL 30 days)), ws AS (SELECT From e2773996b8d1c0214d9ffac634a059b4923caf7b Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 5 Apr 2017 11:47:40 -0700 Subject: [PATCH 0202/1765] [SPARK-19454][PYTHON][SQL] DataFrame.replace improvements ## What changes were proposed in this pull request? - Allows skipping `value` argument if `to_replace` is a `dict`: ```python df = sc.parallelize([("Alice", 1, 3.0)]).toDF() df.replace({"Alice": "Bob"}).show() ```` - Adds validation step to ensure homogeneous values / replacements. - Simplifies internal control flow. - Improves unit tests coverage. ## How was this patch tested? Existing unit tests, additional unit tests, manual testing. Author: zero323 Closes #16793 from zero323/SPARK-19454. --- python/pyspark/sql/dataframe.py | 81 +++++++++++++++++++++++---------- python/pyspark/sql/tests.py | 72 +++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a24512f53c525..774caf53f3a4b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,6 +25,8 @@ else: from itertools import imap as map +import warnings + from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -1281,7 +1283,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1326,43 +1328,72 @@ def replace(self, to_replace, value, subset=None): |null| null|null| +----+------+----+ """ - if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + # Helper functions + def all_of(types): + """Given a type or tuple of types and a sequence of xs + check if each x is instance of type(s) + + >>> all_of(bool)([True, False]) + True + >>> all_of(basestring)(["a", 1]) + False + """ + def all_of_(xs): + return all(isinstance(x, types) for x in xs) + return all_of_ + + all_of_bool = all_of(bool) + all_of_str = all_of(basestring) + all_of_numeric = all_of((float, int, long)) + + # Validate input types + valid_types = (bool, float, int, long, basestring, list, tuple) + if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict") + "to_replace should be a float, int, long, string, list, tuple, or dict. " + "Got {0}".format(type(to_replace))) - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + raise ValueError("If to_replace is not a dict, value should be " + "a float, int, long, string, list, or tuple. " + "Got {0}".format(type(value))) + + if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length. " + "Got {0} and {1}".format(len(to_replace), len(value))) - rep_dict = dict() + if not (subset is None or isinstance(subset, (list, tuple, basestring))): + raise ValueError("subset should be a list or tuple of column names, " + "column name or None. Got {0}".format(type(subset))) + # Reshape input arguments if necessary if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(to_replace, tuple): - to_replace = list(to_replace) + if isinstance(value, (float, int, long, basestring)): + value = [value for _ in range(len(to_replace))] - if isinstance(value, tuple): - value = list(value) - - if isinstance(to_replace, list) and isinstance(value, list): - if len(to_replace) != len(value): - raise ValueError("to_replace and value lists should be of the same length") - rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): - rep_dict = dict([(tr, value) for tr in to_replace]) - elif isinstance(to_replace, dict): + if isinstance(to_replace, dict): rep_dict = to_replace + if value is not None: + warnings.warn("to_replace is a dict and value is not None. value will be ignored.") + else: + rep_dict = dict(zip(to_replace, value)) - if subset is None: - return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) - elif isinstance(subset, basestring): + if isinstance(subset, basestring): subset = [subset] - if not isinstance(subset, (list, tuple)): - raise ValueError("subset should be a list or tuple of column names") + # Verify we were not passed in mixed type generics." + if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): + raise ValueError("Mixed type replacements are not supported") - return DataFrame( - self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + else: + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) @since(2.0) def approxQuantile(self, col, probabilities, relativeError): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index db41b4edb6dde..2b2444304e04a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1779,6 +1779,78 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + # replace with lists + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() + self.assertTupleEqual(row, (u'Ann', 10, 80.1)) + + # replace with dict + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() + self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + + # test backward compatibility with dummy value + dummy_value = 1 + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # test dict with mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', -10, 90.5)) + + # replace with tuples + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # replace multiple columns + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.0)) + + # test for mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + # replace with boolean + row = (self + .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) + .selectExpr("name = 'Bob'", 'age <= 15') + .replace(False, True).first()) + self.assertTupleEqual(row, (True, True)) + + # should fail if subset is not list, tuple or None + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() + + # should fail if to_replace and value have different length + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() + + # should fail if when received unexpected type + with self.assertRaises(ValueError): + from datetime import datetime + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() + + # should fail if provided mixed type replacements + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() + + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From 9543fc0e08a21680961689ea772441c49fcd52ee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 5 Apr 2017 16:03:04 -0700 Subject: [PATCH 0203/1765] [SPARK-20224][SS] Updated docs for streaming dropDuplicates and mapGroupsWithState ## What changes were proposed in this pull request? - Fixed bug in Java API not passing timeout conf to scala API - Updated markdown docs - Updated scala docs - Added scala and Java example ## How was this patch tested? Manually ran examples. Author: Tathagata Das Closes #17539 from tdas/SPARK-20224. --- .../structured-streaming-programming-guide.md | 98 ++++++- .../JavaStructuredSessionization.java | 255 ++++++++++++++++++ .../streaming/StructuredSessionization.scala | 151 +++++++++++ .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/streaming/GroupState.scala | 15 +- 5 files changed, 509 insertions(+), 12 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index b5cf9f1644986..37a1d6189a42d 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Structured Streaming Programming Guide [Alpha] +displayTitle: Structured Streaming Programming Guide [Experimental] title: Structured Streaming Programming Guide --- @@ -871,6 +871,65 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat +### Streaming Deduplication +You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. + +- *With watermark* - If there is a upper bound on how late a duplicate record may arrive, then you can define a watermark on a event time column and deduplicate using both the guid and the event time columns. The query will use the watermark to remove old state data from past records that are not expected to get any duplicates any more. This bounds the amount of the state the query has to maintain. + +- *Without watermark* - Since there are no bounds on when a duplicate record may arrive, the query stores the data from all the past records as state. + +
    +
    + +{% highlight scala %} +val streamingDf = spark.readStream. ... // columns: guid, eventTime, ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid") + +// With watermark using guid and eventTime columns +streamingDf + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("guid", "eventTime") +{% endhighlight %} + +
    +
    + +{% highlight java %} +Dataset streamingDf = spark.readStream. ...; // columns: guid, eventTime, ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid"); + +// With watermark using guid and eventTime columns +streamingDf + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("guid", "eventTime"); +{% endhighlight %} + + +
    +
    + +{% highlight python %} +streamingDf = spark.readStream. ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid") + +// With watermark using guid and eventTime columns +streamingDf \ + .withWatermark("eventTime", "10 seconds") \ + .dropDuplicates("guid", "eventTime") +{% endhighlight %} + +
    +
    + +### Arbitrary Stateful Operations +Many uscases require more advanced stateful operations than aggregations. For example, in many usecases, you have to track sessions from data streams of events. For doing such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)). + ### Unsupported Operations There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets. Some of them are as follows. @@ -891,7 +950,7 @@ Some of them are as follows. + Right outer join with a streaming Dataset on the left is not supported -- Any kind of joins between two streaming Datasets are not yet supported. +- Any kind of joins between two streaming Datasets is not yet supported. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -951,13 +1010,6 @@ Here is the compatibility matrix.
    - - - - - @@ -986,6 +1038,33 @@ Here is the compatibility matrix. this mode. + + + + + + + + + + + + + + + + + + + + + @@ -994,6 +1073,7 @@ Here is the compatibility matrix.
    EndpointMeaning
    /applications/[app-id]/jobs A list of all jobs for a given application. -
    ?status=[complete|succeeded|failed] list only jobs in the specific state. +
    ?status=[running|succeeded|failed|unknown] list only jobs in the specific state.
    Supported Output Modes Notes
    Queries without aggregationAppend, Update - Complete mode not supported as it is infeasible to keep all data in the Result Table. -
    Queries with aggregation Aggregation on event-time with watermark
    Queries with mapGroupsWithStateUpdate
    Queries with flatMapGroupsWithStateAppend operation modeAppend + Aggregations are allowed after flatMapGroupsWithState. +
    Update operation modeUpdate + Aggregations not allowed after flatMapGroupsWithState. +
    Other queriesAppend, Update + Complete mode not supported as it is infeasible to keep all unaggregated data in the Result Table. +
    + #### Output Sinks There are a few types of built-in output sinks. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java new file mode 100644 index 0000000000000..da3a5dfe8628b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapGroupsWithStateFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.streaming.GroupState; +import org.apache.spark.sql.streaming.GroupStateTimeout; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.io.Serializable; +import java.sql.Timestamp; +import java.util.*; + +import scala.Tuple2; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + *

    + * Usage: JavaStructuredNetworkWordCount + * and describe the TCP server that Structured Streaming + * would connect to receive data. + *

    + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.JavaStructuredSessionization + * localhost 9999` + */ +public final class JavaStructuredSessionization { + + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaStructuredSessionization "); + System.exit(1); + } + + String host = args[0]; + int port = Integer.parseInt(args[1]); + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredSessionization") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to host:port + Dataset lines = spark + .readStream() + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load(); + + FlatMapFunction linesToEvents = + new FlatMapFunction() { + @Override + public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exception { + ArrayList eventList = new ArrayList(); + for (String word : lineWithTimestamp.getLine().split(" ")) { + eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); + } + System.out.println( + "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size()); + return eventList.iterator(); + } + }; + + // Split the lines into words, treat words as sessionId of events + Dataset events = lines + .withColumnRenamed("value", "line") + .as(Encoders.bean(LineWithTimestamp.class)) + .flatMap(linesToEvents, Encoders.bean(Event.class)); + + // Sessionize the events. Track number of events, start and end timestamps of session, and + // and report session updates. + // + // Step 1: Define the state update function + MapGroupsWithStateFunction stateUpdateFunc = + new MapGroupsWithStateFunction() { + @Override public SessionUpdate call( + String sessionId, Iterator events, GroupState state) + throws Exception { + // If timed out, then remove session and send final update + if (state.hasTimedOut()) { + SessionUpdate finalUpdate = new SessionUpdate( + sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true); + state.remove(); + return finalUpdate; + + } else { + // Find max and min timestamps in events + long maxTimestampMs = Long.MIN_VALUE; + long minTimestampMs = Long.MAX_VALUE; + int numNewEvents = 0; + while (events.hasNext()) { + Event e = events.next(); + long timestampMs = e.getTimestamp().getTime(); + maxTimestampMs = Math.max(timestampMs, maxTimestampMs); + minTimestampMs = Math.min(timestampMs, minTimestampMs); + numNewEvents += 1; + } + SessionInfo updatedSession = new SessionInfo(); + + // Update start and end timestamps in session + if (state.exists()) { + SessionInfo oldSession = state.get(); + updatedSession.setNumEvents(oldSession.numEvents + numNewEvents); + updatedSession.setStartTimestampMs(oldSession.startTimestampMs); + updatedSession.setEndTimestampMs(Math.max(oldSession.endTimestampMs, maxTimestampMs)); + } else { + updatedSession.setNumEvents(numNewEvents); + updatedSession.setStartTimestampMs(minTimestampMs); + updatedSession.setEndTimestampMs(maxTimestampMs); + } + state.update(updatedSession); + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds"); + return new SessionUpdate( + sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false); + } + } + }; + + // Step 2: Apply the state update function to the events streaming Dataset grouped by sessionId + Dataset sessionUpdates = events + .groupByKey( + new MapFunction() { + @Override public String call(Event event) throws Exception { + return event.getSessionId(); + } + }, Encoders.STRING()) + .mapGroupsWithState( + stateUpdateFunc, + Encoders.bean(SessionInfo.class), + Encoders.bean(SessionUpdate.class), + GroupStateTimeout.ProcessingTimeTimeout()); + + // Start running the query that prints the session updates to the console + StreamingQuery query = sessionUpdates + .writeStream() + .outputMode("update") + .format("console") + .start(); + + query.awaitTermination(); + } + + /** + * User-defined data type representing the raw lines with timestamps. + */ + public static class LineWithTimestamp implements Serializable { + private String line; + private Timestamp timestamp; + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getLine() { return line; } + public void setLine(String sessionId) { this.line = sessionId; } + } + + /** + * User-defined data type representing the input events + */ + public static class Event implements Serializable { + private String sessionId; + private Timestamp timestamp; + + public Event() { } + public Event(String sessionId, Timestamp timestamp) { + this.sessionId = sessionId; + this.timestamp = timestamp; + } + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getSessionId() { return sessionId; } + public void setSessionId(String sessionId) { this.sessionId = sessionId; } + } + + /** + * User-defined data type for storing a session information as state in mapGroupsWithState. + */ + public static class SessionInfo implements Serializable { + private int numEvents = 0; + private long startTimestampMs = -1; + private long endTimestampMs = -1; + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public long getStartTimestampMs() { return startTimestampMs; } + public void setStartTimestampMs(long startTimestampMs) { + this.startTimestampMs = startTimestampMs; + } + + public long getEndTimestampMs() { return endTimestampMs; } + public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } + + public long getDurationMs() { return endTimestampMs - startTimestampMs; } + @Override public String toString() { + return "SessionInfo(numEvents = " + numEvents + + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; + } + } + + /** + * User-defined data type representing the update information returned by mapGroupsWithState. + */ + public static class SessionUpdate implements Serializable { + private String id; + private long durationMs; + private int numEvents; + private boolean expired; + + public SessionUpdate() { } + + public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) { + this.id = id; + this.durationMs = durationMs; + this.numEvents = numEvents; + this.expired = expired; + } + + public String getId() { return id; } + public void setId(String id) { this.id = id; } + + public long getDurationMs() { return durationMs; } + public void setDurationMs(long durationMs) { this.durationMs = durationMs; } + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public boolean isExpired() { return expired; } + public void setExpired(boolean expired) { this.expired = expired; } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala new file mode 100644 index 0000000000000..2ce792c00849c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import java.sql.Timestamp + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.streaming._ + + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + * + * Usage: MapGroupsWithState + * and describe the TCP server that Structured Streaming + * would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.StructuredNetworkWordCount + * localhost 9999` + */ +object StructuredSessionization { + + def main(args: Array[String]): Unit = { + if (args.length < 2) { + System.err.println("Usage: StructuredNetworkWordCount ") + System.exit(1) + } + + val host = args(0) + val port = args(1).toInt + + val spark = SparkSession + .builder + .appName("StructuredSessionization") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to host:port + val lines = spark.readStream + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load() + + // Split the lines into words, treat words as sessionId of events + val events = lines + .as[(String, Timestamp)] + .flatMap { case (line, timestamp) => + line.split(" ").map(word => Event(sessionId = word, timestamp)) + } + + // Sessionize the events. Track number of events, start and end timestamps of session, and + // and report session updates. + val sessionUpdates = events + .groupByKey(event => event.sessionId) + .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { + + case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => + + // If timed out, then remove session and send final update + if (state.hasTimedOut) { + val finalUpdate = + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) + state.remove() + finalUpdate + } else { + // Update start and end timestamps in session + val timestamps = events.map(_.timestamp.getTime).toSeq + val updatedSession = if (state.exists) { + val oldSession = state.get + SessionInfo( + oldSession.numEvents + timestamps.size, + oldSession.startTimestampMs, + math.max(oldSession.endTimestampMs, timestamps.max)) + } else { + SessionInfo(timestamps.size, timestamps.min, timestamps.max) + } + state.update(updatedSession) + + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds") + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) + } + } + + // Start running the query that prints the session updates to the console + val query = sessionUpdates + .writeStream + .outputMode("update") + .format("console") + .start() + + query.awaitTermination() + } +} +/** User-defined data type representing the input events */ +case class Event(sessionId: String, timestamp: Timestamp) + +/** + * User-defined data type for storing a session information as state in mapGroupsWithState. + * + * @param numEvents total number of events received in the session + * @param startTimestampMs timestamp of first event received in the session when it started + * @param endTimestampMs timestamp of last event received in the session before it expired + */ +case class SessionInfo( + numEvents: Int, + startTimestampMs: Long, + endTimestampMs: Long) { + + /** Duration of the session, between the first and last events */ + def durationMs: Long = endTimestampMs - startTimestampMs +} + +/** + * User-defined data type representing the update information returned by mapGroupsWithState. + * + * @param id Id of the session + * @param durationMs Duration the session was active, that is, from first event to its expiry + * @param numEvents Number of events received by the session while it was active + * @param expired Is the session active or expired + */ +case class SessionUpdate( + id: String, + durationMs: Long, + numEvents: Int, + expired: Boolean) + +// scalastyle:on println + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 022c2f5629e86..cb42e9e4560cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -347,7 +347,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U]( + mapGroupsWithState[S, U](timeoutConf)( (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 15df906ca7b13..c659ac7fcf3d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `streaming.StreamingQuery`, + * That is, in every batch of the `StreamingQuery`, * the function will be invoked once for each group that has data in the trigger. Furthermore, * if timeout is set, then the function will invoked on timed out groups (more detail below). * @@ -42,12 +42,23 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * - The key of the group. * - An iterator containing all the values for this group. * - A user-defined state object set by previous invocations of the given function. + * * In case of a batch Dataset, there is only one invocation and state object will be empty as * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have * no effect. * - * Important points to note about the function. + * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the + * former allows the function to return one and only one record, whereas the latter + * allows the function to return any number of records (including no records). Furthermore, the + * `flatMapGroupsWithState` is associated with an operation output mode, which can be either + * `Append` or `Update`. Semantically, this defines whether the output records of one trigger + * is effectively replacing the previously output records (from previous triggers) or is appending + * to the list of previously output records. Essentially, this defines how the Result Table (refer + * to the semantics in the programming guide) is updated, and allows us to reason about the + * semantics of later operations. + * + * Important points to note about the function (both mapGroupsWithState and flatMapGroupsWithState). * - In a trigger, the function will be called only the groups present in the batch. So do not * assume that the function will be called in every trigger for every group that has state. * - There is no guaranteed ordering of values in the iterator in the function, neither with From 9d68c67235481fa33983afb766916b791ca8212a Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 6 Apr 2017 08:33:14 +0800 Subject: [PATCH 0204/1765] [SPARK-20204][SQL][FOLLOWUP] SQLConf should react to change in default timezone settings ## What changes were proposed in this pull request? Make sure SESSION_LOCAL_TIMEZONE reflects the change in JVM's default timezone setting. Currently several timezone related tests fail as the change to default timezone is not picked up by SQLConf. ## How was this patch tested? Added an unit test in ConfigEntrySuite Author: Dilip Biswal Closes #17537 from dilipbiswal/timezone_debug. --- .../spark/internal/config/ConfigBuilder.scala | 8 ++++++++ .../spark/internal/config/ConfigEntry.scala | 17 +++++++++++++++++ .../internal/config/ConfigEntrySuite.scala | 9 +++++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b9921138cc6c7..e5d60a7ef0984 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -147,6 +147,14 @@ private[spark] class TypedConfigBuilder[T]( } } + /** Creates a [[ConfigEntry]] with a function to determine the default value */ + def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_ (entry)) + entry + } + /** * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a * [[String]] and must be a valid value for the entry. diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index 4f3e42bb3c94e..e86712e84d6ac 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -78,7 +78,24 @@ private class ConfigEntryWithDefault[T] ( def readFrom(reader: ConfigReader): T = { reader.get(key).map(valueConverter).getOrElse(_defaultValue) } +} + +private class ConfigEntryWithDefaultFunction[T] ( + key: String, + _defaultFunction: () => T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + + override def defaultValue: Option[T] = Some(_defaultFunction()) + override def defaultValueString: String = stringConverter(_defaultFunction()) + + def readFrom(reader: ConfigReader): T = { + reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + } } private class ConfigEntryWithDefaultString[T] ( diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 3ff7e84d73bd4..e2ba0d2a53d04 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -251,4 +251,13 @@ class ConfigEntrySuite extends SparkFunSuite { .createWithDefault(null) testEntryRef(nullConf, ref(nullConf)) } + + test("conf entry : default function") { + var data = 0 + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("intval")).intConf.createWithDefaultFunction(() => data) + assert(conf.get(iConf) === 0) + data = 2 + assert(conf.get(iConf) === 2) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5b5d547f8fe54..e685c2bed50ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -752,7 +752,7 @@ object SQLConf { buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""") .stringConf - .createWithDefault(TimeZone.getDefault().getID()) + .createWithDefaultFunction(() => TimeZone.getDefault.getID) val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.windowExec.buffer.spill.threshold") From 12206058e8780e202c208b92774df3773eff36ae Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Apr 2017 17:46:44 -0700 Subject: [PATCH 0205/1765] [SPARK-20214][ML] Make sure converted csc matrix has sorted indices ## What changes were proposed in this pull request? `_convert_to_vector` converts a scipy sparse matrix to csc matrix for initializing `SparseVector`. However, it doesn't guarantee the converted csc matrix has sorted indices and so a failure happens when you do something like that: from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) lil[1, 0] = 1 lil[3, 0] = 2 _convert_to_vector(lil.todok()) File "/home/jenkins/workspace/python/pyspark/mllib/linalg/__init__.py", line 78, in _convert_to_vector return SparseVector(l.shape[0], csc.indices, csc.data) File "/home/jenkins/workspace/python/pyspark/mllib/linalg/__init__.py", line 556, in __init__ % (self.indices[i], self.indices[i + 1])) TypeError: Indices 3 and 1 are not strictly increasing A simple test can confirm that `dok_matrix.tocsc()` won't guarantee sorted indices: >>> from scipy.sparse import lil_matrix >>> lil = lil_matrix((4, 1)) >>> lil[1, 0] = 1 >>> lil[3, 0] = 2 >>> dok = lil.todok() >>> csc = dok.tocsc() >>> csc.has_sorted_indices 0 >>> csc.indices array([3, 1], dtype=int32) I checked the source codes of scipy. The only way to guarantee it is `csc_matrix.tocsr()` and `csr_matrix.tocsc()`. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17532 from viirya/make-sure-sorted-indices. --- python/pyspark/ml/linalg/__init__.py | 3 +++ python/pyspark/mllib/linalg/__init__.py | 3 +++ python/pyspark/mllib/tests.py | 11 +++++++++++ 3 files changed, 17 insertions(+) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index b765343251965..ad1b487676fa7 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -72,7 +72,10 @@ def _convert_to_vector(l): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" + # Make sure the converted csc_matrix has sorted indices. csc = l.tocsc() + if not csc.has_sorted_indices: + csc.sort_indices() return SparseVector(l.shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(l)) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 031f22c02098e..7b24b3c74a9fa 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -74,7 +74,10 @@ def _convert_to_vector(l): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" + # Make sure the converted csc_matrix has sorted indices. csc = l.tocsc() + if not csc.has_sorted_indices: + csc.sort_indices() return SparseVector(l.shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(l)) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c519883cdd73b..523b3f1113317 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -853,6 +853,17 @@ def serialize(l): self.assertEqual(sv, serialize(lil.tocsr())) self.assertEqual(sv, serialize(lil.todok())) + def test_convert_to_vector(self): + from scipy.sparse import csc_matrix + # Create a CSC matrix with non-sorted indices + indptr = array([0, 2]) + indices = array([3, 1]) + data = array([2.0, 1.0]) + csc = csc_matrix((data, indices, indptr)) + self.assertFalse(csc.has_sorted_indices) + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(csc)) + def test_dot(self): from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) From 4000f128b7101484ba618115504ca916c22fa84a Mon Sep 17 00:00:00 2001 From: Ioana Delaney Date: Wed, 5 Apr 2017 18:02:53 -0700 Subject: [PATCH 0206/1765] [SPARK-20231][SQL] Refactor star schema code for the subsequent star join detection in CBO ## What changes were proposed in this pull request? This commit moves star schema code from ```join.scala``` to ```StarSchemaDetection.scala```. It also applies some minor fixes in ```StarJoinReorderSuite.scala```. ## How was this patch tested? Run existing ```StarJoinReorderSuite.scala```. Author: Ioana Delaney Closes #17544 from ioana-delaney/starSchemaCBOv2. --- .../optimizer/StarSchemaDetection.scala | 351 ++++++++++++++++++ .../spark/sql/catalyst/optimizer/joins.scala | 328 +--------------- .../optimizer/StarJoinReorderSuite.scala | 4 +- 3 files changed, 354 insertions(+), 329 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala new file mode 100644 index 0000000000000..91cb004eaec46 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Encapsulates star-schema detection logic. + */ +case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { + + /** + * Star schema consists of one or more fact tables referencing a number of dimension + * tables. In general, star-schema joins are detected using the following conditions: + * 1. Informational RI constraints (reliable detection) + * + Dimension contains a primary key that is being joined to the fact table. + * + Fact table contains foreign keys referencing multiple dimension tables. + * 2. Cardinality based heuristics + * + Usually, the table with the highest cardinality is the fact table. + * + Table being joined with the most number of tables is the fact table. + * + * To detect star joins, the algorithm uses a combination of the above two conditions. + * The fact table is chosen based on the cardinality heuristics, and the dimension + * tables are chosen based on the RI constraints. A star join will consist of the largest + * fact table joined with the dimension tables on their primary keys. To detect that a + * column is a primary key, the algorithm uses table and column statistics. + * + * The algorithm currently returns only the star join with the largest fact table. + * Choosing the largest fact table on the driving arm to avoid large inners is in + * general a good heuristic. This restriction will be lifted to observe multiple + * star joins. + * + * The highlights of the algorithm are the following: + * + * Given a set of joined tables/plans, the algorithm first verifies if they are eligible + * for star join detection. An eligible plan is a base table access with valid statistics. + * A base table access represents Project or Filter operators above a LeafNode. Conservatively, + * the algorithm only considers base table access as part of a star join since they provide + * reliable statistics. This restriction can be lifted with the CBO enablement by default. + * + * If some of the plans are not base table access, or statistics are not available, the algorithm + * returns an empty star join plan since, in the absence of statistics, it cannot make + * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality + * (number of rows), which is assumed to be a fact table. + * + * Next, it computes the set of dimension tables for the current fact table. A dimension table + * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, + * the algorithm compares the number of distinct values with the total number of rows in the + * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted + * based on 1TB TPC-DS data), the column is assumed to be unique. + */ + def findStarJoins( + input: Seq[LogicalPlan], + conditions: Seq[Expression]): Seq[LogicalPlan] = { + + val emptyStarJoinPlan = Seq.empty[LogicalPlan] + + if (!conf.starSchemaDetection || input.size < 2) { + emptyStarJoinPlan + } else { + // Find if the input plans are eligible for star join detection. + // An eligible plan is a base table access with valid statistics. + val foundEligibleJoin = input.forall { + case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case _ => false + } + + if (!foundEligibleJoin) { + // Some plans don't have stats or are complex plans. Conservatively, + // return an empty star join. This restriction can be lifted + // once statistics are propagated in the plan. + emptyStarJoinPlan + } else { + // Find the fact table using cardinality based heuristics i.e. + // the table with the largest number of rows. + val sortedFactTables = input.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.collect { case t @ TableAccessCardinality(_, Some(_)) => + t + }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) + + sortedFactTables match { + case Nil => + emptyStarJoinPlan + case table1 :: table2 :: _ + if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => + // If the top largest tables have comparable number of rows, return an empty star plan. + // This restriction will be lifted when the algorithm is generalized + // to return multiple star plans. + emptyStarJoinPlan + case TableAccessCardinality(factTable, _) :: rest => + // Find the fact table joins. + val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) + if findJoinConditions(factTable, plan, conditions).nonEmpty => + plan + } + + // Find the corresponding join conditions. + val allFactJoinCond = allFactJoins.flatMap { plan => + val joinCond = findJoinConditions(factTable, plan, conditions) + joinCond + } + + // Verify if the join columns have valid statistics. + // Allow any relational comparison between the tables. Later + // we will heuristically choose a subset of equi-join + // tables. + val areStatsAvailable = allFactJoins.forall { dimTable => + allFactJoinCond.exists { + case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs + hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) + case _ => false + } + } + + if (!areStatsAvailable) { + emptyStarJoinPlan + } else { + // Find the subset of dimension tables. A dimension table is assumed to be in a + // RI relationship with the fact table. Only consider equi-joins + // between a fact and a dimension table to avoid expanding joins. + val eligibleDimPlans = allFactJoins.filter { dimTable => + allFactJoinCond.exists { + case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + isUnique(dimCol, dimTable) + case _ => false + } + } + + if (eligibleDimPlans.isEmpty || eligibleDimPlans.size < 2) { + // An eligible star join was not found since the join is not + // an RI join, or the star join is an expanding join. + // Also, a star would involve more than one dimension table. + emptyStarJoinPlan + } else { + factTable +: eligibleDimPlans + } + } + } + } + } + } + + /** + * Determines if a column referenced by a base table access is a primary key. + * A column is a PK if it is not nullable and has unique values. + * To determine if a column has unique values in the absence of informational + * RI constraints, the number of distinct values is compared to the total + * number of rows in the table. If their relative difference + * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based + * on TPC-DS data results), the column is assumed to have unique values. + */ + private def isUnique( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.rowCount match { + case Some(rowCount) if rowCount >= 0 => + if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { + val colStats = stats.attributeStats.get(col) + if (colStats.get.nullCount > 0) { + false + } else { + val distinctCount = colStats.get.distinctCount + val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) + // ndvMaxErr adjusted based on TPCDS 1TB data results + relDiff <= conf.ndvMaxError * 2 + } + } else { + false + } + case None => false + } + case None => false + } + case _ => false + } + + /** + * Given a column over a base table access, it returns + * the leaf node column from which the input column is derived. + */ + @tailrec + private def findLeafNodeCol( + column: Attribute, + plan: LogicalPlan): Option[Attribute] = plan match { + case pl @ PhysicalOperation(_, _, _: LeafNode) => + pl match { + case t: LeafNode if t.outputSet.contains(column) => + Option(column) + case p: Project if p.outputSet.exists(_.semanticEquals(column)) => + val col = p.outputSet.find(_.semanticEquals(column)).get + findLeafNodeCol(col, p.child) + case f: Filter => + findLeafNodeCol(column, f.child) + case _ => None + } + case _ => None + } + + /** + * Checks if a column has statistics. + * The column is assumed to be over a base table access. + */ + private def hasStatistics( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) + case None => false + } + case _ => false + } + + /** + * Returns the join predicates between two input plans. It only + * considers basic comparison operators. + */ + @inline + private def findJoinConditions( + plan1: LogicalPlan, + plan2: LogicalPlan, + conditions: Seq[Expression]): Seq[Expression] = { + val refs = plan1.outputSet ++ plan2.outputSet + conditions.filter { + case BinaryComparison(_, _) => true + case _ => false + }.filterNot(canEvaluate(_, plan1)) + .filterNot(canEvaluate(_, plan2)) + .filter(_.references.subsetOf(refs)) + } + + /** + * Checks if a star join is a selective join. A star join is assumed + * to be selective if there are local predicates on the dimension + * tables. + */ + private def isSelectiveStarJoin( + dimTables: Seq[LogicalPlan], + conditions: Seq[Expression]): Boolean = dimTables.exists { + case plan @ PhysicalOperation(_, p, _: LeafNode) => + // Checks if any condition applies to the dimension tables. + // Exclude the IsNotNull predicates until predicate selectivity is available. + // In most cases, this predicate is artificially introduced by the Optimizer + // to enforce nullability constraints. + val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) + .exists(canEvaluate(_, plan)) + + // Checks if there are any predicates pushed down to the base table access. + val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) + + localPredicates || pushedDownPredicates + case _ => false + } + + /** + * Helper case class to hold (plan, rowCount) pairs. + */ + private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) + + /** + * Returns the cardinality of a base table access. A base table access represents + * a LeafNode, or Project or Filter operators above a LeafNode. + */ + private def getTableAccessCardinality( + input: LogicalPlan): Option[BigInt] = input match { + case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => + if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { + Option(input.stats(conf).rowCount.get) + } else { + Option(t.stats(conf).rowCount.get) + } + case _ => None + } + + /** + * Reorders a star join based on heuristics. It is called from ReorderJoin if CBO is disabled. + * 1) Finds the star join with the largest fact table. + * 2) Places the fact table the driving arm of the left-deep tree. + * This plan avoids large table access on the inner, and thus favor hash joins. + * 3) Applies the most selective dimensions early in the plan to reduce the amount of + * data flow. + */ + def reorderStarJoins( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { + assert(input.size >= 2) + + val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] + + // Find the eligible star plans. Currently, it only returns + // the star join with the largest fact table. + val eligibleJoins = input.collect{ case (plan, Inner) => plan } + val starPlan = findStarJoins(eligibleJoins, conditions) + + if (starPlan.isEmpty) { + emptyStarJoinPlan + } else { + val (factTable, dimTables) = (starPlan.head, starPlan.tail) + + // Only consider selective joins. This case is detected by observing local predicates + // on the dimension tables. In a star schema relationship, the join between the fact and the + // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce + // the result of a join. + if (isSelectiveStarJoin(dimTables, conditions)) { + val reorderDimTables = dimTables.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.sortBy(_.size).map { + case TableAccessCardinality(p1, _) => p1 + } + + val reorderStarPlan = factTable +: reorderDimTables + reorderStarPlan.map(plan => (plan, Inner)) + } else { + emptyStarJoinPlan + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 250dd07a16eb4..c3ab58744953d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -20,338 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf -/** - * Encapsulates star-schema join detection. - */ -case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { - - /** - * Star schema consists of one or more fact tables referencing a number of dimension - * tables. In general, star-schema joins are detected using the following conditions: - * 1. Informational RI constraints (reliable detection) - * + Dimension contains a primary key that is being joined to the fact table. - * + Fact table contains foreign keys referencing multiple dimension tables. - * 2. Cardinality based heuristics - * + Usually, the table with the highest cardinality is the fact table. - * + Table being joined with the most number of tables is the fact table. - * - * To detect star joins, the algorithm uses a combination of the above two conditions. - * The fact table is chosen based on the cardinality heuristics, and the dimension - * tables are chosen based on the RI constraints. A star join will consist of the largest - * fact table joined with the dimension tables on their primary keys. To detect that a - * column is a primary key, the algorithm uses table and column statistics. - * - * Since Catalyst only supports left-deep tree plans, the algorithm currently returns only - * the star join with the largest fact table. Choosing the largest fact table on the - * driving arm to avoid large inners is in general a good heuristic. This restriction can - * be lifted with support for bushy tree plans. - * - * The highlights of the algorithm are the following: - * - * Given a set of joined tables/plans, the algorithm first verifies if they are eligible - * for star join detection. An eligible plan is a base table access with valid statistics. - * A base table access represents Project or Filter operators above a LeafNode. Conservatively, - * the algorithm only considers base table access as part of a star join since they provide - * reliable statistics. - * - * If some of the plans are not base table access, or statistics are not available, the algorithm - * returns an empty star join plan since, in the absence of statistics, it cannot make - * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality - * (number of rows), which is assumed to be a fact table. - * - * Next, it computes the set of dimension tables for the current fact table. A dimension table - * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, - * the algorithm compares the number of distinct values with the total number of rows in the - * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted - * based on 1TB TPC-DS data), the column is assumed to be unique. - */ - def findStarJoins( - input: Seq[LogicalPlan], - conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = { - - val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]] - - if (!conf.starSchemaDetection || input.size < 2) { - emptyStarJoinPlan - } else { - // Find if the input plans are eligible for star join detection. - // An eligible plan is a base table access with valid statistics. - val foundEligibleJoin = input.forall { - case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true - case _ => false - } - - if (!foundEligibleJoin) { - // Some plans don't have stats or are complex plans. Conservatively, - // return an empty star join. This restriction can be lifted - // once statistics are propagated in the plan. - emptyStarJoinPlan - } else { - // Find the fact table using cardinality based heuristics i.e. - // the table with the largest number of rows. - val sortedFactTables = input.map { plan => - TableAccessCardinality(plan, getTableAccessCardinality(plan)) - }.collect { case t @ TableAccessCardinality(_, Some(_)) => - t - }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) - - sortedFactTables match { - case Nil => - emptyStarJoinPlan - case table1 :: table2 :: _ - if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => - // If the top largest tables have comparable number of rows, return an empty star plan. - // This restriction will be lifted when the algorithm is generalized - // to return multiple star plans. - emptyStarJoinPlan - case TableAccessCardinality(factTable, _) :: rest => - // Find the fact table joins. - val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) - if findJoinConditions(factTable, plan, conditions).nonEmpty => - plan - } - - // Find the corresponding join conditions. - val allFactJoinCond = allFactJoins.flatMap { plan => - val joinCond = findJoinConditions(factTable, plan, conditions) - joinCond - } - - // Verify if the join columns have valid statistics. - // Allow any relational comparison between the tables. Later - // we will heuristically choose a subset of equi-join - // tables. - val areStatsAvailable = allFactJoins.forall { dimTable => - allFactJoinCond.exists { - case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => - val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs - val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs - hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) - case _ => false - } - } - - if (!areStatsAvailable) { - emptyStarJoinPlan - } else { - // Find the subset of dimension tables. A dimension table is assumed to be in a - // RI relationship with the fact table. Only consider equi-joins - // between a fact and a dimension table to avoid expanding joins. - val eligibleDimPlans = allFactJoins.filter { dimTable => - allFactJoinCond.exists { - case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => - val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs - isUnique(dimCol, dimTable) - case _ => false - } - } - - if (eligibleDimPlans.isEmpty) { - // An eligible star join was not found because the join is not - // an RI join, or the star join is an expanding join. - emptyStarJoinPlan - } else { - Seq(factTable +: eligibleDimPlans) - } - } - } - } - } - } - - /** - * Reorders a star join based on heuristics: - * 1) Finds the star join with the largest fact table and places it on the driving - * arm of the left-deep tree. This plan avoids large table access on the inner, and - * thus favor hash joins. - * 2) Applies the most selective dimensions early in the plan to reduce the amount of - * data flow. - */ - def reorderStarJoins( - input: Seq[(LogicalPlan, InnerLike)], - conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { - assert(input.size >= 2) - - val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] - - // Find the eligible star plans. Currently, it only returns - // the star join with the largest fact table. - val eligibleJoins = input.collect{ case (plan, Inner) => plan } - val starPlans = findStarJoins(eligibleJoins, conditions) - - if (starPlans.isEmpty) { - emptyStarJoinPlan - } else { - val starPlan = starPlans.head - val (factTable, dimTables) = (starPlan.head, starPlan.tail) - - // Only consider selective joins. This case is detected by observing local predicates - // on the dimension tables. In a star schema relationship, the join between the fact and the - // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce - // the result of a join. - // Also, conservatively assume that a fact table is joined with more than one dimension. - if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) { - val reorderDimTables = dimTables.map { plan => - TableAccessCardinality(plan, getTableAccessCardinality(plan)) - }.sortBy(_.size).map { - case TableAccessCardinality(p1, _) => p1 - } - - val reorderStarPlan = factTable +: reorderDimTables - reorderStarPlan.map(plan => (plan, Inner)) - } else { - emptyStarJoinPlan - } - } - } - - /** - * Determines if a column referenced by a base table access is a primary key. - * A column is a PK if it is not nullable and has unique values. - * To determine if a column has unique values in the absence of informational - * RI constraints, the number of distinct values is compared to the total - * number of rows in the table. If their relative difference - * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based - * on TPCDS data results), the column is assumed to have unique values. - */ - private def isUnique( - column: Attribute, - plan: LogicalPlan): Boolean = plan match { - case PhysicalOperation(_, _, t: LeafNode) => - val leafCol = findLeafNodeCol(column, plan) - leafCol match { - case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) - stats.rowCount match { - case Some(rowCount) if rowCount >= 0 => - if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { - false - } else { - val distinctCount = colStats.get.distinctCount - val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) - // ndvMaxErr adjusted based on TPCDS 1TB data results - relDiff <= conf.ndvMaxError * 2 - } - } else { - false - } - case None => false - } - case None => false - } - case _ => false - } - - /** - * Given a column over a base table access, it returns - * the leaf node column from which the input column is derived. - */ - @tailrec - private def findLeafNodeCol( - column: Attribute, - plan: LogicalPlan): Option[Attribute] = plan match { - case pl @ PhysicalOperation(_, _, _: LeafNode) => - pl match { - case t: LeafNode if t.outputSet.contains(column) => - Option(column) - case p: Project if p.outputSet.exists(_.semanticEquals(column)) => - val col = p.outputSet.find(_.semanticEquals(column)).get - findLeafNodeCol(col, p.child) - case f: Filter => - findLeafNodeCol(column, f.child) - case _ => None - } - case _ => None - } - - /** - * Checks if a column has statistics. - * The column is assumed to be over a base table access. - */ - private def hasStatistics( - column: Attribute, - plan: LogicalPlan): Boolean = plan match { - case PhysicalOperation(_, _, t: LeafNode) => - val leafCol = findLeafNodeCol(column, plan) - leafCol match { - case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) - stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) - case None => false - } - case _ => false - } - - /** - * Returns the join predicates between two input plans. It only - * considers basic comparison operators. - */ - @inline - private def findJoinConditions( - plan1: LogicalPlan, - plan2: LogicalPlan, - conditions: Seq[Expression]): Seq[Expression] = { - val refs = plan1.outputSet ++ plan2.outputSet - conditions.filter { - case BinaryComparison(_, _) => true - case _ => false - }.filterNot(canEvaluate(_, plan1)) - .filterNot(canEvaluate(_, plan2)) - .filter(_.references.subsetOf(refs)) - } - - /** - * Checks if a star join is a selective join. A star join is assumed - * to be selective if there are local predicates on the dimension - * tables. - */ - private def isSelectiveStarJoin( - dimTables: Seq[LogicalPlan], - conditions: Seq[Expression]): Boolean = dimTables.exists { - case plan @ PhysicalOperation(_, p, _: LeafNode) => - // Checks if any condition applies to the dimension tables. - // Exclude the IsNotNull predicates until predicate selectivity is available. - // In most cases, this predicate is artificially introduced by the Optimizer - // to enforce nullability constraints. - val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) - .exists(canEvaluate(_, plan)) - - // Checks if there are any predicates pushed down to the base table access. - val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) - - localPredicates || pushedDownPredicates - case _ => false - } - - /** - * Helper case class to hold (plan, rowCount) pairs. - */ - private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) - - /** - * Returns the cardinality of a base table access. A base table access represents - * a LeafNode, or Project or Filter operators above a LeafNode. - */ - private def getTableAccessCardinality( - input: LogicalPlan): Option[BigInt] = input match { - case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => - if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { - Option(input.stats(conf).rowCount.get) - } else { - Option(t.stats(conf).rowCount.get) - } - case _ => None - } -} - /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least * one condition. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 003ce49eaf8e6..605c01b7220d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -206,7 +206,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // and d3_fk1 = s3_pk1 // // Default join reordering: d1, f1, d2, d3, s3 - // Star join reordering: f1, d1, d3, d2,, d3 + // Star join reordering: f1, d1, d3, d2, s3 val query = d1.join(f1).join(d2).join(s3).join(d3) @@ -242,7 +242,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // and d3_fk1 = s3_pk1 // // Default join reordering: d1, f1, d2, d3, s3 - // Star join reordering: f1, d1, d3, d2, d3 + // Star join reordering: f1, d1, d3, d2, s3 val query = d1.join(f1).join(d2).join(s3).join(d3) .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && From 5142e5d4e09c7cb36cf1d792934a21c5305c6d42 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Apr 2017 19:37:21 -0700 Subject: [PATCH 0207/1765] [SPARK-20217][CORE] Executor should not fail stage if killed task throws non-interrupted exception ## What changes were proposed in this pull request? If tasks throw non-interrupted exceptions on kill (e.g. java.nio.channels.ClosedByInterruptException), their death is reported back as TaskFailed instead of TaskKilled. This causes stage failure in some cases. This is reproducible as follows. Run the following, and then use SparkContext.killTaskAttempt to kill one of the tasks. The entire stage will fail since we threw a RuntimeException instead of InterruptedException. ``` spark.range(100).repartition(100).foreach { i => try { Thread.sleep(10000000) } catch { case t: InterruptedException => throw new RuntimeException(t) } } ``` Based on the code in TaskSetManager, I think this also affects kills of speculative tasks. However, since the number of speculated tasks is few, and usually you need to fail a task a few times before the stage is cancelled, it unlikely this would be noticed in production unless both speculation was enabled and the num allowed task failures was = 1. We should probably unconditionally return TaskKilled instead of TaskFailed if the task was killed by the driver, regardless of the actual exception thrown. ## How was this patch tested? Unit test. The test fails before the change in Executor.scala cc JoshRosen Author: Eric Liang Closes #17531 from ericl/fix-task-interrupt. --- .../main/scala/org/apache/spark/executor/Executor.scala | 2 +- .../test/scala/org/apache/spark/SparkContextSuite.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 99b1608010ddb..83469c5ff0600 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -432,7 +432,7 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case _: InterruptedException if task.reasonIfKilled.isDefined => + case NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 2c947556dfd30..735f4454e299e 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -572,7 +572,13 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu // first attempt will hang if (!SparkContextSuite.isTaskStarted) { SparkContextSuite.isTaskStarted = true - Thread.sleep(9999999) + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + // SPARK-20217 should not fail stage if task throws non-interrupted exception + throw new RuntimeException("killed") + } } // second attempt succeeds immediately } From e156b5dd39dc1992077fe06e0f8be810c49c8255 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 6 Apr 2017 09:41:32 +0200 Subject: [PATCH 0208/1765] [SPARK-19953][ML] Random Forest Models use parent UID when being fit ## What changes were proposed in this pull request? The ML `RandomForestClassificationModel` and `RandomForestRegressionModel` were not using the estimator parent UID when being fit. This change fixes that so the models can be properly be identified with their parents. ## How was this patch tested?Existing tests. Added check to verify that model uid matches that of the parent, then renamed `checkCopy` to `checkCopyAndUids` and verified that it was called by one test for each ML algorithm. Author: Bryan Cutler Closes #17296 from BryanCutler/rfmodels-use-parent-uid-SPARK-19953. --- .../RandomForestClassifier.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 3 +- .../classification/GBTClassifierSuite.scala | 6 ++-- .../ml/classification/LinearSVCSuite.scala | 3 +- .../LogisticRegressionSuite.scala | 3 +- .../MultilayerPerceptronClassifierSuite.scala | 2 +- .../ml/classification/NaiveBayesSuite.scala | 1 + .../ml/classification/OneVsRestSuite.scala | 3 +- .../RandomForestClassifierSuite.scala | 3 +- .../ml/clustering/BisectingKMeansSuite.scala | 3 +- .../ml/clustering/GaussianMixtureSuite.scala | 3 +- .../spark/ml/clustering/KMeansSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 4 +-- .../BucketedRandomProjectionLSHSuite.scala | 3 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 9 +++-- .../ml/feature/CountVectorizerSuite.scala | 9 ++--- .../apache/spark/ml/feature/IDFSuite.scala | 8 +++-- .../org/apache/spark/ml/feature/LSHTest.scala | 4 ++- .../spark/ml/feature/MaxAbsScalerSuite.scala | 3 +- .../spark/ml/feature/MinHashLSHSuite.scala | 2 +- .../spark/ml/feature/MinMaxScalerSuite.scala | 3 +- .../apache/spark/ml/feature/PCASuite.scala | 8 ++--- .../spark/ml/feature/RFormulaSuite.scala | 2 +- .../ml/feature/StandardScalerSuite.scala | 7 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 7 ++-- .../spark/ml/feature/VectorIndexerSuite.scala | 3 +- .../spark/ml/feature/Word2VecSuite.scala | 7 ++-- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 7 ++-- .../spark/ml/recommendation/ALSSuite.scala | 3 +- .../AFTSurvivalRegressionSuite.scala | 3 +- .../DecisionTreeRegressorSuite.scala | 7 ++-- .../ml/regression/GBTRegressorSuite.scala | 3 +- .../GeneralizedLinearRegressionSuite.scala | 3 +- .../regression/IsotonicRegressionSuite.scala | 3 +- .../ml/regression/LinearRegressionSuite.scala | 3 +- .../RandomForestRegressorSuite.scala | 2 ++ .../spark/ml/tuning/CrossValidatorSuite.scala | 3 +- .../ml/tuning/TrainValidationSplitSuite.scala | 35 +++++++++---------- .../apache/spark/ml/util/MLTestingUtils.scala | 8 +++-- 41 files changed, 98 insertions(+), 100 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ce834f1d17e0d..ab4c235209289 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -140,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") ( .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestClassificationModel(trees, numFeatures, numClasses) + val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) instr.logSuccess(m) m } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2f524a8c5784d..a58da50fad972 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -131,7 +131,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestRegressionModel(trees, numFeatures) + val m = new RandomForestRegressionModel(uid, trees, numFeatures) instr.logSuccess(m) m } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index dafc6c200f95f..4cdbf845ae4f5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -79,7 +79,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) - MLTestingUtils.checkCopy(pipelineModel) + MLTestingUtils.checkCopyAndUids(pipeline, pipelineModel) assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 964fcfbdd87a2..918ab27e2730b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -249,8 +249,7 @@ class DecisionTreeClassifierSuite val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // copied model must have the same parent. - MLTestingUtils.checkCopy(newTree) + MLTestingUtils.checkCopyAndUids(dt, newTree) val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 0cddb37281b39..1f79e0d4e6228 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -97,8 +97,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.getProbabilityCol === "probability") assert(model.hasParent) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) } test("setThreshold, getThreshold") { @@ -261,8 +260,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext .setSeed(123) val model = gbt.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) sc.checkpointDir = None Utils.deleteRecursively(tempDir) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index c763a4cef1afd..2f87afc23fe7e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -124,8 +124,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(model.hasParent) assert(model.numFeatures === 2) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lsvc, model) } test("linear svc doesn't fit intercept when fitIntercept is off") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index f0648d0936a12..c858b9bbfc256 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -142,8 +142,7 @@ class LogisticRegressionSuite assert(model.intercept !== 0.0) assert(model.hasParent) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lr, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 7700099caac37..ce54c3df4f3f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -74,8 +74,8 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - MLTestingUtils.checkCopy(model) val result = model.transform(dataset) + MLTestingUtils.checkCopyAndUids(trainer, model) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => assert(p == l) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index d41c5b533dedf..b56f8e19ca53c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -149,6 +149,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) + MLTestingUtils.checkCopyAndUids(nb, model) val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index aacb7921b835f..c02e38ad64e3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -76,8 +76,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) - // copied model must have the same parent. - MLTestingUtils.checkCopy(ovaModel) + MLTestingUtils.checkCopyAndUids(ova, ovaModel) assert(ovaModel.models.length === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index c3003cec73b41..ca2954d2f32c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -141,8 +141,7 @@ class RandomForestClassifierSuite val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(rf, model) val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 200a892f6c694..fa7471fa2d658 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -47,8 +47,7 @@ class BisectingKMeansSuite assert(bkm.getMinDivisibleClusterSize === 1.0) val model = bkm.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(bkm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 61da897b666f4..08b800b7e4183 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -77,8 +77,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getTol === 0.01) val model = gm.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ca05b9c389f65..119fe1dead9a9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,8 +52,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getTol === 1e-4) val model = kmeans.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(kmeans, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 75aa0be61a3ed..b4fe63a89f871 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -176,7 +176,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) val model = lda.fit(dataset) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lda, model) assert(model.isInstanceOf[LocalLDAModel]) assert(model.vocabSize === vocabSize) @@ -221,7 +221,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) val model_ = lda.fit(dataset) - MLTestingUtils.checkCopy(model_) + MLTestingUtils.checkCopyAndUids(lda, model_) assert(model_.isInstanceOf[DistributedLDAModel]) val model = model_.asInstanceOf[DistributedLDAModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index cc81da5c66e6d..7175c721bff36 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -94,7 +94,8 @@ class BucketedRandomProjectionLSHSuite unitVectors.foreach { v: Vector => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } - MLTestingUtils.checkCopy(brpModel) + + MLTestingUtils.checkCopyAndUids(brp, brpModel) } test("BucketedRandomProjectionLSH: test of LSH property") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index d6925da97d57e..c83909c4498f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -119,7 +119,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - ChiSqSelectorSuite.testSelector(selector, dataset) + val model = ChiSqSelectorSuite.testSelector(selector, dataset) + MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { @@ -166,11 +167,13 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext object ChiSqSelectorSuite { - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { - selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(dataset) + selectorModel.transform(dataset).select("filtered", "topFeature").collect() .foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) } + selectorModel } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 69d3033bb2189..f213145f1ba0a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -68,10 +68,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .fit(df) - assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) + val cvm = cv.fit(df) + MLTestingUtils.checkCopyAndUids(cv, cvm) + assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cv.transform(df).select("features", "expected").collect().foreach { + cvm.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 5325d95526a50..005edf73d29be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -65,10 +65,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val df = data.zip(expected).toSeq.toDF("features", "expected") - val idfModel = new IDF() + val idfEst = new IDF() .setInputCol("features") .setOutputCol("idfValue") - .fit(df) + val idfModel = idfEst.fit(df) + + MLTestingUtils.checkCopyAndUids(idfEst, idfModel) idfModel.transform(df).select("idfValue", "expected").collect().foreach { case Row(x: Vector, y: Vector) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index a9b559f7ba648..dd4dd62b8cfe9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.linalg.{Vector, VectorUDT} -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils} import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DataTypes @@ -58,6 +58,8 @@ private[ml] object LSHTest { val outputCol = model.getOutputCol val transformedData = model.transform(dataset) + MLTestingUtils.checkCopyAndUids(lsh, model) + // Check output column type SchemaUtils.checkColumnType( transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index a12174493b867..918da4f9388d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -50,8 +50,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") } - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(scaler, model) } test("MaxAbsScaler read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 0ddf097a6eb22..96df68dbdf053 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -63,7 +63,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("values") val model = mh.fit(dataset) assert(mh.uid === model.uid) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(mh, model) } test("hashFunction") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index b79eeb2d75ef0..51db74eb739ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -53,8 +53,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De assert(vector1.equals(vector2), "Transformed vector is different with expected.") } - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(scaler, model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index a60e87590f060..3067a52a4df76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -58,12 +58,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setInputCol("features") .setOutputCol("pca_features") .setK(3) - .fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(pca) + val pcaModel = pca.fit(df) - pca.transform(df).select("pca_features", "expected").collect().foreach { + MLTestingUtils.checkCopyAndUids(pca, pcaModel) + + pcaModel.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5cfd59e6b88a2..fbebd75d70ac5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -37,7 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("id ~ v1 + v2") val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(formula, model) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) val expected = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index a928f93633011..350ba44baa1eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -77,10 +77,11 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("Standardization with default parameter") { val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected") - val standardScaler0 = new StandardScaler() + val standardScalerEst0 = new StandardScaler() .setInputCol("features") .setOutputCol("standardized_features") - .fit(df0) + val standardScaler0 = standardScalerEst0.fit(df0) + MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) assertResult(standardScaler0.transform(df0)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 8d9042b31e033..5634d4210f478 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -45,12 +45,11 @@ class StringIndexerSuite val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") - .fit(df) + val indexerModel = indexer.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(indexer) + MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - val transformed = indexer.transform(df) + val transformed = indexerModel.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index b28ce2ab45b45..f2cca8aa82e85 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -114,8 +114,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(vectorIndexer, model) model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 2043a16c15f1a..a6a1c2b4f32bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -57,15 +57,14 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val docDF = doc.zip(expected).toDF("text", "expected") - val model = new Word2Vec() + val w2v = new Word2Vec() .setVectorSize(3) .setInputCol("text") .setOutputCol("result") .setSeed(42L) - .fit(docDF) + val model = w2v.fit(docDF) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(w2v, model) // These expectations are just magic values, characterizing the current // behavior. The test needs to be updated to be more general, see SPARK-11502 diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 6bec057511cd1..6806cb03bc42b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -121,7 +122,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(fpGrowth, model) + ParamsSuite.checkParams(fpGrowth) + ParamsSuite.checkParams(model) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index a177ed13bf8ef..7574af3d77ea8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -409,8 +409,7 @@ class ALSSuite logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(als, model) } test("exact rank-1 matrix") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 708185a0943df..fb39e50a83552 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -83,8 +83,7 @@ class AFTSurvivalRegressionSuite .setQuantilesCol("quantiles") .fit(datasetUnivariate) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(aftr, model) model.transform(datasetUnivariate) .select("label", "prediction", "quantiles") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 0e91284d03d98..642f266891b57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -69,11 +69,12 @@ class DecisionTreeRegressorSuite test("copied model must have the same parent") { val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) - val model = new DecisionTreeRegressor() + val dtr = new DecisionTreeRegressor() .setImpurity("variance") .setMaxDepth(2) - .setMaxBins(8).fit(df) - MLTestingUtils.checkCopy(model) + .setMaxBins(8) + val model = dtr.fit(df) + MLTestingUtils.checkCopyAndUids(dtr, model) } test("predictVariance") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 03c2f97797bce..2da25f7e0100a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -90,8 +90,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext .setMaxIter(2) val model = gbt.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) val preds = model.transform(df) val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 401911763fa3b..f7c7c001a36af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,8 +197,7 @@ class GeneralizedLinearRegressionSuite val model = glr.setFamily("gaussian").setLink("identity") .fit(datasetGaussianIdentity) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(glr, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index f41a3601b1fa8..180f5f7ce5ab2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -93,8 +93,7 @@ class IsotonicRegressionSuite val model = ir.fit(dataset) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(ir, model) model.transform(dataset) .select("label", "features", "prediction", "weight") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c6a267b7283d8..e7bd4eb9e0adf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -148,8 +148,7 @@ class LinearRegressionSuite assert(lir.getSolver == "auto") val model = lir.fit(datasetWithDenseFeature) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lir, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 3bf0445ebd3dd..8b8e8a655f47b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -90,6 +90,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val model = rf.fit(df) + MLTestingUtils.checkCopyAndUids(rf, model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7116265474f22..2b4e6b53e4f81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -58,8 +58,7 @@ class CrossValidatorSuite .setNumFolds(3) val cvModel = cv.fit(dataset) - // copied model must have the same paren. - MLTestingUtils.checkCopy(cvModel) + MLTestingUtils.checkCopyAndUids(cv, cvModel) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 4463a9b6e543a..a34f930aa11c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -45,18 +45,18 @@ class TrainValidationSplitSuite .addGrid(lr.maxIter, Array(0, 10)) .build() val eval = new BinaryClassificationEvaluator - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(lr) .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) - val cvModel = cv.fit(dataset) - val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] - assert(cv.getTrainRatio === 0.5) + val tvsModel = tvs.fit(dataset) + val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(tvs.getTrainRatio === 0.5) assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.validationMetrics.length === lrParamMaps.length) + assert(tvsModel.validationMetrics.length === lrParamMaps.length) } test("train validation with linear regression") { @@ -71,28 +71,27 @@ class TrainValidationSplitSuite .addGrid(trainer.maxIter, Array(0, 10)) .build() val eval = new RegressionEvaluator() - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(trainer) .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) - val cvModel = cv.fit(dataset) + val tvsModel = tvs.fit(dataset) - // copied model must have the same paren. - MLTestingUtils.checkCopy(cvModel) + MLTestingUtils.checkCopyAndUids(tvs, tvsModel) - val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + val parent = tvsModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.validationMetrics.length === lrParamMaps.length) + assert(tvsModel.validationMetrics.length === lrParamMaps.length) eval.setMetricName("r2") - val cvModel2 = cv.fit(dataset) - val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + val tvsModel2 = tvs.fit(dataset) + val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression] assert(parent2.getRegParam === 0.001) assert(parent2.getMaxIter === 10) - assert(cvModel2.validationMetrics.length === lrParamMaps.length) + assert(tvsModel2.validationMetrics.length === lrParamMaps.length) } test("transformSchema should check estimatorParamMaps") { @@ -104,17 +103,17 @@ class TrainValidationSplitSuite .addGrid(est.inputCol, Array("input1", "input2")) .build() - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(est) .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.transformSchema(new StructType()) // This should pass. + tvs.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") - cv.setEstimatorParamMaps(invalidParamMaps) + tvs.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.transformSchema(new StructType()) + tvs.transformSchema(new StructType()) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 578f31c8e7dba..bef79e634f75f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -31,11 +31,15 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ object MLTestingUtils extends SparkFunSuite { - def checkCopy(model: Model[_]): Unit = { + + def checkCopyAndUids[T <: Estimator[_]](estimator: T, model: Model[_]): Unit = { + assert(estimator.uid === model.uid, "Model uid does not match parent estimator") + + // copied model must have the same parent val copied = model.copy(ParamMap.empty) .asInstanceOf[Model[_]] - assert(copied.parent.uid == model.parent.uid) assert(copied.parent == model.parent) + assert(copied.parent.uid == model.parent.uid) } def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( From c8fc1f3badf61bcfc4bd8eeeb61f73078ca068d1 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 6 Apr 2017 09:14:31 +0100 Subject: [PATCH 0209/1765] [SPARK-20085][MESOS] Configurable mesos labels for executors ## What changes were proposed in this pull request? Add spark.mesos.task.labels configuration option to add mesos key:value labels to the executor. "k1:v1,k2:v2" as the format, colons separating key-value and commas to list out more than one. Discussion of labels with mgummelt at #17404 ## How was this patch tested? Added unit tests to verify labels were added correctly, with incorrect labels being ignored and added a test to test the name of the executor. Tested with: `./build/sbt -Pmesos mesos/test` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Kalvin Chau Closes #17413 from kalvinnchau/mesos-labels. --- docs/running-on-mesos.md | 9 ++++ .../MesosCoarseGrainedSchedulerBackend.scala | 24 ++++++++++ ...osCoarseGrainedSchedulerBackendSuite.scala | 46 +++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8d5ad12cb85be..ef01cfe4b92cd 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -367,6 +367,15 @@ See the [configuration page](configuration.html) for information on Spark config

    [host_path:]container_path[:ro|:rw]
    + + spark.mesos.task.labels + (none) + + Set the Mesos labels to add to each task. Labels are free-form key-value pairs. + Key-value pairs should be separated by a colon, and commas used to list more than one. + Ex. key:value,key2:value2. + + spark.mesos.executor.home driver side SPARK_HOME diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 5bdc2a2b840e3..2a36ec4fa8112 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -67,6 +67,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private val taskLabels = conf.get("spark.mesos.task.labels", "") + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") @@ -408,6 +410,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( taskBuilder.addAllResources(resourcesToUse.asJava) taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + val labelsBuilder = taskBuilder.getLabelsBuilder + val labels = buildMesosLabels().asJava + + labelsBuilder.addAllLabels(labels) + + taskBuilder.setLabels(labelsBuilder) + tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava totalCoresAcquired += taskCPUs @@ -422,6 +431,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( tasks.toMap } + private def buildMesosLabels(): List[Label] = { + taskLabels.split(",").flatMap(label => + label.split(":") match { + case Array(key, value) => + Some(Label.newBuilder() + .setKey(key) + .setValue(value) + .build()) + case _ => + logWarning(s"Unable to parse $label into a key:value label for the task.") + None + } + ).toList + } + /** Extracts task needed resources from a list of available resources. */ private def partitionTaskResources( resources: JList[Resource], diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index eb83926ae4102..c040f05d93b3a 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -475,6 +475,52 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getName == "test-mesos-dynamic-alloc 0") } + test("mesos sets configurable labels on tasks") { + val taskLabelsString = "mesos:test,label:test" + setBackend(Map( + "spark.mesos.task.labels" -> taskLabelsString + )) + + // Build up the labels + val taskLabels = Protos.Labels.newBuilder() + .addLabels(Protos.Label.newBuilder() + .setKey("mesos").setValue("test").build()) + .addLabels(Protos.Label.newBuilder() + .setKey("label").setValue("test").build()) + .build() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + val labels = launchedTasks.head.getLabels + + assert(launchedTasks.head.getLabels.equals(taskLabels)) + } + + test("mesos ignored invalid labels and sets configurable labels on tasks") { + val taskLabelsString = "mesos:test,label:test,incorrect:label:here" + setBackend(Map( + "spark.mesos.task.labels" -> taskLabelsString + )) + + // Build up the labels + val taskLabels = Protos.Labels.newBuilder() + .addLabels(Protos.Label.newBuilder() + .setKey("mesos").setValue("test").build()) + .addLabels(Protos.Label.newBuilder() + .setKey("label").setValue("test").build()) + .build() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + val labels = launchedTasks.head.getLabels + + assert(launchedTasks.head.getLabels.equals(taskLabels)) + } + test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" From d009fb369bbea0df81bbcf9c8028d14cfcaa683b Mon Sep 17 00:00:00 2001 From: setjet Date: Thu, 6 Apr 2017 09:43:07 +0100 Subject: [PATCH 0210/1765] [SPARK-20064][PYSPARK] Bump the PySpark verison number to 2.2 ## What changes were proposed in this pull request? PySpark version in version.py was lagging behind Versioning is in line with PEP 440: https://www.python.org/dev/peps/pep-0440/ ## How was this patch tested? Simply rebuild the project with existing tests Author: setjet Author: Ruben Janssen Closes #17523 from setjet/SPARK-20064. --- python/pyspark/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 08a301695fda7..41bf8c269b795 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.1.0.dev0" +__version__ = "2.2.0.dev0" From bccc330193217b2ec9660e06f1db6dd58f7af5d8 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 6 Apr 2017 09:09:43 -0700 Subject: [PATCH 0211/1765] [SPARK-20196][PYTHON][SQL] update doc for catalog functions for all languages, add pyspark refreshByPath API ## What changes were proposed in this pull request? Update doc to remove external for createTable, add refreshByPath in python ## How was this patch tested? manual Author: Felix Cheung Closes #17512 from felixcheung/catalogdoc. --- R/pkg/R/SQLContext.R | 11 ++-- R/pkg/R/catalog.R | 52 +++++++++++-------- python/pyspark/sql/catalog.py | 27 +++++++--- python/pyspark/sql/context.py | 2 +- .../apache/spark/sql/catalog/Catalog.scala | 17 +++--- .../spark/sql/internal/CatalogImpl.scala | 22 +++++--- 6 files changed, 79 insertions(+), 52 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a1edef7608fa1..c2a1e240ad395 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -544,12 +544,15 @@ sql <- function(x, ...) { dispatchFunc("sql(sqlQuery)", x, ...) } -#' Create a SparkDataFrame from a SparkSQL Table +#' Create a SparkDataFrame from a SparkSQL table or view #' -#' Returns the specified Table as a SparkDataFrame. The Table must have already been registered -#' in the SparkSession. +#' Returns the specified table or view as a SparkDataFrame. The table or view must already exist or +#' have already been registered in the SparkSession. #' -#' @param tableName The SparkSQL Table to convert to a SparkDataFrame. +#' @param tableName the qualified or unqualified name that designates a table or view. If a database +#' is specified, it identifies the table/view from the database. +#' Otherwise, it first attempts to find a temporary view with the given name +#' and then match the table/view from the current database. #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index 07a89f763cde1..4b7f841b55dd0 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -65,7 +65,8 @@ createExternalTable <- function(x, ...) { #' #' Caches the specified table in-memory. #' -#' @param tableName The name of the table being cached +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable #' @export @@ -94,7 +95,8 @@ cacheTable <- function(x, ...) { #' #' Removes the specified table from the in-memory cache. #' -#' @param tableName The name of the table being uncached +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable #' @export @@ -162,6 +164,7 @@ clearCache <- function() { #' @method dropTempTable default #' @note dropTempTable since 1.4.0 dropTempTable.default <- function(tableName) { + .Deprecated("dropTempView", old = "dropTempTable") if (class(tableName) != "character") { stop("tableName must be a string.") } @@ -169,7 +172,6 @@ dropTempTable.default <- function(tableName) { } dropTempTable <- function(x, ...) { - .Deprecated("dropTempView") dispatchFunc("dropTempView(viewName)", x, ...) } @@ -178,7 +180,7 @@ dropTempTable <- function(x, ...) { #' Drops the temporary view with the given view name in the catalog. #' If the view has been cached before, then it will also be uncached. #' -#' @param viewName the name of the view to be dropped. +#' @param viewName the name of the temporary view to be dropped. #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView @@ -317,10 +319,10 @@ listDatabases <- function() { dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF")) } -#' Returns a list of tables in the specified database +#' Returns a list of tables or views in the specified database #' -#' Returns a list of tables in the specified database. -#' This includes all temporary tables. +#' Returns a list of tables or views in the specified database. +#' This includes all temporary views. #' #' @param databaseName (optional) name of the database #' @return a SparkDataFrame of the list of tables. @@ -349,11 +351,13 @@ listTables <- function(databaseName = NULL) { dataFrame(callJMethod(jdst, "toDF")) } -#' Returns a list of columns for the given table in the specified database +#' Returns a list of columns for the given table/view in the specified database #' -#' Returns a list of columns for the given table in the specified database. +#' Returns a list of columns for the given table/view in the specified database. #' -#' @param tableName a name of the table. +#' @param tableName the qualified or unqualified name that designates a table/view. If no database +#' identifier is provided, it refers to a table/view in the current database. +#' If \code{databaseName} parameter is specified, this must be an unqualified name. #' @param databaseName (optional) name of the database #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns @@ -409,12 +413,13 @@ listFunctions <- function(databaseName = NULL) { dataFrame(callJMethod(jdst, "toDF")) } -#' Recover all the partitions in the directory of a table and update the catalog +#' Recovers all the partitions in the directory of a table and update the catalog #' -#' Recover all the partitions in the directory of a table and update the catalog. The name should -#' reference a partitioned table, and not a temporary view. +#' Recovers all the partitions in the directory of a table and update the catalog. The name should +#' reference a partitioned table, and not a view. #' -#' @param tableName a name of the table. +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions #' @export @@ -430,17 +435,18 @@ recoverPartitions <- function(tableName) { invisible(handledCallJMethod(catalog, "recoverPartitions", tableName)) } -#' Invalidate and refresh all the cached metadata of the given table +#' Invalidates and refreshes all the cached data and metadata of the given table #' -#' Invalidate and refresh all the cached metadata of the given table. For performance reasons, -#' Spark SQL or the external data source library it uses might cache certain metadata about a -#' table, such as the location of blocks. When those change outside of Spark SQL, users should +#' Invalidates and refreshes all the cached data and metadata of the given table. For performance +#' reasons, Spark SQL or the external data source library it uses might cache certain metadata about +#' a table, such as the location of blocks. When those change outside of Spark SQL, users should #' call this function to invalidate the cache. #' #' If this table is cached as an InMemoryRelation, drop the original cached version and make the #' new version cached lazily. #' -#' @param tableName a name of the table. +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable #' @export @@ -456,11 +462,11 @@ refreshTable <- function(tableName) { invisible(handledCallJMethod(catalog, "refreshTable", tableName)) } -#' Invalidate and refresh all the cached data and metadata for SparkDataFrame containing path +#' Invalidates and refreshes all the cached data and metadata for SparkDataFrame containing path #' -#' Invalidate and refresh all the cached data (and the associated metadata) for any SparkDataFrame -#' that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate -#' everything that is cached. +#' Invalidates and refreshes all the cached data (and the associated metadata) for any +#' SparkDataFrame that contains the given data source path. Path matching is by prefix, i.e. "/" +#' would invalidate everything that is cached. #' #' @param path the path of the data source. #' @rdname refreshByPath diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 253a750629170..41e68a45a6159 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -72,10 +72,10 @@ def listDatabases(self): @ignore_unicode_prefix @since(2.0) def listTables(self, dbName=None): - """Returns a list of tables in the specified database. + """Returns a list of tables/views in the specified database. If no database is specified, the current database is used. - This includes all temporary tables. + This includes all temporary views. """ if dbName is None: dbName = self.currentDatabase() @@ -115,7 +115,7 @@ def listFunctions(self, dbName=None): @ignore_unicode_prefix @since(2.0) def listColumns(self, tableName, dbName=None): - """Returns a list of columns for the given table in the specified database. + """Returns a list of columns for the given table/view in the specified database. If no database is specified, the current database is used. @@ -161,14 +161,15 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** def createTable(self, tableName, path=None, source=None, schema=None, **options): """Creates a table based on the dataset in a data source. - It returns the DataFrame associated with the external table. + It returns the DataFrame associated with the table. The data source is specified by the ``source`` and a set of ``options``. If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. + ``spark.sql.sources.default`` will be used. When ``path`` is specified, an external table is + created from the data at the given path. Otherwise a managed table is created. Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and - created external table. + created table. :return: :class:`DataFrame` """ @@ -276,14 +277,24 @@ def clearCache(self): @since(2.0) def refreshTable(self, tableName): - """Invalidate and refresh all the cached metadata of the given table.""" + """Invalidates and refreshes all the cached data and metadata of the given table.""" self._jcatalog.refreshTable(tableName) @since('2.1.1') def recoverPartitions(self, tableName): - """Recover all the partitions of the given table and update the catalog.""" + """Recovers all the partitions of the given table and update the catalog. + + Only works with a partitioned table, and not a view. + """ self._jcatalog.recoverPartitions(tableName) + @since('2.2.0') + def refreshByPath(self, path): + """Invalidates and refreshes all the cached data (and the associated metadata) for any + DataFrame that contains the given data source path. + """ + self._jcatalog.refreshByPath(path) + def _reset(self): """(Internal use only) Drop all existing databases (except "default"), tables, partitions and functions, and set the current database to "default". diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c22f4b87e1a78..fdb7abbad4e5f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -385,7 +385,7 @@ def sql(self, sqlQuery): @since(1.0) def table(self, tableName): - """Returns the specified table as a :class:`DataFrame`. + """Returns the specified table or view as a :class:`DataFrame`. :return: :class:`DataFrame` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 137b0cbc84f8f..074952ff7900a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -283,7 +283,7 @@ abstract class Catalog { /** * :: Experimental :: - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -321,7 +321,7 @@ abstract class Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -357,7 +357,7 @@ abstract class Catalog { /** * :: Experimental :: - * Create a table from the given path based on a data source, a schema and a set of options. + * Create a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -397,7 +397,7 @@ abstract class Catalog { /** * :: Experimental :: * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. + * Create a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -447,6 +447,7 @@ abstract class Catalog { /** * Recovers all the partitions in the directory of a table and update the catalog. + * Only works with a partitioned table, and not a view. * * @param tableName is either a qualified or unqualified name that designates a table. * If no database identifier is provided, it refers to a table in the @@ -493,10 +494,10 @@ abstract class Catalog { def clearCache(): Unit /** - * Invalidates and refreshes all the cached metadata of the given table. For performance reasons, - * Spark SQL or the external data source library it uses might cache certain metadata about a - * table, such as the location of blocks. When those change outside of Spark SQL, users should - * call this function to invalidate the cache. + * Invalidates and refreshes all the cached data and metadata of the given table. For performance + * reasons, Spark SQL or the external data source library it uses might cache certain metadata + * about a table, such as the location of blocks. When those change outside of Spark SQL, users + * should call this function to invalidate the cache. * * If this table is cached as an InMemoryRelation, drop the original cached version and make the * new version cached lazily. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 5d1c35aba529a..aebb663df5c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -141,7 +141,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table temporary view. + * Returns a list of columns for the given table/view or temporary view. */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { @@ -150,7 +150,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table in the specified database. + * Returns a list of columns for the given table/view or temporary view in the specified database. */ @throws[AnalysisException]("database or table does not exist") override def listColumns(dbName: String, tableName: String): Dataset[Column] = { @@ -273,7 +273,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: - * Creates a table from the given path based on a data source and returns the corresponding + * Creates a table from the given path and returns the corresponding * DataFrame. * * @group ddl_ops @@ -287,7 +287,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -304,7 +304,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source, a schema and a set of options. + * Creates a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -367,6 +367,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * Recovers all the partitions in the directory of a table and update the catalog. + * Only works with a partitioned table, and not a temporary view. * * @param tableName is either a qualified or unqualified name that designates a table. * If no database identifier is provided, it refers to a table in the @@ -431,8 +432,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refreshes the cache entry for a table or view, if any. For Hive metastore table, the metadata - * is refreshed. For data source tables, the schema will not be inferred and refreshed. + * Invalidates and refreshes all the cached data and metadata of the given table or view. + * For Hive metastore table, the metadata is refreshed. For data source tables, the schema will + * not be inferred and refreshed. + * + * If this table is cached as an InMemoryRelation, drop the original cached version and make the + * new version cached lazily. * * @group cachemgmt * @since 2.0.0 @@ -456,7 +461,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * Refreshes the cache entry and the associated metadata for all Dataset (if any), that contain - * the given data source path. + * the given data source path. Path matching is by prefix, i.e. "/" would invalidate + * everything that is cached. * * @group cachemgmt * @since 2.0.0 From 5a693b4138d4ce948e3bcdbe28d5c01d5deb8fa9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 6 Apr 2017 09:15:13 -0700 Subject: [PATCH 0212/1765] [SPARK-20195][SPARKR][SQL] add createTable catalog API and deprecate createExternalTable ## What changes were proposed in this pull request? Following up on #17483, add createTable (which is new in 2.2.0) and deprecate createExternalTable, plus a number of minor fixes ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #17511 from felixcheung/rceatetable. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 4 +- R/pkg/R/catalog.R | 59 +++++++++++++++++++---- R/pkg/inst/tests/testthat/test_sparkSQL.R | 20 ++++++-- 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9b7e95ce30acb..ca45c6f9b0a96 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -361,6 +361,7 @@ export("as.DataFrame", "clearCache", "createDataFrame", "createExternalTable", + "createTable", "currentDatabase", "dropTempTable", "dropTempView", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 97786df4ae6a1..ec85f723c08c6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -557,7 +557,7 @@ setMethod("insertInto", jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - callJMethod(write, "insertInto", tableName) + invisible(callJMethod(write, "insertInto", tableName)) }) #' Cache @@ -2894,7 +2894,7 @@ setMethod("saveAsTable", write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) write <- callJMethod(write, "options", options) - callJMethod(write, "saveAsTable", tableName) + invisible(callJMethod(write, "saveAsTable", tableName)) }) #' summary diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index 4b7f841b55dd0..e59a7024333ac 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -17,7 +17,7 @@ # catalog.R: SparkSession catalog functions -#' Create an external table +#' (Deprecated) Create an external table #' #' Creates an external table based on the dataset in a data source, #' Returns a SparkDataFrame associated with the external table. @@ -29,10 +29,11 @@ #' @param tableName a name of the table. #' @param path the path of files to load. #' @param source the name of external data source. -#' @param schema the schema of the data for certain data source. +#' @param schema the schema of the data required for some data sources. #' @param ... additional argument(s) passed to the method. #' @return A SparkDataFrame. -#' @rdname createExternalTable +#' @rdname createExternalTable-deprecated +#' @seealso \link{createTable} #' @export #' @examples #'\dontrun{ @@ -43,24 +44,64 @@ #' @method createExternalTable default #' @note createExternalTable since 1.4.0 createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + .Deprecated("createTable", old = "createExternalTable") + createTable(tableName, path, source, schema, ...) +} + +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + +#' Creates a table based on the dataset in a data source +#' +#' Creates a table based on the dataset in a data source. Returns a SparkDataFrame associated with +#' the table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. When a \code{path} is specified, an external table is +#' created from the data at the given path. Otherwise a managed table is created. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @param path (optional) the path of files to load. +#' @param source (optional) the name of the data source. +#' @param schema (optional) the schema of the data required for some data sources. +#' @param ... additional named parameters as options for the data source. +#' @return A SparkDataFrame. +#' @rdname createTable +#' @seealso \link{createExternalTable} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createTable("myjson", path="path/to/json", source="json", schema) +#' +#' createTable("people", source = "json", schema = schema) +#' insertInto(df, "people") +#' } +#' @name createTable +#' @note createTable since 2.2.0 +createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) if (!is.null(path)) { options[["path"]] <- path } + if (is.null(source)) { + source <- getDefaultSqlSource() + } catalog <- callJMethod(sparkSession, "catalog") if (is.null(schema)) { - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) + sdf <- callJMethod(catalog, "createTable", tableName, source, options) + } else if (class(schema) == "structType") { + sdf <- callJMethod(catalog, "createTable", tableName, source, schema$jobj, options) } else { - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, schema$jobj, options) + stop("schema must be a structType.") } dataFrame(sdf) } -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Cache Table #' #' Caches the specified table in-memory. diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ad06711a79a78..58cf24256a94f 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -281,7 +281,7 @@ test_that("create DataFrame from RDD", { setHiveContext(sc) sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) - invisible(insertInto(df, "people")) + insertInto(df, "people") expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, @@ -1268,7 +1268,16 @@ test_that("column calculation", { test_that("test HiveContext", { setHiveContext(sc) - df <- createExternalTable("json", jsonPath, "json") + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) df2 <- sql("select * from json") @@ -1276,25 +1285,26 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + saveAsTable(df, "json2", "json", "append", path = jsonPath2) df3 <- sql("select * from json2") expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) + saveAsTable(df, "hivetestbl", path = hivetestDataPath) df4 <- sql("select * from hivetestbl") expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) df5 <- sql("select * from parquetest") expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) + unsetHiveContext() }) From a4491626ed8169f0162a0dfb78736c9b9e7fb434 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 6 Apr 2017 13:23:54 -0500 Subject: [PATCH 0213/1765] [SPARK-17019][CORE] Expose on-heap and off-heap memory usage in various places ## What changes were proposed in this pull request? With [SPARK-13992](https://issues.apache.org/jira/browse/SPARK-13992), Spark supports persisting data into off-heap memory, but the usage of on-heap and off-heap memory is not exposed currently, it is not so convenient for user to monitor and profile, so here propose to expose off-heap memory as well as on-heap memory usage in various places: 1. Spark UI's executor page will display both on-heap and off-heap memory usage. 2. REST request returns both on-heap and off-heap memory. 3. Also this can be gotten from MetricsSystem. 4. Last this usage can be obtained programmatically from SparkListener. Attach the UI changes: ![screen shot 2016-08-12 at 11 20 44 am](https://cloud.githubusercontent.com/assets/850797/17612032/6c2f4480-607f-11e6-82e8-a27fb8cbb4ae.png) Backward compatibility is also considered for event-log and REST API. Old event log can still be replayed with off-heap usage displayed as 0. For REST API, only adds the new fields, so JSON backward compatibility can still be kept. ## How was this patch tested? Unit test added and manual verification. Author: jerryshao Closes #14617 from jerryshao/SPARK-17019. --- .../ui/static/executorspage-template.html | 18 ++- .../apache/spark/ui/static/executorspage.js | 103 ++++++++++++- .../org/apache/spark/ui/static/webui.css | 3 +- .../spark/scheduler/SparkListener.scala | 9 +- .../spark/status/api/v1/AllRDDResource.scala | 8 +- .../org/apache/spark/status/api/v1/api.scala | 12 +- .../apache/spark/storage/BlockManager.scala | 9 +- .../spark/storage/BlockManagerMaster.scala | 5 +- .../storage/BlockManagerMasterEndpoint.scala | 22 ++- .../spark/storage/BlockManagerMessages.scala | 3 +- .../spark/storage/BlockManagerSource.scala | 66 +++++---- .../spark/storage/StorageStatusListener.scala | 8 +- .../apache/spark/storage/StorageUtils.scala | 99 +++++++++---- .../apache/spark/ui/exec/ExecutorsPage.scala | 46 +++++- .../org/apache/spark/ui/storage/RDDPage.scala | 11 +- .../org/apache/spark/util/JsonProtocol.scala | 8 +- .../executor_memory_usage_expectation.json | 139 ++++++++++++++++++ ...xecutor_node_blacklisting_expectation.json | 41 ++++-- .../spark-events/app-20161116163331-0000 | 10 +- .../deploy/history/HistoryServerSuite.scala | 3 +- .../apache/spark/storage/StorageSuite.scala | 87 ++++++++++- .../org/apache/spark/ui/UISeleniumSuite.scala | 36 ++++- project/MimaExcludes.scala | 11 +- 23 files changed, 638 insertions(+), 119 deletions(-) create mode 100644 core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 4e83d6d564986..5c91304e49fd7 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -24,7 +24,15 @@

    Summary

    RDD Blocks Storage Memory + title="Memory used / total available memory for storage of data like RDD partitions cached in memory.">Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores @@ -73,6 +81,14 @@

    Executors

    Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores Active Tasks diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 7dbfe32de903a..930a0698928d1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -190,6 +190,10 @@ $(document).ready(function () { var allRDDBlocks = 0; var allMemoryUsed = 0; var allMaxMemory = 0; + var allOnHeapMemoryUsed = 0; + var allOnHeapMaxMemory = 0; + var allOffHeapMemoryUsed = 0; + var allOffHeapMaxMemory = 0; var allDiskUsed = 0; var allTotalCores = 0; var allMaxTasks = 0; @@ -208,6 +212,10 @@ $(document).ready(function () { var activeRDDBlocks = 0; var activeMemoryUsed = 0; var activeMaxMemory = 0; + var activeOnHeapMemoryUsed = 0; + var activeOnHeapMaxMemory = 0; + var activeOffHeapMemoryUsed = 0; + var activeOffHeapMaxMemory = 0; var activeDiskUsed = 0; var activeTotalCores = 0; var activeMaxTasks = 0; @@ -226,6 +234,10 @@ $(document).ready(function () { var deadRDDBlocks = 0; var deadMemoryUsed = 0; var deadMaxMemory = 0; + var deadOnHeapMemoryUsed = 0; + var deadOnHeapMaxMemory = 0; + var deadOffHeapMemoryUsed = 0; + var deadOffHeapMaxMemory = 0; var deadDiskUsed = 0; var deadTotalCores = 0; var deadMaxTasks = 0; @@ -240,11 +252,22 @@ $(document).ready(function () { var deadTotalShuffleWrite = 0; var deadTotalBlacklisted = 0; + response.forEach(function (exec) { + exec.onHeapMemoryUsed = exec.hasOwnProperty('onHeapMemoryUsed') ? exec.onHeapMemoryUsed : 0; + exec.maxOnHeapMemory = exec.hasOwnProperty('maxOnHeapMemory') ? exec.maxOnHeapMemory : 0; + exec.offHeapMemoryUsed = exec.hasOwnProperty('offHeapMemoryUsed') ? exec.offHeapMemoryUsed : 0; + exec.maxOffHeapMemory = exec.hasOwnProperty('maxOffHeapMemory') ? exec.maxOffHeapMemory : 0; + }); + response.forEach(function (exec) { allExecCnt += 1; allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; + allOnHeapMemoryUsed += exec.onHeapMemoryUsed; + allOnHeapMaxMemory += exec.maxOnHeapMemory; + allOffHeapMemoryUsed += exec.offHeapMemoryUsed; + allOffHeapMaxMemory += exec.maxOffHeapMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -263,6 +286,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; + activeOnHeapMemoryUsed += exec.onHeapMemoryUsed; + activeOnHeapMaxMemory += exec.maxOnHeapMemory; + activeOffHeapMemoryUsed += exec.offHeapMemoryUsed; + activeOffHeapMaxMemory += exec.maxOffHeapMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -281,6 +308,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; + deadOnHeapMemoryUsed += exec.onHeapMemoryUsed; + deadOnHeapMaxMemory += exec.maxOnHeapMemory; + deadOffHeapMemoryUsed += exec.offHeapMemoryUsed; + deadOffHeapMaxMemory += exec.maxOffHeapMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -302,6 +333,10 @@ $(document).ready(function () { "allRDDBlocks": allRDDBlocks, "allMemoryUsed": allMemoryUsed, "allMaxMemory": allMaxMemory, + "allOnHeapMemoryUsed": allOnHeapMemoryUsed, + "allOnHeapMaxMemory": allOnHeapMaxMemory, + "allOffHeapMemoryUsed": allOffHeapMemoryUsed, + "allOffHeapMaxMemory": allOffHeapMaxMemory, "allDiskUsed": allDiskUsed, "allTotalCores": allTotalCores, "allMaxTasks": allMaxTasks, @@ -321,6 +356,10 @@ $(document).ready(function () { "allRDDBlocks": activeRDDBlocks, "allMemoryUsed": activeMemoryUsed, "allMaxMemory": activeMaxMemory, + "allOnHeapMemoryUsed": activeOnHeapMemoryUsed, + "allOnHeapMaxMemory": activeOnHeapMaxMemory, + "allOffHeapMemoryUsed": activeOffHeapMemoryUsed, + "allOffHeapMaxMemory": activeOffHeapMaxMemory, "allDiskUsed": activeDiskUsed, "allTotalCores": activeTotalCores, "allMaxTasks": activeMaxTasks, @@ -340,6 +379,10 @@ $(document).ready(function () { "allRDDBlocks": deadRDDBlocks, "allMemoryUsed": deadMemoryUsed, "allMaxMemory": deadMaxMemory, + "allOnHeapMemoryUsed": deadOnHeapMemoryUsed, + "allOnHeapMaxMemory": deadOnHeapMaxMemory, + "allOffHeapMemoryUsed": deadOffHeapMemoryUsed, + "allOffHeapMaxMemory": deadOffHeapMaxMemory, "allDiskUsed": deadDiskUsed, "allTotalCores": deadTotalCores, "allMaxTasks": deadMaxTasks, @@ -378,7 +421,35 @@ $(document).ready(function () { {data: 'rddBlocks'}, { data: function (row, type) { - return type === 'display' ? (formatBytes(row.memoryUsed, type) + ' / ' + formatBytes(row.maxMemory, type)) : row.memoryUsed; + if (type !== 'display') + return row.memoryUsed; + else + return (formatBytes(row.memoryUsed, type) + ' / ' + + formatBytes(row.maxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.onHeapMemoryUsed; + else + return (formatBytes(row.onHeapMemoryUsed, type) + ' / ' + + formatBytes(row.maxOnHeapMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.offHeapMemoryUsed; + else + return (formatBytes(row.offHeapMemoryUsed, type) + ' / ' + + formatBytes(row.maxOffHeapMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'diskUsed', render: formatBytes}, @@ -450,7 +521,35 @@ $(document).ready(function () { {data: 'allRDDBlocks'}, { data: function (row, type) { - return type === 'display' ? (formatBytes(row.allMemoryUsed, type) + ' / ' + formatBytes(row.allMaxMemory, type)) : row.allMemoryUsed; + if (type !== 'display') + return row.allMemoryUsed + else + return (formatBytes(row.allMemoryUsed, type) + ' / ' + + formatBytes(row.allMaxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOnHeapMemoryUsed; + else + return (formatBytes(row.allOnHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOnHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOffHeapMemoryUsed; + else + return (formatBytes(row.allOffHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOffHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'allDiskUsed', render: formatBytes}, diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 319a719efaa79..935d9b1aec615 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -205,7 +205,8 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time, .peak_execution_memory { +.serialization_time, .getting_result_time, .peak_execution_memory, +.on_heap_memory, .off_heap_memory { display: none; } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 4331addb44172..bc2e530716686 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -87,8 +87,13 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long) - extends SparkListenerEvent +case class SparkListenerBlockManagerAdded( + time: Long, + blockManagerId: BlockManagerId, + maxMem: Long, + maxOnHeapMem: Option[Long] = None, + maxOffHeapMem: Option[Long] = None) extends SparkListenerEvent { +} @DeveloperApi case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 5c03609e5e5e5..1279b281ad8d8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -70,7 +70,13 @@ private[spark] object AllRDDResource { address = status.blockManagerId.hostPort, memoryUsed = status.memUsedByRdd(rddId), memoryRemaining = status.memRemaining, - diskUsed = status.diskUsedByRdd(rddId) + diskUsed = status.diskUsedByRdd(rddId), + onHeapMemoryUsed = Some( + if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + offHeapMemoryUsed = Some( + if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + onHeapMemoryRemaining = status.onHeapMemRemaining, + offHeapMemoryRemaining = status.offHeapMemRemaining ) } ) } else { None diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 5b9227350edaa..d159b9450ef5c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -75,7 +75,11 @@ class ExecutorSummary private[spark]( val totalShuffleWrite: Long, val isBlacklisted: Boolean, val maxMemory: Long, - val executorLogs: Map[String, String]) + val executorLogs: Map[String, String], + val onHeapMemoryUsed: Option[Long], + val offHeapMemoryUsed: Option[Long], + val maxOnHeapMemory: Option[Long], + val maxOffHeapMemory: Option[Long]) class JobData private[spark]( val jobId: Int, @@ -111,7 +115,11 @@ class RDDDataDistribution private[spark]( val address: String, val memoryUsed: Long, val memoryRemaining: Long, - val diskUsed: Long) + val diskUsed: Long, + val onHeapMemoryUsed: Option[Long], + val offHeapMemoryUsed: Option[Long], + val onHeapMemoryRemaining: Option[Long], + val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( val blockName: String, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 46a078b2f9f93..63acba65d3c5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -150,8 +150,8 @@ private[spark] class BlockManager( // However, since we use this only for reporting and logging, what we actually want here is // the absolute maximum value that `maxMemory` can ever possibly reach. We may need // to revisit whether reporting this value as the "max" is intuitive to the user. - private val maxMemory = - memoryManager.maxOnHeapStorageMemory + memoryManager.maxOffHeapStorageMemory + private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory + private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. @@ -229,7 +229,8 @@ private[spark] class BlockManager( val idFromMaster = master.registerBlockManager( id, - maxMemory, + maxOnHeapMemory, + maxOffHeapMemory, slaveEndpoint) blockManagerId = if (idFromMaster != null) idFromMaster else id @@ -307,7 +308,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo(s"BlockManager $blockManagerId re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager(blockManagerId, maxOnHeapMemory, maxOffHeapMemory, slaveEndpoint) reportAllBlocks() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 3ca690db9e79f..ea5d8423a588c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -57,11 +57,12 @@ class BlockManagerMaster( */ def registerBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") val updatedId = driverEndpoint.askSync[BlockManagerId]( - RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) logInfo(s"Registered BlockManager $updatedId") updatedId } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 84c04d22600ad..467c3e0e6b51f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -71,8 +71,8 @@ class BlockManagerMasterEndpoint( logInfo("BlockManagerMasterEndpoint up") override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - context.reply(register(blockManagerId, maxMemSize, slaveEndpoint)) + case RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) => + context.reply(register(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -276,7 +276,8 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) + new StorageStatus(blockManagerId, info.maxMem, Some(info.maxOnHeapMem), + Some(info.maxOffHeapMem), info.blocks.asScala) }.toArray } @@ -338,7 +339,8 @@ class BlockManagerMasterEndpoint( */ private def register( idWithoutTopologyInfo: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { // the dummy id is not expected to contain the topology information. // we get that info here and respond back with a more fleshed out block manager id @@ -359,14 +361,15 @@ class BlockManagerMasterEndpoint( case None => } logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) + id.hostPort, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize), id)) blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, + Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) id } @@ -464,10 +467,13 @@ object BlockStatus { private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, - val maxMem: Long, + val maxOnHeapMem: Long, + val maxOffHeapMem: Long, val slaveEndpoint: RpcEndpointRef) extends Logging { + val maxMem = maxOnHeapMem + maxOffHeapMem + private var _lastSeenMs: Long = timeMs private var _remainingMem: Long = maxMem diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0aea438e7f473..0c0ff144596ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -58,7 +58,8 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, sender: RpcEndpointRef) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index c5ba9af3e2658..197a01762c0c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -26,35 +26,39 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager) override val metricRegistry = new MetricRegistry() override val sourceName = "BlockManager" - metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - maxMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).sum - remainingMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val memUsed = storageStatusList.map(_.memUsed).sum - memUsed / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum - diskSpaceUsed / 1024 / 1024 - } - }) + private def registerGauge(name: String, func: BlockManagerMaster => Long): Unit = { + metricRegistry.register(name, new Gauge[Long] { + override def getValue: Long = func(blockManager.master) / 1024 / 1024 + }) + } + + registerGauge(MetricRegistry.name("memory", "maxMem_MB"), + _.getStorageStatus.map(_.maxMem).sum) + + registerGauge(MetricRegistry.name("memory", "maxOnHeapMem_MB"), + _.getStorageStatus.map(_.maxOnHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "maxOffHeapMem_MB"), + _.getStorageStatus.map(_.maxOffHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingMem_MB"), + _.getStorageStatus.map(_.memRemaining).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOnHeapMem_MB"), + _.getStorageStatus.map(_.onHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOffHeapMem_MB"), + _.getStorageStatus.map(_.offHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "memUsed_MB"), + _.getStorageStatus.map(_.memUsed).sum) + + registerGauge(MetricRegistry.name("memory", "onHeapMemUsed_MB"), + _.getStorageStatus.map(_.onHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "offHeapMemUsed_MB"), + _.getStorageStatus.map(_.offHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("disk", "diskSpaceUsed_MB"), + _.getStorageStatus.map(_.diskUsed).sum) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 798658a15b797..1b30d4fa93bc0 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -41,7 +41,7 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { } def deadStorageStatusList: Seq[StorageStatus] = synchronized { - deadExecutorStorageStatus.toSeq + deadExecutorStorageStatus } /** Update storage status list to reflect updated block statuses */ @@ -74,8 +74,10 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { synchronized { val blockManagerId = blockManagerAdded.blockManagerId val executorId = blockManagerId.executorId - val maxMem = blockManagerAdded.maxMem - val storageStatus = new StorageStatus(blockManagerId, maxMem) + // The onHeap and offHeap memory are always defined for new applications, + // but they can be missing if we are replaying old event logs. + val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem, + blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem) executorIdToStorageStatus(executorId) = storageStatus // Try to remove the dead storage status if same executor register the block manager twice. diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 241aacd74b586..8f0d181fc8fe5 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,7 +35,11 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi -class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { +class StorageStatus( + val blockManagerId: BlockManagerId, + val maxMemory: Long, + val maxOnHeapMem: Option[Long], + val maxOffHeapMem: Option[Long]) { /** * Internal representation of the blocks stored in this block manager. @@ -46,25 +50,21 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] - /** - * Storage information of the blocks that entails memory, disk, and off-heap memory usage. - * - * As with the block maps, we store the storage information separately for RDD blocks and - * non-RDD blocks for the same reason. In particular, RDD storage information is stored - * in a map indexed by the RDD ID to the following 4-tuple: - * - * (memory size, disk size, storage level) - * - * We assume that all the blocks that belong to the same RDD have the same storage level. - * This field is not relevant to non-RDD blocks, however, so the storage information for - * non-RDD blocks contains only the first 3 fields (in the same order). - */ - private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, StorageLevel)] - private var _nonRddStorageInfo: (Long, Long) = (0L, 0L) + private case class RddStorageInfo(memoryUsage: Long, diskUsage: Long, level: StorageLevel) + private val _rddStorageInfo = new mutable.HashMap[Int, RddStorageInfo] + + private case class NonRddStorageInfo(var onHeapUsage: Long, var offHeapUsage: Long, + var diskUsage: Long) + private val _nonRddStorageInfo = NonRddStorageInfo(0L, 0L, 0L) /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ - def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { - this(bmid, maxMem) + def this( + bmid: BlockManagerId, + maxMemory: Long, + maxOnHeapMem: Option[Long], + maxOffHeapMem: Option[Long], + initialBlocks: Map[BlockId, BlockStatus]) { + this(bmid, maxMemory, maxOnHeapMem, maxOffHeapMem) initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) } } @@ -176,26 +176,57 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { */ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) + /** Return the max memory can be used by this block manager. */ + def maxMem: Long = maxMemory + /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed + /** Return the memory used by caching RDDs */ + def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) + /** Return the memory used by this block manager. */ - def memUsed: Long = _nonRddStorageInfo._1 + cacheSize + def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) - /** Return the memory used by caching RDDs */ - def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + /** Return the on-heap memory remaining in this block manager. */ + def onHeapMemRemaining: Option[Long] = + for (m <- maxOnHeapMem; o <- onHeapMemUsed) yield m - o + + /** Return the off-heap memory remaining in this block manager. */ + def offHeapMemRemaining: Option[Long] = + for (m <- maxOffHeapMem; o <- offHeapMemUsed) yield m - o + + /** Return the on-heap memory used by this block manager. */ + def onHeapMemUsed: Option[Long] = onHeapCacheSize.map(_ + _nonRddStorageInfo.onHeapUsage) + + /** Return the off-heap memory used by this block manager. */ + def offHeapMemUsed: Option[Long] = offHeapCacheSize.map(_ + _nonRddStorageInfo.offHeapUsage) + + /** Return the memory used by on-heap caching RDDs */ + def onHeapCacheSize: Option[Long] = maxOnHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if !storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } + + /** Return the memory used by off-heap caching RDDs */ + def offHeapCacheSize: Option[Long] = maxOffHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } /** Return the disk space used by this block manager. */ - def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) + def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) /** Return the disk space used by the given RDD in this block manager in O(1) time. */ - def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) + def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._3) + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) /** * Update the relevant storage info, taking into account any existing status for this block. @@ -210,10 +241,12 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val (oldMem, oldDisk) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, _) => (mem, disk) } + .map { case RddStorageInfo(mem, disk, _) => (mem, disk) } .getOrElse((0L, 0L)) - case _ => - _nonRddStorageInfo + case _ if !level.useOffHeap => + (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage) + case _ if level.useOffHeap => + (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage) } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) @@ -225,13 +258,17 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { if (newMem + newDisk == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, level) + _rddStorageInfo(rddId) = RddStorageInfo(newMem, newDisk, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk) + if (!level.useOffHeap) { + _nonRddStorageInfo.onHeapUsage = newMem + } else { + _nonRddStorageInfo.offHeapUsage = newMem + } + _nonRddStorageInfo.diskUsage = newDisk } } - } /** Helper methods for storage-related objects. */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index d849ce76a9e3c..0a3c63d14ca8a 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -40,7 +40,8 @@ private[ui] case class ExecutorSummaryInfo( totalShuffleRead: Long, totalShuffleWrite: Long, isBlacklisted: Int, - maxMemory: Long, + maxOnHeapMem: Long, + maxOffHeapMem: Long, executorLogs: Map[String, String]) @@ -53,6 +54,34 @@ private[ui] class ExecutorsPage( val content =
    { +
    + + + Show Additional Metrics + + +
    ++
    ++ ++ ++ @@ -65,6 +94,11 @@ private[ui] class ExecutorsPage( } private[spark] object ExecutorsPage { + private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " + + "storage of data like RDD partitions cached in memory." + private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " + + "storage of data like RDD partitions cached in memory." + /** Represent an executor's info as a map given a storage status index */ def getExecInfo( listener: ExecutorsListener, @@ -80,6 +114,10 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem + val onHeapMemUsed = status.onHeapMemUsed + val offHeapMemUsed = status.offHeapMemUsed + val maxOnHeapMem = status.maxOnHeapMem + val maxOffHeapMem = status.maxOffHeapMem val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -103,7 +141,11 @@ private[spark] object ExecutorsPage { taskSummary.shuffleWrite, taskSummary.isBlacklisted, maxMem, - taskSummary.executorLogs + taskSummary.executorLogs, + onHeapMemUsed, + offHeapMemUsed, + maxOnHeapMem, + maxOffHeapMem ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 227e940c9c50c..a1a0c729b9240 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -147,7 +147,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { /** Header fields for the worker table */ private def workerHeader = Seq( "Host", - "Memory Usage", + "On Heap Memory Usage", + "Off Heap Memory Usage", "Disk Usage") /** Render an HTML row representing a worker */ @@ -155,8 +156,12 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {worker.address} - {Utils.bytesToString(worker.memoryUsed)} - ({Utils.bytesToString(worker.memoryRemaining)} Remaining) + {Utils.bytesToString(worker.onHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.onHeapMemoryRemaining.getOrElse(0L))} Remaining) + + + {Utils.bytesToString(worker.offHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.offHeapMemoryRemaining.getOrElse(0L))} Remaining) {Utils.bytesToString(worker.diskUsed)} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 1d2cb7acefa33..8296c4294242c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -182,7 +182,9 @@ private[spark] object JsonProtocol { ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~ ("Block Manager ID" -> blockManagerId) ~ ("Maximum Memory" -> blockManagerAdded.maxMem) ~ - ("Timestamp" -> blockManagerAdded.time) + ("Timestamp" -> blockManagerAdded.time) ~ + ("Maximum Onheap Memory" -> blockManagerAdded.maxOnHeapMem) ~ + ("Maximum Offheap Memory" -> blockManagerAdded.maxOffHeapMem) } def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { @@ -612,7 +614,9 @@ private[spark] object JsonProtocol { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - SparkListenerBlockManagerAdded(time, blockManagerId, maxMem) + val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json new file mode 100644 index 0000000000000..e732af2663503 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -0,0 +1,139 @@ +[ { + "id" : "2", + "hostPort" : "172.22.0.167:51487", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "1", + "hostPort" : "172.22.0.167:51490", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3152, + "totalGCTime" : 68, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", + "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" + }, + + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "0", + "hostPort" : "172.22.0.167:51491", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2551, + "totalGCTime" : 116, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 5914a1c2c4b6d..e732af2663503 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -17,11 +17,15 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" - } + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -41,8 +45,12 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, - "executorLogs" : { } + "maxMemory" : 908381388, + "executorLogs" : { }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -62,11 +70,16 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" - } + }, + + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -86,11 +99,15 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" - } + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -110,9 +127,13 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" - } + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 } ] diff --git a/core/src/test/resources/spark-events/app-20161116163331-0000 b/core/src/test/resources/spark-events/app-20161116163331-0000 index 7566c9fc0a20b..57cfc5b973129 100755 --- a/core/src/test/resources/spark-events/app-20161116163331-0000 +++ b/core/src/test/resources/spark-events/app-20161116163331-0000 @@ -1,15 +1,15 @@ {"Event":"SparkListenerLogStart","Spark Version":"2.1.0-SNAPSHOT"} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":384093388,"Timestamp":1479335611477} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":908381388,"Timestamp":1479335611477,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","Java Version":"1.8.0_92 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.task.maxTaskAttemptsPerExecutor":"3","spark.blacklist.enabled":"TRUE","spark.driver.host":"172.22.0.167","spark.blacklist.task.maxTaskAttemptsPerNode":"3","spark.eventLog.enabled":"TRUE","spark.driver.port":"51459","spark.repl.class.uri":"spark://172.22.0.167:51459/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-1cbc97d0-7fe6-4c9f-8c2c-f6fe51ee3cf2/repl-39929169-ac4c-4c6d-b116-f648e4dd62ed","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"3","spark.scheduler.mode":"FIFO","spark.eventLog.overwrite":"TRUE","spark.blacklist.stage.maxFailedTasksPerExecutor":"3","spark.executor.id":"driver","spark.blacklist.application.maxFailedExecutorsPerNode":"2","spark.submit.deployMode":"client","spark.master":"local-cluster[4,4,1024]","spark.home":"/Users/Jose/IdeaProjects/spark","spark.eventLog.dir":"/Users/jose/logs","spark.sql.catalogImplementation":"in-memory","spark.eventLog.compress":"FALSE","spark.blacklist.application.maxFailedTasksPerExecutor":"1","spark.blacklist.timeout":"1000000","spark.app.id":"app-20161116163331-0000","spark.task.maxFailures":"4"},"System Properties":{"java.io.tmpdir":"/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/Jose","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib","user.dir":"/Users/Jose/IdeaProjects/spark","java.library.path":"/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.92-b14","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_92-b14","java.vm.info":"mixed mode","java.ext.dirs":"/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","io.netty.maxDirectMemory":"0","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.11.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"jose","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=1000000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i /Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","java.version":"1.8.0_92","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/conf/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/jars/*":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar":"System Classpath"}} {"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20161116163331-0000","Timestamp":1479335609916,"User":"jose"} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615320,"Executor ID":"3","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout","stderr":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":384093388,"Timestamp":1479335615387} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":908381388,"Timestamp":1479335615387,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615393,"Executor ID":"2","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout","stderr":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr"}}} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615443,"Executor ID":"1","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout","stderr":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":384093388,"Timestamp":1479335615448} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":908381388,"Timestamp":1479335615448,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615462,"Executor ID":"0","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout","stderr":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":384093388,"Timestamp":1479335615496} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":384093388,"Timestamp":1479335615515} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":908381388,"Timestamp":1479335615496,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":908381388,"Timestamp":1479335615515,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1479335616467,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0],"Properties":{}} {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},"Properties":{}} {"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479335616657,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index dcf83cb530a91..764156c3edc41 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -153,7 +153,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", - "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors" + "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", + "executor memory usage" -> "applications/app-20161116163331-0000/executors" // Todo: enable this test when logging the even of onBlockUpdated. See: SPARK-13845 // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index e5733aebf607c..da198f946fd64 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -27,7 +27,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, and remove (for non-RDD blocks) private def storageStatus1: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.blocks.isEmpty) assert(status.rddBlocks.isEmpty) assert(status.memUsed === 0L) @@ -74,7 +74,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.rddBlocks.isEmpty) status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L)) @@ -252,9 +252,9 @@ class StorageSuite extends SparkFunSuite { // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations private def stockStorageStatuses: Seq[StorageStatus] = { - val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) - val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) - val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) + val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) + val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L, Some(2000L), Some(0L)) + val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L, Some(3000L), Some(0L)) status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L)) status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L)) @@ -332,4 +332,81 @@ class StorageSuite extends SparkFunSuite { assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) } + private val offheap = StorageLevel.OFF_HEAP + // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD onheap + // and offheap blocks + private def storageStatus3: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, Some(1000L), Some(1000L)) + assert(status.rddBlocks.isEmpty) + status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("man"), BlockStatus(offheap, 10L, 0L)) + status.addBlock(RDDBlockId(0, 0), BlockStatus(offheap, 10L, 0L)) + status.addBlock(RDDBlockId(1, 1), BlockStatus(offheap, 100L, 0L)) + status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L)) + status + } + + test("storage memUsed, diskUsed with on-heap and off-heap blocks") { + val status = storageStatus3 + def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum + def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum + + def actualOnHeapMemUsed: Long = + status.blocks.values.filter(!_.storageLevel.useOffHeap).map(_.memSize).sum + def actualOffHeapMemUsed: Long = + status.blocks.values.filter(_.storageLevel.useOffHeap).map(_.memSize).sum + + assert(status.maxMem === status.maxOnHeapMem.get + status.maxOffHeapMem.get) + + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) + assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) + + assert(status.memRemaining === status.maxMem - actualMemUsed) + assert(status.onHeapMemRemaining.get === status.maxOnHeapMem.get - actualOnHeapMemUsed) + assert(status.offHeapMemRemaining.get === status.maxOffHeapMem.get - actualOffHeapMemUsed) + + status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L)) + status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) + status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) + assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) + + status.removeBlock(TestBlockId("fire")) + status.removeBlock(TestBlockId("man")) + status.removeBlock(RDDBlockId(2, 2)) + status.removeBlock(RDDBlockId(2, 3)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + } + + private def storageStatus4: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, None, None) + status + } + test("old SparkListenerBlockManagerAdded event compatible") { + // This scenario will only be happened when replaying old event log. In this scenario there's + // no block add or remove event replayed, so only total amount of memory is valid. + val status = storageStatus4 + assert(status.maxMem === status.maxMemory) + + assert(status.memUsed === 0L) + assert(status.diskUsed === 0L) + assert(status.onHeapMemUsed === None) + assert(status.offHeapMemUsed === None) + + assert(status.memRemaining === status.maxMem) + assert(status.onHeapMemRemaining === None) + assert(status.offHeapMemRemaining === None) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 4228373036425..f4c561c737794 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} +import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { @@ -103,6 +103,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") .set("spark.ui.killEnabled", killEnabled.toString) + .set("spark.memory.offHeap.size", "64m") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -151,6 +152,39 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) + + val dataDistributions0 = + (updatedRddJson \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions0.length should be (1) + val dist0 = dataDistributions0.head + + dist0.onHeapMemoryUsed should not be (None) + dist0.memoryUsed should be (dist0.onHeapMemoryUsed.get) + dist0.onHeapMemoryRemaining should not be (None) + dist0.offHeapMemoryRemaining should not be (None) + dist0.memoryRemaining should be ( + dist0.onHeapMemoryRemaining.get + dist0.offHeapMemoryRemaining.get) + dist0.onHeapMemoryUsed should not be (Some(0L)) + dist0.offHeapMemoryUsed should be (Some(0L)) + + rdd.unpersist() + rdd.persist(StorageLevels.OFF_HEAP).count() + val updatedStorageJson1 = getJson(ui, "storage/rdd") + updatedStorageJson1.children.length should be (1) + val updatedRddJson1 = getJson(ui, "storage/rdd/0") + val dataDistributions1 = + (updatedRddJson1 \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions1.length should be (1) + val dist1 = dataDistributions1.head + + dist1.offHeapMemoryUsed should not be (None) + dist1.memoryUsed should be (dist1.offHeapMemoryUsed.get) + dist1.onHeapMemoryRemaining should not be (None) + dist1.offHeapMemoryRemaining should not be (None) + dist1.memoryRemaining should be ( + dist1.onHeapMemoryRemaining.get + dist1.offHeapMemoryRemaining.get) + dist1.onHeapMemoryUsed should be (Some(0L)) + dist1.offHeapMemoryUsed should not be (Some(0L)) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2e3f9f2d0f3ac..feae76a087dec 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -100,7 +100,16 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") - ) + ) ++ Seq( + // [SPARK-17019] Expose on-heap and off-heap memory usage in various places + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.this"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.apply"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.StorageStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.RDDDataDistribution.this") + ) // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { From 8129d59d0e389fa8074958f1b90f7539e3e79bb7 Mon Sep 17 00:00:00 2001 From: Dustin Koupal Date: Thu, 6 Apr 2017 16:56:36 -0700 Subject: [PATCH 0214/1765] [MINOR][DOCS] Fix typo in Hive Examples ## What changes were proposed in this pull request? Fix typo in hive examples from "DaraFrames" to "DataFrames" ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dustin Koupal Closes #17554 from cooper6581/typo-daraframes. --- .../apache/spark/examples/sql/hive/JavaSparkHiveExample.java | 2 +- examples/src/main/python/sql/hive.py | 2 +- .../org/apache/spark/examples/sql/hive/SparkHiveExample.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java index 47638565b1663..575a463e8725f 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) { // The results of SQL queries are themselves DataFrames and support all normal functions. Dataset sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key"); - // The items in DaraFrames are of type Row, which lets you to access each column by ordinal. + // The items in DataFrames are of type Row, which lets you to access each column by ordinal. Dataset stringsDS = sqlDF.map( (MapFunction) row -> "Key: " + row.get(0) + ", Value: " + row.get(1), Encoders.STRING()); diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f175d725800f..1f83a6fb48b97 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -68,7 +68,7 @@ # The results of SQL queries are themselves DataFrames and support all normal functions. sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - # The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + # The items in DataFrames are of type Row, which allows you to access each column by ordinal. stringsDS = sqlDF.rdd.map(lambda row: "Key: %d, Value: %s" % (row.key, row.value)) for record in stringsDS.collect(): print(record) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index 3de26364b5288..e5f75d53edc86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -76,7 +76,7 @@ object SparkHiveExample { // The results of SQL queries are themselves DataFrames and support all normal functions. val sqlDF = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - // The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + // The items in DataFrames are of type Row, which allows you to access each column by ordinal. val stringsDS = sqlDF.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } From 626b4cafce7d2dca186144336939d4d993b6f878 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Apr 2017 19:24:03 -0700 Subject: [PATCH 0215/1765] [SPARK-19495][SQL] Make SQLConf slightly more extensible - addendum ## What changes were proposed in this pull request? This is a tiny addendum to SPARK-19495 to remove the private visibility for copy, which is the only package private method in the entire file. ## How was this patch tested? N/A - no semantic change. Author: Reynold Xin Closes #17555 from rxin/SPARK-19495-2. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e685c2bed50ae..640c0f189c237 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1153,7 +1153,7 @@ class SQLConf extends Serializable with Logging { } // For test only - private[spark] def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { val cloned = clone() entries.foreach { case (entry, value) => cloned.setConfString(entry.key, value.toString) From ad3cc1312db3b5667cea134940a09896a4609b74 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Apr 2017 15:58:50 +0800 Subject: [PATCH 0216/1765] [SPARK-20245][SQL][MINOR] pass output to LogicalRelation directly ## What changes were proposed in this pull request? Currently `LogicalRelation` has a `expectedOutputAttributes` parameter, which makes it hard to reason about what the actual output is. Like other leaf nodes, `LogicalRelation` should also take `output` as a parameter, to simplify the logic ## How was this patch tested? existing tests Author: Wenchen Fan Closes #17552 from cloud-fan/minor. --- .../sql/catalyst/catalog/interface.scala | 8 ++-- .../datasources/DataSourceStrategy.scala | 15 +++---- .../datasources/LogicalRelation.scala | 39 +++++++------------ .../PruneFileSourcePartitions.scala | 4 +- .../spark/sql/sources/PathOptionSuite.scala | 19 ++++----- .../spark/sql/hive/HiveMetastoreCatalog.scala | 13 +++++-- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../PruneFileSourcePartitionsSuite.scala | 2 +- 8 files changed, 49 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index dc2e40424fd5f..360e55d922821 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -27,7 +27,7 @@ import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -403,14 +403,14 @@ object CatalogTypes { */ case class CatalogRelation( tableMeta: CatalogTable, - dataCols: Seq[Attribute], - partitionCols: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + dataCols: Seq[AttributeReference], + partitionCols: Seq[AttributeReference]) extends LeafNode with MultiInstanceRelation { assert(tableMeta.identifier.database.isDefined) assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType)) assert(tableMeta.dataSchema.sameType(dataCols.toStructType)) // The partition column should always appear after data columns. - override def output: Seq[Attribute] = dataCols ++ partitionCols + override def output: Seq[AttributeReference] = dataCols ++ partitionCols def isPartitioned: Boolean = partitionCols.nonEmpty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e5c7c383d708c..2d83d512e702d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -231,16 +231,17 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] options = table.storage.properties ++ pathOption, catalogTable = Some(table)) - LogicalRelation( - dataSource.resolveRelation(checkFilesExist = false), - catalogTable = Some(table)) + LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table) } }).asInstanceOf[LogicalRelation] - // It's possible that the table schema is empty and need to be inferred at runtime. We should - // not specify expected outputs for this case. - val expectedOutputs = if (r.output.isEmpty) None else Some(r.output) - plan.copy(expectedOutputAttributes = expectedOutputs) + if (r.output.isEmpty) { + // It's possible that the table schema is empty and need to be inferred at runtime. For this + // case, we don't need to change the output of the cached plan. + plan + } else { + plan.copy(output = r.output) + } } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3b14b794fd08c..4215203960075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation @@ -26,31 +26,13 @@ import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. - * - * Note that sometimes we need to use `LogicalRelation` to replace an existing leaf node without - * changing the output attributes' IDs. The `expectedOutputAttributes` parameter is used for - * this purpose. See https://issues.apache.org/jira/browse/SPARK-10741 for more details. */ case class LogicalRelation( relation: BaseRelation, - expectedOutputAttributes: Option[Seq[Attribute]] = None, - catalogTable: Option[CatalogTable] = None) + output: Seq[AttributeReference], + catalogTable: Option[CatalogTable]) extends LeafNode with MultiInstanceRelation { - override val output: Seq[AttributeReference] = { - val attrs = relation.schema.toAttributes - expectedOutputAttributes.map { expectedAttrs => - assert(expectedAttrs.length == attrs.length) - attrs.zip(expectedAttrs).map { - // We should respect the attribute names provided by base relation and only use the - // exprId in `expectedOutputAttributes`. - // The reason is that, some relations(like parquet) will reconcile attribute names to - // workaround case insensitivity issue. - case (attr, expected) => attr.withExprId(expected.exprId) - } - }.getOrElse(attrs) - } - // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output @@ -87,11 +69,8 @@ case class LogicalRelation( * unique expression ids. We respect the `expectedOutputAttributes` and create * new instances of attributes in it. */ - override def newInstance(): this.type = { - LogicalRelation( - relation, - expectedOutputAttributes.map(_.map(_.newInstance())), - catalogTable).asInstanceOf[this.type] + override def newInstance(): LogicalRelation = { + this.copy(output = output.map(_.newInstance())) } override def refresh(): Unit = relation match { @@ -101,3 +80,11 @@ case class LogicalRelation( override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" } + +object LogicalRelation { + def apply(relation: BaseRelation): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, None) + + def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, Some(table)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 8566a8061034b..905b8683e10bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -59,9 +59,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(sparkSession) - val prunedLogicalRelation = logicalRelation.copy( - relation = prunedFsRelation, - expectedOutputAttributes = Some(logicalRelation.output)) + val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) // Keep partition-pruning predicates so that they are visible in physical planning val filterExpression = filters.reduceLeft(And) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index 60adee4599b0b..6dd4847ead738 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -75,13 +75,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path') """.stripMargin) - assert(getPathOption("src") == Some("file:/tmp/path")) + assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path"))) } // should exist even path option is not specified when creating table withTable("src") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") - assert(getPathOption("src") == Some(CatalogUtils.URIToString(defaultTablePath("src")))) + assert(getPathOption("src").map(makeQualifiedPath) == Some(defaultTablePath("src"))) } } @@ -95,9 +95,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |OPTIONS (PATH '$p') |AS SELECT 1 """.stripMargin) - assert(CatalogUtils.stringToURI( - spark.table("src").schema.head.metadata.getString("path")) == - makeQualifiedPath(p.getAbsolutePath)) + assert( + spark.table("src").schema.head.metadata.getString("path") == + p.getAbsolutePath) } } @@ -109,8 +109,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == - CatalogUtils.URIToString(defaultTablePath("src"))) + assert( + makeQualifiedPath(spark.table("src").schema.head.metadata.getString("path")) == + defaultTablePath("src")) } } @@ -122,13 +123,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path')""".stripMargin) sql("ALTER TABLE src SET LOCATION '/tmp/path2'") - assert(getPathOption("src") == Some("/tmp/path2")) + assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path2"))) } withTable("src", "src2") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") sql("ALTER TABLE src RENAME TO src2") - assert(getPathOption("src2") == Some(CatalogUtils.URIToString(defaultTablePath("src2")))) + assert(getPathOption("src2").map(makeQualifiedPath) == Some(defaultTablePath("src2"))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 10f432570e94b..6b98066cb76c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -175,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log bucketSpec = None, fileFormat = fileFormat, options = options)(sparkSession = sparkSession) - val created = LogicalRelation(fsRelation, catalogTable = Some(updatedTable)) + val created = LogicalRelation(fsRelation, updatedTable) tableRelationCache.put(tableIdentifier, created) created } @@ -203,7 +203,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log bucketSpec = None, options = options, className = fileType).resolveRelation(), - catalogTable = Some(updatedTable)) + table = updatedTable) tableRelationCache.put(tableIdentifier, created) created @@ -212,7 +212,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } - result.copy(expectedOutputAttributes = Some(relation.output)) + // The inferred schema may have different filed names as the table schema, we should respect + // it, but also respect the exprId in table relation output. + assert(result.output.length == relation.output.length && + result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType }) + val newOutput = result.output.zip(relation.output).map { + case (a1, a2) => a1.withExprId(a2.exprId) + } + result.copy(output = newOutput) } private def inferIfNeeded( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 2b3f36064c1f8..d3cbf898e2439 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -329,7 +329,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val plan = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val plan = LogicalRelation(relation, tableMeta) spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan)) assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) @@ -342,7 +342,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto bucketSpec = None, fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val samePlan = LogicalRelation(sameRelation, catalogTable = Some(tableMeta)) + val samePlan = LogicalRelation(sameRelation, tableMeta) assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index cd8f94b1cc4f0..f818e29555468 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -58,7 +58,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val logicalRelation = LogicalRelation(relation, tableMeta) val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze val optimized = Optimize.execute(query) From 1a52a62377a87cec493c8c6711bfd44e779c7973 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Apr 2017 11:00:10 +0200 Subject: [PATCH 0217/1765] [SPARK-20076][ML][PYSPARK] Add Python interface for ml.stats.Correlation ## What changes were proposed in this pull request? The Dataframes-based support for the correlation statistics is added in #17108. This patch adds the Python interface for it. ## How was this patch tested? Python unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17494 from viirya/correlation-python-api. --- .../apache/spark/ml/stat/Correlation.scala | 8 +-- python/pyspark/ml/stat.py | 61 +++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala index d3c84b77d26ac..e185bc8a6faaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -38,7 +38,7 @@ object Correlation { /** * :: Experimental :: - * Compute the correlation matrix for the input RDD of Vectors using the specified method. + * Compute the correlation matrix for the input Dataset of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * * @param dataset A dataset or a dataframe @@ -56,14 +56,14 @@ object Correlation { * Here is how to access the correlation coefficient: * {{{ * val data: Dataset[Vector] = ... - * val Row(coeff: Matrix) = Statistics.corr(data, "value").head + * val Row(coeff: Matrix) = Correlation.corr(data, "value").head * // coeff now contains the Pearson correlation matrix. * }}} * * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. + * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"` + * to avoid recomputing the common lineage. */ @Since("2.2.0") def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index db043ff68feca..079b0833e1c6d 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -71,6 +71,67 @@ def test(dataset, featuresCol, labelCol): return _java2py(sc, javaTestObj.test(*args)) +class Correlation(object): + """ + .. note:: Experimental + + Compute the correlation matrix for the input dataset of Vectors using the specified method. + Methods currently supported: `pearson` (default), `spearman`. + + .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column + and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` + to avoid recomputing the common lineage. + + :param dataset: + A dataset or a dataframe. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A dataframe that contains the correlation matrix of the column of vectors. This + dataframe contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def corr(dataset, column, method="pearson"): + """ + Compute the correlation matrix with specified method using dataset. + """ + sc = SparkContext._active_spark_context + javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation + args = [_py2java(sc, arg) for arg in (dataset, column, method)] + return _java2py(sc, javaCorrObj.corr(*args)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 9e0893b53d68f777c1f3fb0a67820424a9c253ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Fri, 7 Apr 2017 13:03:07 +0100 Subject: [PATCH 0218/1765] [SPARK-20218][DOC][APP-ID] applications//stages' in REST API,add description. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. '/applications/[app-id]/stages' in rest api.status should add description '?status=[active|complete|pending|failed] list only stages in the state.' Now the lack of this description, resulting in the use of this api do not know the use of the status through the brush stage list. 2.'/applications/[app-id]/stages/[stage-id]' in REST API,remove redundant description ‘?status=[active|complete|pending|failed] list only stages in the state.’. Because only one stage is determined based on stage-id. code: GET def stageList(QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { val listener = ui.jobProgressListener val stageAndStatus = AllStagesResource.stagesAndStatus(ui) val adjStatuses = { if (statuses.isEmpty()) { Arrays.asList(StageStatus.values(): _*) } else { statuses } }; ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Closes #17534 from guoxiaolongzte/SPARK-20218. --- docs/monitoring.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 4d0617d253b80..da954385dc452 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -299,12 +299,12 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/stages A list of all stages for a given application. +
    ?status=[active|complete|pending|failed] list only stages in the state. /applications/[app-id]/stages/[stage-id] A list of all attempts for the given stage. -
    ?status=[active|complete|pending|failed] list only stages in the state. From 870b9d9aa00c260b532c78088e4a0384f7f1fa8a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 7 Apr 2017 10:57:12 -0700 Subject: [PATCH 0219/1765] [SPARK-20026][DOC][SPARKR] Add Tweedie example for SparkR in programming guide ## What changes were proposed in this pull request? Add Tweedie example for SparkR in programming guide. The doc was already updated in #17103. Author: actuaryzhang Closes #17553 from actuaryzhang/programGuide. --- examples/src/main/r/ml/glm.R | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index ee13910382c58..23141b57df143 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -56,6 +56,15 @@ summary(binomialGLM) # Prediction binomialPredictions <- predict(binomialGLM, binomialTestDF) head(binomialPredictions) + +# Fit a generalized linear model of family "tweedie" with spark.glm +training3 <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +tweedieDF <- transform(training3, label = training3$label * exp(randn(10))) +tweedieGLM <- spark.glm(tweedieDF, label ~ features, family = "tweedie", + var.power = 1.2, link.power = 0) + +# Model summary +summary(tweedieGLM) # $example off$ sparkR.session.stop() From 8feb799af0bb67618310947342e3e4d2a77aae13 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 7 Apr 2017 11:17:49 -0700 Subject: [PATCH 0220/1765] [SPARK-20197][SPARKR] CRAN check fail with package installation ## What changes were proposed in this pull request? Test failed because SPARK_HOME is not set before Spark is installed. Author: Felix Cheung Closes #17516 from felixcheung/rdircheckincran. --- R/pkg/tests/run-all.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index cefaadda6e215..29812f872c784 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -22,12 +22,13 @@ library(SparkR) options("warn" = 2) # Setup global test environment +# Install Spark first to set SPARK_HOME +install.spark() + sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) -install.spark() - test_package("SparkR") From 1ad73f0a21d8007d8466ef8756f751c0ab6a9d1f Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 7 Apr 2017 12:29:45 -0700 Subject: [PATCH 0221/1765] [SPARK-20258][DOC][SPARKR] Fix SparkR logistic regression example in programming guide (did not converge) ## What changes were proposed in this pull request? SparkR logistic regression example did not converge in programming guide (for IRWLS). All estimates are essentially zero: ``` training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") df_list2 <- randomSplit(training2, c(7,3), 2) binomialDF <- df_list2[[1]] binomialTestDF <- df_list2[[2]] binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") 17/04/07 11:42:03 WARN WeightedLeastSquares: Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver. > summary(binomialGLM) Coefficients: Estimate (Intercept) 9.0255e+00 features_0 0.0000e+00 features_1 0.0000e+00 features_2 0.0000e+00 features_3 0.0000e+00 features_4 0.0000e+00 features_5 0.0000e+00 features_6 0.0000e+00 features_7 0.0000e+00 ``` Author: actuaryzhang Closes #17571 from actuaryzhang/programGuide2. --- examples/src/main/r/ml/glm.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index 23141b57df143..68787f9aa9dca 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -27,7 +27,7 @@ sparkR.session(appName = "SparkR-ML-glm-example") # $example on$ training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -df_list <- randomSplit(training, c(7,3), 2) +df_list <- randomSplit(training, c(7, 3), 2) gaussianDF <- df_list[[1]] gaussianTestDF <- df_list[[2]] gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") @@ -44,8 +44,9 @@ gaussianGLM2 <- glm(label ~ features, gaussianDF, family = "gaussian") summary(gaussianGLM2) # Fit a generalized linear model of family "binomial" with spark.glm -training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") -df_list2 <- randomSplit(training2, c(7,3), 2) +training2 <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +training2 <- transform(training2, label = cast(training2$label > 1, "integer")) +df_list2 <- randomSplit(training2, c(7, 3), 2) binomialDF <- df_list2[[1]] binomialTestDF <- df_list2[[2]] binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") From 589f3edb82e970b6df9121861ed0c6b4a6d02cb6 Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 7 Apr 2017 14:00:23 -0700 Subject: [PATCH 0222/1765] [SPARK-20255] Move listLeafFiles() to InMemoryFileIndex ## What changes were proposed in this pull request Trying to get a grip on the `FileIndex` hierarchy, I was confused by the following inconsistency: On the one hand, `PartitioningAwareFileIndex` defines `leafFiles` and `leafDirToChildrenFiles` as abstract, but on the other it fully implements `listLeafFiles` which does all the listing of files. However, the latter is only used by `InMemoryFileIndex`. I'm hereby proposing to move this method (and all its dependencies) to the implementation class that actually uses it, and thus unclutter the `PartitioningAwareFileIndex` interface. ## How was this patch tested? `./build/sbt sql/test` Author: Adrian Ionescu Closes #17570 from adrian-ionescu/list-leaf-files. --- .../datasources/InMemoryFileIndex.scala | 226 ++++++++++++++++++ .../PartitioningAwareFileIndex.scala | 223 +---------------- .../datasources/FileIndexSuite.scala | 18 +- 3 files changed, 236 insertions(+), 231 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index ee4d0863d9771..11605dd280569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import scala.collection.mutable +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** @@ -84,4 +91,223 @@ class InMemoryFileIndex( } override def hashCode(): Int = rootPaths.toSet.hashCode() + + /** + * List leaf files of given paths. This method will submit a Spark job to do parallel + * listing whenever there is a path having more files than the parallel partition discovery + * discovery threshold. + * + * This is publicly visible for testing. + */ + def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + val output = mutable.LinkedHashSet[FileStatus]() + val pathsToFetch = mutable.ArrayBuffer[Path]() + for (path <- paths) { + fileStatusCache.getLeafFiles(path) match { + case Some(files) => + HiveCatalogMetrics.incrementFileCacheHits(files.length) + output ++= files + case None => + pathsToFetch += path + } + } + val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) + val discovered = InMemoryFileIndex.bulkListLeafFiles( + pathsToFetch, hadoopConf, filter, sparkSession) + discovered.foreach { case (path, leafFiles) => + HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) + fileStatusCache.putLeafFiles(path, leafFiles.toArray) + output ++= leafFiles + } + output + } +} + +object InMemoryFileIndex extends Logging { + + /** A serializable variant of HDFS's BlockLocation. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** + * Lists a collection of paths recursively. Picks the listing strategy adaptively depending + * on the number of paths to list. + * + * This may only be called on the driver. + * + * @return for each input path, the set of discovered files for the path + */ + private def bulkListLeafFiles( + paths: Seq[Path], + hadoopConf: Configuration, + filter: PathFilter, + sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { + + // Short-circuits parallel listing when serial listing is likely to be faster. + if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + return paths.map { path => + (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) + } + } + + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + HiveCatalogMetrics.incrementParallelListingJobCount(1) + + val sparkContext = sparkSession.sparkContext + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + val parallelPartitionDiscoveryParallelism = + sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) + + val statusMap = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } + } + + /** + * Lists a single filesystem path recursively. If a SparkSession object is specified, this + * function may launch Spark jobs to parallelize listing. + * + * If sessionOpt is None, this may be called on executors. + * + * @return all children of path that match the specified filter. + */ + private def listLeafFiles( + path: Path, + hadoopConf: Configuration, + filter: PathFilter, + sessionOpt: Option[SparkSession]): Seq[FileStatus] = { + logTrace(s"Listing $path") + val fs = path.getFileSystem(hadoopConf) + val name = path.getName.toLowerCase + + // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses = try fs.listStatus(path) catch { + case _: FileNotFoundException => + logWarning(s"The directory $path was not found. Was it deleted very recently?") + Array.empty[FileStatus] + } + + val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) + + val allLeafStatuses = { + val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) + val nestedFiles: Seq[FileStatus] = sessionOpt match { + case Some(session) => + bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2) + case _ => + dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt)) + } + val allFiles = topLevelFiles ++ nestedFiles + if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles + } + + allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + case f: LocatedFileStatus => + f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + lfs + } + } + + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter follow paths: + // 1. everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we + // should skip this file in case of double reading. + val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || + pathName.startsWith(".") || pathName.endsWith("._COPYING_") + val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") + exclude && !include + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 71500a010581e..ffd7f6c750f85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -17,22 +17,17 @@ package org.apache.spark.sql.execution.datasources -import java.io.FileNotFoundException - import scala.collection.mutable import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration /** * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables. @@ -241,224 +236,8 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } - - /** - * List leaf files of given paths. This method will submit a Spark job to do parallel - * listing whenever there is a path having more files than the parallel partition discovery - * discovery threshold. - * - * This is publicly visible for testing. - */ - def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - val output = mutable.LinkedHashSet[FileStatus]() - val pathsToFetch = mutable.ArrayBuffer[Path]() - for (path <- paths) { - fileStatusCache.getLeafFiles(path) match { - case Some(files) => - HiveCatalogMetrics.incrementFileCacheHits(files.length) - output ++= files - case None => - pathsToFetch += path - } - } - val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) - val discovered = PartitioningAwareFileIndex.bulkListLeafFiles( - pathsToFetch, hadoopConf, filter, sparkSession) - discovered.foreach { case (path, leafFiles) => - HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) - fileStatusCache.putLeafFiles(path, leafFiles.toArray) - output ++= leafFiles - } - output - } } -object PartitioningAwareFileIndex extends Logging { +object PartitioningAwareFileIndex { val BASE_PATH_PARAM = "basePath" - - /** A serializable variant of HDFS's BlockLocation. */ - private case class SerializableBlockLocation( - names: Array[String], - hosts: Array[String], - offset: Long, - length: Long) - - /** A serializable variant of HDFS's FileStatus. */ - private case class SerializableFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long, - blockLocations: Array[SerializableBlockLocation]) - - /** - * Lists a collection of paths recursively. Picks the listing strategy adaptively depending - * on the number of paths to list. - * - * This may only be called on the driver. - * - * @return for each input path, the set of discovered files for the path - */ - private def bulkListLeafFiles( - paths: Seq[Path], - hadoopConf: Configuration, - filter: PathFilter, - sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { - - // Short-circuits parallel listing when serial listing is likely to be faster. - if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - return paths.map { path => - (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) - } - } - - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - HiveCatalogMetrics.incrementParallelListingJobCount(1) - - val sparkContext = sparkSession.sparkContext - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - val parallelPartitionDiscoveryParallelism = - sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism - - // Set the number of parallelism to prevent following file listing from generating many tasks - // in case of large #defaultParallelism. - val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) - } - (path.toString, serializableStatuses) - }.collect() - - // turn SerializableFileStatus back to Status - statusMap.map { case (path, serializableStatuses) => - val statuses = serializableStatuses.map { f => - val blockLocations = f.blockLocations.map { loc => - new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) - } - new LocatedFileStatus( - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, - new Path(f.path)), - blockLocations) - } - (new Path(path), statuses) - } - } - - /** - * Lists a single filesystem path recursively. If a SparkSession object is specified, this - * function may launch Spark jobs to parallelize listing. - * - * If sessionOpt is None, this may be called on executors. - * - * @return all children of path that match the specified filter. - */ - private def listLeafFiles( - path: Path, - hadoopConf: Configuration, - filter: PathFilter, - sessionOpt: Option[SparkSession]): Seq[FileStatus] = { - logTrace(s"Listing $path") - val fs = path.getFileSystem(hadoopConf) - val name = path.getName.toLowerCase - - // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist - // Note that statuses only include FileStatus for the files and dirs directly under path, - // and does not include anything else recursively. - val statuses = try fs.listStatus(path) catch { - case _: FileNotFoundException => - logWarning(s"The directory $path was not found. Was it deleted very recently?") - Array.empty[FileStatus] - } - - val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) - - val allLeafStatuses = { - val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) - val nestedFiles: Seq[FileStatus] = sessionOpt match { - case Some(session) => - bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2) - case _ => - dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt)) - } - val allFiles = topLevelFiles ++ nestedFiles - if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles - } - - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { - case f: LocatedFileStatus => - f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => - // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), - // which is very slow on some file system (RawLocalFileSystem, which is launch a - // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) - } - lfs - } - } - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // We filter follow paths: - // 1. everything that starts with _ and ., except _common_metadata and _metadata - // because Parquet needs to find those metadata files from leaf files returned by this method. - // We should refactor this logic to not mix metadata files with data files. - // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we - // should skip this file in case of double reading. - val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || - pathName.startsWith(".") || pathName.endsWith("._COPYING_") - val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") - exclude && !include - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 7ea4064927576..00f5d5db8f5f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -135,15 +135,15 @@ class FileIndexSuite extends SharedSQLContext { } } - test("PartitioningAwareFileIndex - file filtering") { - assert(!PartitioningAwareFileIndex.shouldFilterOut("abcd")) - assert(PartitioningAwareFileIndex.shouldFilterOut(".ab")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_cd")) - assert(!PartitioningAwareFileIndex.shouldFilterOut("_metadata")) - assert(!PartitioningAwareFileIndex.shouldFilterOut("_common_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_ab_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_cd_common_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("a._COPYING_")) + test("InMemoryFileIndex - file filtering") { + assert(!InMemoryFileIndex.shouldFilterOut("abcd")) + assert(InMemoryFileIndex.shouldFilterOut(".ab")) + assert(InMemoryFileIndex.shouldFilterOut("_cd")) + assert(!InMemoryFileIndex.shouldFilterOut("_metadata")) + assert(!InMemoryFileIndex.shouldFilterOut("_common_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("_ab_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("_cd_common_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("a._COPYING_")) } test("SPARK-17613 - PartitioningAwareFileIndex: base path w/o '/' at end") { From 7577e9c356b580d744e1fc27c645fce41bdf9cf0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Apr 2017 20:54:18 -0700 Subject: [PATCH 0223/1765] [SPARK-20246][SQL] should not push predicate down through aggregate with non-deterministic expressions ## What changes were proposed in this pull request? Similar to `Project`, when `Aggregate` has non-deterministic expressions, we should not push predicate down through it, as it will change the number of input rows and thus change the evaluation result of non-deterministic expressions in `Aggregate`. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17562 from cloud-fan/filter. --- .../sql/catalyst/optimizer/Optimizer.scala | 60 ++++++++++--------- .../optimizer/FilterPushdownSuite.scala | 41 +++++++++++-- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 577112779eea4..d221b0611a892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -755,7 +755,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // implies that, for a given input row, the output are determined by the expression's initial // state and all the input rows processed before. In another word, the order of input rows // matters for non-deterministic expressions, while pushing down predicates changes the order. - case filter @ Filter(condition, project @ Project(fields, grandChild)) + // This also applies to Aggregate. + case Filter(condition, project @ Project(fields, grandChild)) if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => // Create a map of Aliases to their values from the child projection. @@ -766,33 +767,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be - // pushed beneath must satisfy the following conditions: - // 1. All the expressions are part of window partitioning key. The expressions can be compound. - // 2. Deterministic. - // 3. Placed before any non-deterministic predicates. - case filter @ Filter(condition, w: Window) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) - - val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) - } - - val stayUp = rest ++ containingNonDeterministic - - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) - } else { - filter - } - - case filter @ Filter(condition, aggregate: Aggregate) => + case filter @ Filter(condition, aggregate: Aggregate) + if aggregate.aggregateExpressions.forall(_.deterministic) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { @@ -823,6 +799,32 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be + // pushed beneath must satisfy the following conditions: + // 1. All the expressions are part of window partitioning key. The expressions can be compound. + // 2. Deterministic. + // 3. Placed before any non-deterministic predicates. + case filter @ Filter(condition, w: Window) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + } else { + filter + } + case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) @@ -848,7 +850,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(condition, u: UnaryNode) + case filter @ Filter(_, u: UnaryNode) if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d846786473eb0..ccd0b7c5d7f79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -134,15 +134,20 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter with nondeterministic condition through project") { + test("nondeterministic: can always push down filter through project with deterministic field") { val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('rand > 5 || 'a > 5) + .select('a) + .where(Rand(10) > 5 || 'a > 5) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, originalQuery) + val correctAnswer = testRelation + .where(Rand(10) > 5 || 'a > 5) + .select('a) + .analyze + + comparePlans(optimized, correctAnswer) } test("nondeterministic: can't push down filter through project with nondeterministic field") { @@ -156,6 +161,34 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + test("nondeterministic: can't push down filter through aggregate with nondeterministic field") { + val originalQuery = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .where('a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through aggregate with deterministic field") { + val originalQuery = testRelation + .groupBy('a)('a) + .where('a > 5 && Rand(10) > 5) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a > 5) + .groupBy('a)('a) + .where(Rand(10) > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) From e1afc4dcca8ba517f48200c0ecde1152505e41ec Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Apr 2017 21:14:50 -0700 Subject: [PATCH 0224/1765] [SPARK-20262][SQL] AssertNotNull should throw NullPointerException ## What changes were proposed in this pull request? AssertNotNull currently throws RuntimeException. It should throw NullPointerException, which is more specific. ## How was this patch tested? N/A Author: Reynold Xin Closes #17573 from rxin/SPARK-20262. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 00e2ac91e67ca..53842ef348a57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -989,7 +989,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. */ -case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType @@ -1005,7 +1005,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = { val result = child.eval(input) if (result == null) { - throw new RuntimeException(errMsg) + throw new NullPointerException(errMsg) } result } @@ -1021,7 +1021,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException($errMsgField); + throw new NullPointerException($errMsgField); } """ ev.copy(code = code, isNull = "false", value = childGen.value) From 34fc48fb5976ede00f3f6d8c4d3eec979e4f4d7f Mon Sep 17 00:00:00 2001 From: asmith26 Date: Sun, 9 Apr 2017 07:47:23 +0100 Subject: [PATCH 0225/1765] [MINOR] Issue: Change "slice" vs "partition" in exception messages (and code?) ## What changes were proposed in this pull request? Came across the term "slice" when running some spark scala code. Consequently, a Google search indicated that "slices" and "partitions" refer to the same things; indeed see: - [This issue](https://issues.apache.org/jira/browse/SPARK-1701) - [This pull request](https://github.com/apache/spark/pull/2305) - [This StackOverflow answer](http://stackoverflow.com/questions/23436640/what-is-the-difference-between-an-rdd-partition-and-a-slice) and [this one](http://stackoverflow.com/questions/24269495/what-are-the-differences-between-slices-and-partitions-of-rdds) Thus this pull request fixes the occurrence of slice I came accross. Nonetheless, [it would appear](https://github.com/apache/spark/search?utf8=%E2%9C%93&q=slice&type=) there are still many references to "slice/slices" - thus I thought I'd raise this Pull Request to address the issue (sorry if this is the wrong place, I'm not too familar with raising apache issues). ## How was this patch tested? (Not tested locally - only a minor exception message change.) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: asmith26 Closes #17565 from asmith26/master. --- .../main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala | 2 +- .../src/main/java/org/apache/spark/examples/JavaSparkPi.java | 2 +- examples/src/main/java/org/apache/spark/examples/JavaTC.java | 2 +- .../main/scala/org/apache/spark/examples/BroadcastTest.scala | 2 +- .../scala/org/apache/spark/examples/MultiBroadcastTest.scala | 2 +- .../src/main/scala/org/apache/spark/examples/SparkALS.scala | 2 +- examples/src/main/scala/org/apache/spark/examples/SparkLR.scala | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e9092739b298a..9f8019b80a4dd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -116,7 +116,7 @@ private object ParallelCollectionRDD { */ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { - throw new IllegalArgumentException("Positive number of slices required") + throw new IllegalArgumentException("Positive number of partitions required") } // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index cb4b26569088a..37bd8fffbe45a 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -26,7 +26,7 @@ /** * Computes an approximation to pi - * Usage: JavaSparkPi [slices] + * Usage: JavaSparkPi [partitions] */ public final class JavaSparkPi { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index bde30b84d6cf3..c9ca9c9b3a412 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -32,7 +32,7 @@ /** * Transitive closure on a graph, implemented in Java. - * Usage: JavaTC [slices] + * Usage: JavaTC [partitions] */ public final class JavaTC { diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 86eed3867c539..25718f904cc49 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples import org.apache.spark.sql.SparkSession /** - * Usage: BroadcastTest [slices] [numElem] [blockSize] + * Usage: BroadcastTest [partitions] [numElem] [blockSize] */ object BroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 6495a86fcd77c..e6f33b7adf5d1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession /** - * Usage: MultiBroadcastTest [slices] [numElem] + * Usage: MultiBroadcastTest [partitions] [numElem] */ object MultiBroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 8a3d08f459783..a99ddd9fd37db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -100,7 +100,7 @@ object SparkALS { ITERATIONS = iters.getOrElse("5").toInt slices = slices_.getOrElse("2").toInt case _ => - System.err.println("Usage: SparkALS [M] [U] [F] [iters] [slices]") + System.err.println("Usage: SparkALS [M] [U] [F] [iters] [partitions]") System.exit(1) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index afa8f58c96e59..cb2be091ffcf3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. - * Usage: SparkLR [slices] + * Usage: SparkLR [partitions] * * This is an example implementation for learning how to use Spark. For more conventional use, * please refer to org.apache.spark.ml.classification.LogisticRegression. From 1f0de3c1c85a41eadc7c4131bdc948405f340099 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Apr 2017 08:44:02 +0100 Subject: [PATCH 0226/1765] [SPARK-19991][CORE][YARN] FileSegmentManagedBuffer performance improvement ## What changes were proposed in this pull request? Avoid `NoSuchElementException` every time `ConfigProvider.get(val, default)` falls back to default. This apparently causes non-trivial overhead in at least one path, and can easily be avoided. See https://github.com/apache/spark/pull/17329 ## How was this patch tested? Existing tests Author: Sean Owen Closes #17567 from srowen/SPARK-19991. --- .../org/apache/spark/network/util/MapConfigProvider.java | 6 ++++++ .../spark/network/yarn/util/HadoopConfigProvider.java | 6 ++++++ .../org/apache/spark/network/netty/SparkTransportConf.scala | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java index 9cfee7f08d155..a2cf87d1af7ed 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -42,6 +42,12 @@ public String get(String name) { return value; } + @Override + public String get(String name, String defaultValue) { + String value = config.get(name); + return value == null ? defaultValue : value; + } + @Override public Iterable> getAll() { return config.entrySet(); diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java index 62a6cca4ed4eb..8beb033699471 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -41,6 +41,12 @@ public String get(String name) { return value; } + @Override + public String get(String name, String defaultValue) { + String value = conf.get(name); + return value == null ? defaultValue : value; + } + @Override public Iterable> getAll() { return conf; diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index df520f804b4c3..25f7bcb9801b9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -60,7 +60,7 @@ object SparkTransportConf { new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) - + override def get(name: String, defaultValue: String): String = conf.get(name, defaultValue) override def getAll(): java.lang.Iterable[java.util.Map.Entry[String, String]] = { conf.getAll.toMap.asJava.entrySet() } From 261eaf5149a8fe479ab4f9c34db892bcedbf5739 Mon Sep 17 00:00:00 2001 From: Vijay Ramesh Date: Sun, 9 Apr 2017 19:39:09 +0100 Subject: [PATCH 0227/1765] [SPARK-20260][MLLIB] String interpolation required for error message ## What changes were proposed in this pull request? This error message doesn't get properly formatted because of a missing `s`. Currently the error looks like: ``` Caused by: java.lang.IllegalArgumentException: requirement failed: indices should be one-based and in ascending order; found current=$current, previous=$previous; line="$line" ``` (note the literal `$current` instead of the interpolated value) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Vijay Ramesh Closes #17572 from vijaykramesh/master. --- .../scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 2 +- .../test/scala/org/apache/spark/ml/util/TestingUtils.scala | 2 +- .../spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../apache/spark/mllib/tree/model/DecisionTreeModel.scala | 2 +- .../main/scala/org/apache/spark/mllib/util/MLUtils.scala | 2 +- .../scala/org/apache/spark/mllib/util/TestingUtils.scala | 2 +- .../scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 6 +++--- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f475ce87540aa..bae7a3f307f52 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -349,7 +349,7 @@ class SparkHadoopUtil extends Logging { } } catch { case e: IOException => - logDebug("Failed to decode $token: $e", e) + logDebug(s"Failed to decode $token: $e", e) } buffer.toString } diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 30edd00fb53e1..6c79d77f142e5 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -215,7 +215,7 @@ object TestingUtils { if (r.fun(x, r.y, r.eps)) { throw new TestFailedException( s"Did not expect \n$x\n and \n${r.y}\n to be within " + - "${r.eps}${r.method} for all elements.", 0) + s"${r.eps}${r.method} for all elements.", 0) } true } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 4d3e265455da6..b2437b845f826 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -259,7 +259,7 @@ object PowerIterationClustering extends Logging { val j = ctx.dstId val s = ctx.attr if (s < 0.0) { - throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") } if (s > 0.0) { ctx.sendToSrc(s) @@ -283,7 +283,7 @@ object PowerIterationClustering extends Logging { : Graph[Double, Double] = { val edges = similarities.flatMap { case (i, j, s) => if (s < 0.0) { - throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") } if (i != j) { Seq(Edge(i, j, s), Edge(j, i, s)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a1562384b0a7e..27618e122aefd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -248,7 +248,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { // Build node data into a tree. val trees = constructTrees(nodes) assert(trees.length == 1, - "Decision tree should contain exactly one tree but got ${trees.size} trees.") + s"Decision tree should contain exactly one tree but got ${trees.size} trees.") val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $dataPath." + s" Expected $numNodes nodes but found ${model.numNodes}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 95f904dac552c..4fdad05973969 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -119,7 +119,7 @@ object MLUtils extends Logging { while (i < indicesLength) { val current = indices(i) require(current > previous, s"indices should be one-based and in ascending order;" - + " found current=$current, previous=$previous; line=\"$line\"") + + s""" found current=$current, previous=$previous; line="$line"""") previous = current i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 39a6bc37d9638..d39865a19a5c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -207,7 +207,7 @@ object TestingUtils { if (r.fun(x, r.y, r.eps)) { throw new TestFailedException( s"Did not expect \n$x\n and \n${r.y}\n to be within " + - "${r.eps}${r.method} for all elements.", 0) + s"${r.eps}${r.method} for all elements.", 0) } true } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 5d8ba9d7c85d1..8c855730c31f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -285,7 +285,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -294,7 +294,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withOrcTable(data, "t") { checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -303,7 +303,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withOrcTable(data, "t") { checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) From 7a63f5e82758345ff1f3322950f2bbea350c48b9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 10 Apr 2017 10:47:17 +0800 Subject: [PATCH 0228/1765] [SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code ## What changes were proposed in this pull request? This PR elminates unnecessary nullchecks of a return value from known Spark runtime routines. We know whether a given Spark runtime routine returns ``null`` or not (e.g. ``ArrayData.toDoubleArray()`` never returns ``null``). Thus, we can eliminate a null check for the return value from the Spark runtime routine. When we run the following example program, now we get the Java code "Without this PR". In this code, since we know ``ArrayData.toDoubleArray()`` never returns ``null```, we can eliminate null checks at lines 90-92, and 97. ```java val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache ds.count ds.map(e => e).show ``` Without this PR ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ if (deserializetoobject_funcResult == null) { /* 091 */ deserializetoobject_isNull = true; /* 092 */ } else { /* 093 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 098 */ } /* 099 */ /* 100 */ boolean mapelements_isNull = true; /* 101 */ double[] mapelements_value = null; /* 102 */ if (!false) { /* 103 */ mapelements_resultIsNull = false; /* 104 */ /* 105 */ if (!mapelements_resultIsNull) { /* 106 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 107 */ mapelements_argValue = deserializetoobject_value; /* 108 */ } /* 109 */ /* 110 */ mapelements_isNull = mapelements_resultIsNull; /* 111 */ if (!mapelements_isNull) { /* 112 */ Object mapelements_funcResult = null; /* 113 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 114 */ if (mapelements_funcResult == null) { /* 115 */ mapelements_isNull = true; /* 116 */ } else { /* 117 */ mapelements_value = (double[]) mapelements_funcResult; /* 118 */ } /* 119 */ /* 120 */ } /* 121 */ mapelements_isNull = mapelements_value == null; /* 122 */ } /* 123 */ /* 124 */ serializefromobject_resultIsNull = false; /* 125 */ /* 126 */ if (!serializefromobject_resultIsNull) { /* 127 */ serializefromobject_resultIsNull = mapelements_isNull; /* 128 */ serializefromobject_argValue = mapelements_value; /* 129 */ } /* 130 */ /* 131 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 132 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 133 */ serializefromobject_isNull = serializefromobject_value == null; /* 134 */ serializefromobject_holder.reset(); /* 135 */ /* 136 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 137 */ /* 138 */ if (serializefromobject_isNull) { /* 139 */ serializefromobject_rowWriter.setNullAt(0); /* 140 */ } else { /* 141 */ // Remember the current cursor so that we can calculate how many bytes are /* 142 */ // written later. /* 143 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 144 */ /* 145 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 146 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 147 */ // grow the global buffer before writing data. /* 148 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 149 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 150 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 151 */ /* 152 */ } else { /* 153 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 154 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 155 */ /* 156 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 157 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 158 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 159 */ } else { /* 160 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 161 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 162 */ } /* 163 */ } /* 164 */ } /* 165 */ /* 166 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 167 */ } /* 168 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 169 */ append(serializefromobject_result); /* 170 */ if (shouldStop()) return; /* 171 */ } /* 172 */ } ``` With this PR (removed most of lines 90-97 in the above code) ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 091 */ /* 092 */ } /* 093 */ /* 094 */ } /* 095 */ /* 096 */ boolean mapelements_isNull = true; /* 097 */ double[] mapelements_value = null; /* 098 */ if (!false) { /* 099 */ mapelements_resultIsNull = false; /* 100 */ /* 101 */ if (!mapelements_resultIsNull) { /* 102 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 103 */ mapelements_argValue = deserializetoobject_value; /* 104 */ } /* 105 */ /* 106 */ mapelements_isNull = mapelements_resultIsNull; /* 107 */ if (!mapelements_isNull) { /* 108 */ Object mapelements_funcResult = null; /* 109 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 110 */ if (mapelements_funcResult == null) { /* 111 */ mapelements_isNull = true; /* 112 */ } else { /* 113 */ mapelements_value = (double[]) mapelements_funcResult; /* 114 */ } /* 115 */ /* 116 */ } /* 117 */ mapelements_isNull = mapelements_value == null; /* 118 */ } /* 119 */ /* 120 */ serializefromobject_resultIsNull = false; /* 121 */ /* 122 */ if (!serializefromobject_resultIsNull) { /* 123 */ serializefromobject_resultIsNull = mapelements_isNull; /* 124 */ serializefromobject_argValue = mapelements_value; /* 125 */ } /* 126 */ /* 127 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 128 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 129 */ serializefromobject_isNull = serializefromobject_value == null; /* 130 */ serializefromobject_holder.reset(); /* 131 */ /* 132 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 133 */ /* 134 */ if (serializefromobject_isNull) { /* 135 */ serializefromobject_rowWriter.setNullAt(0); /* 136 */ } else { /* 137 */ // Remember the current cursor so that we can calculate how many bytes are /* 138 */ // written later. /* 139 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 140 */ /* 141 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 142 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 143 */ // grow the global buffer before writing data. /* 144 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 145 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 146 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 147 */ /* 148 */ } else { /* 149 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 150 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 151 */ /* 152 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 153 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 154 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 155 */ } else { /* 156 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 157 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 158 */ } /* 159 */ } /* 160 */ } /* 161 */ /* 162 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 163 */ } /* 164 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 165 */ append(serializefromobject_result); /* 166 */ if (shouldStop()) return; /* 167 */ } /* 168 */ } ``` ## How was this patch tested? Add test suites to ``DatasetPrimitiveSuite`` Author: Kazuaki Ishizaki Closes #17569 from kiszk/SPARK-20253. --- .../spark/sql/catalyst/ScalaReflection.scala | 27 ++++++++++-------- .../sql/catalyst/encoders/RowEncoder.scala | 19 +++++++------ .../expressions/objects/objects.scala | 28 +++++++++---------- .../spark/sql/DatasetPrimitiveSuite.scala | 10 +++++++ 4 files changed, 51 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 206ae2f0e5eb1..198122759e4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + returnNullable = false) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection { val arrayCls = arrayClassFor(elementType) if (elementNullable) { - Invoke(arrayData, "array", arrayCls) + Invoke(arrayData, "array", arrayCls, returnNullable = false) } else { val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => "toIntArray" @@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection { case other => throw new IllegalStateException("expect primitive array element type " + "but got " + other) } - Invoke(arrayData, primitiveMethod, arrayCls) + Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } case t if t <:< localTypeOf[Seq[_]] => @@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection { Invoke( MapObjects( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), + returnNullable = false), schemaFor(keyType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) val valueData = Invoke( MapObjects( p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), + returnNullable = false), schemaFor(valueType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) StaticInvoke( ArrayBasedMapData.getClass, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6cb..0f8282d3b2f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -89,7 +89,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case TimestampType => StaticInvoke( @@ -136,16 +136,18 @@ object RowEncoder { case t @ MapType(kt, vt, valueNullable) => val keys = Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( @@ -262,17 +264,18 @@ object RowEncoder { input :: Nil) case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case StringType => - Invoke(input, "toString", ObjectType(classOf[String])) + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), "array", - ObjectType(classOf[Array[_]])) + ObjectType(classOf[Array[_]]), returnNullable = false) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 53842ef348a57..6d94764f1bfac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -225,25 +225,26 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") + // If the function can return null, we do an extra check to make sure our null bit is still + // set correctly. + val assignResult = if (!returnNullable) { + s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + } else { + s""" + if ($funcResult != null) { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + } else { + ${ev.isNull} = true; + } + """ + } s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - if ($funcResult == null) { - ${ev.isNull} = true; - } else { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - } + $assignResult """ } - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } - val code = s""" ${obj.code} boolean ${ev.isNull} = true; @@ -254,7 +255,6 @@ case class Invoke( if (!${ev.isNull}) { $evaluate } - $postNullCheck } """ ev.copy(code = code) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 82b707537e45f..541565344f758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(dsBoolean.map(e => !e), false, true) } + test("mapPrimitiveArray") { + val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS() + checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4)) + checkDataset(dsInt.map(e => null: Array[Int]), null, null) + + val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS() + checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D)) + checkDataset(dsDouble.map(e => null: Array[Double]), null, null) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( From 7bfa05e0a5e6860a942e1ce47e7890d665acdfe3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 9 Apr 2017 20:32:07 -0700 Subject: [PATCH 0229/1765] [SPARK-20264][SQL] asm should be non-test dependency in sql/core ## What changes were proposed in this pull request? sq/core module currently declares asm as a test scope dependency. Transitively it should actually be a normal dependency since the actual core module defines it. This occasionally confuses IntelliJ. ## How was this patch tested? N/A - This is a build change. Author: Reynold Xin Closes #17574 from rxin/SPARK-20264. --- sql/core/pom.xml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 69d797b479159..b203f31a76f03 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.xbean + xbean-asm5-shaded + org.scalacheck scalacheck_${scala.binary.version} @@ -147,11 +151,6 @@ mockito-core test - - org.apache.xbean - xbean-asm5-shaded - test - target/scala-${scala.binary.version}/classes From 1a0bc41659eef317dcac18df35c26857216a4314 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 10 Apr 2017 05:16:34 +0000 Subject: [PATCH 0230/1765] [SPARK-20270][SQL] na.fill should not change the values in long or integer when the default value is in double ## What changes were proposed in this pull request? This bug was partially addressed in SPARK-18555 https://github.com/apache/spark/pull/15994, but the root cause isn't completely solved. This bug is pretty critical since it changes the member id in Long in our application if the member id can not be represented by Double losslessly when the member id is very big. Here is an example how this happens, with ``` Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), ``` the logical plan will be ``` == Analyzed Logical Plan == a: bigint, b: double Project [cast(coalesce(cast(a#232L as double), cast(0.2 as double)) as bigint) AS a#240L, cast(coalesce(nanvl(b#233, cast(null as double)), 0.2) as double) AS b#241] +- Project [_1#229L AS a#232L, _2#230 AS b#233] +- LocalRelation [_1#229L, _2#230] ``` Note that even the value is not null, Spark will cast the Long into Double first. Then if it's not null, Spark will cast it back to Long which results in losing precision. The behavior should be that the original value should not be changed if it's not null, but Spark will change the value which is wrong. With the PR, the logical plan will be ``` == Analyzed Logical Plan == a: bigint, b: double Project [coalesce(a#232L, cast(0.2 as bigint)) AS a#240L, coalesce(nanvl(b#233, cast(null as double)), cast(0.2 as double)) AS b#241] +- Project [_1#229L AS a#232L, _2#230 AS b#233] +- LocalRelation [_1#229L, _2#230] ``` which behaves correctly without changing the original Long values and also avoids extra cost of unnecessary casting. ## How was this patch tested? unit test added. +cc srowen rxin cloud-fan gatorsmile Thanks. Author: DB Tsai Closes #17577 from dbtsai/fixnafill. --- .../apache/spark/sql/DataFrameNaFunctions.scala | 5 +++-- .../spark/sql/DataFrameNaFunctionsSuite.scala | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 28820681cd3a6..d8f953fba5a8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -407,10 +407,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + // nanvl only supports these types + nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) case _ => df.col(quotedColName) } - coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) + coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fd829846ac332..aa237d0619ac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -145,6 +145,20 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil ) + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), + (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) + :: Row(0, 0.2) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), + (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) + :: Row(0, 0.2f) :: Nil + ) + checkAnswer( Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) .toDF("a", "b").na.fill(2.34), From 3d7f201f2adc2d33be6f564fa76435c18552f4ba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 10 Apr 2017 13:36:08 +0800 Subject: [PATCH 0231/1765] [SPARK-20229][SQL] add semanticHash to QueryPlan ## What changes were proposed in this pull request? Like `Expression`, `QueryPlan` should also have a `semanticHash` method, then we can put plans to a hash map and look it up fast. This PR refactors `QueryPlan` to follow `Expression` and put all the normalization logic in `QueryPlan.canonicalized`, so that it's very natural to implement `semanticHash`. follow-up: improve `CacheManager` to leverage this `semanticHash` and speed up plan lookup, instead of iterating all cached plans. ## How was this patch tested? existing tests. Note that we don't need to test the `semanticHash` method, once the existing tests prove `sameResult` is correct, we are good. Author: Wenchen Fan Closes #17541 from cloud-fan/plan-semantic. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/catalog/interface.scala | 11 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 102 +++++++++++------- .../plans/logical/LocalRelation.scala | 8 -- .../catalyst/plans/logical/LogicalPlan.scala | 2 - .../plans/logical/basicLogicalOperators.scala | 2 + .../plans/physical/broadcastMode.scala | 9 +- .../sql/execution/DataSourceScanExec.scala | 37 +++---- .../spark/sql/execution/ExistingRDD.scala | 14 --- .../sql/execution/LocalTableScanExec.scala | 2 +- .../execution/basicPhysicalOperators.scala | 10 +- .../datasources/LogicalRelation.scala | 13 +-- .../exchange/BroadcastExchangeExec.scala | 6 +- .../sql/execution/exchange/Exchange.scala | 6 +- .../sql/execution/joins/HashedRelation.scala | 11 +- .../spark/sql/execution/ExchangeSuite.scala | 18 ++-- .../hive/execution/HiveTableScanExec.scala | 45 ++++---- 17 files changed, 135 insertions(+), 163 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c698ca6a8347c..b0cdef70297cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -617,7 +617,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - lookupTableFromCatalog(u).canonicalized match { + EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.") case other => i.copy(table = other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 360e55d922821..cc0cbba275b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -423,8 +423,15 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - /** Only compare table identifier. */ - override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) + override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( + identifier = tableMeta.identifier, + tableType = tableMeta.tableType, + storage = CatalogStorageFormat.empty, + schema = tableMeta.schema, + partitionColumnNames = tableMeta.partitionColumnNames, + bucketSpec = tableMeta.bucketSpec, + createTime = -1 + )) override def computeStats(conf: SQLConf): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2d8ec2053a4cb..3008e8cb84659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -359,9 +359,59 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** - * Canonicalized copy of this query plan. + * Returns a plan where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Some nodes should overwrite this to provide proper canonicalize logic. + */ + lazy val canonicalized: PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + preCanonicalized.mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. + Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the epxrId too. + id += 1 + ar.withExprId(ExprId(id)) + + case other => normalizeExprId(other) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Do some simple transformation on this plan before canonicalizing. Implementations can override + * this method to provide customized canonicalize logic without rewriting the whole logic. */ - protected lazy val canonicalized: PlanType = this + protected def preCanonicalized: PlanType = this + + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = { + e.transformUp { + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] + } /** * Returns true when the given query plan will return the same results as this query plan. @@ -372,49 +422,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * enhancements like caching. However, it is not acceptable to return true if the results could * possibly be different. * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Operators that - * can do better should override this function. + * This function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. */ - def sameResult(plan: PlanType): Boolean = { - val left = this.canonicalized - val right = plan.canonicalized - left.getClass == right.getClass && - left.children.size == right.children.size && - left.cleanArgs == right.cleanArgs && - (left.children, right.children).zipped.forall(_ sameResult _) - } + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + */ + final def semanticHash(): Int = canonicalized.hashCode() /** * All the attributes that are used for this plan. */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) - - protected def cleanExpression(e: Expression): Expression = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. - val cleanedExprId = - Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) - case other => - BindReferences.bindReference(other, allAttributes, allowFailures = true) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - def cleanArg(arg: Any): Any = arg match { - // Children are checked using sameResult above. - case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e).canonicalized - case other => other - } - - mapProductIterator { - case s: Option[_] => s.map(cleanArg) - case s: Seq[_] => s.map(cleanArg) - case m: Map[_, _] => m.mapValues(cleanArg) - case other => cleanArg(other) - }.toSeq - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b7177c4a2c4e4..9cd5dfd21b160 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -67,14 +67,6 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LocalRelation(otherOutput, otherData) => - otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data - case _ => false - } - } - override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 036b6256684cb..6bdcf490ca5c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -143,8 +143,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) - /** * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function * should only be called on analyzed plans since it will throw [[AnalysisException]] for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c91de08ca5ef6..3ad757ebba851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -803,6 +803,8 @@ case class SubqueryAlias( child: LogicalPlan) extends UnaryNode { + override lazy val canonicalized: LogicalPlan = child.canonicalized + override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9dfdf4da78ff6..2ab46dc8330aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,10 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any - /** - * Returns true iff this [[BroadcastMode]] generates the same result as `other`. - */ - def compatibleWith(other: BroadcastMode): Boolean + def canonicalized: BroadcastMode } /** @@ -39,7 +36,5 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows - override def compatibleWith(other: BroadcastMode): Boolean = { - this eq other - } + override def canonicalized: BroadcastMode = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 2fa660c4d5e01..3a9132d74ac11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -119,7 +119,7 @@ case class RowDataSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row @@ -136,19 +136,17 @@ case class RowDataSourceScanExec( """.stripMargin } - // Ignore rdd when checking results - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata - case _ => false - } + // Only care about `relation` and `metadata` when canonicalizing. + override def preCanonicalized: SparkPlan = + copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) } /** * Physical plan node for scanning data from HadoopFsRelations. * * @param relation The file-based relation to scan. - * @param output Output attributes of the scan. - * @param outputSchema Output schema of the scan. + * @param output Output attributes of the scan, including data attributes and partition attributes. + * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. * @param dataFilters Filters on non-partition columns. * @param metastoreTableIdentifier identifier for the table in the metastore. @@ -156,7 +154,7 @@ case class RowDataSourceScanExec( case class FileSourceScanExec( @transient relation: HadoopFsRelation, output: Seq[Attribute], - outputSchema: StructType, + requiredSchema: StructType, partitionFilters: Seq[Expression], dataFilters: Seq[Expression], override val metastoreTableIdentifier: Option[TableIdentifier]) @@ -267,7 +265,7 @@ case class FileSourceScanExec( val metadata = Map( "Format" -> relation.fileFormat.toString, - "ReadSchema" -> outputSchema.catalogString, + "ReadSchema" -> requiredSchema.catalogString, "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), "PushedFilters" -> seqToString(pushedDownFilters), @@ -287,7 +285,7 @@ case class FileSourceScanExec( sparkSession = relation.sparkSession, dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, - requiredSchema = outputSchema, + requiredSchema = requiredSchema, filters = pushedDownFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) @@ -515,14 +513,13 @@ case class FileSourceScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: FileSourceScanExec => - val thisPredicates = partitionFilters.map(cleanExpression) - val otherPredicates = other.partitionFilters.map(cleanExpression) - val result = relation == other.relation && metadata == other.metadata && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: FileSourceScanExec = { + FileSourceScanExec( + relation, + output.map(normalizeExprId(_, output)), + requiredSchema, + partitionFilters.map(normalizeExprId(_, output)), + dataFilters.map(normalizeExprId(_, output)), + None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2827b8ac00331..3d1b481a53e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -87,13 +87,6 @@ case class ExternalRDD[T]( override def newInstance(): ExternalRDD.this.type = ExternalRDD(outputObjAttr.newInstance(), rdd)(session).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case ExternalRDD(_, otherRDD) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) @transient override def computeStats(conf: SQLConf): Statistics = Statistics( @@ -162,13 +155,6 @@ case class LogicalRDD( )(session).asInstanceOf[this.type] } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) @transient override def computeStats(conf: SQLConf): Statistics = Statistics( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index e366b9af35c62..19c68c13262a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -33,7 +33,7 @@ case class LocalTableScanExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - private val unsafeRows: Array[InternalRow] = { + private lazy val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 66a8e044ab879..44278e37c5276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -342,8 +342,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows")) - // output attributes should not affect the results - override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) + override lazy val canonicalized: SparkPlan = { + RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) + } override def inputRDDs(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) @@ -607,11 +608,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def sameResult(o: SparkPlan): Boolean = o match { - case s: SubqueryExec => child.sameResult(s.child) - case _ => false - } - @transient private lazy val relationFuture: Future[Array[InternalRow]] = { // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 4215203960075..3813f953e06a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -43,17 +43,8 @@ case class LogicalRelation( com.google.common.base.Objects.hashCode(relation, output) } - override def sameResult(otherPlan: LogicalPlan): Boolean = { - otherPlan.canonicalized match { - case LogicalRelation(otherRelation, _, _) => relation == otherRelation - case _ => false - } - } - - // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need - // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's - // expId can be different but the relation is still the same. - override lazy val cleanArgs: Seq[Any] = Seq(relation) + // Only care about relation when canonicalizing. + override def preCanonicalized: LogicalPlan = copy(catalogTable = None) @transient override def computeStats(conf: SQLConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index efcaca9338ad6..9c859e41f8762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -48,10 +48,8 @@ case class BroadcastExchangeExec( override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - override def sameResult(plan: SparkPlan): Boolean = plan match { - case p: BroadcastExchangeExec => - mode.compatibleWith(p.mode) && child.sameResult(p.child) - case _ => false + override lazy val canonicalized: SparkPlan = { + BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } @transient diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 9a9597d3733e0..d993ea6c6cef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -48,10 +48,8 @@ abstract class Exchange extends UnaryExecNode { case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange) extends LeafExecNode { - override def sameResult(plan: SparkPlan): Boolean = { - // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. - plan.sameResult(child) - } + // Ignore this wrapper for canonicalizing. + override lazy val canonicalized: SparkPlan = child.canonicalized def doExecute(): RDD[InternalRow] = { child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b9f6601ea87fe..2dd1dc3da96c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -829,15 +829,10 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalizedKey, rows.length) + HashedRelation(rows.iterator, canonicalized.key, rows.length) } - private lazy val canonicalizedKey: Seq[Expression] = { - key.map { e => e.canonicalized } - } - - override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey - case _ => false + override lazy val canonicalized: HashedRelationBroadcastMode = { + this.copy(key = key.map(_.canonicalized)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 36cde3233dce8..59eaf4d1c29b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -36,17 +36,17 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { ) } - test("compatible BroadcastMode") { + test("BroadcastMode.canonicalized") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) - assert(mode1.compatibleWith(mode1)) - assert(!mode1.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode1)) - assert(mode2.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode3)) - assert(mode3.compatibleWith(mode3)) + assert(mode1.canonicalized == mode1.canonicalized) + assert(mode1.canonicalized != mode2.canonicalized) + assert(mode2.canonicalized != mode1.canonicalized) + assert(mode2.canonicalized == mode2.canonicalized) + assert(mode2.canonicalized != mode3.canonicalized) + assert(mode3.canonicalized == mode3.canonicalized) } test("BroadcastExchange same result") { @@ -70,7 +70,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) - assert(!exchange3.sameResult(exchange4)) + assert(exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } @@ -98,7 +98,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) - assert(!exchange4.sameResult(exchange5)) + assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 28f074849c0f5..fab0d7fa84827 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -72,7 +72,7 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -80,20 +80,22 @@ case class HiveTableScanExec( BindReferences.bindReference(pred, relation.partitionCols) } - // Create a local copy of hadoopConf,so that scan specific modifications should not impact - // other queries - @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf() - - @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) - @transient private val tableDesc = new TableDesc( + @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private lazy val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, hiveQlTable.getOutputFormatClass, hiveQlTable.getMetadata) - // append columns ids and names before broadcast - addColumnMetadataToConf(hadoopConf) + // Create a local copy of hadoopConf,so that scan specific modifications should not impact + // other queries + @transient private lazy val hadoopConf = { + val c = sparkSession.sessionState.newHadoopConf() + // append columns ids and names before broadcast + addColumnMetadataToConf(c) + c + } - @transient private val hadoopReader = new HadoopTableReader( + @transient private lazy val hadoopReader = new HadoopTableReader( output, relation.partitionCols, tableDesc, @@ -104,7 +106,7 @@ case class HiveTableScanExec( Cast(Literal(value), dataType).eval(null) } - private def addColumnMetadataToConf(hiveConf: Configuration) { + private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { // Specifies needed column IDs for those non-partitioning columns. val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) @@ -198,18 +200,13 @@ case class HiveTableScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: HiveTableScanExec => - val thisPredicates = partitionPruningPred.map(cleanExpression) - val otherPredicates = other.partitionPruningPred.map(cleanExpression) - - val result = relation.sameResult(other.relation) && - output.length == other.output.length && - output.zip(other.output) - .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: HiveTableScanExec = { + val input: AttributeSeq = relation.output + HiveTableScanExec( + requestedAttributes.map(normalizeExprId(_, input)), + relation.canonicalized.asInstanceOf[CatalogRelation], + partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession) } + + override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) } From 4f7d49b955b8c362da29a2540697240f4564d3ee Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Mon, 10 Apr 2017 17:34:15 +0200 Subject: [PATCH 0232/1765] [SPARK-20243][TESTS] DebugFilesystem.assertNoOpenStreams thread race ## What changes were proposed in this pull request? Synchronize access to openStreams map. ## How was this patch tested? Existing tests. Author: Bogdan Raducanu Closes #17592 from bogdanrdc/SPARK-20243. --- .../org/apache/spark/DebugFilesystem.scala | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index 72aea841117cc..91355f7362900 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io.{FileDescriptor, InputStream} import java.lang import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,21 +30,29 @@ import org.apache.spark.internal.Logging object DebugFilesystem extends Logging { // Stores the set of active streams and their creation sites. - private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]() + private val openStreams = mutable.Map.empty[FSDataInputStream, Throwable] - def clearOpenStreams(): Unit = { + def addOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.put(stream, new Throwable()) + } + + def clearOpenStreams(): Unit = openStreams.synchronized { openStreams.clear() } - def assertNoOpenStreams(): Unit = { - val numOpen = openStreams.size() + def removeOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.remove(stream) + } + + def assertNoOpenStreams(): Unit = openStreams.synchronized { + val numOpen = openStreams.values.size if (numOpen > 0) { - for (exc <- openStreams.values().asScala) { + for (exc <- openStreams.values) { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", - openStreams.values().asScala.head) + openStreams.values.head) } } } @@ -60,8 +67,7 @@ class DebugFilesystem extends LocalFileSystem { override def open(f: Path, bufferSize: Int): FSDataInputStream = { val wrapped: FSDataInputStream = super.open(f, bufferSize) - openStreams.put(wrapped, new Throwable()) - + addOpenStream(wrapped) new FSDataInputStream(wrapped.getWrappedStream) { override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) @@ -98,7 +104,7 @@ class DebugFilesystem extends LocalFileSystem { override def close(): Unit = { wrapped.close() - openStreams.remove(wrapped) + removeOpenStream(wrapped) } override def read(): Int = wrapped.read() From 5acaf8c0c685e47ec619fbdfd353163721e1cf50 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 10 Apr 2017 17:45:27 +0200 Subject: [PATCH 0233/1765] [SPARK-19518][SQL] IGNORE NULLS in first / last in SQL ## What changes were proposed in this pull request? This PR proposes to add `IGNORE NULLS` keyword in `first`/`last` in Spark's parser likewise http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions057.htm. This simply maps the keywords to existing `ignoreNullsExpr`. **Before** ```scala scala> sql("select first('a' IGNORE NULLS)").show() ``` ``` org.apache.spark.sql.catalyst.parser.ParseException: extraneous input 'NULLS' expecting {')', ','}(line 1, pos 24) == SQL == select first('a' IGNORE NULLS) ------------------------^^^ at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:210) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:112) at org.apache.spark.sql.execution.SparkSqlParser.parse(SparkSqlParser.scala:46) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parsePlan(ParseDriver.scala:66) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:622) ... 48 elided ``` **After** ```scala scala> sql("select first('a' IGNORE NULLS)").show() ``` ``` +--------------+ |first(a, true)| +--------------+ | a| +--------------+ ``` ## How was this patch tested? Unit tests in `ExpressionParserSuite`. Author: hyukjinkwon Closes #17566 from HyukjinKwon/SPARK-19518. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 5 ++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 17 +++++++++++++++++ .../catalyst/parser/ExpressionParserSuite.scala | 8 ++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 52b5b347fa9c7..1ecb3d1958f43 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -552,6 +552,8 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star @@ -710,7 +712,7 @@ nonReserved | VIEW | REPLACE | IF | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK + | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT @@ -836,6 +838,7 @@ TRANSACTION: 'TRANSACTION'; COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; +IGNORE: 'IGNORE'; IF: 'IF'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index fab7e4c5b1285..c37255153802b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -1022,6 +1023,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + /** * Create a (windowed) Function expression. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index d1c6b50536cd2..e7f3b64a71130 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -549,4 +550,11 @@ class ExpressionParserSuite extends PlanTest { val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r")) } + + test("SPARK-19526 Support ignore nulls keywords for first and last") { + assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) + assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) + assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression()) + } } From fd711ea13e558f0e7d3e01f08e01444d394499a6 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 10 Apr 2017 09:15:04 -0700 Subject: [PATCH 0234/1765] [SPARK-20273][SQL] Disallow Non-deterministic Filter push-down into Join Conditions ## What changes were proposed in this pull request? ``` sql("SELECT t1.b, rand(0) as r FROM cachedData, cachedData t1 GROUP BY t1.b having r > 0.5").show() ``` We will get the following error: ``` Job aborted due to stage failure: Task 1 in stage 4.0 failed 1 times, most recent failure: Lost task 1.0 in stage 4.0 (TID 8, localhost, executor driver): java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec$$anonfun$org$apache$spark$sql$execution$joins$BroadcastNestedLoopJoinExec$$boundCondition$1.apply(BroadcastNestedLoopJoinExec.scala:87) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec$$anonfun$org$apache$spark$sql$execution$joins$BroadcastNestedLoopJoinExec$$boundCondition$1.apply(BroadcastNestedLoopJoinExec.scala:87) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:463) ``` Filters could be pushed down to the join conditions by the optimizer rule `PushPredicateThroughJoin`. However, Analyzer [blocks users to add non-deterministics conditions](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L386-L395) (For details, see the PR https://github.com/apache/spark/pull/7535). We should not push down non-deterministic conditions; otherwise, we need to explicitly initialize the non-deterministic expressions. This PR is to simply block it. ### How was this patch tested? Added a test case Author: Xiao Li Closes #17585 from gatorsmile/joinRandCondition. --- .../spark/sql/catalyst/expressions/predicates.scala | 2 ++ .../sql/catalyst/optimizer/FilterPushdownSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1235204591bbd..8acb740f8db8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,6 +90,8 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + // Non-deterministic expressions are not allowed as join conditions. + case e if !e.deterministic => false case l: ListQuery => // A ListQuery defines the query which we want to search in an IN subquery expression. // Currently the only way to evaluate an IN subquery is to convert it to a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ccd0b7c5d7f79..950aa2379517e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -241,6 +241,16 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: do not push down non-deterministic filters into join condition") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = x.join(y).where(Rand(10) > 5.0).analyze + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + test("joins: push to one side after transformCondition") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) From a26e3ed5e414d0a350cfe65dd511b154868b9f1d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 10 Apr 2017 20:11:56 +0100 Subject: [PATCH 0235/1765] [SPARK-20156][CORE][SQL][STREAMING][MLLIB] Java String toLowerCase "Turkish locale bug" causes Spark problems ## What changes were proposed in this pull request? Add Locale.ROOT to internal calls to String `toLowerCase`, `toUpperCase`, to avoid inadvertent locale-sensitive variation in behavior (aka the "Turkish locale problem"). The change looks large but it is just adding `Locale.ROOT` (the locale with no country or language specified) to every call to these methods. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #17527 from srowen/SPARK-20156. --- .../apache/spark/network/util/JavaUtils.java | 5 ++- .../spark/network/util/TransportConf.java | 5 ++- .../spark/status/api/v1/TaskSorting.java | 3 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 4 +- .../CoarseGrainedExecutorBackend.scala | 3 +- .../apache/spark/io/CompressionCodec.scala | 4 +- .../spark/metrics/sink/ConsoleSink.scala | 4 +- .../apache/spark/metrics/sink/CsvSink.scala | 2 +- .../spark/metrics/sink/GraphiteSink.scala | 6 +-- .../apache/spark/metrics/sink/Slf4jSink.scala | 4 +- .../scheduler/EventLoggingListener.scala | 3 +- .../spark/scheduler/SchedulableBuilder.scala | 5 ++- .../spark/scheduler/TaskSchedulerImpl.scala | 18 ++++---- .../spark/serializer/KryoSerializer.scala | 4 +- .../ui/exec/ExecutorThreadDumpPage.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 4 +- .../scala/org/apache/spark/ShuffleSuite.scala | 4 +- .../spark/broadcast/BroadcastSuite.scala | 4 +- .../internal/config/ConfigEntrySuite.scala | 3 +- .../BlockManagerReplicationSuite.scala | 6 ++- .../org/apache/spark/ui/StagePageSuite.scala | 5 ++- .../org/apache/spark/ui/UISeleniumSuite.scala | 5 ++- .../scala/org/apache/spark/ui/UISuite.scala | 11 ++--- .../examples/ml/DecisionTreeExample.scala | 4 +- .../apache/spark/examples/ml/GBTExample.scala | 4 +- .../examples/ml/RandomForestExample.scala | 4 +- .../spark/examples/mllib/LDAExample.scala | 4 +- .../sql/kafka010/KafkaSourceProvider.scala | 22 +++++----- .../sql/kafka010/KafkaRelationSuite.scala | 3 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 24 ++++++----- .../spark/sql/kafka010/KafkaSourceSuite.scala | 6 +-- .../streaming/kafka010/ConsumerStrategy.scala | 9 ++-- .../spark/streaming/kafka/KafkaUtils.scala | 4 +- .../classification/LogisticRegression.scala | 4 +- .../org/apache/spark/ml/clustering/LDA.scala | 5 ++- .../GeneralizedLinearRegressionWrapper.scala | 6 ++- .../apache/spark/ml/recommendation/ALS.scala | 9 ++-- .../GeneralizedLinearRegression.scala | 38 +++++++++-------- .../org/apache/spark/ml/tree/treeParams.scala | 41 ++++++++++++------- .../apache/spark/mllib/clustering/LDA.scala | 4 +- .../spark/mllib/tree/impurity/Impurity.scala | 4 +- .../scala/org/apache/spark/repl/Main.scala | 3 +- .../org/apache/spark/deploy/yarn/Client.scala | 4 +- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../catalog/ExternalCatalogUtils.scala | 7 +++- .../sql/catalyst/catalog/SessionCatalog.scala | 3 +- .../catalyst/catalog/functionResources.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/mathExpressions.scala | 6 ++- .../expressions/regexpExpressions.scala | 3 +- .../expressions/windowExpressions.scala | 4 +- .../sql/catalyst/json/JacksonParser.scala | 5 ++- .../sql/catalyst/parser/AstBuilder.scala | 12 ++++-- .../spark/sql/catalyst/plans/joinTypes.scala | 4 +- .../streaming/InternalOutputModes.scala | 4 +- .../catalyst/util/CaseInsensitiveMap.scala | 9 ++-- .../sql/catalyst/util/CompressionCodecs.scala | 4 +- .../sql/catalyst/util/DateTimeUtils.scala | 4 +- .../spark/sql/catalyst/util/ParseMode.scala | 4 +- .../sql/catalyst/util/StringKeyHashMap.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 6 +-- .../org/apache/spark/sql/types/DataType.scala | 8 +++- .../apache/spark/sql/types/DecimalType.scala | 4 +- .../sql/streaming/JavaOutputModeSuite.java | 6 ++- .../sql/catalyst/analysis/AnalysisTest.scala | 5 ++- .../analysis/UnsupportedOperationsSuite.scala | 7 ++-- .../catalyst/expressions/ScalaUDFSuite.scala | 4 +- .../streaming/InternalOutputModesSuite.scala | 4 +- .../spark/sql/DataFrameNaFunctions.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +-- .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 24 ++++++----- .../spark/sql/execution/SparkSqlParser.scala | 20 +++++---- .../sql/execution/WholeStageCodegenExec.scala | 6 ++- .../spark/sql/execution/command/ddl.scala | 6 ++- .../sql/execution/command/functions.scala | 4 +- .../execution/datasources/DataSource.scala | 16 ++++---- .../datasources/InMemoryFileIndex.scala | 1 - .../datasources/PartitioningUtils.scala | 4 +- .../datasources/csv/CSVOptions.scala | 4 +- .../spark/sql/execution/datasources/ddl.scala | 4 +- .../datasources/jdbc/JDBCOptions.scala | 6 +-- .../datasources/jdbc/JdbcUtils.scala | 3 +- .../datasources/parquet/ParquetOptions.scala | 8 +++- .../sql/execution/datasources/rules.scala | 5 ++- .../state/HDFSBackedStateStoreProvider.scala | 3 +- .../apache/spark/sql/internal/HiveSerDe.scala | 4 +- .../spark/sql/internal/SharedState.scala | 1 - .../sql/streaming/DataStreamReader.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 4 +- .../apache/spark/sql/JavaDatasetSuite.java | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../sql/execution/QueryExecutionSuite.scala | 13 +++--- .../execution/command/DDLCommandSuite.scala | 7 +++- .../sql/execution/command/DDLSuite.scala | 5 ++- .../datasources/parquet/ParquetIOSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 8 ++-- .../spark/sql/sources/FilteredScanSuite.scala | 9 ++-- .../sql/streaming/FileStreamSinkSuite.scala | 4 +- .../streaming/StreamingAggregationSuite.scala | 4 +- .../test/DataStreamReaderWriterSuite.scala | 7 ++-- .../sql/test/DataFrameReaderWriterSuite.scala | 10 +++-- .../hive/service/auth/HiveAuthFactory.java | 5 ++- .../org/apache/hive/service/auth/SaslQOP.java | 3 +- .../org/apache/hive/service/cli/Type.java | 3 +- .../hive/thriftserver/HiveThriftServer2.scala | 2 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +- .../spark/sql/hive/HiveExternalCatalog.scala | 5 ++- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 11 ++--- .../org/apache/spark/sql/hive/HiveUtils.scala | 3 +- .../sql/hive/client/HiveClientImpl.scala | 9 ++-- .../spark/sql/hive/client/HiveShim.scala | 6 +-- .../sql/hive/execution/HiveOptions.scala | 10 +++-- .../spark/sql/hive/orc/OrcOptions.scala | 6 ++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 3 +- .../sql/hive/HiveSchemaInferenceSuite.scala | 3 -- .../hive/execution/HiveComparisonTest.scala | 8 ++-- .../sql/hive/execution/HiveQuerySuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 7 ++-- .../streaming/dstream/InputDStream.scala | 6 ++- .../apache/spark/streaming/Java8APISuite.java | 5 ++- .../apache/spark/streaming/JavaAPISuite.java | 4 +- .../streaming/StreamingContextSuite.scala | 3 +- 126 files changed, 482 insertions(+), 299 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 51d7fda0cb260..afc59efaef810 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -24,6 +24,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -210,7 +211,7 @@ private static boolean isSymlink(File file) throws IOException { * The unit is also considered the default if the given string does not specify a unit. */ public static long timeStringAs(String str, TimeUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); @@ -258,7 +259,7 @@ public static long timeStringAsSec(String str) { * provided, a direct conversion to the provided unit is attempted. */ public static long byteStringAs(String str, ByteUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index c226d8f3bc8fa..a25078e262efb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,7 @@ package org.apache.spark.network.util; +import java.util.Locale; import java.util.Properties; import com.google.common.primitives.Ints; @@ -75,7 +76,9 @@ public String getModuleName() { } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } + public String ioMode() { + return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(Locale.ROOT); + } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index b38639e854815..dff4f5df68784 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; +import java.util.Locale; import java.util.Set; public enum TaskSorting { @@ -35,7 +36,7 @@ public enum TaskSorting { } public static TaskSorting fromString(String str) { - String lower = str.toLowerCase(); + String lower = str.toLowerCase(Locale.ROOT); for (TaskSorting t: values()) { if (t.alternateNames.contains(lower)) { return t; diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0225fd6056074..99efc4893fda4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -361,7 +361,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def setLogLevel(logLevel: String) { // let's allow lowercase or mixed case too - val upperCased = logLevel.toUpperCase(Locale.ENGLISH) + val upperCased = logLevel.toUpperCase(Locale.ROOT) require(SparkContext.VALID_LOG_LEVELS.contains(upperCased), s"Supplied level $logLevel did not match one of:" + s" ${SparkContext.VALID_LOG_LEVELS.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 539dbb55eeff0..f4a59f069a5f9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.File import java.net.Socket +import java.util.Locale import scala.collection.mutable import scala.util.Properties @@ -319,7 +320,8 @@ object SparkEnv extends Logging { "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleMgrClass = + shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ba0096d874567..b2b26ee107c00 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -72,7 +73,7 @@ private[spark] class CoarseGrainedExecutorBackend( def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)) } override def receive: PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index c216fe477fd15..0cb16f0627b72 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -18,6 +18,7 @@ package org.apache.spark.io import java.io._ +import java.util.Locale import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.LZ4BlockOutputStream @@ -66,7 +67,8 @@ private[spark] object CompressionCodec { } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val codecClass = + shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName) val codec = try { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 81b9056b40fb8..fce556fd0382c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} @@ -39,7 +39,7 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR } val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 9d5f2ae9328ad..88bba2fdbd1c6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -42,7 +42,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis } val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 22454e50b14b4..23e31823f4930 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink import java.net.InetSocketAddress -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -59,7 +59,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric } val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) } @@ -67,7 +67,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 773e074336cb0..7fa4ba7622980 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{MetricRegistry, Slf4jReporter} @@ -42,7 +42,7 @@ private[spark] class Slf4jSink( } val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index af9bdefc967ef..aecb3a980e7c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.util.Locale import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -316,7 +317,7 @@ private[spark] object EventLoggingListener extends Logging { } private def sanitize(str: String): String = { - str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 20cedaf060420..417103436144a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.io.{FileInputStream, InputStream} -import java.util.{NoSuchElementException, Properties} +import java.util.{Locale, NoSuchElementException, Properties} import scala.util.control.NonFatal import scala.xml.{Node, XML} @@ -142,7 +142,8 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) defaultValue: SchedulingMode, fileName: String): SchedulingMode = { - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase + val xmlSchedulingMode = + (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase(Locale.ROOT) val warningMessage = s"Unsupported schedulingMode: $xmlSchedulingMode found in " + s"Fair Scheduler configuration file: $fileName, using " + s"the default schedulingMode: $defaultValue for pool: $poolName" diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 07aea773fa632..c849a16023a7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{Timer, TimerTask} +import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -56,8 +56,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val maxTaskFailures: Int, private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) - extends TaskScheduler with Logging -{ + extends TaskScheduler with Logging { import TaskSchedulerImpl._ @@ -135,12 +134,13 @@ private[spark] class TaskSchedulerImpl private[scheduler]( private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString) - val schedulingMode: SchedulingMode = try { - SchedulingMode.withName(schedulingModeConf.toUpperCase) - } catch { - case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") - } + val schedulingMode: SchedulingMode = + try { + SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT)) + } catch { + case e: java.util.NoSuchElementException => + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") + } val rootPool: Pool = new Pool("", schedulingMode, 0, 0) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 6fc66e2374bd9..e15166d11c243 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import java.util.Locale import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -244,7 +245,8 @@ class KryoDeserializationStream( kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. - case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") => + case e: KryoException + if e.getMessage.toLowerCase(Locale.ROOT).contains("buffer underflow") => throw new EOFException } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index dbcc6402bc309..6ce3f511e89c7 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -42,7 +43,8 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { - threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + threadTrace1.threadName.toLowerCase(Locale.ROOT) < + threadTrace2.threadName.toLowerCase(Locale.ROOT) } else { v1 > v2 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 0ff9e5e9411ca..3131c4a1eb7d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.Date +import java.util.{Date, Locale} import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -77,7 +77,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'content': '
    // Put the block into one of the stores - val blockId = new TestBlockId( - "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) + val blockId = TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase(Locale.ROOT)) val testValue = Array.fill[Byte](blockSize)(1) stores(0).putSingle(blockId, testValue, storageLevel) diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 38030e066080f..499d47b13d702 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -37,14 +38,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { test("peak execution memory should displayed") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index f4c561c737794..bdd148875e38a 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} +import java.util.Locale import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -453,8 +454,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(10 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") findAll(cssSelector("tbody tr a")).foreach { link => - link.text.toLowerCase should include ("count") - link.text.toLowerCase should not include "unknown" + link.text.toLowerCase(Locale.ROOT) should include ("count") + link.text.toLowerCase(Locale.ROOT) should not include "unknown" } } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index f1be0f6de3ce2..0c3d4caeeabf9 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui import java.net.{BindException, ServerSocket} import java.net.{URI, URL} +import java.util.Locale import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -72,10 +73,10 @@ class UISuite extends SparkFunSuite { eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("storage")) + assert(html.toLowerCase(Locale.ROOT).contains("environment")) + assert(html.toLowerCase(Locale.ROOT).contains("executors")) } } } @@ -85,7 +86,7 @@ class UISuite extends SparkFunSuite { // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString - assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) } } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 1745281c266cc..f736ceed4436f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -203,7 +205,7 @@ object DecisionTreeExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"DecisionTreeExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index db55298d8ea10..ed598d0d7dfae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -140,7 +142,7 @@ object GBTExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"GBTExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index a9e07c0705c92..8fd46c37e2987 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -146,7 +148,7 @@ object RandomForestExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"RandomForestExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index b923e627f2095..cd77ecf990b3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import java.util.Locale + import org.apache.log4j.{Level, Logger} import scopt.OptionParser @@ -131,7 +133,7 @@ object LDAExample { // Run LDA. val lda = new LDA() - val optimizer = params.algorithm.toLowerCase match { + val optimizer = params.algorithm.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets. case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 58b52692b57ce..ab1ce347cbe34 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.UUID +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -74,11 +74,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // id. Hence, we should generate a unique id for each query. val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap @@ -115,11 +115,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap @@ -192,7 +192,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " @@ -207,7 +207,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) @@ -272,7 +272,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister private def validateGeneralOptions(parameters: Map[String, String]): Unit = { // Validate source options - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedStrategies = caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq @@ -451,8 +451,10 @@ private[kafka010] object KafkaSourceProvider { offsetOptionKey: String, defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = { params.get(offsetOptionKey).map(_.trim) match { - case Some(offset) if offset.toLowerCase == "latest" => LatestOffsetRangeLimit - case Some(offset) if offset.toLowerCase == "earliest" => EarliestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" => + LatestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" => + EarliestOffsetRangeLimit case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) case None => defaultOffsets } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 68bc3e3e2e9a8..91893df4ec32f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.common.TopicPartition @@ -195,7 +196,7 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 490535623cb36..4bd052d249eca 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig @@ -75,7 +76,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .option("kafka.bootstrap.servers", testUtils.brokerAddress) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "null topic present in the data")) } @@ -92,7 +93,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .mode(SaveMode.Ignore) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( s"save mode ignore not allowed for kafka")) // Test bad save mode Overwrite @@ -103,7 +104,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .mode(SaveMode.Overwrite) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( s"save mode overwrite not allowed for kafka")) } @@ -233,7 +234,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage - .toLowerCase + .toLowerCase(Locale.ROOT) .contains("topic option required when no 'topic' attribute is present")) try { @@ -248,7 +249,8 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) } test("streaming - write data with valid schema but wrong types") { @@ -270,7 +272,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("topic type must be a string")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) try { /* value field wrong type */ @@ -284,7 +286,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "value attribute type must be a string or binarytype")) try { @@ -299,7 +301,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binarytype")) } @@ -318,7 +320,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("job aborted")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -330,7 +332,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) ex = intercept[IllegalArgumentException] { @@ -338,7 +340,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 0046ba7e43d13..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -491,7 +491,7 @@ class KafkaSourceSuite extends KafkaSourceTest { reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } @@ -524,7 +524,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option(s"$key", value) reader.load() } - assert(ex.getMessage.toLowerCase.contains("not supported")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("not supported")) } testUnsupportedConfig("kafka.group.id") diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 778c06ea16a2b..d2100fc5a4aba 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.kafka010 -import java.{ lang => jl, util => ju } +import java.{lang => jl, util => ju} +import java.util.Locale import scala.collection.JavaConverters._ @@ -93,7 +94,8 @@ private case class Subscribe[K, V]( // but cant seek to a position before poll, because poll is what gets subscription partitions // So, poll, suppress the first exception, then seek val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { @@ -145,7 +147,8 @@ private case class SubscribePattern[K, V]( if (!toSeek.isEmpty) { // work around KAFKA-3370 when reset is none, see explanation in Subscribe above val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d5aef8184fc87..78230725f322e 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong, Number => JNumber} import java.nio.charset.StandardCharsets -import java.util.{List => JList, Map => JMap, Set => JSet} +import java.util.{List => JList, Locale, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -206,7 +206,7 @@ object KafkaUtils { kafkaParams: Map[String, String], topics: Set[String] ): Map[TopicAndPartition, Long] = { - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase(Locale.ROOT)) val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 7b56bce41c326..965ce3d6f275f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import java.util.Locale + import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} @@ -654,7 +656,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { override def load(path: String): LogisticRegression = super.load(path) private[classification] val supportedFamilyNames = - Array("auto", "binomial", "multinomial").map(_.toLowerCase) + Array("auto", "binomial", "multinomial").map(_.toLowerCase(Locale.ROOT)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 55720e2d613d9..2f50dc7c85f35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats import org.json4s.JsonAST.JObject @@ -173,7 +175,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), - (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + (o: String) => + ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index c49416b240181..4bd4aa7113f68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.r +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.JsonDSL._ @@ -91,7 +93,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) // set variancePower and linkPower if family is tweedie; otherwise, set link function - if (family.toLowerCase == "tweedie") { + if (family.toLowerCase(Locale.ROOT) == "tweedie") { glr.setVariancePower(variancePower).setLinkPower(linkPower) } else { glr.setLink(link) @@ -151,7 +153,7 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = if (family.toLowerCase == "tweedie" && + val rAic: Double = if (family.toLowerCase(Locale.ROOT) == "tweedie" && !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 60dd7367053e2..a20ef72446661 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} import java.io.IOException +import java.util.Locale import scala.collection.mutable import scala.reflect.ClassTag @@ -40,8 +41,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -118,10 +118,11 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo "useful in cross-validation or production scenarios, for handling user/item ids the model " + "has not seen in the training data. Supported values: " + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", - (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + (s: String) => + ALSModel.supportedColdStartStrategies.contains(s.toLowerCase(Locale.ROOT))) /** @group expertGetParam */ - def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase(Locale.ROOT) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3be8b533ee3f3..33137b0c0fdec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.regression +import java.util.Locale + import breeze.stats.{distributions => dist} import org.apache.hadoop.fs.Path @@ -57,7 +59,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - (value: String) => supportedFamilyNames.contains(value.toLowerCase)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -99,7 +101,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", - (value: String) => supportedLinkNames.contains(value.toLowerCase)) + (value: String) => supportedLinkNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -148,7 +150,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(family).toLowerCase == "tweedie") { + if ($(family).toLowerCase(Locale.ROOT) == "tweedie") { if (isSet(link)) { logWarning("When family is tweedie, use param linkPower to specify link function. " + "Setting param link will take no effect.") @@ -460,13 +462,15 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine */ def apply(params: GeneralizedLinearRegressionBase): FamilyAndLink = { val familyObj = Family.fromParams(params) - val linkObj = if ((params.getFamily.toLowerCase != "tweedie" && - params.isSet(params.link)) || (params.getFamily.toLowerCase == "tweedie" && - params.isSet(params.linkPower))) { - Link.fromParams(params) - } else { - familyObj.defaultLink - } + val linkObj = + if ((params.getFamily.toLowerCase(Locale.ROOT) != "tweedie" && + params.isSet(params.link)) || + (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie" && + params.isSet(params.linkPower))) { + Link.fromParams(params) + } else { + familyObj.defaultLink + } new FamilyAndLink(familyObj, linkObj) } } @@ -519,7 +523,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family name and variance power */ def fromParams(params: GeneralizedLinearRegressionBase): Family = { - params.getFamily.toLowerCase match { + params.getFamily.toLowerCase(Locale.ROOT) match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson @@ -795,7 +799,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family, link and linkPower */ def fromParams(params: GeneralizedLinearRegressionBase): Link = { - if (params.getFamily.toLowerCase == "tweedie") { + if (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie") { params.getLinkPower match { case 0.0 => Log case 1.0 => Identity @@ -804,7 +808,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine case others => new Power(others) } } else { - params.getLink.toLowerCase match { + params.getLink.toLowerCase(Locale.ROOT) match { case Identity.name => Identity case Logit.name => Logit case Log.name => Log @@ -1253,8 +1257,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val dispersion: Double = if ( - model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { 1.0 } else { val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) @@ -1357,8 +1361,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( @Since("2.0.0") lazy val pValues: Array[Double] = { if (isNormalSolver) { - if (model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + if (model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } } else { tValues.map { x => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 5eb707dfe7bc3..cd1950bd76c05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import java.util.Locale + import scala.util.Try import org.apache.spark.ml.PredictorParams @@ -218,7 +220,8 @@ private[ml] trait TreeClassifierParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "gini") @@ -230,7 +233,7 @@ private[ml] trait TreeClassifierParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -247,7 +250,8 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeClassifierParams @@ -267,7 +271,8 @@ private[ml] trait TreeRegressorParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -279,7 +284,7 @@ private[ml] trait TreeRegressorParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -295,7 +300,8 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("variance").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams @@ -417,7 +423,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + RandomForestParams.supportedFeatureSubsetStrategies.contains( + value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,13 +438,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait RandomForestClassifierParams @@ -509,7 +516,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { private[ml] object GBTClassifierParams { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("logistic").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { @@ -523,12 +531,13 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "logistic") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldClassificationLoss = { @@ -544,7 +553,8 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam private[ml] object GBTRegressorParams { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("squared", "absolute").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { @@ -558,12 +568,13 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "squared") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6c5f529fb8bfd..4aa647236b31c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.Locale + import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.{DeveloperApi, Since} @@ -306,7 +308,7 @@ class LDA private ( @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = - optimizerName.toLowerCase match { + optimizerName.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer case "online" => new OnlineLDAOptimizer case other => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 98a3021461eb8..4c7746869dde1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.impurity +import java.util.Locale + import org.apache.spark.annotation.{DeveloperApi, Since} /** @@ -184,7 +186,7 @@ private[spark] object ImpurityCalculator { * the given stats. */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity.toLowerCase match { + impurity.toLowerCase(Locale.ROOT) match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 7f2ec01cc9676..39fc621de7807 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -18,6 +18,7 @@ package org.apache.spark.repl import java.io.File +import java.util.Locale import scala.tools.nsc.GenericRunnerSettings @@ -88,7 +89,7 @@ object Main extends Logging { } val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { if (SparkSession.hiveClassesArePresent) { // In the case that the property is not set at all, builder's config // does not have this value set to 'hive' yet. The original default diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 3218d221143e5..424bbca123190 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.util.{Properties, UUID} +import java.util.{Locale, Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -532,7 +532,7 @@ private[spark] class Client( try { jarsStream.setLevel(0) jarsDir.listFiles().foreach { f => - if (f.isFile && f.getName.toLowerCase().endsWith(".jar") && f.canRead) { + if (f.isFile && f.getName.toLowerCase(Locale.ROOT).endsWith(".jar") && f.canRead) { jarsStream.putNextEntry(new ZipEntry(f.getName)) Files.copy(f, jarsStream) jarsStream.closeEntry() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f8004ca300ac7..c4827b81e8b63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -83,7 +85,7 @@ object ResolveHints { } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => applyBroadcastHint(h.child, h.parameters.toSet) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 254eedfe77517..3ca9e6a8da5b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -167,8 +168,10 @@ object CatalogUtils { */ def maskCredentials(options: Map[String, String]): Map[String, String] = { options.map { - case (key, _) if key.toLowerCase == "password" => (key, "###") - case (key, value) if key.toLowerCase == "url" && value.toLowerCase.contains("password") => + case (key, _) if key.toLowerCase(Locale.ROOT) == "password" => (key, "###") + case (key, value) + if key.toLowerCase(Locale.ROOT) == "url" && + value.toLowerCase(Locale.ROOT).contains("password") => (key, "###") case o => o } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6f8c6ee2f0f44..faedf5f91c3ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -1098,7 +1099,7 @@ class SessionCatalog( name.database.isEmpty && functionRegistry.functionExists(name.funcName) && !FunctionRegistry.builtin.functionExists(name.funcName) && - !hiveFunctions.contains(name.funcName.toLowerCase) + !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } protected def failFunctionLookup(name: String): Nothing = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala index 8e46b962ff432..67bf2d06c95dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.util.Locale + import org.apache.spark.sql.AnalysisException /** A trait that represents the type of a resourced needed by a function. */ @@ -33,7 +35,7 @@ object ArchiveResource extends FunctionResourceType("archive") object FunctionResourceType { def fromString(resourceType: String): FunctionResourceType = { - resourceType.toLowerCase match { + resourceType.toLowerCase(Locale.ROOT) match { case "jar" => JarResource case "file" => FileResource case "archive" => ArchiveResource diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1db26d9c415a7..b847ef7bfaa97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -184,7 +186,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. */ - def prettyName: String = nodeName.toLowerCase + def prettyName: String = nodeName.toLowerCase(Locale.ROOT) protected def flatArguments: Iterator[Any] = productIterator.flatMap { case t: Traversable[_] => t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dea5f85cb08cc..c4d47ab2084fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} +import java.util.Locale import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -68,7 +69,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) } // name of function in java.lang.Math - def funcName: String = name.toLowerCase + def funcName: String = name.toLowerCase(Locale.ROOT) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") @@ -124,7 +125,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => + s"java.lang.Math.${name.toLowerCase(Locale.ROOT)}($c1, $c2)") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b23da537be721..49b779711308f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -60,7 +61,7 @@ abstract class StringRegexExpression extends BinaryExpression } } - override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index b2a3888ff7b08..37190429fc423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -631,7 +633,7 @@ abstract class RankLike extends AggregateWindowFunction { override val updateExpressions = increaseRank +: increaseRowNumber +: children override val evaluateExpression: Expression = rank - override def sql: String = s"${prettyName.toUpperCase}()" + override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}()" def withOrder(order: Seq[Expression]): RankLike } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index fdb7d88d5bd7f..ff6c93ae9815c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream +import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -126,7 +127,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || @@ -146,7 +147,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c37255153802b..e1db1ef5b8695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} +import java.util.Locale import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ @@ -1047,7 +1048,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val arguments = ctx.namedExpression().asScala.map(expression) match { - case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) case expressions => @@ -1271,7 +1273,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { val value = string(ctx.STRING) - val valueType = ctx.identifier.getText.toUpperCase + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) try { valueType match { case "DATE" => @@ -1427,7 +1429,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { import ctx._ val s = value.getText try { - val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + val unitText = unit.getText.toLowerCase(Locale.ROOT) + val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match { case (u, None) if u.endsWith("s") => // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) @@ -1465,7 +1468,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Resolve/create a primitive type. */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { case ("boolean", Nil) => BooleanType case ("tinyint" | "byte", Nil) => ByteType case ("smallint" | "short", Nil) => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 818f4e5ed2ae5..90d11d6d91512 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.plans +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { - def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter case "leftouter" | "left" => LeftOuter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index bdf2baf7361d3..3cd6970ebefbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.sql.streaming.OutputMode /** @@ -47,7 +49,7 @@ private[sql] object InternalOutputModes { def apply(outputMode: String): OutputMode = { - outputMode.toLowerCase match { + outputMode.toLowerCase(Locale.ROOT) match { case "append" => OutputMode.Append case "complete" => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index 66dd093bbb691..bb2c5926ae9bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Builds a map in which keys are case insensitive. Input map can be accessed for cases where * case-sensitive information is required. The primary constructor is marked private to avoid @@ -26,11 +28,12 @@ package org.apache.spark.sql.catalyst.util class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { - val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase)) + val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) - override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase) + override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase(Locale.ROOT)) - override def contains(k: String): Boolean = keyLowerCasedMap.contains(k.toLowerCase) + override def contains(k: String): Boolean = + keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT)) override def +[B1 >: T](kv: (String, B1)): Map[String, B1] = { new CaseInsensitiveMap(originalMap + kv) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 435fba9d8851c..1377a03d93b7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress._ @@ -38,7 +40,7 @@ object CompressionCodecs { * If it is already a class name, just return it. */ def getCodecClassName(name: String): String = { - val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase(Locale.ROOT), name) try { // Validate the codec name if (codecName != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f614965520f4a..eb6aad5b2d2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -894,7 +894,7 @@ object DateTimeUtils { * (Because 1970-01-01 is Thursday). */ def getDayOfWeekFromString(string: UTF8String): Int = { - val dowString = string.toString.toUpperCase + val dowString = string.toString.toUpperCase(Locale.ROOT) dowString match { case "SU" | "SUN" | "SUNDAY" => 3 case "MO" | "MON" | "MONDAY" => 4 @@ -951,7 +951,7 @@ object DateTimeUtils { if (format == null) { TRUNC_INVALID } else { - format.toString.toUpperCase match { + format.toString.toUpperCase(Locale.ROOT) match { case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH case _ => TRUNC_INVALID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala index 4565dbde88c88..2beb875d1751d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.spark.internal.Logging sealed trait ParseMode { @@ -45,7 +47,7 @@ object ParseMode extends Logging { /** * Returns the parse mode from the given string. */ - def fromString(mode: String): ParseMode = mode.toUpperCase match { + def fromString(mode: String): ParseMode = mode.toUpperCase(Locale.ROOT) match { case PermissiveMode.name => PermissiveMode case DropMalformedMode.name => DropMalformedMode case FailFastMode.name => FailFastMode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index a7ac6136835a7..812d5ded4bf0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Build a map with String type of key, and it also supports either key case * sensitive or insensitive. @@ -25,7 +27,7 @@ object StringKeyHashMap { def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = if (caseSensitive) { new StringKeyHashMap[T](identity) } else { - new StringKeyHashMap[T](_.toLowerCase) + new StringKeyHashMap[T](_.toLowerCase(Locale.ROOT)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 640c0f189c237..6b0f495033494 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal -import java.util.{NoSuchElementException, Properties, TimeZone} +import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -243,7 +243,7 @@ object SQLConf { .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + "uncompressed, snappy, gzip, lzo.") .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") @@ -324,7 +324,7 @@ object SQLConf { "properties) and NEVER_INFER (fallback to using the case-insensitive metastore schema " + "instead of inferring).") .stringConf - .transform(_.toUpperCase()) + .transform(_.toUpperCase(Locale.ROOT)) .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 26871259c6b6e..520aff5e2b677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -49,7 +51,9 @@ abstract class DataType extends AbstractDataType { /** Name of the type used in JSON serialization. */ def typeName: String = { - this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + this.getClass.getSimpleName + .stripSuffix("$").stripSuffix("Type").stripSuffix("UDT") + .toLowerCase(Locale.ROOT) } private[sql] def jsonValue: JValue = typeName @@ -69,7 +73,7 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString - def sql: String = simpleString.toUpperCase + def sql: String = simpleString.toUpperCase(Locale.ROOT) /** * Check if `this` and `other` are the same data type when ignoring nullability diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 4dc06fc9cf09b..5c4bc5e33c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability @@ -65,7 +67,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" - override def sql: String = typeName.toUpperCase + override def sql: String = typeName.toUpperCase(Locale.ROOT) /** * Returns whether this DecimalType is wider than `other`. If yes, it means `other` diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java index e0a54fe30ac7d..d8845e0c838ff 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming; +import java.util.Locale; + import org.junit.Test; public class JavaOutputModeSuite { @@ -24,8 +26,8 @@ public class JavaOutputModeSuite { @Test public void testOutputModes() { OutputMode o1 = OutputMode.Append(); - assert(o1.toString().toLowerCase().contains("append")); + assert(o1.toString().toLowerCase(Locale.ROOT).contains("append")); OutputMode o2 = OutputMode.Complete(); - assert (o2.toString().toLowerCase().contains("complete")); + assert (o2.toString().toLowerCase(Locale.ROOT).contains("complete")); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 1be25ec06c741..82015b1e0671c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -79,7 +81,8 @@ trait AnalysisTest extends PlanTest { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + if (!expectedErrors.map(_.toLowerCase(Locale.ROOT)).forall( + e.getMessage.toLowerCase(Locale.ROOT).contains)) { fail( s"""Exception message should contain the following substrings: | diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 8f0a0c0d99d15..c39e372c272b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -17,19 +17,20 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} -import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -696,7 +697,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBody } expectedMsgs.foreach { m => - if (!e.getMessage.toLowerCase.contains(m.toLowerCase)) { + if (!e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) { fail(s"Exception message should contain: '$m', " + s"actual exception message:\n\t'${e.getMessage}'") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 7e45028653e36..13bd363c8b692 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.types.{IntegerType, StringType} @@ -32,7 +34,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("better error message for NPE") { val udf = ScalaUDF( - (s: String) => s.toLowerCase, + (s: String) => s.toLowerCase(Locale.ROOT), StringType, Literal.create(null, StringType) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala index 201dac35ed2d8..3159b541dca79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.OutputMode @@ -40,7 +42,7 @@ class InternalOutputModesSuite extends SparkFunSuite { val acceptedModes = Seq("append", "update", "complete") val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode)) (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } testMode("Xyz") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index d8f953fba5a8b..93d565d9fe904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.{lang => jl} +import java.util.Locale import scala.collection.JavaConverters._ @@ -89,7 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def drop(how: String, cols: Seq[String]): DataFrame = { - how.toLowerCase match { + how.toLowerCase(Locale.ROOT) match { case "any" => drop(cols.size, cols) case "all" => drop(1, cols) case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2b8537c3d4a63..49691c15d0f7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -164,7 +164,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ @scala.annotation.varargs def load(paths: String*): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 338a6e1314d90..1732a8e08b73f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -66,7 +66,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase(Locale.ROOT) match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore @@ -223,7 +223,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def save(): Unit = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0fe8d87ebd6ba..64755434784a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Locale + import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -108,7 +110,7 @@ class RelationalGroupedDataset protected[sql]( private[this] def strToExpr(expr: String): (Expression => Expression) = { val exprToFunc: (Expression => Expression) = { - (inputExpr: Expression) => expr.toLowerCase match { + (inputExpr: Expression) => expr.toLowerCase(Locale.ROOT) match { // We special handle a few cases that have alias that are not in function registry. case "avg" | "average" | "mean" => UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index c77328690daec..a26d00411fbaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.util.{Map => JMap} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.util.matching.Regex @@ -47,17 +47,19 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport - && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { - SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() - } else { - if (enableHiveSupport) { - logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + - "falling back to without Hive support.") + val spark = + if (SparkSession.hiveClassesArePresent && enableHiveSupport && + jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == + "hive") { + SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() + } else { + if (enableHiveSupport) { + logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to " + + "'hive', falling back to without Hive support.") + } + SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } - SparkSession.builder().sparkContext(jsc.sc).getOrCreate() - } setSparkContextSessionConf(spark, sparkConfigMap) spark } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 80afb59b3e88e..20dacf88504f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Locale + import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -103,7 +105,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") } if (ctx.identifier != null) { - if (ctx.identifier.getText.toLowerCase != "noscan") { + if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) @@ -563,7 +565,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } else if (value.STRING != null) { string(value.STRING) } else if (value.booleanValue != null) { - value.getText.toLowerCase + value.getText.toLowerCase(Locale.ROOT) } else { value.getText } @@ -647,7 +649,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { import ctx._ - val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase) match { + val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase(Locale.ROOT)) match { case None | Some("all") => (true, true) case Some("system") => (false, true) case Some("user") => (true, false) @@ -677,7 +679,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { val resources = ctx.resource.asScala.map { resource => - val resourceType = resource.identifier.getText.toLowerCase + val resourceType = resource.identifier.getText.toLowerCase(Locale.ROOT) resourceType match { case "jar" | "file" | "archive" => FunctionResource(FunctionResourceType.fromString(resourceType), string(resource.STRING)) @@ -959,7 +961,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .flatMap(_.orderedIdentifier.asScala) .map { orderedIdCtx => Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase != "asc") { + if (dir.toLowerCase(Locale.ROOT) != "asc") { operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) } } @@ -1012,13 +1014,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val mayebePaths = remainder(ctx.identifier).trim ctx.op.getType match { case SqlBaseParser.ADD => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "file" => AddFileCommand(mayebePaths) case "jar" => AddJarCommand(mayebePaths) case other => operationNotAllowed(s"ADD with resource type '$other'", ctx) } case SqlBaseParser.LIST => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "files" | "file" => if (mayebePaths.length > 0) { ListFilesCommand(mayebePaths.split("\\s+")) @@ -1305,7 +1307,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { (rowFormatCtx, createFileFormatCtx.fileFormat) match { case (_, ffTable: TableFileFormatContext) => // OK case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case ("sequencefile" | "textfile" | "rcfile") => // OK case fmt => operationNotAllowed( @@ -1313,7 +1315,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { parentCtx) } case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case "textfile" => // OK case fmt => operationNotAllowed( s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c31fd92447c0d..c1e1a631c677e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.{broadcast, TaskContext} +import java.util.Locale + +import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -43,7 +45,7 @@ trait CodegenSupport extends SparkPlan { case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" case _: DataSourceScanExec => "scan" - case _ => nodeName.toLowerCase + case _ => nodeName.toLowerCase(Locale.ROOT) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 9d3c55060dfb6..55540563ef911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool @@ -764,11 +766,11 @@ object DDLUtils { val HIVE_PROVIDER = "hive" def isHiveTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase == HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER } def isDatasourceTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase != HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index ea5398761c46d..5687f9332430e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException} @@ -100,7 +102,7 @@ case class DescribeFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. - functionName.funcName.toLowerCase match { + functionName.funcName.toLowerCase(Locale.ROOT) match { case "<>" => Row(s"Function: $functionName") :: Row("Usage: expr1 <> expr2 - " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index c9384e44255b8..f3b209deaae5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources -import java.util.{ServiceConfigurationError, ServiceLoader} +import java.util.{Locale, ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -539,15 +538,16 @@ object DataSource { // Found the data source using fully qualified path dataSource case Failure(error) => - if (provider1.toLowerCase == "orc" || + if (provider1.toLowerCase(Locale.ROOT) == "orc" || provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw new AnalysisException( "The ORC data source must be used with Hive support enabled") - } else if (provider1.toLowerCase == "avro" || + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || provider1 == "com.databricks.spark.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " + - "package at http://spark.apache.org/third-party-projects.html") + s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + + "Please find an Avro package at " + + "http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + @@ -596,8 +596,8 @@ object DataSource { */ def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = { val path = CaseInsensitiveMap(options).get("path") - val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path") + val optionsWithoutPath = options.filterKeys(_.toLowerCase(Locale.ROOT) != "path") CatalogStorageFormat.empty.copy( - locationUri = path.map(CatalogUtils.stringToURI(_)), properties = optionsWithoutPath) + locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 11605dd280569..9897ab73b0da8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -245,7 +245,6 @@ object InMemoryFileIndex extends Logging { sessionOpt: Option[SparkSession]): Seq[FileStatus] = { logTrace(s"Listing $path") val fs = path.getFileSystem(hadoopConf) - val name = path.getName.toLowerCase // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist // Note that statuses only include FileStatus for the files and dirs directly under path, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 03980922ab38f..c3583209efc56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -194,7 +194,7 @@ object PartitioningUtils { while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left // uncleaned. Here we simply ignore them. - if (currentPath.getName.toLowerCase == "_temporary") { + if (currentPath.getName.toLowerCase(Locale.ROOT) == "_temporary") { return (None, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 4994b8dc80527..62e4c6e4b4ea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -71,9 +71,9 @@ class CSVOptions( val param = parameters.getOrElse(paramName, default.toString) if (param == null) { default - } else if (param.toLowerCase == "true") { + } else if (param.toLowerCase(Locale.ROOT) == "true") { true - } else if (param.toLowerCase == "false") { + } else if (param.toLowerCase(Locale.ROOT) == "false") { false } else { throw new Exception(s"$paramName flag can be true or false") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 110d503f91cf4..f8d4a9bb5b81a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} @@ -75,7 +77,7 @@ case class CreateTempViewUsing( } def run(sparkSession: SparkSession): Seq[Row] = { - if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, " + "you can't use it with CREATE TEMP VIEW USING") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 89fe86c038b16..591096d5efd22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager} -import java.util.Properties +import java.util.{Locale, Properties} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -55,7 +55,7 @@ class JDBCOptions( */ val asConnectionProperties: Properties = { val properties = new Properties() - parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase)) + parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase(Locale.ROOT))) .foreach { case (k, v) => properties.setProperty(k, v) } properties } @@ -141,7 +141,7 @@ object JDBCOptions { private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - jdbcOptionNames += name.toLowerCase + jdbcOptionNames += name.toLowerCase(Locale.ROOT) name } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 774d1ba194321..5fc3c2753b6cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.util.Try @@ -542,7 +543,7 @@ object JdbcUtils extends Logging { case ArrayType(et, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition - .toLowerCase.split("\\(")(0) + .toLowerCase(Locale.ROOT).split("\\(")(0) (stmt: PreparedStatement, row: Row, pos: Int) => val array = conn.createArrayOf( typeName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index bdda299a621ac..772d4565de548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -40,9 +42,11 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + val codecName = parameters.getOrElse("compression", + sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = + shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 8b598cc60e778..7abf2ae5166b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ @@ -48,7 +50,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { // will catch it and return the original plan, so that the analyzer can report table not // found later. val isFileFormat = classOf[FileFormat].isAssignableFrom(dataSource.providingClass) - if (!isFileFormat || dataSource.className.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (!isFileFormat || + dataSource.className.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Unsupported data source type for direct query on files: " + s"${u.tableIdentifier.database.get}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index f9dd80230e488..1426728f9b550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable @@ -599,7 +600,7 @@ private[state] class HDFSBackedStateStoreProvider( val nameParts = path.getName.split("\\.") if (nameParts.size == 2) { val version = nameParts(0).toLong - nameParts(1).toLowerCase match { + nameParts(1).toLowerCase(Locale.ROOT) match { case "delta" => // ignore the file otherwise, snapshot file already exists for that batch id if (!versionToFiles.contains(version)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index ca46a1151e3e1..b9515ec7bca2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.Locale + import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat case class HiveSerDe( @@ -68,7 +70,7 @@ object HiveSerDe { * @return HiveSerDe associated with the specified source */ def sourceToSerDe(source: String): Option[HiveSerDe] = { - val key = source.toLowerCase match { + val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" case s if s.equals("orcfile") => "orc" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 1ef9d52713d92..0289471bf841a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c3a9cfc08517a..746b2a94f102d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -135,7 +137,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def load(): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f2f700590ca8e..0d2611f9bbcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -230,7 +232,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def start(): StreamingQuery = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 78cf033dd81d7..3ba37addfc8b4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -119,7 +119,7 @@ public void testCommonOperation() { Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { List ls = new LinkedList<>(); while (it.hasNext()) { - ls.add(it.next().toUpperCase(Locale.ENGLISH)); + ls.add(it.next().toUpperCase(Locale.ROOT)); } return ls.iterator(); }, Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 4b69baffab620..d9130fdcfaea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -124,7 +124,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { + if (blackList.exists(t => + testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. ignore(testCase.name) { /* Do nothing */ } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 8bceab39f71d5..1c1931b6a6daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext @@ -24,11 +26,12 @@ class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { val badRule = new SparkStrategy { var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + mode.toLowerCase(Locale.ROOT) match { + case "exception" => throw new AnalysisException(mode) + case "error" => throw new Error(mode) + case _ => Nil + } } spark.experimental.extraStrategies = badRule :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 13202a57851e1..97c61dc8694bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import java.net.URI +import java.util.Locale import scala.reflect.{classTag, ClassTag} @@ -40,8 +41,10 @@ class DDLCommandSuite extends PlanTest { val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) - containsThesePhrases.foreach { p => assert(e.getMessage.toLowerCase.contains(p.toLowerCase)) } + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + containsThesePhrases.foreach { p => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) + } } private def parseAs[T: ClassTag](query: String): T = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9ebf2dd839a79..fe74ab49f91bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -190,7 +191,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { sql(query) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { @@ -1813,7 +1814,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTable(tabName) { sql(s"CREATE TABLE $tabName(col1 int, col2 string) USING parquet ") val message = intercept[AnalysisException] { - sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase}") + sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase(Locale.ROOT)}") }.getMessage assert(message.contains("SHOW COLUMNS with conflicting databases")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 57a0af1dda971..94a2f9a00b3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -300,7 +302,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase(Locale.ROOT)) { compressionCodecFor(path, codec.name()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2b20b9716bf80..b4f3de9961209 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -476,7 +476,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha assert(partDf.schema.map(_.name) === Seq("intField", "stringField")) path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file val df = spark.read.parquet(f.getCanonicalPath) assert(df.schema.map(_.name) === Seq("intField", "stringField")) @@ -484,7 +485,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file but `basePath` is overridden to // the base path containing partitioning directories val df = spark diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index be56c964a18f8..5a0388ec1d1db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.util.Locale + import scala.language.existentials import org.apache.spark.rdd.RDD @@ -76,7 +78,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S case "b" => (i: Int) => Seq(i * 2) case "c" => (i: Int) => val c = (i - 1 + 'a').toChar.toString - Seq(c * 5 + c.toUpperCase * 5) + Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5) } FiltersPushed.list = filters @@ -113,7 +115,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5 + val c = (a - 1 + 'a').toChar.toString * 5 + + (a - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5 filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } @@ -151,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic sqlTest( "SELECT * FROM oneToTenFiltered", (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 - + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq) + + (i - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index f67444fbc49d6..1211242b9fbb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -221,7 +223,7 @@ class FileStreamSinkSuite extends StreamTest { df.writeStream.format("parquet").outputMode(mode).start(dir.getCanonicalPath) } Seq(mode, "not support").foreach { w => - assert(e.getMessage.toLowerCase.contains(w)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(w)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index e5d5b4f328820..f796a4cb4a398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.TimeZone +import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfterAll @@ -105,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte testStream(aggregated, Append)() } Seq("append", "not supported").foreach { m => - assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 05cd3d9f7c2fa..dc2506a48ad00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming.test import java.io.File +import java.util.Locale import java.util.concurrent.TimeUnit import scala.concurrent.duration._ @@ -126,7 +127,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .save() } Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -400,7 +401,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { var w = df.writeStream var e = intercept[IllegalArgumentException](w.foreach(null)) Seq("foreach", "null").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -417,7 +418,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { var w = df.writeStream.partitionBy("value") var e = intercept[AnalysisException](w.foreach(foreachWriter).start()) Seq("foreach", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 7c71e7280c6d3..fb15e7def6dbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.test import java.io.File +import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter @@ -144,7 +145,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be .start() } Seq("'writeStream'", "only", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -276,13 +277,13 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be var w = df.write.partitionBy("value") var e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "bucketing").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -385,7 +386,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be // Reader, with user specified schema, should just apply user schema on the file data val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() } - assert(e.getMessage.toLowerCase.contains("user specified schema not supported")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains( + "user specified schema not supported")) intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index 1e6ac4f3df475..c5ade65283045 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import javax.net.ssl.SSLServerSocket; @@ -259,12 +260,12 @@ public static TServerSocket getServerSSLSocket(String hiveHost, int portNum, Str if (thriftServerSocket.getServerSocket() instanceof SSLServerSocket) { List sslVersionBlacklistLocal = new ArrayList(); for (String sslVersion : sslVersionBlacklist) { - sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase()); + sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase(Locale.ROOT)); } SSLServerSocket sslServerSocket = (SSLServerSocket) thriftServerSocket.getServerSocket(); List enabledProtocols = new ArrayList(); for (String protocol : sslServerSocket.getEnabledProtocols()) { - if (sslVersionBlacklistLocal.contains(protocol.toLowerCase())) { + if (sslVersionBlacklistLocal.contains(protocol.toLowerCase(Locale.ROOT))) { LOG.debug("Disabling SSL Protocol: " + protocol); } else { enabledProtocols.add(protocol); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java index ab3ac6285aa02..ad4dfd75f4707 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java @@ -19,6 +19,7 @@ package org.apache.hive.service.auth; import java.util.HashMap; +import java.util.Locale; import java.util.Map; /** @@ -52,7 +53,7 @@ public String toString() { public static SaslQOP fromString(String str) { if (str != null) { - str = str.toLowerCase(); + str = str.toLowerCase(Locale.ROOT); } SaslQOP saslQOP = STR_TO_ENUM.get(str); if (saslQOP == null) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java index a96d2ac371cd3..7752ec03a29b7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java @@ -19,6 +19,7 @@ package org.apache.hive.service.cli; import java.sql.DatabaseMetaData; +import java.util.Locale; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hive.service.cli.thrift.TTypeId; @@ -160,7 +161,7 @@ public static Type getType(String name) { if (name.equalsIgnoreCase(type.name)) { return type; } else if (type.isQualifiedType() || type.isComplexType()) { - if (name.toUpperCase().startsWith(type.name)) { + if (name.toUpperCase(Locale.ROOT).startsWith(type.name)) { return type; } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 14553601b1d58..5e4734ad3ad25 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -294,7 +294,7 @@ private[hive] class HiveThriftServer2(sqlContext: SQLContext) private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.toLowerCase(Locale.ENGLISH).equals("http") + transportMode.toLowerCase(Locale.ROOT).equals("http") } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 1bc5c3c62f045..d5cc3b3855045 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -302,7 +302,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() - val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { sessionState.close() System.exit(0) } - if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + if (tokens(0).toLowerCase(Locale.ROOT).equals("source") || cmd_trimmed.startsWith("!") || isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f0e35dff57f7b..806f2be5faeb0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.lang.reflect.InvocationTargetException import java.util +import java.util.Locale import scala.collection.mutable import scala.util.control.NonFatal @@ -499,7 +500,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // We can't use `filterKeys` here, as the map returned by `filterKeys` is not serializable, // while `CatalogTable` should be serializable. val propsWithoutPath = table.storage.properties.filter { - case (k, v) => k.toLowerCase != "path" + case (k, v) => k.toLowerCase(Locale.ROOT) != "path" } table.storage.copy(properties = propsWithoutPath ++ newPath.map("path" -> _)) } @@ -1060,7 +1061,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Hive's metastore is case insensitive. However, Hive's createFunction does // not normalize the function name (unlike the getFunction part). So, // we are normalizing the function name. - val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) requireFunctionNotExists(db, functionName) val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 9e3eb2dd8234a..c917f110b90f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.Locale + import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal @@ -143,7 +145,7 @@ private[sql] class HiveSessionCatalog( // This function is not in functionRegistry, let's try to load it as a Hive's // built-in function. // Hive is case insensitive. - val functionName = funcName.unquotedString.toLowerCase + val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { failFunctionLookup(funcName.unquotedString) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0465e9c031e27..09a5eda6e543f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.IOException +import java.util.Locale import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -184,14 +185,14 @@ case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { private def isConvertible(relation: CatalogRelation): Boolean = { - (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET)) || - (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)) + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || + serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } private def convert(relation: CatalogRelation): LogicalRelation = { - if (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet")) { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + if (serde.contains("parquet")) { val options = Map(ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index afc2bf85334d0..3de60c7fc1318 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -21,6 +21,7 @@ import java.io.File import java.net.{URL, URLClassLoader} import java.nio.charset.StandardCharsets import java.sql.Timestamp +import java.util.Locale import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap @@ -338,7 +339,7 @@ private[spark] object HiveUtils extends Logging { logWarning(s"Hive jar path '$path' does not exist.") Nil } else { - files.filter(_.getName.toLowerCase.endsWith(".jar")) + files.filter(_.getName.toLowerCase(Locale.ROOT).endsWith(".jar")) } case path => new File(path) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 56ccac32a8d88..387ec4f967233 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -153,7 +154,7 @@ private[hive] class HiveClientImpl( hadoopConf.iterator().asScala.foreach { entry => val key = entry.getKey val value = entry.getValue - if (key.toLowerCase.contains("password")) { + if (key.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=xxx") } else { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=$value") @@ -168,7 +169,7 @@ private[hive] class HiveClientImpl( hiveConf.setClassLoader(initClassLoader) // 2: we set all spark confs to this hiveConf. sparkConf.getAll.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Spark config to Hive Conf: $k=xxx") } else { logDebug(s"Applying Spark config to Hive Conf: $k=$v") @@ -177,7 +178,7 @@ private[hive] class HiveClientImpl( } // 3: we set all entries in config to this hiveConf. extraConfig.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying extra config to HiveConf: $k=xxx") } else { logDebug(s"Applying extra config to HiveConf: $k=$v") @@ -622,7 +623,7 @@ private[hive] class HiveClientImpl( */ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { logDebug(s"Running hiveql '$cmd'") - if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + if (cmd.toLowerCase(Locale.ROOT).startsWith("set")) { logDebug(s"Changing config: $cmd") } try { val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 2e35f39839488..7abb9f06b1310 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.client import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{InvocationTargetException, Method, Modifier} import java.net.URI -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -505,8 +505,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { val resourceUris = f.resources.map { resource => - new ResourceUri( - ResourceType.valueOf(resource.resourceType.resourceType.toUpperCase()), resource.uri) + new ResourceUri(ResourceType.valueOf( + resource.resourceType.resourceType.toUpperCase(Locale.ROOT)), resource.uri) } new HiveFunction( f.identifier.funcName, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 192851028031b..5c515515b9b9c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -29,7 +31,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase) + val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase(Locale.ROOT)) val inputFormat = parameters.get(INPUT_FORMAT) val outputFormat = parameters.get(OUTPUT_FORMAT) @@ -75,7 +77,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) } def serdeProperties: Map[String, String] = parameters.filterKeys { - k => !lowerCasedOptionNames.contains(k.toLowerCase) + k => !lowerCasedOptionNames.contains(k.toLowerCase(Locale.ROOT)) }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v } } @@ -83,7 +85,7 @@ object HiveOptions { private val lowerCasedOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - lowerCasedOptionNames += name.toLowerCase + lowerCasedOptionNames += name.toLowerCase(Locale.ROOT) name } @@ -99,5 +101,5 @@ object HiveOptions { // The following typo is inherited from Hive... "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", - "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase -> v } + "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index ccaa568dcce2a..043eb69818ba1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.orc +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -41,9 +43,9 @@ private[orc] class OrcOptions(@transient private val parameters: CaseInsensitive val codecName = parameters .get("compression") .orElse(orcCompressionConf) - .getOrElse("snappy").toLowerCase + .getOrElse("snappy").toLowerCase(Locale.ROOT) if (!shortOrcCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 490e02d0bd541..59cc6605a1243 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.net.URI +import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier @@ -49,7 +50,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def analyzeCreateTable(sql: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index e48ce2304d086..319d02613f00a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -18,18 +18,15 @@ package org.apache.spark.sql.hive import java.io.File -import java.util.concurrent.{Executors, TimeUnit} import scala.util.Random import org.scalatest.BeforeAndAfterEach -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index e45cf977bfaa2..abe5d835719b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets import java.util +import java.util.Locale import scala.util.control.NonFatal @@ -299,10 +300,11 @@ abstract class HiveComparisonTest // thus the tables referenced in those DDL commands cannot be extracted for use by our // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES // command expect these tables to exist. - val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + val hasShowTableCommand = + queryList.exists(_.toLowerCase(Locale.ROOT).contains("show tables")) for (table <- Seq("src", "srcpart")) { val hasMatchingQuery = queryList.exists { query => - val normalizedQuery = query.toLowerCase.stripSuffix(";") + val normalizedQuery = query.toLowerCase(Locale.ROOT).stripSuffix(";") normalizedQuery.endsWith(table) || normalizedQuery.contains(s"from $table") || normalizedQuery.contains(s"from default.$table") @@ -444,7 +446,7 @@ abstract class HiveComparisonTest "create table", "drop index" ) - !queryList.map(_.toLowerCase).exists { query => + !queryList.map(_.toLowerCase(Locale.ROOT)).exists { query => excludedSubstrings.exists(s => query.contains(s)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 65a902fc5438e..cf33760360724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -80,7 +80,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd private def assertUnsupportedFeature(body: => Unit): Unit = { val e = intercept[ParseException] { body } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } // Testing the Broadcast based join for cartesian join (cross join) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d012797e19926..75f3744ff35be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -475,13 +476,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { case None => // OK. } // Also make sure that the format and serde are as desired. - assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) - assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) + assert(catalogTable.storage.inputFormat.get.toLowerCase(Locale.ROOT).contains(format)) + assert(catalogTable.storage.outputFormat.get.toLowerCase(Locale.ROOT).contains(format)) val serde = catalogTable.storage.serde.get format match { case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) - case _ => assert(serde.toLowerCase.contains(format)) + case _ => assert(serde.toLowerCase(Locale.ROOT).contains(format)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 9a760e2947d0b..931f015f03b6f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import java.util.Locale + import scala.reflect.ClassTag import org.apache.spark.SparkContext @@ -60,7 +62,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) .split("(?=[A-Z])") .filter(_.nonEmpty) .mkString(" ") - .toLowerCase + .toLowerCase(Locale.ROOT) .capitalize s"$newName [$id]" } @@ -74,7 +76,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) protected[streaming] override val baseScope: Option[String] = { val scopeName = Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } - .getOrElse(name.toLowerCase) + .getOrElse(name.toLowerCase(Locale.ROOT)) Some(new RDDOperationScope(scopeName).toJson) } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 80513de4ee117..90d1f8c5035b3 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -101,7 +101,7 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions(in -> { String out = ""; while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out = out + in.next().toUpperCase(Locale.ROOT); } return Arrays.asList(out).iterator(); }); @@ -806,7 +806,8 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(String::toUpperCase); + JavaPairDStream mapped = + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 96f8d9593d630..6c86cacec8279 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -267,7 +267,7 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions(in -> { StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out.append(in.next().toUpperCase(Locale.ENGLISH)); + out.append(in.next().toUpperCase(Locale.ROOT)); } return Arrays.asList(out.toString()).iterator(); }); @@ -1315,7 +1315,7 @@ public void testMapValues() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream mapped = - pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5645996de5a69..eb996c93ff381 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} +import java.util.Locale import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger @@ -745,7 +746,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val ex = intercept[IllegalStateException] { body } - assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(expectedErrorMsg)) } } From f6dd8e0e1673aa491b895c1f0467655fa4e9d52f Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Mon, 10 Apr 2017 21:56:21 +0200 Subject: [PATCH 0236/1765] [SPARK-20280][CORE] FileStatusCache Weigher integer overflow ## What changes were proposed in this pull request? Weigher.weigh needs to return Int but it is possible for an Array[FileStatus] to have size > Int.maxValue. To avoid this, the size is scaled down by a factor of 32. The maximumWeight of the cache is also scaled down by the same factor. ## How was this patch tested? New test in FileIndexSuite Author: Bogdan Raducanu Closes #17591 from bogdanrdc/SPARK-20280. --- .../datasources/FileStatusCache.scala | 47 ++++++++++++++----- .../datasources/FileIndexSuite.scala | 16 +++++++ 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala index 5d97558633146..aea27bd4c4d7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -94,27 +94,48 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { // Opaque object that uniquely identifies a shared cache user private type ClientId = Object + private val warnedAboutEviction = new AtomicBoolean(false) // we use a composite cache key in order to distinguish entries inserted by different clients - private val cache: Cache[(ClientId, Path), Array[FileStatus]] = CacheBuilder.newBuilder() - .weigher(new Weigher[(ClientId, Path), Array[FileStatus]] { + private val cache: Cache[(ClientId, Path), Array[FileStatus]] = { + // [[Weigher]].weigh returns Int so we could only cache objects < 2GB + // instead, the weight is divided by this factor (which is smaller + // than the size of one [[FileStatus]]). + // so it will support objects up to 64GB in size. + val weightScale = 32 + val weigher = new Weigher[(ClientId, Path), Array[FileStatus]] { override def weigh(key: (ClientId, Path), value: Array[FileStatus]): Int = { - (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)).toInt - }}) - .removalListener(new RemovalListener[(ClientId, Path), Array[FileStatus]]() { - override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) - : Unit = { + val estimate = (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)) / weightScale + if (estimate > Int.MaxValue) { + logWarning(s"Cached table partition metadata size is too big. Approximating to " + + s"${Int.MaxValue.toLong * weightScale}.") + Int.MaxValue + } else { + estimate.toInt + } + } + } + val removalListener = new RemovalListener[(ClientId, Path), Array[FileStatus]]() { + override def onRemoval( + removed: RemovalNotification[(ClientId, Path), + Array[FileStatus]]): Unit = { if (removed.getCause == RemovalCause.SIZE && - warnedAboutEviction.compareAndSet(false, true)) { + warnedAboutEviction.compareAndSet(false, true)) { logWarning( "Evicting cached table partition metadata from memory due to size constraints " + - "(spark.sql.hive.filesourcePartitionFileCacheSize = " + maxSizeInBytes + " bytes). " + - "This may impact query planning performance.") + "(spark.sql.hive.filesourcePartitionFileCacheSize = " + + maxSizeInBytes + " bytes). This may impact query planning performance.") } - }}) - .maximumWeight(maxSizeInBytes) - .build[(ClientId, Path), Array[FileStatus]]() + } + } + CacheBuilder.newBuilder() + .weigher(weigher) + .removalListener(removalListener) + .maximumWeight(maxSizeInBytes / weightScale) + .build[(ClientId, Path), Array[FileStatus]]() + } + /** * @return a FileStatusCache that does not share any entries with any other client, but does diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 00f5d5db8f5f4..a9511cbd9e4cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} class FileIndexSuite extends SharedSQLContext { @@ -220,6 +221,21 @@ class FileIndexSuite extends SharedSQLContext { assert(catalog.leafDirPaths.head == fs.makeQualified(dirPath)) } } + + test("SPARK-20280 - FileStatusCache with a partition with very many files") { + /* fake the size, otherwise we need to allocate 2GB of data to trigger this bug */ + class MyFileStatus extends FileStatus with KnownSizeEstimation { + override def estimatedSize: Long = 1000 * 1000 * 1000 + } + /* files * MyFileStatus.estimatedSize should overflow to negative integer + * so, make it between 2bn and 4bn + */ + val files = (1 to 3).map { i => + new MyFileStatus() + } + val fileStatusCache = FileStatusCache.getOrCreate(spark) + fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) + } } class FakeParentPathFileSystem extends RawLocalFileSystem { From f9a50ba2d1bfa3f55199df031e71154611ba51f6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Apr 2017 14:06:49 -0700 Subject: [PATCH 0237/1765] [SPARK-20285][TESTS] Increase the pyspark streaming test timeout to 30 seconds ## What changes were proposed in this pull request? Saw the following failure locally: ``` Traceback (most recent call last): File "/home/jenkins/workspace/python/pyspark/streaming/tests.py", line 351, in test_cogroup self._test_func(input, func, expected, sort=True, input2=input2) File "/home/jenkins/workspace/python/pyspark/streaming/tests.py", line 162, in _test_func self.assertEqual(expected, result) AssertionError: Lists differ: [[(1, ([1], [2])), (2, ([1], [... != [] First list contains 3 additional elements. First extra element 0: [(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))] + [] - [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], - [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], - [('', ([1, 1], [1, 2])), ('a', ([1, 1], [1, 1])), ('b', ([1], [1]))]] ``` It also happened on Jenkins: http://spark-tests.appspot.com/builds/spark-branch-2.1-test-sbt-hadoop-2.7/120 It's because when the machine is overloaded, the timeout is not enough. This PR just increases the timeout to 30 seconds. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17597 from zsxwing/SPARK-20285. --- python/pyspark/streaming/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 1bec33509580c..ffba99502b148 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -55,7 +55,7 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 30 # seconds duration = .5 @classmethod From a35b9d97123697d23fa0f691c1054f9adab5956c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Apr 2017 14:09:32 -0700 Subject: [PATCH 0238/1765] [SPARK-20282][SS][TESTS] Write the commit log first to fix a race contion in tests ## What changes were proposed in this pull request? This PR fixes the following failure: ``` sbt.ForkMain$ForkError: org.scalatest.exceptions.TestFailedException: Assert on query failed: == Progress == AssertOnQuery(, ) StopStream AddData to MemoryStream[value#30891]: 1,2 StartStream(OneTimeTrigger,org.apache.spark.util.SystemClock35cdc93a,Map()) CheckAnswer: [6],[3] StopStream => AssertOnQuery(, ) AssertOnQuery(, ) StartStream(OneTimeTrigger,org.apache.spark.util.SystemClockcdb247d,Map()) CheckAnswer: [6],[3] StopStream AddData to MemoryStream[value#30891]: 3 StartStream(OneTimeTrigger,org.apache.spark.util.SystemClock55394e4d,Map()) CheckLastBatch: [2] StopStream AddData to MemoryStream[value#30891]: 0 StartStream(OneTimeTrigger,org.apache.spark.util.SystemClock749aa997,Map()) ExpectFailure[org.apache.spark.SparkException, isFatalError: false] AssertOnQuery(, ) AssertOnQuery(, incorrect start offset or end offset on exception) == Stream == Output Mode: Append Stream state: not started Thread state: dead == Sink == 0: [6] [3] == Plan == at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:495) at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555) at org.scalatest.Assertions$class.fail(Assertions.scala:1328) at org.scalatest.FunSuite.fail(FunSuite.scala:1555) at org.apache.spark.sql.streaming.StreamTest$class.failTest$1(StreamTest.scala:347) at org.apache.spark.sql.streaming.StreamTest$class.verify$1(StreamTest.scala:318) at org.apache.spark.sql.streaming.StreamTest$$anonfun$liftedTree1$1$1.apply(StreamTest.scala:483) at org.apache.spark.sql.streaming.StreamTest$$anonfun$liftedTree1$1$1.apply(StreamTest.scala:357) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at org.apache.spark.sql.streaming.StreamTest$class.liftedTree1$1(StreamTest.scala:357) at org.apache.spark.sql.streaming.StreamTest$class.testStream(StreamTest.scala:356) at org.apache.spark.sql.streaming.StreamingQuerySuite.testStream(StreamingQuerySuite.scala:41) at org.apache.spark.sql.streaming.StreamingQuerySuite$$anonfun$6.apply$mcV$sp(StreamingQuerySuite.scala:166) at org.apache.spark.sql.streaming.StreamingQuerySuite$$anonfun$6.apply(StreamingQuerySuite.scala:161) at org.apache.spark.sql.streaming.StreamingQuerySuite$$anonfun$6.apply(StreamingQuerySuite.scala:161) at org.apache.spark.sql.catalyst.util.package$.quietly(package.scala:42) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$testQuietly$1.apply$mcV$sp(SQLTestUtils.scala:268) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$testQuietly$1.apply(SQLTestUtils.scala:268) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$testQuietly$1.apply(SQLTestUtils.scala:268) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.apache.spark.sql.streaming.StreamingQuerySuite.org$scalatest$BeforeAndAfterEach$$super$runTest(StreamingQuerySuite.scala:41) at org.scalatest.BeforeAndAfterEach$class.runTest(BeforeAndAfterEach.scala:255) at org.apache.spark.sql.streaming.StreamingQuerySuite.org$scalatest$BeforeAndAfter$$super$runTest(StreamingQuerySuite.scala:41) at org.scalatest.BeforeAndAfter$class.runTest(BeforeAndAfter.scala:200) at org.apache.spark.sql.streaming.StreamingQuerySuite.runTest(StreamingQuerySuite.scala:41) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.streaming.StreamingQuerySuite.org$scalatest$BeforeAndAfter$$super$run(StreamingQuerySuite.scala:41) at org.scalatest.BeforeAndAfter$class.run(BeforeAndAfter.scala:241) at org.apache.spark.sql.streaming.StreamingQuerySuite.run(StreamingQuerySuite.scala:41) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:357) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:502) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` The failure is because `CheckAnswer` will run once `committedOffsets` is updated. Then writing the commit log may be interrupted by the following `StopStream`. This PR just change the order to write the commit log first. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17594 from zsxwing/SPARK-20282. --- .../apache/spark/sql/execution/streaming/StreamExecution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 5f548172f5ced..8857966676ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -304,8 +304,8 @@ class StreamExecution( finishTrigger(dataAvailable) if (dataAvailable) { // Update committed offsets. - committedOffsets ++= availableOffsets batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 From 379b0b0bbdbba2278ce3bcf471bd75f6ffd9cf0d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Apr 2017 14:14:09 -0700 Subject: [PATCH 0239/1765] [SPARK-20283][SQL] Add preOptimizationBatches ## What changes were proposed in this pull request? We currently have postHocOptimizationBatches, but not preOptimizationBatches. This patch adds preOptimizationBatches so the optimizer debugging extensions are symmetric. ## How was this patch tested? N/A Author: Reynold Xin Closes #17595 from rxin/SPARK-20283. --- .../org/apache/spark/sql/execution/SparkOptimizer.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 2cdfb7a7828c9..1de4f508b89a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -30,13 +30,19 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = (super.batches :+ + override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + /** + * Optimization batches that are executed before the regular optimization batches (also before + * the finish analysis batch). + */ + def preOptimizationBatches: Seq[Batch] = Nil + /** * Optimization batches that are executed after the regular optimization batches, but before the * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add From 734dfbfcfea1ed1ab3a5f18f84c412a569dd87e7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Apr 2017 20:41:08 -0700 Subject: [PATCH 0240/1765] [SPARK-17564][TESTS] Fix flaky RequestTimeoutIntegrationSuite.furtherRequestsDelay ## What changes were proposed in this pull request? This PR fixs the following failure: ``` sbt.ForkMain$ForkError: java.lang.AssertionError: null at org.junit.Assert.fail(Assert.java:86) at org.junit.Assert.assertTrue(Assert.java:41) at org.junit.Assert.assertTrue(Assert.java:52) at org.apache.spark.network.RequestTimeoutIntegrationSuite.furtherRequestsDelay(RequestTimeoutIntegrationSuite.java:230) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:497) at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50) at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12) at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47) at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17) at org.junit.internal.runners.statements.RunBefores.evaluate(RunBefores.java:26) at org.junit.internal.runners.statements.RunAfters.evaluate(RunAfters.java:27) at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runners.Suite.runChild(Suite.java:128) at org.junit.runners.Suite.runChild(Suite.java:27) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runner.JUnitCore.run(JUnitCore.java:137) at org.junit.runner.JUnitCore.run(JUnitCore.java:115) at com.novocode.junit.JUnitRunner$1.execute(JUnitRunner.java:132) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` It happens several times per month on [Jenkins](http://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.network.RequestTimeoutIntegrationSuite&test_name=furtherRequestsDelay). The failure is because `callback1` may not be called before `assertTrue(callback1.failure instanceof IOException);`. It's pretty easy to reproduce this error by adding a sleep before this line: https://github.com/apache/spark/blob/379b0b0bbdbba2278ce3bcf471bd75f6ffd9cf0d/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java#L267 The fix is straightforward: just use the latch to wait until `callback1` is called. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17599 from zsxwing/SPARK-17564. --- .../apache/spark/network/RequestTimeoutIntegrationSuite.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 9aa17e24b6246..c0724e018263f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -225,6 +225,8 @@ public StreamManager getStreamManager() { callback0.latch.await(60, TimeUnit.SECONDS); assertTrue(callback0.failure instanceof IOException); + // make sure callback1 is called. + callback1.latch.await(60, TimeUnit.SECONDS); // failed at same time as previous assertTrue(callback1.failure instanceof IOException); } From 0d2b796427a59d3e9967b62618be301307f29162 Mon Sep 17 00:00:00 2001 From: Benjamin Fradet Date: Tue, 11 Apr 2017 09:12:49 +0200 Subject: [PATCH 0241/1765] [SPARK-20097][ML] Fix visibility discrepancy with numInstances and degreesOfFreedom in LR and GLR ## What changes were proposed in this pull request? - made `numInstances` public in GLR - made `degreesOfFreedom` public in LR ## How was this patch tested? reran the concerned test suites Author: Benjamin Fradet Closes #17431 from BenFradet/SPARK-20097. --- .../spark/ml/regression/GeneralizedLinearRegression.scala | 3 ++- .../org/apache/spark/ml/regression/LinearRegression.scala | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33137b0c0fdec..d6093a01c671c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1133,7 +1133,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( private[regression] lazy val link: Link = familyLink.link /** Number of instances in DataFrame predictions. */ - private[regression] lazy val numInstances: Long = predictions.count() + @Since("2.2.0") + lazy val numInstances: Long = predictions.count() /** The numeric rank of the fitted linear model. */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 45df1d9be647d..f7e3c8fa5b6e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -696,7 +696,8 @@ class LinearRegressionSummary private[regression] ( lazy val numInstances: Long = predictions.count() /** Degrees of freedom */ - private val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { + @Since("2.2.0") + val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { numInstances - privateModel.coefficients.size - 1 } else { numInstances - privateModel.coefficients.size From d11ef3d77ec2136d6b28bd69f5dd2cc0a22e4717 Mon Sep 17 00:00:00 2001 From: MirrorZ Date: Tue, 11 Apr 2017 10:34:39 +0100 Subject: [PATCH 0242/1765] Document Master URL format in high availability set up ## What changes were proposed in this pull request? Add documentation for adding master url in multi host, port format for standalone cluster with high availability with zookeeper. Referring documentation [Standby Masters with ZooKeeper](http://spark.apache.org/docs/latest/spark-standalone.html#standby-masters-with-zookeeper) ## How was this patch tested? Documenting the functionality already present. Author: MirrorZ Closes #17584 from MirrorZ/master. --- docs/submitting-applications.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index d23dbcf10d952..866d6e527549c 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -143,6 +143,9 @@ The master URL passed to Spark can be in one of the following formats: spark://HOST:PORT Connect to the given Spark standalone cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. + spark://HOST1:PORT1,HOST2:PORT2 Connect to the given Spark standalone + cluster with standby masters with Zookeeper. The list must have all the master hosts in the high availability cluster set up with Zookeeper. The port must be whichever each master is configured to use, which is 7077 by default. + mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... From c8706980ae07362ae5963829e9ada5007eada46b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 11 Apr 2017 20:21:04 +0800 Subject: [PATCH 0243/1765] [SPARK-20274][SQL] support compatible array element type in encoder ## What changes were proposed in this pull request? This is a regression caused by SPARK-19716. Before SPARK-19716, we will cast an array field to the expected array type. However, after SPARK-19716, the cast is removed, but we forgot to push the cast to the element level. ## How was this patch tested? new regression tests Author: Wenchen Fan Closes #17587 from cloud-fan/array. --- .../spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++------ .../sql/catalyst/analysis/Analyzer.scala | 8 +++++-- .../encoders/EncoderResolutionSuite.scala | 23 +++++++++++++++++++ 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 198122759e4ad..0c5a818f54f5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -132,7 +132,7 @@ object ScalaReflection extends ScalaReflection { def deserializerFor[T : TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil deserializerFor(tpe, None, walkedTypePath) } @@ -270,12 +270,14 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(_, elementNullable) = schemaFor(elementType) + val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val mapFunction: Expression => Expression = p => { - val converter = deserializerFor(elementType, Some(p), newTypePath) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(elementType, Some(casted), newTypePath) if (elementNullable) { converter } else { @@ -305,12 +307,14 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(_, elementNullable) = schemaFor(elementType) + val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val mapFunction: Expression => Expression = p => { - val converter = deserializerFor(elementType, Some(p), newTypePath) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(elementType, Some(casted), newTypePath) if (elementNullable) { converter } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b0cdef70297cf..9816b33ae8dff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ @@ -2321,7 +2321,11 @@ class Analyzer( */ object ResolveUpCast extends Rule[LogicalPlan] { private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + val fromStr = from match { + case l: LambdaVariable => "array element" + case e => e.sql + } + throw new AnalysisException(s"Cannot up cast $fromStr from " + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index e5a3e1fd374dc..630e8a7990e7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -33,6 +33,8 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) +case class PrimitiveArrayClass(arr: Array[Long]) + case class ArrayClass(arr: Seq[StringIntClass]) case class NestedArrayClass(nestedArr: Array[ArrayClass]) @@ -66,6 +68,27 @@ class EncoderResolutionSuite extends PlanTest { encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } + test("real type doesn't match encoder schema but they are compatible: primitive array") { + val encoder = ExpressionEncoder[PrimitiveArrayClass] + val attrs = Seq('arr.array(IntegerType)) + val array = new GenericArrayData(Array(1, 2, 3)) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("the real type is not compatible with encoder schema: primitive array") { + val encoder = ExpressionEncoder[PrimitiveArrayClass] + val attrs = Seq('arr.array(StringType)) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + s""" + |Cannot up cast array element from string to bigint as it may truncate + |The type path of the target object is: + |- array element class: "scala.Long" + |- field (class: "scala.Array", name: "arr") + |- root class: "org.apache.spark.sql.catalyst.encoders.PrimitiveArrayClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + test("real type doesn't match encoder schema but they are compatible: array") { val encoder = ExpressionEncoder[ArrayClass] val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) From cd91f967145909852d9af09b10b80f86ed05edb5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Apr 2017 20:33:10 +0800 Subject: [PATCH 0244/1765] [SPARK-20175][SQL] Exists should not be evaluated in Join operator ## What changes were proposed in this pull request? Similar to `ListQuery`, `Exists` should not be evaluated in `Join` operator too. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17491 from viirya/dont-push-exists-to-join. --- .../spark/sql/catalyst/expressions/predicates.scala | 3 ++- .../scala/org/apache/spark/sql/SubquerySuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8acb740f8db8c..5034566132f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -92,11 +92,12 @@ trait PredicateHelper { protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { // Non-deterministic expressions are not allowed as join conditions. case e if !e.deterministic => false - case l: ListQuery => + case _: ListQuery | _: Exists => // A ListQuery defines the query which we want to search in an IN subquery expression. // Currently the only way to evaluate an IN subquery is to convert it to a // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. // It cannot be evaluated as part of a Join operator. + // An Exists shouldn't be push into a Join operator too. false case e: SubqueryExpression => // non-correlated subquery will be replaced as literal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5fe6667ceca18..0f0199cbe2777 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -844,4 +844,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(0) :: Row(1) :: Nil) } } + + test("ListQuery and Exists should work even no correlated references") { + checkAnswer( + sql("select * from l, r where l.a = r.c AND (r.d in (select d from r) OR l.a >= 1)"), + Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: + Row(2, 1.0, 2, 3.0) :: Row(3.0, 3.0, 3, 2.0) :: Row(6, null, 6, null) :: Nil) + checkAnswer( + sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), + Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) + } } From 123b4fbbc331f116b45f11b9f7ecbe0b0575323d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Apr 2017 11:12:31 -0700 Subject: [PATCH 0245/1765] [SPARK-20289][SQL] Use StaticInvoke to box primitive types ## What changes were proposed in this pull request? Dataset typed API currently uses NewInstance to box primitive types (i.e. calling the constructor). Instead, it'd be slightly more idiomatic in Java to use PrimitiveType.valueOf, which can be invoked using StaticInvoke expression. ## How was this patch tested? The change should be covered by existing tests for Dataset encoders. Author: Reynold Xin Closes #17604 from rxin/SPARK-20289. --- .../sql/catalyst/JavaTypeInference.scala | 27 +++++++++---------- .../spark/sql/catalyst/ScalaReflection.scala | 14 +++++----- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 9d4617dda555f..86a73a319ec3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -204,20 +204,19 @@ object JavaTypeInference { typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath - case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Short] || + c == classOf[java.lang.Integer] || + c == classOf[java.lang.Long] || + c == classOf[java.lang.Double] || + c == classOf[java.lang.Float] || + c == classOf[java.lang.Byte] || + c == classOf[java.lang.Boolean] => + StaticInvoke( + c, + ObjectType(c), + "valueOf", + getPath :: Nil, + propagateNull = true) case c if c == classOf[java.sql.Date] => StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0c5a818f54f5c..82710a2a183ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -204,37 +204,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( From 6297697f975960a3006c4e58b4964d9ac40eeaf5 Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Tue, 11 Apr 2017 12:18:31 -0700 Subject: [PATCH 0246/1765] [SPARK-19505][PYTHON] AttributeError on Exception.message in Python3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Added `util._message_exception` helper to use `str(e)` when `e.message` is unavailable (Python3). Grepped for all occurrences of `.message` in `pyspark/` and these were the only occurrences. ## How was this patch tested? - Doctests for helper function ## Legal This is my original work and I license the work to the project under the project’s open source license. Author: David Gingrich Closes #16845 from dgingrich/topic-spark-19505-py3-exceptions. --- dev/sparktestsupport/modules.py | 1 + python/pyspark/broadcast.py | 4 ++- python/pyspark/cloudpickle.py | 9 ++++--- python/pyspark/util.py | 45 +++++++++++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 python/pyspark/util.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 246f5188a518d..78b5b8b0f4b59 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -340,6 +340,7 @@ def __hash__(self): "pyspark.profiler", "pyspark.shuffle", "pyspark.tests", + "pyspark.util", ] ) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 74dee1420754a..b1b59f73d6718 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,6 +21,7 @@ from tempfile import NamedTemporaryFile from pyspark.cloudpickle import print_exec +from pyspark.util import _exception_message if sys.version < '3': import cPickle as pickle @@ -82,7 +83,8 @@ def dump(self, value, f): except pickle.PickleError: raise except Exception as e: - msg = "Could not serialize broadcast: " + e.__class__.__name__ + ": " + e.message + msg = "Could not serialize broadcast: %s: %s" \ + % (e.__class__.__name__, _exception_message(e)) print_exec(sys.stderr) raise pickle.PicklingError(msg) f.close() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 959fb8b357f99..389bee7eee6e9 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -56,6 +56,7 @@ import traceback import weakref +from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -152,13 +153,13 @@ def dump(self, obj): except pickle.PickleError: raise except Exception as e: - if "'i' format requires" in e.message: - msg = "Object too large to serialize: " + e.message + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: " + e.__class__.__name__ + ": " + e.message + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) print_exec(sys.stderr) raise pickle.PicklingError(msg) - def save_memoryview(self, obj): """Fallback to save_string""" diff --git a/python/pyspark/util.py b/python/pyspark/util.py new file mode 100644 index 0000000000000..e5d332ce54429 --- /dev/null +++ b/python/pyspark/util.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = [] + + +def _exception_message(excp): + """Return the message from an exception as either a str or unicode object. Supports both + Python 2 and Python 3. + + >>> msg = "Exception message" + >>> excp = Exception(msg) + >>> msg == _exception_message(excp) + True + + >>> msg = u"unicöde" + >>> excp = Exception(msg) + >>> msg == _exception_message(excp) + True + """ + if hasattr(excp, "message"): + return excp.message + return str(excp) + + +if __name__ == "__main__": + import doctest + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) From cde9e328484e4007aa6b505312d7cea5461a6eaf Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 11 Apr 2017 19:30:34 -0700 Subject: [PATCH 0247/1765] [MINOR][DOCS] Update supported versions for Hive Metastore ## What changes were proposed in this pull request? Since SPARK-18112 and SPARK-13446, Apache Spark starts to support reading Hive metastore 2.0 ~ 2.1.1. This updates the docs. ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #17612 from dongjoon-hyun/metastore. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9847983d4d..c425faca4c273 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1700,7 +1700,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 8ad63ee158815de5ffff7bf03cdf25aef312095f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 12 Apr 2017 11:19:20 +0800 Subject: [PATCH 0248/1765] [SPARK-20291][SQL] NaNvl(FloatType, NullType) should not be cast to NaNvl(DoubleType, DoubleType) ## What changes were proposed in this pull request? `NaNvl(float value, null)` will be converted into `NaNvl(float value, Cast(null, DoubleType))` and finally `NaNvl(Cast(float value, DoubleType), Cast(null, DoubleType))`. This will cause mismatching in the output type when the input type is float. By adding extra rule in TypeCoercion can resolve this issue. ## How was this patch tested? unite tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: DB Tsai Closes #17606 from dbtsai/fixNaNvl. --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + .../sql/catalyst/analysis/TypeCoercionSuite.scala | 14 ++++++++++---- .../apache/spark/sql/DataFrameNaFunctions.scala | 3 +-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 768897dc0713c..e1dd010d37a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -571,6 +571,7 @@ object TypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3e0c357b6de42..011d09ff60641 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -656,14 +656,20 @@ class TypeCoercionSuite extends PlanTest { test("nanvl casts") { ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), - NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), - NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 93d565d9fe904..052d85ad33bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -408,8 +408,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - // nanvl only supports these types - nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types case _ => df.col(quotedColName) } coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) From b14bfc3f8e97479ac5927c071b00ed18f2104c95 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 12 Apr 2017 12:18:01 +0800 Subject: [PATCH 0249/1765] [SPARK-19993][SQL] Caching logical plans containing subquery expressions does not work. ## What changes were proposed in this pull request? The sameResult() method does not work when the logical plan contains subquery expressions. **Before the fix** ```SQL scala> val ds = spark.sql("select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)") ds: org.apache.spark.sql.DataFrame = [c1: int] scala> ds.cache res13: ds.type = [c1: int] scala> spark.sql("select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)").explain(true) == Analyzed Logical Plan == c1: int Project [c1#86] +- Filter c1#86 IN (list#78 [c1#86]) : +- Project [c1#87] : +- Filter (outer(c1#86) = c1#87) : +- SubqueryAlias s2 : +- Relation[c1#87] parquet +- SubqueryAlias s1 +- Relation[c1#86] parquet == Optimized Logical Plan == Join LeftSemi, ((c1#86 = c1#87) && (c1#86 = c1#87)) :- Relation[c1#86] parquet +- Relation[c1#87] parquet ``` **Plan after fix** ```SQL == Analyzed Logical Plan == c1: int Project [c1#22] +- Filter c1#22 IN (list#14 [c1#22]) : +- Project [c1#23] : +- Filter (outer(c1#22) = c1#23) : +- SubqueryAlias s2 : +- Relation[c1#23] parquet +- SubqueryAlias s1 +- Relation[c1#22] parquet == Optimized Logical Plan == InMemoryRelation [c1#22], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- *BroadcastHashJoin [c1#1, c1#1], [c1#2, c1#2], LeftSemi, BuildRight :- *FileScan parquet default.s1[c1#1] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/dbiswal/mygit/apache/spark/bin/spark-warehouse/s1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295)))) +- *FileScan parquet default.s2[c1#2] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/dbiswal/mygit/apache/spark/bin/spark-warehouse/s2], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` ## How was this patch tested? New tests are added to CachedTableSuite. Author: Dilip Biswal Closes #17330 from dilipbiswal/subquery_cache_final. --- .../sql/catalyst/expressions/subquery.scala | 26 +++- .../spark/sql/catalyst/plans/QueryPlan.scala | 43 +++--- .../sql/execution/DataSourceScanExec.scala | 7 +- .../apache/spark/sql/CachedTableSuite.scala | 143 +++++++++++++++++- .../hive/execution/HiveTableScanExec.scala | 5 +- 5 files changed, 198 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 59db28d58afce..d7b493d521ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -47,7 +47,6 @@ abstract class SubqueryExpression( plan: LogicalPlan, children: Seq[Expression], exprId: ExprId) extends PlanExpression[LogicalPlan] { - override lazy val resolved: Boolean = childrenResolved && plan.resolved override lazy val references: AttributeSet = if (plan.resolved) super.references -- plan.outputSet else super.references @@ -59,6 +58,13 @@ abstract class SubqueryExpression( children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) case _ => false } + def canonicalize(attrs: AttributeSeq): SubqueryExpression = { + // Normalize the outer references in the subquery plan. + val normalizedPlan = plan.transformAllExpressions { + case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs)) + } + withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression] + } } object SubqueryExpression { @@ -236,6 +242,12 @@ case class ScalarSubquery( override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ScalarSubquery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } object ScalarSubquery { @@ -268,6 +280,12 @@ case class ListQuery( override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) override def toString: String = s"list#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ListQuery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } /** @@ -290,4 +308,10 @@ case class Exists( override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) override def toString: String = s"exists#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + Exists( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3008e8cb84659..2fb65bd435507 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -377,7 +377,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // As the root of the expression, Alias will always take an arbitrary exprId, we need to // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. - Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => // Top level `AttributeReference` may also be used for output like `Alias`, we should @@ -385,7 +386,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT id += 1 ar.withExprId(ExprId(id)) - case other => normalizeExprId(other) + case other => QueryPlan.normalizeExprId(other, allAttributes) }.withNewChildren(canonicalizedChildren) } @@ -395,23 +396,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def preCanonicalized: PlanType = this - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we - * do not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = { - e.transformUp { - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized.asInstanceOf[T] - } /** * Returns true when the given query plan will return the same results as this query plan. @@ -438,3 +422,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) } + +object QueryPlan { + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: SubqueryExpression => s.canonicalize(input) + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 3a9132d74ac11..866fa98533218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} @@ -516,10 +517,10 @@ case class FileSourceScanExec( override lazy val canonicalized: FileSourceScanExec = { FileSourceScanExec( relation, - output.map(normalizeExprId(_, output)), + output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, - partitionFilters.map(normalizeExprId(_, output)), - dataFilters.map(normalizeExprId(_, output)), + partitionFilters.map(QueryPlan.normalizeExprId(_, output)), + dataFilters.map(QueryPlan.normalizeExprId(_, output)), None) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 7a7d52b21427a..e66fe97afad45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ @@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sum } + private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { + plan.collect { + case InMemoryTableScanExec(_, _, relation) => + getNumInMemoryTablesRecursively(relation.child) + 1 + }.sum + } + test("withColumn doesn't invalidate cached dataframe") { var evalCount = 0 val myUDF = udf((x: String) => { evalCount += 1; "result" }) @@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) } } + + test("SPARK-19993 simple subquery caching") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Additional predicate in the subquery plan should cause a cache miss + val cachedMissDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2 where c1 = 0) + """.stripMargin) + assert(getNumInMemoryRelations(cachedMissDs) == 0) + } + } + + test("SPARK-19993 subquery caching with correlated predicates") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(1).toDF("c1").createOrReplaceTempView("t2") + + // Simple correlated predicate in subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + } + } + + test("SPARK-19993 subquery with cached underlying relation") { + withTempView("t1") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + spark.catalog.cacheTable("t1") + + // underlying table t1 is cached as well as the query that refers to it. + val ds = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin).cache() + assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) + } + } + + test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + Seq(1).toDF("c1").createOrReplaceTempView("t4") + + // Nested predicate subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Scalar subquery and predicate subquery + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin).cache() + + val cachedDs2 = + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs2) == 1) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index fab0d7fa84827..666548d1a490b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ @@ -203,9 +204,9 @@ case class HiveTableScanExec( override lazy val canonicalized: HiveTableScanExec = { val input: AttributeSeq = relation.output HiveTableScanExec( - requestedAttributes.map(normalizeExprId(_, input)), + requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), relation.canonicalized.asInstanceOf[CatalogRelation], - partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession) + partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) From b9384382484a9f5c6b389742e7fdf63865de81c0 Mon Sep 17 00:00:00 2001 From: Lee Dongjin Date: Wed, 12 Apr 2017 09:12:14 +0100 Subject: [PATCH 0250/1765] [MINOR][DOCS] Fix spacings in Structured Streaming Programming Guide ## What changes were proposed in this pull request? 1. Omitted space between the sentences: `... on static data.The Spark SQL engine will ...` -> `... on static data. The Spark SQL engine will ...` 2. Omitted colon in Output Model section. ## How was this patch tested? None. Author: Lee Dongjin Closes #17564 from dongjinleekr/feature/fix-programming-guide. --- docs/structured-streaming-programming-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 37a1d6189a42d..3cf7151819e2d 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data.The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* **Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. @@ -362,7 +362,7 @@ A query on the input will generate the "Result Table". Every trigger interval (s ![Model](img/structured-streaming-model.png) -The "Output" is defined as what gets written out to the external storage. The output can be defined in different modes +The "Output" is defined as what gets written out to the external storage. The output can be defined in a different mode: - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table. From bca4259f12b32eeb156b6755d0ec5e16d8e566b3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 12 Apr 2017 09:16:39 +0100 Subject: [PATCH 0251/1765] [MINOR][DOCS] JSON APIs related documentation fixes ## What changes were proposed in this pull request? This PR proposes corrections related to JSON APIs as below: - Rendering links in Python documentation - Replacing `RDD` to `Dataset` in programing guide - Adding missing description about JSON Lines consistently in `DataFrameReader.json` in Python API - De-duplicating little bit of `DataFrameReader.json` in Scala/Java API ## How was this patch tested? Manually build the documentation via `jekyll build`. Corresponding snapstops will be left on the codes. Note that currently there are Javadoc8 breaks in several places. These are proposed to be handled in https://github.com/apache/spark/pull/17477. So, this PR does not fix those. Author: hyukjinkwon Closes #17602 from HyukjinKwon/minor-json-documentation. --- docs/sql-programming-guide.md | 4 ++-- .../spark/examples/sql/JavaSQLDataSourceExample.java | 2 +- .../apache/spark/examples/sql/SQLDataSourceExample.scala | 2 +- python/pyspark/sql/readwriter.py | 8 +++++--- python/pyspark/sql/streaming.py | 4 ++-- .../main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++-- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c425faca4c273..28942b68fa20d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -883,7 +883,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset[Row]`. -This conversion can be done using `SparkSession.read.json()` on either an RDD of String, +This conversion can be done using `SparkSession.read.json()` on either a `Dataset[String]`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each @@ -897,7 +897,7 @@ For a regular multi-line JSON file, set the `wholeFile` option to `true`.
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset`. -This conversion can be done using `SparkSession.read().json()` on either an RDD of String, +This conversion can be done using `SparkSession.read().json()` on either a `Dataset`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 1a7054614b348..b66abaed66000 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -215,7 +215,7 @@ private static void runJsonDatasetExample(SparkSession spark) { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string. + // a Dataset storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); Dataset anotherPeopleDataset = spark.createDataset(jsonData, Encoders.STRING()); diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 82fd56de39847..ad74da72bd5e6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -139,7 +139,7 @@ object SQLDataSourceExample { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string + // a Dataset[String] storing one JSON object per string val otherPeopleDataset = spark.createDataset( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) val otherPeople = spark.read.json(otherPeopleDataset) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d912f395dafce..960fb882cf901 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -173,8 +173,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads JSON files and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -634,7 +634,9 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): - """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + """Saves the content of the :class:`DataFrame` in JSON format + (`JSON Lines text format or newline-delimited JSON `_) at the + specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 3b604963415f9..65b59d480da36 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,8 +405,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 49691c15d0f7d..c1b32917415ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -268,8 +268,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 From 044f7ecbfd75ac5a13bfc8cd01990e195c9bd178 Mon Sep 17 00:00:00 2001 From: Brendan Dwyer Date: Wed, 12 Apr 2017 09:24:41 +0100 Subject: [PATCH 0252/1765] [SPARK-20298][SPARKR][MINOR] fixed spelling mistake "charactor" ## What changes were proposed in this pull request? Fixed spelling of "charactor" ## How was this patch tested? Spelling change only Author: Brendan Dwyer Closes #17611 from bdwyer2/SPARK-20298. --- R/pkg/R/DataFrame.R | 10 +++++----- R/pkg/R/SQLContext.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ec85f723c08c6..88a138fd8eb1f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2818,14 +2818,14 @@ setMethod("write.df", signature(df = "SparkDataFrame"), function(df, path = NULL, source = NULL, mode = "error", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", "in 'spark.sql.sources.default' configuration by default.") } if (!is.character(mode)) { - stop("mode should be charactor or omitted. It is 'error' by default.") + stop("mode should be character or omitted. It is 'error' by default.") } if (is.null(source)) { source <- getDefaultSqlSource() @@ -3040,7 +3040,7 @@ setMethod("fillna", signature(x = "SparkDataFrame"), function(x, value, cols = NULL) { if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { - stop("value should be an integer, numeric, charactor or named list.") + stop("value should be an integer, numeric, character or named list.") } if (class(value) == "list") { @@ -3052,7 +3052,7 @@ setMethod("fillna", # Check each item in the named list is of valid type lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { - stop("Each item in value should be an integer, numeric or charactor.") + stop("Each item in value should be an integer, numeric or character.") } }) @@ -3598,7 +3598,7 @@ setMethod("write.stream", "in 'spark.sql.sources.default' configuration by default.") } if (!is.null(outputMode) && !is.character(outputMode)) { - stop("outputMode should be charactor or omitted.") + stop("outputMode should be character or omitted.") } if (is.null(source)) { source <- getDefaultSqlSource() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index c2a1e240ad395..f5c3a749fe0a1 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -606,7 +606,7 @@ tableToDF <- function(tableName) { #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 58cf24256a94f..3fbb618ddfc39 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2926,9 +2926,9 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) expect_error(write.df(df, path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(write.df(df, mode = TRUE), - "mode should be charactor or omitted. It is 'error' by default.") + "mode should be character or omitted. It is 'error' by default.") }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { @@ -2947,7 +2947,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # Arguments checking in R side. expect_error(read.df(path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) From ffc57b0118b58de57520967d8e8730b11baad507 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Apr 2017 01:30:00 -0700 Subject: [PATCH 0253/1765] [SPARK-20302][SQL] Short circuit cast when from and to types are structurally the same ## What changes were proposed in this pull request? When we perform a cast expression and the from and to types are structurally the same (having the same structure but different field names), we should be able to skip the actual cast. ## How was this patch tested? Added unit tests for the newly introduced functions. Author: Reynold Xin Closes #17614 from rxin/SPARK-20302. --- .../spark/sql/catalyst/expressions/Cast.scala | 65 ++++++++++++------- .../org/apache/spark/sql/types/DataType.scala | 26 ++++++++ .../sql/catalyst/expressions/CastSuite.scala | 14 ++++ .../spark/sql/types/DataTypeSuite.scala | 31 +++++++++ 4 files changed, 113 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1049915986d9b..bb1273f5c3d84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String }) } - private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == from => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case CalendarIntervalType => castToInterval(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) - case udt: UserDefinedType[_] - if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - identity[Any] - case _: UserDefinedType[_] => - throw new SparkException(s"Cannot cast $from to $to.") + private[this] def cast(from: DataType, to: DataType): Any => Any = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the codegen path. + if (DataType.equalsStructurally(from, to)) { + identity + } else { + to match { + case dt if dt == from => identity[Any] + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case CalendarIntervalType => castToInterval(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => + castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") + } + } } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) + override def genCode(ctx: CodegenContext): ExprCode = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the interpreted path. + if (DataType.equalsStructurally(child.dataType, dataType)) { + child.genCode(ctx) + } else { + super.genCode(ctx) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 520aff5e2b677..30745c6a9d42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -288,4 +288,30 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + /** + * Returns true if the two data types share the same "shape", i.e. the types (including + * nullability) are the same, but the field names don't need to be the same. + */ + def equalsStructurally(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (left: ArrayType, right: ArrayType) => + equalsStructurally(left.elementType, right.elementType) && + left.containsNull == right.containsNull + + case (left: MapType, right: MapType) => + equalsStructurally(left.keyType, right.keyType) && + equalsStructurally(left.valueType, right.valueType) && + left.valueContainsNull == right.valueContainsNull + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields) + .forall { case (l, r) => + equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 8eccadbdd8afb..a7ffa884d2286 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure) assert(cast(1.0, DateType).checkInputDataTypes().isFailure) } + + test("SPARK-20302 cast with same structure") { + val from = new StructType() + .add("a", IntegerType) + .add("b", new StructType().add("b1", LongType)) + + val to = new StructType() + .add("a1", IntegerType) + .add("b1", new StructType().add("b11", LongType)) + + val input = Row(10, Row(12L)) + + checkEvaluation(cast(Literal.create(input, from), to), input) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index f078ef013387b..c4635c8f126af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite { checkCatalogString(ArrayType(createStruct(40))) checkCatalogString(MapType(IntegerType, StringType)) checkCatalogString(MapType(IntegerType, createStruct(40))) + + def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = { + val testName = s"equalsStructurally: (from: $from, to: $to)" + test(testName) { + assert(DataType.equalsStructurally(from, to) === expected) + } + } + + checkEqualsStructurally(BooleanType, BooleanType, true) + checkEqualsStructurally(IntegerType, IntegerType, true) + checkEqualsStructurally(IntegerType, LongType, false) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType, false), + false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + false) } From 2e1fd46e12bf948490ece2caa73d227b6a924a14 Mon Sep 17 00:00:00 2001 From: jtoka Date: Wed, 12 Apr 2017 11:36:08 +0100 Subject: [PATCH 0254/1765] [SPARK-20296][TRIVIAL][DOCS] Count distinct error message for streaming ## What changes were proposed in this pull request? Update count distinct error message for streaming datasets/dataframes to match current behavior. These aggregations are not yet supported, regardless of whether the dataset/dataframe is aggregated. Author: jtoka Closes #17609 from jtoka/master. --- .../sql/catalyst/analysis/UnsupportedOperationChecker.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 7da7f55aa5d7f..3f76f26dbe4ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -139,9 +139,8 @@ object UnsupportedOperationChecker { } throwErrorIf( child.isStreaming && distinctAggExprs.nonEmpty, - "Distinct aggregations are not supported on streaming DataFrames/Datasets, unless " + - "it is on aggregated DataFrame/Dataset in Complete output mode. Consider using " + - "approximate distinct aggregation (e.g. approx_count_distinct() instead of count()).") + "Distinct aggregations are not supported on streaming DataFrames/Datasets. Consider " + + "using approx_count_distinct() instead.") case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + From ceaf77ae43a14e993ac6d1ff34b50256eacd6abb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 12 Apr 2017 12:38:48 +0100 Subject: [PATCH 0255/1765] [SPARK-18692][BUILD][DOCS] Test Java 8 unidoc build on Jenkins ## What changes were proposed in this pull request? This PR proposes to run Spark unidoc to test Javadoc 8 build as Javadoc 8 is easily re-breakable. There are several problems with it: - It introduces little extra bit of time to run the tests. In my case, it took 1.5 mins more (`Elapsed :[94.8746569157]`). How it was tested is described in "How was this patch tested?". - > One problem that I noticed was that Unidoc appeared to be processing test sources: if we can find a way to exclude those from being processed in the first place then that might significantly speed things up. (see joshrosen's [comment](https://issues.apache.org/jira/browse/SPARK-18692?focusedCommentId=15947627&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15947627)) To complete this automated build, It also suggests to fix existing Javadoc breaks / ones introduced by test codes as described above. There fixes are similar instances that previously fixed. Please refer https://github.com/apache/spark/pull/15999 and https://github.com/apache/spark/pull/16013 Note that this only fixes **errors** not **warnings**. Please see my observation https://github.com/apache/spark/pull/17389#issuecomment-288438704 for spurious errors by warnings. ## How was this patch tested? Manually via `jekyll build` for building tests. Also, tested via running `./dev/run-tests`. This was tested via manually adding `time.time()` as below: ```diff profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) + import time + st = time.time() exec_sbt(profiles_and_goals) + print("Elapsed :[%s]" % str(time.time() - st)) ``` produces ``` ... ======================================================================== Building Unidoc API Documentation ======================================================================== ... [info] Main Java API documentation successful. ... Elapsed :[94.8746569157] ... Author: hyukjinkwon Closes #17477 from HyukjinKwon/SPARK-18692. --- .../org/apache/spark/rpc/RpcEndpoint.scala | 10 +++++----- .../org/apache/spark/rpc/RpcTimeout.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 ++-- .../scheduler/ExternalClusterManager.scala | 2 +- .../spark/scheduler/TaskSchedulerImpl.scala | 8 ++++---- .../apache/spark/storage/BlockManager.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 4 ++-- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../org/apache/spark/LocalSparkContext.scala | 2 +- .../scheduler/SchedulerIntegrationSuite.scala | 4 ++-- .../serializer/SerializerPropertiesSuite.scala | 2 +- dev/run-tests.py | 15 +++++++++++++++ .../spark/graphx/LocalSparkContext.scala | 2 +- .../spark/ml/classification/Classifier.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 8 ++++++-- .../org/apache/spark/ml/feature/LSHTest.scala | 12 ++++++++---- .../apache/spark/ml/param/ParamsSuite.scala | 2 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 6 ++++-- .../spark/ml/util/DefaultReadWriteTest.scala | 18 +++++++++--------- .../apache/spark/ml/util/StopwatchSuite.scala | 4 ++-- .../apache/spark/ml/util/TempDirectory.scala | 4 +++- .../spark/mllib/tree/ImpuritySuite.scala | 2 +- .../mllib/util/MLlibTestSparkContext.scala | 2 +- .../cluster/mesos/MesosSchedulerUtils.scala | 6 +++--- .../apache/spark/sql/RandomDataGenerator.scala | 8 ++++---- .../spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../org/apache/spark/sql/catalog/Catalog.scala | 2 +- .../DatasetSerializerRegistratorSuite.scala | 4 +++- .../sql/streaming/FileStreamSourceSuite.scala | 4 ++-- .../spark/sql/streaming/StreamSuite.scala | 4 ++-- .../sql/streaming/StreamingQuerySuite.scala | 12 ++++++++---- .../apache/spark/sql/test/SQLTestUtils.scala | 18 ++++++++++-------- .../apache/spark/sql/test/TestSQLContext.scala | 2 +- .../java/org/apache/hive/service/Service.java | 2 +- .../apache/hive/service/ServiceOperations.java | 12 ++++++------ .../hive/service/auth/HttpAuthUtils.java | 2 +- .../auth/PasswdAuthenticationProvider.java | 2 +- .../service/auth/TSetIpAddressProcessor.java | 9 +++------ .../hive/service/cli/CLIServiceUtils.java | 2 +- .../cli/operation/ClassicTableTypeMapping.java | 6 +++--- .../cli/operation/TableTypeMapping.java | 2 +- .../service/cli/session/SessionManager.java | 4 +++- .../ThreadFactoryWithGarbageCleanup.java | 6 +++--- .../apache/spark/sql/hive/HiveInspectors.scala | 4 ++-- .../sql/hive/execution/HiveQueryFileTest.scala | 2 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 4 ++-- .../spark/streaming/rdd/MapWithStateRDD.scala | 4 ++-- .../scheduler/rate/PIDRateEstimator.scala | 2 +- .../scheduler/rate/RateEstimator.scala | 2 +- 49 files changed, 140 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 0ba95169529e6..97eed540b8f59 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -35,7 +35,7 @@ private[spark] trait RpcEnvFactory { * * The life-cycle of an endpoint is: * - * constructor -> onStart -> receive* -> onStop + * {@code constructor -> onStart -> receive* -> onStop} * * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] @@ -63,16 +63,16 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a + * unmatched message, `SparkException` will be thrown and sent to `onError`. */ def receive: PartialFunction[Any, Unit] = { case _ => throw new SparkException(self + " does not implement 'receive'") } /** - * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message, + * `SparkException` will be thrown and sent to `onError`. */ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case _ => context.sendFailure(new SparkException(self + " won't reply anything")) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2c9a976e76939..0557b7a3cc0b7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf import org.apache.spark.util.{ThreadUtils, Utils} /** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + * An exception thrown if RpcTimeout modifies a `TimeoutException`. */ private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) extends TimeoutException(message) { initCause(cause) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09717316833a7..aab177f257a8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -607,7 +607,7 @@ class DAGScheduler( * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name * - * @throws Exception when the job fails + * @note Throws `Exception` when the job fails */ def runJob[T, U]( rdd: RDD[T], @@ -644,7 +644,7 @@ class DAGScheduler( * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD - * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param evaluator `ApproximateEvaluator` to receive the partial results * @param callSite where in the user program this job was called * @param timeout maximum time to wait for the job, in milliseconds * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala index d1ac7131baba5..47f3527a32c01 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala @@ -42,7 +42,7 @@ private[spark] trait ExternalClusterManager { /** * Create a scheduler backend for the given SparkContext and scheduler. This is - * called after task scheduler is created using [[ExternalClusterManager.createTaskScheduler()]]. + * called after task scheduler is created using `ExternalClusterManager.createTaskScheduler()`. * @param sc SparkContext * @param masterURL the master URL * @param scheduler TaskScheduler that will be used with the scheduler backend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index c849a16023a7a..1b6bc9139f9c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. - * It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting + * It can also work with a local setup by using a `LocalSchedulerBackend` and setting * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking * up to launch speculative tasks, etc. * @@ -704,12 +704,12 @@ private[spark] object TaskSchedulerImpl { * Used to balance containers across hosts. * * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource + * resource offers representing the order in which the offers should be used. The resource * offers are ordered such that we'll allocate one container on each host before allocating a * second container on any host, and so on, in order to reduce the damage if a host fails. * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] + * For example, given {@literal }, {@literal } and + * {@literal }, returns {@literal [o1, o5, o4, o2, o6, o3]}. */ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { val _keyList = new ArrayBuffer[K](map.size) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 63acba65d3c5b..3219969bcd06f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -66,7 +66,7 @@ private[spark] trait BlockData { /** * Returns a Netty-friendly wrapper for the block's data. * - * @see [[ManagedBuffer#convertToNetty()]] + * Please see `ManagedBuffer.convertToNetty()` for more details. */ def toNetty(): Object diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 6d03ee091e4ed..ddbcb2d19dcbb 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -243,7 +243,7 @@ private[spark] object AccumulatorSuite { import InternalAccumulator._ /** - * Create a long accumulator and register it to [[AccumulatorContext]]. + * Create a long accumulator and register it to `AccumulatorContext`. */ def createLongAccum( name: String, @@ -258,7 +258,7 @@ private[spark] object AccumulatorSuite { } /** - * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * Make an `AccumulableInfo` out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index eb3fb99747d12..fe944031bc948 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh /** * This suite creates an external shuffle server and routes all shuffle fetches through it. * Note that failures in this suite may arise due to changes in Spark that invalidate expectations - * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how * we hash files into folders. */ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 24ec99c7e5e60..1dd89bcbe36bc 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +/** Manages a local `sc` `SparkContext` variable, correctly stopping it after each test. */ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 8103983c4392a..8300607ea888b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -95,12 +95,12 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } /** - * A map from partition -> results for all tasks of a job when you call this test framework's + * A map from partition to results for all tasks of a job when you call this test framework's * [[submit]] method. Two important considerations: * * 1. If there is a job failure, results may or may not be empty. If any tasks succeed before * the job has failed, they will get included in `results`. Instead, check for job failure by - * checking [[failure]]. (Also see [[assertDataStructuresEmpty()]]) + * checking [[failure]]. (Also see `assertDataStructuresEmpty()`) * * 2. This only gets cleared between tests. So you'll need to do special handling if you submit * more than one job in one test. diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index 4ce3b941bea55..99882bf76e29d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** * Tests to ensure that [[Serializer]] implementations obey the API contracts for methods that * describe properties of the serialized stream, such as - * [[Serializer.supportsRelocationOfSerializedObjects]]. + * `Serializer.supportsRelocationOfSerializedObjects`. */ class SerializerPropertiesSuite extends SparkFunSuite { diff --git a/dev/run-tests.py b/dev/run-tests.py index 04035b33e6a6b..450b68123e1fc 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -344,6 +344,19 @@ def build_spark_sbt(hadoop_version): exec_sbt(profiles_and_goals) +def build_spark_unidoc_sbt(hadoop_version): + set_title_and_block("Building Unidoc API Documentation", "BLOCK_DOCUMENTATION") + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + sbt_goals = ["unidoc"] + profiles_and_goals = build_profiles + sbt_goals + + print("[info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) + + exec_sbt(profiles_and_goals) + + def build_spark_assembly_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags @@ -352,6 +365,8 @@ def build_spark_assembly_sbt(hadoop_version): print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index d2ad9be555770..66c4747fec268 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkConf import org.apache.spark.SparkContext /** - * Provides a method to run tests against a {@link SparkContext} variable that is correctly stopped + * Provides a method to run tests against a `SparkContext` variable that is correctly stopped * after each test. */ trait LocalSparkContext { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d8608d885d6f1..bc0b49d48d323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -74,7 +74,7 @@ abstract class Classifier[ * and features (`Vector`). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). - * @throws SparkException if any label is not an integer >= 0 + * @note Throws `SparkException` if any label is a non-integer or is negative */ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4cdbf845ae4f5..4a7e4dd80f246 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -230,7 +230,9 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with `MLWritable` stages + */ class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -257,7 +259,9 @@ object WritableStage extends MLReadable[WritableStage] { override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with non-`MLWritable` stages + */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index dd4dd62b8cfe9..db4f56ed60d32 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -29,8 +29,10 @@ private[ml] object LSHTest { * the following property is satisfied. * * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2, - * If dist(e1, e2) <= dist1, then Pr{h(x) == h(y)} >= p1 - * If dist(e1, e2) >= dist2, then Pr{h(x) == h(y)} <= p2 + * If dist(e1, e2) is less than or equal to dist1, then Pr{h(x) == h(y)} is greater than + * or equal to p1 + * If dist(e1, e2) is greater than or equal to dist2, then Pr{h(x) == h(y)} is less than + * or equal to p2 * * This is called locality sensitive property. This method checks the property on an * existing dataset and calculate the probabilities. @@ -38,8 +40,10 @@ private[ml] object LSHTest { * * This method hashes each elements to hash buckets using LSH, and calculate the false positive * and false negative: - * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP - * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN + * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) is greater + * than distFP + * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) is less + * than distFN * * @param dataset The dataset to verify the locality sensitive hashing property. * @param lsh The lsh instance to perform the hashing diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index aa9c53ca30eee..78a33e05e0e48 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -377,7 +377,7 @@ class ParamsSuite extends SparkFunSuite { object ParamsSuite extends SparkFunSuite { /** - * Checks common requirements for [[Params.params]]: + * Checks common requirements for `Params.params`: * - params are ordered by names * - param parent has the same UID as the object's UID * - param name is the same as the param method name diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index c90cb8ca1034c..92a236928e90b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -34,7 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite { * Convert the given data to a DataFrame, and set the features and label metadata. * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. - * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param categoricalFeatures Map: categorical feature index to number of distinct values * @param numClasses Number of classes label can take. If 0, mark as continuous. * @return DataFrame with metadata */ @@ -69,7 +69,9 @@ private[ml] object TreeTests extends SparkFunSuite { df("label").as("label", labelMetadata)) } - /** Java-friendly version of [[setMetadata()]] */ + /** + * Java-friendly version of `setMetadata()` + */ def setMetadata( data: JavaRDD[LabeledPoint], categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index bfe8f12258bb8..27d606cb05dc2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -81,20 +81,20 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Default test for Estimator, Model pairs: * - Explicitly set Params, and train model - * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Test save/load using `testDefaultReadWrite` on Estimator and Model * - Check Params on Estimator and Model * - Compare model data * - * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s. + * This requires that `Model`'s `Param`s should be a subset of `Estimator`'s `Param`s. * * @param estimator Estimator to test - * @param dataset Dataset to pass to [[Estimator.fit()]] - * @param testEstimatorParams Set of [[Param]] values to set in estimator - * @param testModelParams Set of [[Param]] values to set in model - * @param checkModelData Method which takes the original and loaded [[Model]] and compares their - * data. This method does not need to check [[Param]] values. - * @tparam E Type of [[Estimator]] - * @tparam M Type of [[Model]] produced by estimator + * @param dataset Dataset to pass to `Estimator.fit()` + * @param testEstimatorParams Set of `Param` values to set in estimator + * @param testModelParams Set of `Param` values to set in model + * @param checkModelData Method which takes the original and loaded `Model` and compares their + * data. This method does not need to check `Param` values. + * @tparam E Type of `Estimator` + * @tparam M Type of `Model` produced by estimator */ def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 141249a427a4c..54e363a8b9f2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -105,8 +105,8 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { private object StopwatchSuite extends SparkFunSuite { /** - * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and - * returns the duration reported by the stopwatch. + * Checks the input stopwatch on a task that takes a random time (less than 10ms) to finish. + * Validates and returns the duration reported by the stopwatch. */ def checkStopwatch(sw: Stopwatch): Long = { val ubStart = now diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 8f11bbc8e47af..50b73e0e99a22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -30,7 +30,9 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => private var _tempDir: File = _ - /** Returns the temporary directory as a [[File]] instance. */ + /** + * Returns the temporary directory as a `File` instance. + */ protected def tempDir: File = _tempDir override def beforeAll(): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 14152cdd63bc7..d0f02dd966bd5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} /** - * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + * Test suites for `GiniAggregator` and `EntropyAggregator`. */ class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 6bb7ed9c9513c..720237bd2dddd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -60,7 +60,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 3f25535cb5ec2..9d81025a3016b 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -239,7 +239,7 @@ trait MesosSchedulerUtils extends Logging { } /** - * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * Converts the attributes from the resource offer into a Map of name to Attribute Value * The attribute values are the mesos attribute types and they are * * @param offerAttributes the attributes offered @@ -296,7 +296,7 @@ trait MesosSchedulerUtils extends Logging { /** * Parses the attributes constraints provided to spark and build a matching data struct: - * Map[, Set[values-to-match]] + * {@literal Map[, Set[values-to-match]} * The constraints are specified as ';' separated key-value pairs where keys and values * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: @@ -354,7 +354,7 @@ trait MesosSchedulerUtils extends Logging { * container overheads. * * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value - * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * @return memory requirement as (0.1 * memoryOverhead) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ def executorMemory(sc: SparkContext): Int = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507f..8ae3ff5043e68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -117,11 +117,11 @@ object RandomDataGenerator { } /** - * Returns a function which generates random values for the given [[DataType]], or `None` if no + * Returns a function which generates random values for the given `DataType`, or `None` if no * random data generator is defined for that data type. The generated values will use an external - * representation of the data type; for example, the random generator for [[DateType]] will return - * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. - * For a [[UserDefinedType]] for a class X, an instance of class X is returned. + * representation of the data type; for example, the random generator for `DateType` will return + * instances of [[java.sql.Date]] and the generator for `StructType` will return a [[Row]]. + * For a `UserDefinedType` for a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index a6d90409382e5..769addf3b29e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Benchmark /** - * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + * Benchmark `UnsafeProjection` for fixed-length/primitive-type fields. */ object UnsafeProjectionBenchmark { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 074952ff7900a..7e5da012f84ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -510,7 +510,7 @@ abstract class Catalog { def refreshTable(tableName: String): Unit /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any [[Dataset]] + * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate * everything that is cached. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 0f3d0cefe3bb5..92c5656f65bb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -56,7 +56,9 @@ object TestRegistrator { def apply(): TestRegistrator = new TestRegistrator() } -/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */ +/** + * A `Serializer` that takes a [[KryoData]] and serializes it as KryoData(0). + */ class ZeroKryoDataSerializer extends Serializer[KryoData] { override def write(kryo: Kryo, output: Output, t: KryoData): Unit = { output.writeInt(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 26967782f77c7..2108b118bf059 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -44,8 +44,8 @@ abstract class FileStreamSourceTest import testImplicits._ /** - * A subclass [[AddData]] for adding data to files. This is meant to use the - * [[FileStreamSource]] actually being used in the execution. + * A subclass `AddData` for adding data to files. This is meant to use the + * `FileStreamSource` actually being used in the execution. */ abstract class AddFileData extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 5ab9dc2bc7763..13fe51a557733 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -569,7 +569,7 @@ class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { object ThrowingIOExceptionLikeHadoop12074 { /** - * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * A latch to allow the user to wait until `ThrowingIOExceptionLikeHadoop12074.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null @@ -600,7 +600,7 @@ class ThrowingInterruptedIOException extends FakeSource { object ThrowingInterruptedIOException { /** - * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * A latch to allow the user to wait until `ThrowingInterruptedIOException.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2ebbfcd22b97c..b69536ed37463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -642,8 +642,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ case class TestAwaitTermination( @@ -667,8 +669,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ def assertOnQueryCondition( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index cab219216d1ca..6a4cc95d36bea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -41,11 +41,11 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} /** * Helper trait that should be extended by all SQL test suites. * - * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. - * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. * - * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils @@ -65,7 +65,7 @@ private[sql] trait SQLTestUtils * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { @@ -73,7 +73,7 @@ private[sql] trait SQLTestUtils } /** - * Materialize the test data immediately after the [[SQLContext]] is set up. + * Materialize the test data immediately after the `SQLContext` is set up. * This is necessary if the data is accessed by name but not through direct reference. */ protected def setupTestData(): Unit = { @@ -250,8 +250,8 @@ private[sql] trait SQLTestUtils } /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier - * way to construct [[DataFrame]] directly out of local data without relying on implicits. + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { Dataset.ofRows(spark, plan) @@ -271,7 +271,9 @@ private[sql] trait SQLTestUtils } } - /** Run a test on a separate [[UninterruptibleThread]]. */ + /** + * Run a test on a separate `UninterruptibleThread`. + */ protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) (body: => Unit): Unit = { val timeoutMillis = 10000 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index b01977a23890f..959edf9a49371 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf} /** - * A special [[SparkSession]] prepared for testing. + * A special `SparkSession` prepared for testing. */ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java index b95077cd62186..0d0e3e4011b5b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java @@ -49,7 +49,7 @@ enum STATE { * The transition must be from {@link STATE#NOTINITED} to {@link STATE#INITED} unless the * operation failed and an exception was raised. * - * @param config + * @param conf * the configuration of the service */ void init(HiveConf conf); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index a2c580d6acc71..c3219aabfc23b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -51,7 +51,7 @@ public static void ensureCurrentState(Service.STATE state, /** * Initialize a service. - *

    + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -69,7 +69,7 @@ public static void init(Service service, HiveConf configuration) { /** * Start a service. - *

    + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -86,7 +86,7 @@ public static void start(Service service) { /** * Initialize then start a service. - *

    + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -102,9 +102,9 @@ public static void deploy(Service service, HiveConf configuration) { /** * Stop a service. - *

    Do nothing if the service is null or not - * in a state in which it can be/needs to be stopped. - *

    + * + * Do nothing if the service is null or not in a state in which it can be/needs to be stopped. + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service or null diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java index 5021528299682..f7375ee707830 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java @@ -89,7 +89,7 @@ public static String getKerberosServiceTicket(String principal, String host, * @param clientUserName Client User name. * @return An unsigned cookie token generated from input parameters. * The final cookie generated is of the following format : - * cu=&rn=&s= + * {@code cu=&rn=&s=} */ public static String createCookieToken(String clientUserName) { StringBuffer sb = new StringBuffer(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java index e2a6de165adc5..1af1c1d06e7f7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java @@ -26,7 +26,7 @@ public interface PasswdAuthenticationProvider { * to authenticate users for their requests. * If a user is to be granted, return nothing/throw nothing. * When a user is to be disallowed, throw an appropriate {@link AuthenticationException}. - *

    + * * For an example implementation, see {@link LdapAuthenticationProviderImpl}. * * @param user The username received over the connection request diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java index 645e3e2bbd4e2..9a61ad49942c8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java @@ -31,12 +31,9 @@ /** * This class is responsible for setting the ipAddress for operations executed via HiveServer2. - *

    - *

      - *
    • IP address is only set for operations that calls listeners with hookContext
    • - *
    • IP address is only set if the underlying transport mechanism is socket
    • - *
    - *

    + * + * - IP address is only set for operations that calls listeners with hookContext + * - IP address is only set if the underlying transport mechanism is socket * * @see org.apache.hadoop.hive.ql.hooks.ExecuteWithHookContext */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java index 9d64b102e008d..bf2380632fa6c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java @@ -38,7 +38,7 @@ public class CLIServiceUtils { * Convert a SQL search pattern into an equivalent Java Regex. * * @param pattern input which may contain '%' or '_' wildcard characters, or - * these characters escaped using {@link #getSearchStringEscape()}. + * these characters escaped using {@code getSearchStringEscape()}. * @return replace %/_ with regex search characters, also handle escaped * characters. */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java index 05a6bf938404b..af36057bdaeca 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java @@ -28,9 +28,9 @@ /** * ClassicTableTypeMapping. * Classic table type mapping : - * Managed Table ==> Table - * External Table ==> Table - * Virtual View ==> View + * Managed Table to Table + * External Table to Table + * Virtual View to View */ public class ClassicTableTypeMapping implements TableTypeMapping { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java index e392c459cf586..e59d19ea6be42 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java @@ -31,7 +31,7 @@ public interface TableTypeMapping { /** * Map hive's table type name to client's table type - * @param clientTypeName + * @param hiveTypeName * @return */ String mapToClientType(String hiveTypeName); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index de066dd406c7a..c1b3892f52060 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -224,7 +224,9 @@ public SessionHandle openSession(TProtocolVersion protocol, String username, Str * The username passed to this method is the effective username. * If withImpersonation is true (==doAs true) we wrap all the calls in HiveSession * within a UGI.doAs, where UGI corresponds to the effective user. - * @see org.apache.hive.service.cli.thrift.ThriftCLIService#getUserName() + * + * Please see {@code org.apache.hive.service.cli.thrift.ThriftCLIService.getUserName()} for + * more details. * * @param protocol * @param username diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java index fb8141a905acb..94f8126552e9d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java @@ -30,12 +30,12 @@ * in custom cleanup code to be called before this thread is GC-ed. * Currently cleans up the following: * 1. ThreadLocal RawStore object: - * In case of an embedded metastore, HiveServer2 threads (foreground & background) + * In case of an embedded metastore, HiveServer2 threads (foreground and background) * end up caching a ThreadLocal RawStore object. The ThreadLocal RawStore object has - * an instance of PersistenceManagerFactory & PersistenceManager. + * an instance of PersistenceManagerFactory and PersistenceManager. * The PersistenceManagerFactory keeps a cache of PersistenceManager objects, * which are only removed when PersistenceManager#close method is called. - * HiveServer2 uses ExecutorService for managing thread pools for foreground & background threads. + * HiveServer2 uses ExecutorService for managing thread pools for foreground and background threads. * ExecutorService unfortunately does not provide any hooks to be called, * when a thread from the pool is terminated. * As a solution, we're using this ThreadFactory to keep a cache of RawStore objects per thread. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 6f5b923cd4f9e..4dec2f71b8a50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -53,8 +53,8 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: [[MapData]] - * List: [[ArrayData]] + * Map: `MapData` + * List: `ArrayData` * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index e772324a57ab8..bb4ce6d3aa3f1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * TestSuites that derive from this class must provide a map of testCaseName to testCaseFiles * that should be included. Additionally, there is support for whitelisting and blacklisting * tests as development progresses. */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 7226ed521ef32..a2f08c5ba72c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -43,7 +43,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * Writes `data` to a Orc file and reads it back as a `DataFrame`, * which is then passed to `f`. The Orc file will be deleted after `f` returns. */ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] @@ -53,7 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * Writes `data` to a Orc file, reads it back as a `DataFrame` and registers it as a * temporary table named `tableName`, then call `f`. The temporary table together with the * Orc file will be dropped/deleted after `f` returns. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 58b7031d5ea6a..15d3c7e54b8dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils /** - * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a `StateMap` and a * sequence of records returned by the mapping function of `mapWithState`. */ private[streaming] case class MapWithStateRDDRecord[K, S, E]( @@ -111,7 +111,7 @@ private[streaming] class MapWithStateRDDPartition( /** * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a - * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * `StateMap` (containing the keyed-states) and the sequence of records returned by the mapping * function of `mapWithState`. * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD * will be created diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index a73e6cc2cd9c1..dc02062b9eb44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging * case of Spark Streaming the error is the difference between the measured processing * rate (number of elements/processing delay) and the previous rate. * - * @see https://en.wikipedia.org/wiki/PID_controller + * @see PID controller (Wikipedia) * * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 7b2ef6881d6f7..e4b9dffee04f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Duration * A component that estimates the rate at which an `InputDStream` should ingest * records, based on updates at every batch completion. * - * @see [[org.apache.spark.streaming.scheduler.RateController]] + * Please see `org.apache.spark.streaming.scheduler.RateController` for more details. */ private[streaming] trait RateEstimator extends Serializable { From 504e62e2f4b7df7e002ea014a855cebe1ff95193 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 12 Apr 2017 09:01:26 -0700 Subject: [PATCH 0256/1765] [SPARK-20303][SQL] Rename createTempFunction to registerFunction ### What changes were proposed in this pull request? Session catalog API `createTempFunction` is being used by Hive build-in functions, persistent functions, and temporary functions. Thus, the name is confusing. This PR is to rename it by `registerFunction`. Also we can move construction of `FunctionBuilder` and `ExpressionInfo` into the new `registerFunction`, instead of duplicating the logics everywhere. In the next PRs, the remaining Function-related APIs also need cleanups. ### How was this patch tested? Existing test cases. Author: Xiao Li Closes #17615 from gatorsmile/cleanupCreateTempFunction. --- .../analysis/AlreadyExistException.scala | 3 -- .../sql/catalyst/catalog/SessionCatalog.scala | 31 +++++++------- .../catalog/SessionCatalogSuite.scala | 40 ++++++++++--------- .../sql/execution/command/functions.scala | 9 ++--- .../spark/sql/internal/CatalogSuite.scala | 5 ++- .../spark/sql/hive/HiveSessionCatalog.scala | 18 +++------ .../ObjectHashAggregateExecBenchmark.scala | 10 +++-- 7 files changed, 53 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index ec56fe7729c2a..57f7a80bedc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -44,6 +44,3 @@ class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[Tabl class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") - -class TempFunctionAlreadyExistsException(func: String) - extends AnalysisException(s"Temporary function '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index faedf5f91c3ef..1417bccf657cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1050,7 +1050,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -1064,18 +1064,20 @@ class SessionCatalog( } /** - * Create a temporary function. - * This assumes no database is specified in `funcDefinition`. + * Registers a temporary or permanent function into a session-specific [[FunctionRegistry]] */ - def createTempFunction( - name: String, - info: ExpressionInfo, - funcDefinition: FunctionBuilder, - ignoreIfExists: Boolean): Unit = { - if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { - throw new TempFunctionAlreadyExistsException(name) + def registerFunction( + funcDefinition: CatalogFunction, + ignoreIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { + val func = funcDefinition.identifier + if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + throw new AnalysisException(s"Function $func already exists") } - functionRegistry.registerFunction(name, info, funcDefinition) + val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) + val builder = + functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionRegistry.registerFunction(func.unquotedString, info, builder) } /** @@ -1180,12 +1182,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - val info = new ExpressionInfo( - catalogFunction.className, - qualifiedName.database.orNull, - qualifiedName.funcName) - val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className) - createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName.unquotedString, children) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 9ba846fb25279..be8903000a0d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1162,10 +1162,10 @@ abstract class SessionCatalogSuite extends PlanTest { withBasicCatalog { catalog => val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - val info1 = new ExpressionInfo("tempFunc1", "temp1") - val info2 = new ExpressionInfo("tempFunc2", "temp2") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction( + newFunc("temp2", None), ignoreIfExists = false, functionBuilder = Some(tempFunc2)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) @@ -1174,13 +1174,15 @@ abstract class SessionCatalogSuite extends PlanTest { catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) } val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) - val info3 = new ExpressionInfo("tempFunc3", "temp1") // Temporary function already exists - intercept[TempFunctionAlreadyExistsException] { - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) - } + val e = intercept[AnalysisException] { + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc3)) + }.getMessage + assert(e.contains("Function temp1 already exists")) // Temporary function is overridden - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = true, functionBuilder = Some(tempFunc3)) assert( catalog.lookupFunction( FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) @@ -1193,8 +1195,8 @@ abstract class SessionCatalogSuite extends PlanTest { assert(!catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) val tempFunc1 = (e: Seq[Expression]) => e.head - val info1 = new ExpressionInfo("tempFunc1", "temp1") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) // Returns true when the function is temporary assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) @@ -1243,9 +1245,9 @@ abstract class SessionCatalogSuite extends PlanTest { test("drop temp function") { withBasicCatalog { catalog => - val info = new ExpressionInfo("tempFunc", "func1") val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1284,9 +1286,9 @@ abstract class SessionCatalogSuite extends PlanTest { test("lookup temp function") { withBasicCatalog { catalog => - val info1 = new ExpressionInfo("tempFunc1", "func1") val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) assert(catalog.lookupFunction( FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1298,14 +1300,14 @@ abstract class SessionCatalogSuite extends PlanTest { test("list functions") { withBasicCatalog { catalog => - val info1 = new ExpressionInfo("tempFunc1", "func1") - val info2 = new ExpressionInfo("tempFunc2", "yes_me") + val funcMeta1 = newFunc("func1", None) + val funcMeta2 = newFunc("yes_me", None) val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) + catalog.registerFunction(funcMeta1, ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction(funcMeta2, ignoreIfExists = false, functionBuilder = Some(tempFunc2)) assert(catalog.listFunctions("db1", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 5687f9332430e..e0d0029369576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -51,6 +51,7 @@ case class CreateFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val func = CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources) if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + @@ -59,17 +60,13 @@ case class CreateFunctionCommand( // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. catalog.loadFunctionResources(resources) - val info = new ExpressionInfo(className, functionName) - val builder = catalog.makeFunctionBuilder(functionName, className) - catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + catalog.registerFunction(func, ignoreIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. // TODO: should we also parse "IF NOT EXISTS"? - catalog.createFunction( - CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), - ignoreIfExists = false) + catalog.createFunction(func, ignoreIfExists = false) } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 6469e501c1f68..8f9c52cb1e031 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -75,9 +75,10 @@ class CatalogSuite } private def createTempFunction(name: String): Unit = { - val info = new ExpressionInfo("className", name) val tempFunc = (e: Seq[Expression]) => e.head - sessionCatalog.createTempFunction(name, info, tempFunc, ignoreIfExists = false) + val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) + sessionCatalog.registerFunction( + funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c917f110b90f2..377d4f2473c58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf @@ -124,13 +124,6 @@ private[sql] class HiveSessionCatalog( } private def lookupFunction0(name: FunctionIdentifier, children: Seq[Expression]): Expression = { - // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to - // if (super.functionExists(name)) { - // super.lookupFunction(name, children) - // } else { - // // This function is a Hive builtin function. - // ... - // } val database = name.database.map(formatDatabaseName) val funcName = name.copy(database = database) Try(super.lookupFunction(funcName, children)) match { @@ -164,10 +157,11 @@ private[sql] class HiveSessionCatalog( } } val className = functionInfo.getFunctionClass.getName - val builder = makeFunctionBuilder(functionName, className) + val functionIdentifier = + FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) + val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. - val info = new ExpressionInfo(className, functionName) - createTempFunction(functionName, info, builder, ignoreIfExists = false) + registerFunction(func, ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(functionName, children) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 197110f4912a7..73383ae4d4118 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -22,7 +22,9 @@ import scala.concurrent.duration._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.hive.HiveSessionCatalog import org.apache.spark.sql.hive.execution.TestingTypedCount @@ -217,9 +219,9 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName) - val info = new ExpressionInfo(clazz.getName, functionName) - sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + val functionIdentifier = FunctionIdentifier(functionName, database = None) + val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) + sessionCatalog.registerFunction(func, ignoreIfExists = false) } private def percentile_approx( From 540855382c8f139fbf4eb0800b31c7ce91f29c7f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Apr 2017 09:05:05 -0700 Subject: [PATCH 0257/1765] [SPARK-20304][SQL] AssertNotNull should not include path in string representation ## What changes were proposed in this pull request? AssertNotNull's toString/simpleString dumps the entire walkedTypePath. walkedTypePath is used for error message reporting and shouldn't be part of the output. ## How was this patch tested? Manually tested. Author: Reynold Xin Closes #17616 from rxin/SPARK-20304. --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6d94764f1bfac..eed773d4cb368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -996,6 +996,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) override def foldable: Boolean = false override def nullable: Boolean = false + override def flatArguments: Iterator[Any] = Iterator(child) + private val errMsg = "Null value appeared in non-nullable field:" + walkedTypePath.mkString("\n", "\n", "\n") + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + From 99a9473127ec389283ac4ec3b721d2e34434e647 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 12 Apr 2017 10:54:50 -0700 Subject: [PATCH 0258/1765] [SPARK-19570][PYSPARK] Allow to disable hive in pyspark shell ## What changes were proposed in this pull request? SPARK-15236 do this for scala shell, this ticket is for pyspark shell. This is not only for pyspark itself, but can also benefit downstream project like livy which use shell.py for its interactive session. For now, livy has no control of whether enable hive or not. ## How was this patch tested? I didn't find a way to add test for it. Just manually test it. Run `bin/pyspark --master local --conf spark.sql.catalogImplementation=in-memory` and verify hive is not enabled. Author: Jeff Zhang Closes #16906 from zjffdu/SPARK-19570. --- python/pyspark/shell.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index c1917d2be69d8..b5fcf7092d93a 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -24,13 +24,13 @@ import atexit import os import platform +import warnings import py4j -import pyspark +from pyspark import SparkConf from pyspark.context import SparkContext from pyspark.sql import SparkSession, SQLContext -from pyspark.storagelevel import StorageLevel if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -39,13 +39,23 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + spark = SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext From 924c42477b5d6ed3c217c8eaaf4dc64b2379851a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 12 Apr 2017 11:24:59 -0700 Subject: [PATCH 0259/1765] [SPARK-20301][FLAKY-TEST] Fix Hadoop Shell.runCommand flakiness in Structured Streaming tests ## What changes were proposed in this pull request? Some Structured Streaming tests show flakiness such as: ``` [info] - prune results by current_date, complete mode - 696 *** FAILED *** (10 seconds, 937 milliseconds) [info] Timed out while stopping and waiting for microbatchthread to terminate.: The code passed to failAfter did not complete within 10 seconds. ``` This happens when we wait for the stream to stop, but it doesn't. The reason it doesn't stop is that we interrupt the microBatchThread, but Hadoop's `Shell.runCommand` swallows the interrupt exception, and the exception is not propagated upstream to the microBatchThread. Then this thread continues to run, only to start blocking on the `streamManualClock`. ## How was this patch tested? Thousand retries locally and [Jenkins](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/75720/testReport) of the flaky tests Author: Burak Yavuz Closes #17613 from brkyvz/flaky-stream-agg. --- .../execution/streaming/StreamExecution.scala | 56 +++++++++---------- .../spark/sql/streaming/StreamTest.scala | 6 ++ 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 8857966676ae2..bcf0d970f7ec1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -284,42 +284,38 @@ class StreamExecution( triggerExecutor.execute(() => { startTrigger() - val continueToRun = - if (isActive) { - reportTimeTaken("triggerExecution") { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets(sparkSessionToRunBatches) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() - } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") - runBatch(sparkSessionToRunBatches) - } + if (isActive) { + reportTimeTaken("triggerExecution") { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets(sparkSessionToRunBatches) + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) if (dataAvailable) { - // Update committed offsets. - batchCommitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) + currentStatus = currentStatus.copy(isDataAvailable = true) + updateStatusMessage("Processing new data") + runBatch(sparkSessionToRunBatches) } - true + } + // Report trigger as finished and construct progress object. + finishTrigger(dataAvailable) + if (dataAvailable) { + // Update committed offsets. + batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets + logDebug(s"batch ${currentBatchId} committed") + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 } else { - false + currentStatus = currentStatus.copy(isDataAvailable = false) + updateStatusMessage("Waiting for data to arrive") + Thread.sleep(pollingDelayMs) } + } updateStatusMessage("Waiting for next trigger") - continueToRun + isActive }) updateStatusMessage("Stopped") } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 03aa45b616880..5bc36dd30f6d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -277,6 +277,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def threadStackTrace = if (currentStream != null && currentStream.microBatchThread.isAlive) { + s"Thread stack trace: ${currentStream.microBatchThread.getStackTrace.mkString("\n")}" + } else { + "" + } def testState = s""" @@ -287,6 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { |Output Mode: $outputMode |Stream state: $currentOffsets |Thread state: $threadState + |$threadStackTrace |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == From a7b430b5717e263c1fbb55114deca6028ea9c3b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Apr 2017 08:38:24 +0800 Subject: [PATCH 0260/1765] [SPARK-15354][FLAKY-TEST] TopologyAwareBlockReplicationPolicyBehavior.Peers in 2 racks ## What changes were proposed in this pull request? `TopologyAwareBlockReplicationPolicyBehavior.Peers in 2 racks` is failing occasionally: https://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.storage.TopologyAwareBlockReplicationPolicyBehavior&test_name=Peers+in+2+racks. This is because, when we generate 10 block manager id to test, they may all belong to the same rack, as the rack is randomly picked. This PR fixes this problem by forcing each rack to be picked at least once. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17624 from cloud-fan/test. --- .../spark/storage/BlockReplicationPolicySuite.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index ecad0f5352e59..dfecd04c1b969 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -70,9 +70,18 @@ class RandomBlockReplicationPolicyBehavior extends SparkFunSuite } } + /** + * Returns a sequence of [[BlockManagerId]], whose rack is randomly picked from the given `racks`. + * Note that, each rack will be picked at least once from `racks`, if `count` is greater or equal + * to the number of `racks`. + */ protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { - (1 to count).map{i => - BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(racks(Random.nextInt(racks.size)))) + val randomizedRacks: Seq[String] = Random.shuffle( + racks ++ racks.length.until(count).map(_ => racks(Random.nextInt(racks.length))) + ) + + (0 until count).map { i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(randomizedRacks(i))) } } } From c5f1cc370f0aa1f0151fd34251607a8de861395e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 12 Apr 2017 17:44:18 -0700 Subject: [PATCH 0261/1765] [SPARK-20131][CORE] Don't use `this` lock in StandaloneSchedulerBackend.stop ## What changes were proposed in this pull request? `o.a.s.streaming.StreamingContextSuite.SPARK-18560 Receiver data should be deserialized properly` is flaky is because there is a potential dead-lock in StandaloneSchedulerBackend which causes `await` timeout. Here is the related stack trace: ``` "Thread-31" #211 daemon prio=5 os_prio=31 tid=0x00007fedd4808000 nid=0x16403 waiting on condition [0x00007000239b7000] java.lang.Thread.State: TIMED_WAITING (parking) at sun.misc.Unsafe.park(Native Method) - parking to wait for <0x000000079b49ca10> (a scala.concurrent.impl.Promise$CompletionLatch) at java.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215) at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedNanos(AbstractQueuedSynchronizer.java:1037) at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1328) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:201) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:76) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stop(CoarseGrainedSchedulerBackend.scala:402) at org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend.org$apache$spark$scheduler$cluster$StandaloneSchedulerBackend$$stop(StandaloneSchedulerBackend.scala:213) - locked <0x00000007066fca38> (a org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend) at org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend.stop(StandaloneSchedulerBackend.scala:116) - locked <0x00000007066fca38> (a org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend) at org.apache.spark.scheduler.TaskSchedulerImpl.stop(TaskSchedulerImpl.scala:517) at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:1657) at org.apache.spark.SparkContext$$anonfun$stop$8.apply$mcV$sp(SparkContext.scala:1921) at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1302) at org.apache.spark.SparkContext.stop(SparkContext.scala:1920) at org.apache.spark.streaming.StreamingContext.stop(StreamingContext.scala:708) at org.apache.spark.streaming.StreamingContextSuite$$anonfun$43$$anonfun$apply$mcV$sp$66$$anon$3.run(StreamingContextSuite.scala:827) "dispatcher-event-loop-3" #18 daemon prio=5 os_prio=31 tid=0x00007fedd603a000 nid=0x6203 waiting for monitor entry [0x0000700003be4000] java.lang.Thread.State: BLOCKED (on object monitor) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint.org$apache$spark$scheduler$cluster$CoarseGrainedSchedulerBackend$DriverEndpoint$$makeOffers(CoarseGrainedSchedulerBackend.scala:253) - waiting to lock <0x00000007066fca38> (a org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint$$anonfun$receive$1.applyOrElse(CoarseGrainedSchedulerBackend.scala:124) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:117) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:205) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:101) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:213) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` This PR removes `synchronized` and changes `stopping` to AtomicBoolean to ensure idempotent to fix the dead-lock. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17610 from zsxwing/SPARK-20131. --- .../cluster/StandaloneSchedulerBackend.scala | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7befdb0c1f64d..0529fe9eed4da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future @@ -42,7 +43,7 @@ private[spark] class StandaloneSchedulerBackend( with Logging { private var client: StandaloneAppClient = null - private var stopping = false + private val stopping = new AtomicBoolean(false) private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } @@ -112,7 +113,7 @@ private[spark] class StandaloneSchedulerBackend( launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop(): Unit = synchronized { + override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) } @@ -125,14 +126,14 @@ private[spark] class StandaloneSchedulerBackend( override def disconnected() { notifyContext() - if (!stopping) { + if (!stopping.get) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { notifyContext() - if (!stopping) { + if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { @@ -206,20 +207,20 @@ private[spark] class StandaloneSchedulerBackend( registrationBarrier.release() } - private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - try { - stopping = true - - super.stop() - client.stop() + private def stop(finalState: SparkAppHandle.State): Unit = { + if (stopping.compareAndSet(false, true)) { + try { + super.stop() + client.stop() - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } - } finally { - launcherBackend.setState(finalState) - launcherBackend.close() } } From ec68d8f8cfdede8a0de1d56476205158544cc4eb Mon Sep 17 00:00:00 2001 From: Yash Sharma Date: Thu, 13 Apr 2017 08:49:19 +0100 Subject: [PATCH 0262/1765] [SPARK-20189][DSTREAM] Fix spark kinesis testcases to remove deprecated createStream and use Builders ## What changes were proposed in this pull request? The spark-kinesis testcases use the KinesisUtils.createStream which are deprecated now. Modify the testcases to use the recommended KinesisInputDStream.builder instead. This change will also enable the testcases to automatically use the session tokens automatically. ## How was this patch tested? All the existing testcases work fine as expected with the changes. https://issues.apache.org/jira/browse/SPARK-20189 Author: Yash Sharma Closes #17506 from yssharma/ysharma/cleanup_kinesis_testcases. --- .../kinesis/KinesisInputDStream.scala | 2 +- .../kinesis/KinesisStreamSuite.scala | 58 ++++++++++++------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 8970ad2bafda0..77553412eda56 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -267,7 +267,7 @@ object KinesisInputDStream { getRequiredParam(checkpointAppName, "checkpointAppName"), checkpointInterval.getOrElse(ssc.graph.batchDuration), storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), - handler, + ssc.sc.clean(handler), kinesisCredsProvider.getOrElse(DefaultCredentials), dynamoDBCredsProvider, cloudWatchCredsProvider) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ed7e35805026e..341a6898cbbff 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -22,7 +22,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -173,11 +172,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . */ testIfEnabled("basic operation") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -198,12 +201,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } testIfEnabled("custom message handling") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .buildWithMessageHandler(addFive(_)) stream shouldBe a [ReceiverInputDStream[_]] @@ -233,11 +241,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val localTestUtils = new KPLBasedKinesisTestUtils(1) localTestUtils.createStream() try { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName, - localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(localAppName) + .streamName(localTestUtils.streamName) + .endpointUrl(localTestUtils.endpointUrl) + .regionName(localTestUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -303,13 +315,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun ssc = new StreamingContext(sc, Milliseconds(1000)) ssc.checkpoint(checkpointDir) - val awsCredentials = KinesisTestUtils.getAWSCredentials() val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] - val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val kinesisStream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { From 095d1cb3aa0021c9078a6e910967b9189ddfa177 Mon Sep 17 00:00:00 2001 From: Syrux Date: Thu, 13 Apr 2017 09:44:33 +0100 Subject: [PATCH 0263/1765] [SPARK-20265][MLLIB] Improve Prefix'span pre-processing efficiency ## What changes were proposed in this pull request? Improve PrefixSpan pre-processing efficency by preventing sequences of zero in the cleaned database. The efficiency gain is reflected in the following graph : https://postimg.org/image/9x6ireuvn/ ## How was this patch tested? Using MLlib's PrefixSpan existing tests and tests of my own on the 8 datasets shown in the graph. All result obtained were stricly the same as the original implementation (without this change). dev/run-tests was also runned, no error were found. Author : Cyril de Vogelaere Author: Syrux Closes #17575 from Syrux/SPARK-20265. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 99 ++++++++++++------- .../spark/mllib/fpm/PrefixSpanSuite.scala | 51 ++++++++++ 2 files changed, 115 insertions(+), 35 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 327cb974ef96c..3f8d65a378e2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -144,45 +144,13 @@ class PrefixSpan private ( logInfo(s"minimum count for a frequent pattern: $minCount") // Find frequent items. - val freqItemAndCounts = data.flatMap { itemsets => - val uniqItems = mutable.Set.empty[Item] - itemsets.foreach { _.foreach { item => - uniqItems += item - }} - uniqItems.toIterator.map((_, 1L)) - }.reduceByKey(_ + _) - .filter { case (_, count) => - count >= minCount - }.collect() - val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + val freqItems = findFrequentItems(data, minCount) logInfo(s"number of frequent items: ${freqItems.length}") // Keep only frequent items from input sequences and convert them to internal storage. val itemToInt = freqItems.zipWithIndex.toMap - val dataInternalRepr = data.flatMap { itemsets => - val allItems = mutable.ArrayBuilder.make[Int] - var containsFreqItems = false - allItems += 0 - itemsets.foreach { itemsets => - val items = mutable.ArrayBuilder.make[Int] - itemsets.foreach { item => - if (itemToInt.contains(item)) { - items += itemToInt(item) + 1 // using 1-indexing in internal format - } - } - val result = items.result() - if (result.nonEmpty) { - containsFreqItems = true - allItems ++= result.sorted - } - allItems += 0 - } - if (containsFreqItems) { - Iterator.single(allItems.result()) - } else { - Iterator.empty - } - }.persist(StorageLevel.MEMORY_AND_DISK) + val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt) + .persist(StorageLevel.MEMORY_AND_DISK) val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) @@ -231,6 +199,67 @@ class PrefixSpan private ( @Since("1.5.0") object PrefixSpan extends Logging { + /** + * This methods finds all frequent items in a input dataset. + * + * @param data Sequences of itemsets. + * @param minCount The minimal number of sequence an item should be present in to be frequent + * + * @return An array of Item containing only frequent items. + */ + private[fpm] def findFrequentItems[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + minCount: Long): Array[Item] = { + + data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach(set => uniqItems ++= set) + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _).filter { case (_, count) => + count >= minCount + }.sortBy(-_._2).map(_._1).collect() + } + + /** + * This methods cleans the input dataset from un-frequent items, and translate it's item + * to their corresponding Int identifier. + * + * @param data Sequences of itemsets. + * @param itemToInt A map allowing translation of frequent Items to their Int Identifier. + * The map should only contain frequent item. + * + * @return The internal repr of the inputted dataset. With properly placed zero delimiter. + */ + private[fpm] def toDatabaseInternalRepr[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + itemToInt: Map[Item, Int]): RDD[Array[Int]] = { + + data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + allItems += 0 + } + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + } + } + /** * Find the complete set of frequent sequential patterns in the input sequences. * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int], diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 4c2376376dd2a..c2e08d078fc1a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -360,6 +360,49 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("PrefixSpan pre-processing's cleaning test") { + + // One item per itemSet + val itemToInt1 = (4 to 5).zipWithIndex.toMap + val sequences1 = Seq( + Array(Array(4), Array(1), Array(2), Array(5), Array(2), Array(4), Array(5)), + Array(Array(6), Array(7), Array(8))) + val rdd1 = sc.parallelize(sequences1, 2).cache() + + val cleanedSequence1 = PrefixSpan.toDatabaseInternalRepr(rdd1, itemToInt1).collect() + + val expected1 = Array(Array(0, 4, 0, 5, 0, 4, 0, 5, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt1(x) + 1)) + + compareInternalSequences(expected1, cleanedSequence1) + + // Multi-item sequence + val itemToInt2 = (4 to 6).zipWithIndex.toMap + val sequences2 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd2 = sc.parallelize(sequences2, 2).cache() + + val cleanedSequence2 = PrefixSpan.toDatabaseInternalRepr(rdd2, itemToInt2).collect() + + val expected2 = Array(Array(0, 4, 5, 0, 6, 0, 5, 0, 4, 0, 5, 6, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt2(x) + 1)) + + compareInternalSequences(expected2, cleanedSequence2) + + // Emptied sequence + val itemToInt3 = (10 to 10).zipWithIndex.toMap + val sequences3 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd3 = sc.parallelize(sequences3, 2).cache() + + val cleanedSequence3 = PrefixSpan.toDatabaseInternalRepr(rdd3, itemToInt3).collect() + val expected3 = Array[Array[Int]]() + + compareInternalSequences(expected3, cleanedSequence3) + } + test("model save/load") { val sequences = Seq( Array(Array(1, 2), Array(3)), @@ -409,4 +452,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet assert(expectedSet === actualSet) } + + private def compareInternalSequences( + expectedValue: Array[Array[Int]], + actualValue: Array[Array[Int]]): Unit = { + val expectedSet = expectedValue.map(x => x.toSeq).toSet + val actualSet = actualValue.map(x => x.toSeq).toSet + assert(expectedSet === actualSet) + } } From a4293c28438515d5ccf1f6b82f7b762e316d0a27 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 13 Apr 2017 09:56:34 +0100 Subject: [PATCH 0264/1765] [SPARK-20284][CORE] Make {Des,S}erializationStream extend Closeable ## What changes were proposed in this pull request? This PR allows to use `SerializationStream` and `DeserializationStream` in try-with-resources. ## How was this patch tested? `core` unit tests. Author: Sergei Lebedev Closes #17598 from superbobry/compression-stream-closeable. --- .../scala/org/apache/spark/serializer/Serializer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 01bbda0b5e6b3..cb8b1cc077637 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -125,7 +125,7 @@ abstract class SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -abstract class SerializationStream { +abstract class SerializationStream extends Closeable { /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream /** Writes the object representing the key of a key-value pair. */ @@ -133,7 +133,7 @@ abstract class SerializationStream { /** Writes the object representing the value of a key-value pair. */ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit - def close(): Unit + override def close(): Unit def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { @@ -149,14 +149,14 @@ abstract class SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -abstract class DeserializationStream { +abstract class DeserializationStream extends Closeable { /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T /** Reads the object representing the key of a key-value pair. */ def readKey[T: ClassTag](): T = readObject[T]() /** Reads the object representing the value of a key-value pair. */ def readValue[T: ClassTag](): T = readObject[T]() - def close(): Unit + override def close(): Unit /** * Read the elements of this stream through an iterator. This can only be called once, as From fbe4216e1e83d243a7f0521b76bfb20c25278281 Mon Sep 17 00:00:00 2001 From: Ioana Delaney Date: Thu, 13 Apr 2017 22:27:04 +0800 Subject: [PATCH 0265/1765] [SPARK-20233][SQL] Apply star-join filter heuristics to dynamic programming join enumeration ## What changes were proposed in this pull request? Implements star-join filter to reduce the search space for dynamic programming join enumeration. Consider the following join graph: ``` T1 D1 - T2 - T3 \ / F1 | D2 star-join: {F1, D1, D2} non-star: {T1, T2, T3} ``` The following join combinations will be generated: ``` level 0: (F1), (D1), (D2), (T1), (T2), (T3) level 1: {F1, D1}, {F1, D2}, {T2, T3} level 2: {F1, D1, D2} level 3: {F1, D1, D2, T1}, {F1, D1, D2, T2} level 4: {F1, D1, D2, T1, T2}, {F1, D1, D2, T2, T3 } level 6: {F1, D1, D2, T1, T2, T3} ``` ## How was this patch tested? New test suite ```StarJOinCostBasedReorderSuite.scala```. Author: Ioana Delaney Closes #17546 from ioana-delaney/starSchemaCBOv3. --- .../optimizer/CostBasedJoinReorder.scala | 144 +++++- .../optimizer/StarSchemaDetection.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../StarJoinCostBasedReorderSuite.scala | 426 ++++++++++++++++++ 4 files changed, 571 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index cbd506465ae6a..c704c2e6d36bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -54,8 +54,6 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) - // TODO: Compute the set of star-joins and use them in the join enumeration - // algorithm to prune un-optimal plan choices. val result = // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. @@ -150,12 +148,15 @@ object JoinReorderDP extends PredicateHelper with Logging { case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) }.toMap) + // Build filters from the join graph to be used by the search algorithm. + val filters = JoinReorderDPFilters.buildJoinGraphInfo(conf, items, conditions, itemIndex) + // Build plans for next levels until the last level has only one plan. This plan contains // all items that can be joined, so there's no need to continue. val topOutputSet = AttributeSet(output) - while (foundPlans.size < items.length && foundPlans.last.size > 1) { + while (foundPlans.size < items.length) { // Build plans for the next level. - foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet) + foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet, filters) } val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) @@ -179,7 +180,8 @@ object JoinReorderDP extends PredicateHelper with Logging { existingLevels: Seq[JoinPlanMap], conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): JoinPlanMap = { + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): JoinPlanMap = { val nextLevel = mutable.Map.empty[Set[Int], JoinPlan] var k = 0 @@ -200,7 +202,7 @@ object JoinReorderDP extends PredicateHelper with Logging { } otherSideCandidates.foreach { otherSidePlan => - buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) match { + buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput, filters) match { case Some(newJoinPlan) => // Check if it's the first plan for the item set, or it's a better plan than // the existing one due to lower cost. @@ -218,14 +220,20 @@ object JoinReorderDP extends PredicateHelper with Logging { } /** - * Builds a new JoinPlan when both conditions hold: + * Builds a new JoinPlan if the following conditions hold: * - the sets of items contained in left and right sides do not overlap. * - there exists at least one join condition involving references from both sides. + * - if star-join filter is enabled, allow the following combinations: + * 1) (oneJoinPlan U otherJoinPlan) is a subset of star-join + * 2) star-join is a subset of (oneJoinPlan U otherJoinPlan) + * 3) (oneJoinPlan U otherJoinPlan) is a subset of non star-join + * * @param oneJoinPlan One side JoinPlan for building a new JoinPlan. * @param otherJoinPlan The other side JoinPlan for building a new join node. * @param conf SQLConf for statistics computation. * @param conditions The overall set of join conditions. * @param topOutput The output attributes of the final plan. + * @param filters Join graph info to be used as filters by the search algorithm. * @return Builds and returns a new JoinPlan if both conditions hold. Otherwise, returns None. */ private def buildJoin( @@ -233,13 +241,27 @@ object JoinReorderDP extends PredicateHelper with Logging { otherJoinPlan: JoinPlan, conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): Option[JoinPlan] = { + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): Option[JoinPlan] = { if (oneJoinPlan.itemIds.intersect(otherJoinPlan.itemIds).nonEmpty) { // Should not join two overlapping item sets. return None } + if (filters.isDefined) { + // Apply star-join filter, which ensures that tables in a star schema relationship + // are planned together. The star-filter will eliminate joins among star and non-star + // tables until the star joins are built. The following combinations are allowed: + // 1. (oneJoinPlan U otherJoinPlan) is a subset of star-join + // 2. star-join is a subset of (oneJoinPlan U otherJoinPlan) + // 3. (oneJoinPlan U otherJoinPlan) is a subset of non star-join + val isValidJoinCombination = + JoinReorderDPFilters.starJoinFilter(oneJoinPlan.itemIds, otherJoinPlan.itemIds, + filters.get) + if (!isValidJoinCombination) return None + } + val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan val joinConds = conditions @@ -327,3 +349,109 @@ object JoinReorderDP extends PredicateHelper with Logging { case class Cost(card: BigInt, size: BigInt) { def +(other: Cost): Cost = Cost(this.card + other.card, this.size + other.size) } + +/** + * Implements optional filters to reduce the search space for join enumeration. + * + * 1) Star-join filters: Plan star-joins together since they are assumed + * to have an optimal execution based on their RI relationship. + * 2) Cartesian products: Defer their planning later in the graph to avoid + * large intermediate results (expanding joins, in general). + * 3) Composite inners: Don't generate "bushy tree" plans to avoid materializing + * intermediate results. + * + * Filters (2) and (3) are not implemented. + */ +object JoinReorderDPFilters extends PredicateHelper { + /** + * Builds join graph information to be used by the filtering strategies. + * Currently, it builds the sets of star/non-star joins. + * It can be extended with the sets of connected/unconnected joins, which + * can be used to filter Cartesian products. + */ + def buildJoinGraphInfo( + conf: SQLConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + itemIndex: Seq[(LogicalPlan, Int)]): Option[JoinGraphInfo] = { + + if (conf.joinReorderDPStarFilter) { + // Compute the tables in a star-schema relationship. + val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq) + val nonStarJoin = items.filterNot(starJoin.contains(_)) + + if (starJoin.nonEmpty && nonStarJoin.nonEmpty) { + val itemMap = itemIndex.toMap + Some(JoinGraphInfo(starJoin.map(itemMap).toSet, nonStarJoin.map(itemMap).toSet)) + } else { + // Nothing interesting to return. + None + } + } else { + // Star schema filter is not enabled. + None + } + } + + /** + * Applies the star-join filter that eliminates join combinations among star + * and non-star tables until the star join is built. + * + * Given the oneSideJoinPlan/otherSideJoinPlan, which represent all the plan + * permutations generated by the DP join enumeration, and the star/non-star plans, + * the following plan combinations are allowed: + * 1. (oneSideJoinPlan U otherSideJoinPlan) is a subset of star-join + * 2. star-join is a subset of (oneSideJoinPlan U otherSideJoinPlan) + * 3. (oneSideJoinPlan U otherSideJoinPlan) is a subset of non star-join + * + * It assumes the sets are disjoint. + * + * Example query graph: + * + * t1 d1 - t2 - t3 + * \ / + * f1 + * | + * d2 + * + * star: {d1, f1, d2} + * non-star: {t2, t1, t3} + * + * level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + * level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + * level 2: {d2 f1 d1 } + * level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + * level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + * level 5: {d1 t3 t2 f1 t1 d2 } + * + * @param oneSideJoinPlan One side of the join represented as a set of plan ids. + * @param otherSideJoinPlan The other side of the join represented as a set of plan ids. + * @param filters Star and non-star plans represented as sets of plan ids + */ + def starJoinFilter( + oneSideJoinPlan: Set[Int], + otherSideJoinPlan: Set[Int], + filters: JoinGraphInfo) : Boolean = { + val starJoins = filters.starJoins + val nonStarJoins = filters.nonStarJoins + val join = oneSideJoinPlan.union(otherSideJoinPlan) + + // Disjoint sets + oneSideJoinPlan.intersect(otherSideJoinPlan).isEmpty && + // Either star or non-star is empty + (starJoins.isEmpty || nonStarJoins.isEmpty || + // Join is a subset of the star-join + join.subsetOf(starJoins) || + // Star-join is a subset of join + starJoins.subsetOf(join) || + // Join is a subset of non-star + join.subsetOf(nonStarJoins)) + } +} + +/** + * Helper class that keeps information about the join graph as sets of item/plan ids. + * It currently stores the star/non-star plans. It can be + * extended with the set of connected/unconnected plans. + */ +case class JoinGraphInfo (starJoins: Set[Int], nonStarJoins: Set[Int]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 91cb004eaec46..97ee9988386dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -76,7 +76,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val emptyStarJoinPlan = Seq.empty[LogicalPlan] - if (!conf.starSchemaDetection || input.size < 2) { + if (input.size < 2) { emptyStarJoinPlan } else { // Find if the input plans are eligible for star join detection. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6b0f495033494..2e1798e22b9fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -736,6 +736,12 @@ object SQLConf { .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") .createWithDefault(0.7) + val JOIN_REORDER_DP_STAR_FILTER = + buildConf("spark.sql.cbo.joinReorder.dp.star.filter") + .doc("Applies star-join filter heuristics to cost based join enumeration.") + .booleanConf + .createWithDefault(false) + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") .doc("When true, it enables join reordering based on star schema detection. ") .booleanConf @@ -1011,6 +1017,8 @@ class SQLConf extends Serializable with Logging { def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala new file mode 100644 index 0000000000000..a23d6266b2840 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ + + +class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = new SQLConf().copy( + CBO_ENABLED -> true, + JOIN_REORDER_ENABLED -> true, + JOIN_REORDER_DP_STAR_FILTER -> true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 (fact table) + attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D1 (dimension) + attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D2 (dimension) + attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D3 (dimension) + attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + + // T1 (regular table i.e. outside star) + attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T2 (regular table) + attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T3 (regular table) + attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T4 (regular table) + attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T5 (regular table) + attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T6 (regular table) + attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4) + + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2").map(nameToAttr), + rowCount = 1000, + size = Some(1000 * (8 + 4 * 5)), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2") + .map(nameToColInfo))) + + // To control the layout of the join plans, keep the size for the non-fact tables constant + // and vary the rowcount and the number of distinct values of the join columns. + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk", "d1_c2", "d1_c3").map(nameToAttr), + rowCount = 100, + size = Some(3000), + attributeStats = AttributeMap(Seq("d1_pk", "d1_c2", "d1_c3").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_pk", "d2_c2", "d2_c3").map(nameToAttr), + rowCount = 20, + size = Some(3000), + attributeStats = AttributeMap(Seq("d2_pk", "d2_c2", "d2_c3").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_pk", "d3_c2", "d3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("d3_pk", "d3_c2", "d3_c3").map(nameToColInfo))) + + private val t1 = StatsTestPlan( + outputList = Seq("t1_c1", "t1_c2", "t1_c3").map(nameToAttr), + rowCount = 50, + size = Some(3000), + attributeStats = AttributeMap(Seq("t1_c1", "t1_c2", "t1_c3").map(nameToColInfo))) + + private val t2 = StatsTestPlan( + outputList = Seq("t2_c1", "t2_c2", "t2_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t2_c1", "t2_c2", "t2_c3").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3_c1", "t3_c2", "t3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t3_c1", "t3_c2", "t3_c3").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4_c1", "t4_c2", "t4_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t4_c1", "t4_c2", "t4_c3").map(nameToColInfo))) + + private val t5 = StatsTestPlan( + outputList = Seq("t5_c1", "t5_c2", "t5_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t5_c1", "t5_c2", "t5_c3").map(nameToColInfo))) + + private val t6 = StatsTestPlan( + outputList = Seq("t6_c1", "t6_c2", "t6_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t6_c1", "t6_c2", "t6_c3").map(nameToColInfo))) + + test("Test 1: Star query with two dimensions and two regular tables") { + + // d1 t1 + // \ / + // f1 + // / \ + // d2 t2 + // + // star: {f1, d1, d2} + // non-star: {t1, t2} + // + // level 0: (t2 ), (d2 ), (f1 ), (d1 ), (t1 ) + // level 1: {f1 d1 }, {d2 f1 } + // level 2: {d2 f1 d1 } + // level 3: {t2 d1 d2 f1 }, {t1 d1 d2 f1 } + // level 4: {f1 t1 t2 d1 d2 } + // + // Number of generated plans: 11 (vs. 20 w/o filter) + val query = + f1.join(t1).join(t2).join(d1).join(d2) + .where((nameToAttr("f1_c1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star with a linear branch") { + // + // t1 d1 - t2 - t3 + // \ / + // f1 + // | + // d2 + // + // star: {d1, f1, d2} + // non-star: {t2, t1, t3} + // + // level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + // level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + // level 2: {d2 f1 d1 } + // level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + // level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + // level 5: {d1 t3 t2 f1 t1 d2 } + // + // Number of generated plans: 15 (vs 24) + val query = + d1.join(t1).join(t2).join(f1).join(d2).join(t3) + .where((nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("t1_c1") === nameToAttr("f1_c1")) && + (nameToAttr("d2_pk") === nameToAttr("f1_fk2")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("t2_c2") === nameToAttr("t3_c1"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star with derived branches") { + // t3 t2 + // | | + // d1 - t4 - t1 + // | + // f1 + // | + // d2 + // + // star: (d1 f1 d2 ) + // non-star: (t4 t1 t2 t3 ) + // + // level 0: (t1 ), (t3 ), (f1 ), (d1 ), (t2 ), (d2 ), (t4 ) + // level 1: {f1 d2 }, {t1 t4 }, {t1 t2 }, {f1 d1 }, {t3 t4 } + // level 2: {d1 f1 d2 }, {t2 t1 t4 }, {t1 t3 t4 } + // level 3: {t4 d1 f1 d2 }, {t3 t4 t1 t2 } + // level 4: {d1 f1 t4 d2 t3 }, {d1 f1 t4 d2 t1 } + // level 5: {d1 f1 t4 d2 t1 t2 }, {d1 f1 t4 d2 t1 t3 } + // level 6: {d1 f1 t4 d2 t1 t2 t3 } + // + // Number of generated plans: 22 (vs. 34) + val query = + d1.join(t1).join(t2).join(t3).join(t4).join(f1).join(d2) + .where((nameToAttr("t1_c1") === nameToAttr("t2_c1")) && + (nameToAttr("t3_c1") === nameToAttr("t4_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_c2") === nameToAttr("t4_c3")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))), Inner, + Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) + .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, + Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) + + assertEqualPlans(query, expected) + } + + test("Test 4: Star with several branches") { + // + // d1 - t3 - t4 + // | + // f1 - d3 - t1 - t2 + // | + // d2 - t5 - t6 + // + // star: {d1 f1 d2 d3 } + // non-star: {t5 t3 t6 t2 t4 t1} + // + // level 0: (t4 ), (d2 ), (t5 ), (d3 ), (d1 ), (f1 ), (t2 ), (t6 ), (t1 ), (t3 ) + // level 1: {t5 t6 }, {t4 t3 }, {d3 f1 }, {t2 t1 }, {d2 f1 }, {d1 f1 } + // level 2: {d2 d1 f1 }, {d2 d3 f1 }, {d3 d1 f1 } + // level 3: {d2 d1 d3 f1 } + // level 4: {d1 t3 d3 f1 d2 }, {d1 d3 f1 t1 d2 }, {d1 t5 d3 f1 d2 } + // level 5: {d1 t5 d3 f1 t1 d2 }, {d1 t3 t4 d3 f1 d2 }, {d1 t5 t6 d3 f1 d2 }, + // {d1 t5 t3 d3 f1 d2 }, {d1 t3 d3 f1 t1 d2 }, {d1 t2 d3 f1 t1 d2 } + // level 6: {d1 t5 t3 t4 d3 f1 d2 }, {d1 t3 t2 d3 f1 t1 d2 }, {d1 t5 t6 d3 f1 t1 d2 }, + // {d1 t5 t3 d3 f1 t1 d2 }, {d1 t5 t2 d3 f1 t1 d2 }, ... + // ... + // level 9: {d1 t5 t3 t6 t2 t4 d3 f1 t1 d2 } + // + // Number of generated plans: 46 (vs. 82) + val query = + d1.join(t3).join(t4).join(f1).join(d2).join(t5).join(t6).join(d3).join(t1).join(t2) + .where((nameToAttr("d1_c2") === nameToAttr("t3_c1")) && + (nameToAttr("t3_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d2_c2") === nameToAttr("t5_c1")) && + (nameToAttr("t5_c2") === nameToAttr("t6_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk")) && + (nameToAttr("d3_c2") === nameToAttr("t1_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t2_c2"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t3_c2") === nameToAttr("t4_c2"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t3_c1"))) + .join(t2.join(t1, Inner, Some(nameToAttr("t1_c2") === nameToAttr("t2_c2"))), Inner, + Some(nameToAttr("d3_c2") === nameToAttr("t1_c1"))) + .join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner, + Some(nameToAttr("d2_c2") === nameToAttr("t5_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: RI star only") { + // d1 + // | + // f1 + // / \ + // d2 d3 + // + // star: {f1, d1, d2, d3} + // non-star: {} + // level 0: (d1), (f1), (d2), (d3) + // level 1: {f1 d3 }, {f1 d2 }, {d1 f1 } + // level 2: {d1 f1 d2 }, {d2 f1 d3 }, {d1 f1 d3 } + // level 3: {d1 d2 f1 d3 } + // Number of generated plans: 11 (= 11) + val query = + d1.join(d2).join(f1).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + + assertEqualPlans(query, expected) + } + + test("Test 6: No RI star") { + // + // f1 - t1 - t2 - t3 + // + // star: {} + // non-star: {f1, t1, t2, t3} + // level 0: (t1), (f1), (t2), (t3) + // level 1: {f1 t3 }, {f1 t2 }, {t1 f1 } + // level 2: {t1 f1 t2 }, {t2 f1 t3 }, {dt f1 t3 } + // level 3: {t1 t2 f1 t3 } + // Number of generated plans: 11 (= 11) + val query = + t1.join(f1).join(t2).join(t3) + .where((nameToAttr("f1_fk1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_fk2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + + val expected = + f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + .join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} From 8ddf0d2a60795a2306f94df8eac6e265b1fe5230 Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Thu, 13 Apr 2017 12:43:28 -0700 Subject: [PATCH 0266/1765] [SPARK-20232][PYTHON] Improve combineByKey docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Improve combineByKey documentation: * Add note on memory allocation * Change example code to use different mergeValue and mergeCombiners ## How was this patch tested? Doctest. ## Legal This is my original work and I license the work to the project under the project’s open source license. Author: David Gingrich Closes #17545 from dgingrich/topic-spark-20232-combinebykey-docs. --- python/pyspark/rdd.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 291c1caaaed57..60141792d499b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1804,17 +1804,31 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, a one-element list) - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of a list) - - C{mergeCombiners}, to combine two C's into a single one. + - C{mergeCombiners}, to combine two C's into a single one (e.g., merges + the lists) + + To avoid memory allocation, both mergeValue and mergeCombiners are allowed to + modify and return their first argument instead of creating a new C. In addition, users can control the partitioning of the output RDD. .. note:: V and C can be different -- for example, one might group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def add(a, b): return a + str(b) - >>> sorted(x.combineByKey(str, add, add).collect()) - [('a', '11'), ('b', '1')] + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 2)]) + >>> def to_list(a): + ... return [a] + ... + >>> def append(a, b): + ... a.append(b) + ... return a + ... + >>> def extend(a, b): + ... a.extend(b) + ... return a + ... + >>> sorted(x.combineByKey(to_list, append, extend).collect()) + [('a', [1, 2]), ('b', [1])] """ if numPartitions is None: numPartitions = self._defaultReducePartitions() From 7536e2849df6d63587fbf16b4ecb5db06fed7125 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 13 Apr 2017 15:30:44 -0500 Subject: [PATCH 0267/1765] [SPARK-20038][SQL] FileFormatWriter.ExecuteWriteTask.releaseResources() implementations to be re-entrant ## What changes were proposed in this pull request? have the`FileFormatWriter.ExecuteWriteTask.releaseResources()` implementations set `currentWriter=null` in a finally clause. This guarantees that if the first call to `currentWriter()` throws an exception, the second releaseResources() call made during the task cancel process will not trigger a second attempt to close the stream. ## How was this patch tested? Tricky. I've been fixing the underlying cause when I saw the problem [HADOOP-14204](https://issues.apache.org/jira/browse/HADOOP-14204), but SPARK-10109 shows I'm not the first to have seen this. I can't replicate it locally any more, my code no longer being broken. code review, however, should be straightforward Author: Steve Loughran Closes #17364 from steveloughran/stevel/SPARK-20038-close. --- .../execution/datasources/FileFormatWriter.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index bda64d4b91bbc..4ec09bff429c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -324,8 +324,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } @@ -459,8 +462,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } From fb036c4413c2cd4d90880d080f418ec468d6c0fc Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 14 Apr 2017 19:16:47 +0800 Subject: [PATCH 0268/1765] [SPARK-20318][SQL] Use Catalyst type for min/max in ColumnStat for ease of estimation ## What changes were proposed in this pull request? Currently when estimating predicates like col > literal or col = literal, we will update min or max in column stats based on literal value. However, literal value is of Catalyst type (internal type), while min/max is of external type. Then for the next predicate, we again need to do type conversion to compare and update column stats. This is awkward and causes many unnecessary conversions in estimation. To solve this, we use Catalyst type for min/max in `ColumnStat`. Note that the persistent format in metastore is still of external type, so there's no inconsistency for statistics in metastore. This pr also fixes a bug for boolean type in `IN` condition. ## How was this patch tested? The changes for ColumnStat are covered by existing tests. For bug fix, a new test for boolean type in IN condition is added Author: wangzhenhua Closes #17630 from wzhfy/refactorColumnStat. --- .../catalyst/plans/logical/Statistics.scala | 95 +++++++++++++------ .../statsEstimation/EstimationUtils.scala | 30 +++++- .../statsEstimation/FilterEstimation.scala | 68 ++++--------- .../plans/logical/statsEstimation/Range.scala | 70 +++----------- .../FilterEstimationSuite.scala | 41 ++++---- .../statsEstimation/JoinEstimationSuite.scala | 15 +-- .../ProjectEstimationSuite.scala | 21 ++-- .../command/AnalyzeColumnCommand.scala | 8 +- .../spark/sql/StatisticsCollectionSuite.scala | 19 ++-- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- 10 files changed, 189 insertions(+), 182 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f24b240956a61..3d4efef953a64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -74,11 +75,10 @@ case class Statistics( * Statistics collected for a column. * * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the external data type (used in Row) for the - * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for - * TimestampType we store java.sql.Timestamp. - * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. - * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * Catalyst data type. For example, the internal type of DateType is Int, and that the internal + * type of TimestampType is Long. + * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -104,22 +104,43 @@ case class ColumnStat( /** * Returns a map from string to string that can be used to serialize the column stats. * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * representation for the value. min/max values are converted to the external data type. For + * example, for DateType we store java.sql.Date, and for TimestampType we store + * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. * * As part of the protocol, the returned map always contains a key called "version". * In the case min/max values are null (None), they won't appear in the map. */ - def toMap: Map[String, String] = { + def toMap(colName: String, dataType: DataType): Map[String, String] = { val map = new scala.collection.mutable.HashMap[String, String] map.put(ColumnStat.KEY_VERSION, "1") map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } map.toMap } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + private def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + } @@ -150,28 +171,15 @@ object ColumnStat extends Logging { * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. */ - def fromMap(table: String, field: StructField, map: Map[String, String]) - : Option[ColumnStat] = { - val str2val: (String => Any) = field.dataType match { - case _: IntegralType => _.toLong - case _: DecimalType => new java.math.BigDecimal(_) - case DoubleType | FloatType => _.toDouble - case BooleanType => _.toBoolean - case DateType => java.sql.Date.valueOf - case TimestampType => java.sql.Timestamp.valueOf - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => _ => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column ${field.name} of data type: ${field.dataType}.") - } - + def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { try { Some(ColumnStat( distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + min = map.get(KEY_MIN_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), nullCount = BigInt(map(KEY_NULL_COUNT).toLong), avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong @@ -183,6 +191,30 @@ object ColumnStat extends Logging { } } + /** + * Converts from string representation of external data type to the corresponding Catalyst data + * type. + */ + private def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + /** * Constructs an expression to compute column statistics for a given column. * @@ -232,11 +264,14 @@ object ColumnStat extends Logging { } /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ - def rowToColumnStat(row: Row): ColumnStat = { + def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = { ColumnStat( distinctCount = BigInt(row.getLong(0)), - min = Option(row.get(1)), // for string/binary min/max, get should return null - max = Option(row.get(2)), + // for string/binary min/max, get should return null + min = Option(row.get(1)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), + max = Option(row.get(2)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), nullCount = BigInt(row.getLong(3)), avgLen = row.getLong(4), maxLen = row.getLong(5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 5577233ffa6fe..f1aff62cb6af0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { @@ -75,4 +75,32 @@ object EstimationUtils { // (simple computation of statistics returns product of children). if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } + + /** + * For simplicity we use Decimal to unify operations for data types whose min/max values can be + * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true). + * The two methods below are the contract of conversion. + */ + def toDecimal(value: Any, dataType: DataType): Decimal = { + dataType match { + case _: NumericType | DateType | TimestampType => Decimal(value.toString) + case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0) + } + } + + def fromDecimal(dec: Decimal, dataType: DataType): Any = { + dataType match { + case BooleanType => dec.toLong == 1 + case DateType => dec.toInt + case TimestampType => dec.toLong + case ByteType => dec.toByte + case ShortType => dec.toShort + case IntegerType => dec.toInt + case LongType => dec.toLong + case FloatType => dec.toFloat + case DoubleType => dec.toDouble + case _: DecimalType => dec + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 7bd8e6511232f..4b6b3b14d9ac8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -301,30 +300,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } } - /** - * For a SQL data type, its internal data type may be different from its external type. - * For DateType, its internal type is Int, and its external data type is Java Date type. - * The min/max values in ColumnStat are saved in their corresponding external type. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ - def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { - attrDataType match { - case DateType => - Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) - case TimestampType => - Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) - case _: DecimalType => - Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) - case StringType | BinaryType => - None - case _ => - Some(litValue) - } - } - /** * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. @@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val statsRange = Range(colStat.min, colStat.max, attr.dataType) if (statsRange.contains(literal)) { if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) - val newStats = colStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) + // We update ColumnStat structure after apply this equality predicate: + // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal + // value. + val newStats = attr.dataType match { + case StringType | BinaryType => + colStat.copy(distinctCount = 1, nullCount = 0) + case _ => + colStat.copy(distinctCount = 1, min = Some(literal.value), + max = Some(literal.value), nullCount = 0) + } colStatsMap(attr) = newStats } @@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging return Some(0.0) } - // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue( - attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) - val newMin = convertBoundValue( - attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) - + val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType)) + val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType)) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), + max = Some(newMax), nullCount = 0) colStatsMap(attr) = newStats } @@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val colStat = colStatsMap(attr) val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] - val max = BigDecimal(statsRange.max) - val min = BigDecimal(statsRange.min) + val max = statsRange.max.toBigDecimal + val min = statsRange.min.toBigDecimal val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range @@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } if (update) { - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) + val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() @@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val colStatLeft = colStatsMap(attrLeft) val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) .asInstanceOf[NumericRange] - val maxLeft = BigDecimal(statsRangeLeft.max) - val minLeft = BigDecimal(statsRangeLeft.min) + val maxLeft = statsRangeLeft.max + val minLeft = statsRangeLeft.min val colStatRight = colStatsMap(attrRight) val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) .asInstanceOf[NumericRange] - val maxRight = BigDecimal(statsRangeRight.max) - val minRight = BigDecimal(statsRangeRight.min) + val maxRight = statsRangeRight.max + val minRight = statsRangeRight.min // determine the overlapping degree between predicate range and column's range val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 3d13967cb62a4..4ac5ba5689f82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.math.{BigDecimal => JDecimal} -import java.sql.{Date, Timestamp} - import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} +import org.apache.spark.sql.types._ /** Value range of a column. */ @@ -31,13 +27,10 @@ trait Range { } /** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range { +case class NumericRange(min: Decimal, max: Decimal) extends Range { override def contains(l: Literal): Boolean = { - val decimal = l.dataType match { - case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) - case _ => new JDecimal(l.value.toString) - } - min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + val lit = EstimationUtils.toDecimal(l.value, l.dataType) + min <= lit && max >= lit } } @@ -58,7 +51,10 @@ object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { case StringType | BinaryType => new DefaultRange() case _ if min.isEmpty || max.isEmpty => new NullRange() - case _ => toNumericRange(min.get, max.get, dataType) + case _ => + NumericRange( + min = EstimationUtils.toDecimal(min.get, dataType), + max = EstimationUtils.toDecimal(max.get, dataType)) } def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { @@ -82,51 +78,11 @@ object Range { // binary/string types don't support intersecting. (None, None) case (n1: NumericRange, n2: NumericRange) => - val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max)) - val (newMin, newMax) = fromNumericRange(newRange, dt) - (Some(newMin), Some(newMax)) + // Choose the maximum of two min values, and the minimum of two max values. + val newMin = if (n1.min <= n2.min) n2.min else n1.min + val newMax = if (n1.max <= n2.max) n1.max else n2.max + (Some(EstimationUtils.fromDecimal(newMin, dt)), + Some(EstimationUtils.fromDecimal(newMax, dt))) } } - - /** - * For simplicity we use decimal to unify operations of numeric types, the two methods below - * are the contract of conversion. - */ - private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = { - dataType match { - case _: NumericType => - NumericRange(new JDecimal(min.toString), new JDecimal(max.toString)) - case BooleanType => - val min1 = if (min.asInstanceOf[Boolean]) 1 else 0 - val max1 = if (max.asInstanceOf[Boolean]) 1 else 0 - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case DateType => - val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date]) - val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case TimestampType => - val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) - val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - } - } - - private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = { - dataType match { - case _: IntegralType => - (n.min.longValue(), n.max.longValue()) - case FloatType | DoubleType => - (n.min.doubleValue(), n.max.doubleValue()) - case _: DecimalType => - (n.min, n.max) - case BooleanType => - (n.min.longValue() == 1, n.max.longValue() == 1) - case DateType => - (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue())) - case TimestampType => - (DateTimeUtils.toJavaTimestamp(n.min.longValue()), - DateTimeUtils.toJavaTimestamp(n.max.longValue())) - } - } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index cffb0d8739287..a28447840ae09 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { nullCount = 0, avgLen = 1, maxLen = 1) // column cdate has 10 values from 2017-01-01 through 2017-01-10. - val dMin = Date.valueOf("2017-01-01") - val dMax = Date.valueOf("2017-01-10") + val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) + val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. - val decMin = new java.math.BigDecimal("0.200000000000000000") - val decMax = new java.math.BigDecimal("0.800000000000000000") + val decMin = Decimal("0.200000000000000000") + val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) @@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3 OR null") { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) - val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> colStatInt), @@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 7) } + test("cbool IN (true)") { + validateEstimatedStats( + Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) + } + test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), @@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate = cast('2017-01-02' AS DATE)") { - val d20170102 = Date.valueOf("2017-01-02") + val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02")) validateEstimatedStats( - Filter(EqualTo(attrDate, Literal(d20170102)), + Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4)), @@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate < cast('2017-01-03' AS DATE)") { - val d20170103 = Date.valueOf("2017-01-03") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( - Filter(LessThan(attrDate, Literal(d20170103)), + Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4)), @@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("""cdate IN ( cast('2017-01-03' AS DATE), cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { - val d20170103 = Date.valueOf("2017-01-03") - val d20170104 = Date.valueOf("2017-01-04") - val d20170105 = Date.valueOf("2017-01-05") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) + val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) + val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(attrDate), 10L)), + Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { - val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + val dec_0_40 = Decimal("0.400000000000000000") validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), @@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdecimal < 0.60 ") { - val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), @@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint = cint3") { // no records qualify due to no overlap - val emptyColStats = Seq[(Attribute, ColumnStat)]() validateEstimatedStats( Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), Nil, // set to empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index f62df842fa50a..2d6b6e8e21f34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("test join keys of different types") { /** Columns in a table with only one row */ def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = { - val dec = new java.math.BigDecimal("1.000000000000000000") - val date = Date.valueOf("2016-05-08") - val timestamp = Timestamp.valueOf("2016-05-08 00:00:01") + val dec = Decimal("1.000000000000000000") + val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index f408dc4153586..a5c4d22a29386 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("test row size estimation") { - val dec1 = new java.math.BigDecimal("1.000000000000000000") - val dec2 = new java.math.BigDecimal("8.000000000000000000") - val d1 = Date.valueOf("2016-05-08") - val d2 = Date.valueOf("2016-05-09") - val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + val dec1 = Decimal("1.000000000000000000") + val dec2 = Decimal("8.000000000000000000") + val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09")) + val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) + val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index b89014ed8ef54..0d8db2ff5d5a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -73,10 +73,10 @@ case class AnalyzeColumnCommand( val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - val attributesToAnalyze = AttributeSet(columnNames.map { col => + val attributesToAnalyze = columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) - }).toSeq + } // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => @@ -99,8 +99,8 @@ case class AnalyzeColumnCommand( val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() val rowCount = statsRow.getLong(0) - val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr)) }.toMap (rowCount, columnStats) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 1f547c5a2a8ff..ddc393c8da053 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) stats.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) assert(roundtrip == Some(v)) } } @@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), "cstring" -> ColumnStat(2, None, None, 1, 3, 3), "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) ) private val randomName = new Random(31) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 806f2be5faeb0..8b0fdf49cefab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } + val colNameTypeMap: Map[String, DataType] = + tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap.foreach { case (k, v) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => statsProperties += (columnStatKeyPropName(colName, k) -> v) } } From 98b41ecbcbcddacdc2801c38fccc9823e710783b Mon Sep 17 00:00:00 2001 From: ouyangxiaochen Date: Sat, 15 Apr 2017 10:34:57 +0100 Subject: [PATCH 0269/1765] [SPARK-20316][SQL] Val and Var should strictly follow the Scala syntax ## What changes were proposed in this pull request? val and var should strictly follow the Scala syntax ## How was this patch tested? manual test and exisiting test cases Author: ouyangxiaochen Closes #17628 from ouyangxiaochen/spark-413. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d5cc3b3855045..33e18a8da60fb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -47,8 +47,8 @@ import org.apache.spark.util.ShutdownHookManager * has dropped its support. */ private[hive] object SparkSQLCLIDriver extends Logging { - private var prompt = "spark-sql" - private var continuedPrompt = "".padTo(prompt.length, ' ') + private val prompt = "spark-sql" + private val continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ installSignalHandler() From 35e5ae4f81176af52569c465520a703529893b50 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 16 Apr 2017 11:14:18 +0800 Subject: [PATCH 0270/1765] [SPARK-19716][SQL][FOLLOW-UP] UnresolvedMapObjects should always be serializable ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/17398 we introduced `UnresolvedMapObjects` as a placeholder of `MapObjects`. Unfortunately `UnresolvedMapObjects` is not serializable as its `function` may reference Scala `Type` which is not serializable. Ideally this is fine, as we will never serialize and send unresolved expressions to executors. However users may accidentally do this, e.g. mistakenly reference an encoder instance when implementing `Aggregator`, we should fix it so that it's just a performance issue(more network traffic) and should not fail the query. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17639 from cloud-fan/minor. --- .../expressions/objects/objects.scala | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index eed773d4cb368..f446c3e4a75f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -406,7 +406,7 @@ case class WrapOption(child: Expression, optType: DataType) } /** - * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ case class LambdaVariable( @@ -421,6 +421,27 @@ case class LambdaVariable( } } +/** + * When constructing [[MapObjects]], the element type must be given, which may not be available + * before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by + * [[MapObjects]] during analysis after the input data is resolved. + * Note that, ideally we should not serialize and send unresolved expressions to executors, but + * users may accidentally do this(e.g. mistakenly reference an encoder instance when implementing + * Aggregator). Here we mark `function` as transient because it may reference scala Type, which is + * not serializable. Then even users mistakenly reference unresolved expression and serialize it, + * it's just a performance issue(more network traffic), and will not fail. + */ +case class UnresolvedMapObjects( + @transient function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -442,20 +463,8 @@ object MapObjects { val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - val builderValue = s"MapObjects_builderValue$id" - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, - customCollectionCls, builderValue) - } -} - -case class UnresolvedMapObjects( - function: Expression => Expression, - child: Expression, - customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { - override lazy val resolved = false - - override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { - throw new UnsupportedOperationException("not resolved") + MapObjects( + loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } } @@ -482,8 +491,6 @@ case class UnresolvedMapObjects( * @param inputData An expression that when evaluated returns a collection object. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) - * @param builderValue The name of the builder variable used to construct the resulting collection - * (used only when returning ObjectType) */ case class MapObjects private( loopValue: String, @@ -491,8 +498,7 @@ case class MapObjects private( loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression, - customCollectionCls: Option[Class[_]], - builderValue: String) extends Expression with NonSQLExpression { + customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -590,15 +596,15 @@ case class MapObjects private( customCollectionCls match { case Some(cls) => // collection - val collObjectName = s"${cls.getName}$$.MODULE$$" - val getBuilderVar = s"$collObjectName.newBuilder()" + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" + val builder = ctx.freshName("collectionBuilder") ( s""" - ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength); + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); """, - genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${cls.getName}) $builderValue.result();" + genValue => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) $builder.result();" ) case None => // array From e090f3c0ceebdf341536a1c0c70c80afddf2ee2a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sun, 16 Apr 2017 12:09:34 +0800 Subject: [PATCH 0271/1765] [SPARK-20335][SQL] Children expressions of Hive UDF impacts the determinism of Hive UDF ### What changes were proposed in this pull request? ```JAVA /** * Certain optimizations should not be applied if UDF is not deterministic. * Deterministic UDF returns same result each time it is invoked with a * particular input. This determinism just needs to hold within the context of * a query. * * return true if the UDF is deterministic */ boolean deterministic() default true; ``` Based on the definition of [UDFType](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFType.java#L42-L50), when Hive UDF's children are non-deterministic, Hive UDF is also non-deterministic. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17635 from gatorsmile/udfDeterministic. --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- .../hive/execution/AggregationQuerySuite.scala | 13 +++++++++++++ .../sql/hive/execution/HiveUDAFSuite.scala | 18 +++++++++++++++++- .../sql/hive/execution/HiveUDFSuite.scala | 15 +++++++++++++++ 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 51c814cf32a81..a83ad61b204ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -44,7 +44,7 @@ private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - override def deterministic: Boolean = isUDFDeterministic + override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -123,7 +123,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic + override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4a8086d7e5400..84f915977bd88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -509,6 +509,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null, 110.0, null, null, 10.0) :: Nil) } + test("non-deterministic children expressions of UDAF") { + val e = intercept[AnalysisException] { + spark.sql( + """ + |SELECT mydoublesum(value + 1.5 * key + rand()) + |FROM agg1 + |GROUP BY key + """.stripMargin) + }.getMessage + assert(Seq("nondeterministic expression", + "should not appear in the arguments of an aggregate function").forall(e.contains)) + } + test("interpreted aggregate function") { checkAnswer( spark.sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index c9ef72ee112cf..479ca1e8def56 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import org.apache.hadoop.hive.ql.udf.UDAFPercentile import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} import org.apache.hadoop.hive.ql.util.JavaDataModel @@ -26,7 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -84,6 +85,21 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { Row(1, Row(1, 1)) )) } + + test("non-deterministic children expressions of UDAF") { + withTempView("view1") { + spark.range(1).selectExpr("id as x", "id as y").createTempView("view1") + withUserDefinedFunction("testUDAFPercentile" -> true) { + // non-deterministic children of Hive UDAF + sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") + val e1 = intercept[AnalysisException] { + sql("SELECT testUDAFPercentile(x, rand()) from view1 group by y") + }.getMessage + assert(Seq("nondeterministic expression", + "should not appear in the arguments of an aggregate function").forall(e1.contains)) + } + } + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ef6883839d437..4bbf9259192ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -387,6 +388,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { hiveContext.reset() } + test("non-deterministic children of UDF") { + withUserDefinedFunction("testStringStringUDF" -> true, "testGenericUDFHash" -> true) { + // HiveSimpleUDF + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + val df1 = sql("SELECT testStringStringUDF(rand(), \"hello\")") + assert(!df1.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + + // HiveGenericUDF + sql(s"CREATE TEMPORARY FUNCTION testGenericUDFHash AS '${classOf[GenericUDFHash].getName}'") + val df2 = sql("SELECT testGenericUDFHash(rand())") + assert(!df2.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + } + } + test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") From a888fed3099e84c2cf45e9419f684a3658ada19d Mon Sep 17 00:00:00 2001 From: Ji Yan Date: Sun, 16 Apr 2017 14:34:12 +0100 Subject: [PATCH 0272/1765] [SPARK-19740][MESOS] Add support in Spark to pass arbitrary parameters into docker when running on mesos with docker containerizer ## What changes were proposed in this pull request? Allow passing in arbitrary parameters into docker when launching spark executors on mesos with docker containerizer tnachen ## How was this patch tested? Manually built and tested with passed in parameter Author: Ji Yan Closes #17109 from yanji84/ji/allow_set_docker_user. --- docs/running-on-mesos.md | 10 ++++ .../mesos/MesosSchedulerBackendUtil.scala | 36 +++++++++++-- .../MesosSchedulerBackendUtilSuite.scala | 53 +++++++++++++++++++ 3 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ef01cfe4b92cd..314a806edf39e 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -356,6 +356,16 @@ See the [configuration page](configuration.html) for information on Spark config By default Mesos agents will not pull images they already have cached. + + spark.mesos.executor.docker.parameters + (none) + + Set the list of custom parameters which will be passed into the docker run command when launching the Spark executor on Mesos using the docker containerizer. The format of this property is a comma-separated list of + key/value pairs. Example: + +
    key1=val1,key2=val2,key3=val3
    + + spark.mesos.executor.docker.volumes (none) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index a2adb228dc299..fbcbc55099ec5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} import org.apache.spark.{SparkConf, SparkException} @@ -99,6 +99,28 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } + /** + * Parse a list of docker parameters, each of which + * takes the form key=value + */ + private def parseParamsSpec(params: String): List[Parameter] = { + // split with limit of 2 to avoid parsing error when '=' + // exists in the parameter value + params.split(",").map(_.split("=", 2)).flatMap { spec: Array[String] => + val param: Parameter.Builder = Parameter.newBuilder() + spec match { + case Array(key, value) => + Some(param.setKey(key).setValue(value)) + case spec => + logWarning(s"Unable to parse arbitary parameters: $params. " + + "Expected form: \"key=value(, ...)\"") + None + } + } + .map { _.build() } + .toList + } + def containerInfo(conf: SparkConf): ContainerInfo = { val containerType = if (conf.contains("spark.mesos.executor.docker.image") && conf.get("spark.mesos.containerizer", "docker") == "docker") { @@ -120,8 +142,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .map(parsePortMappingsSpec) .getOrElse(List.empty) + val params = conf + .getOption("spark.mesos.executor.docker.parameters") + .map(parseParamsSpec) + .getOrElse(List.empty) + if (containerType == ContainerInfo.Type.DOCKER) { - containerInfo.setDocker(dockerInfo(image, forcePullImage, portMaps)) + containerInfo + .setDocker(dockerInfo(image, forcePullImage, portMaps, params)) } else { containerInfo.setMesos(mesosInfo(image, forcePullImage)) } @@ -144,11 +172,13 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { private def dockerInfo( image: String, forcePullImage: Boolean, - portMaps: List[ContainerInfo.DockerInfo.PortMapping]): DockerInfo = { + portMaps: List[ContainerInfo.DockerInfo.PortMapping], + params: List[Parameter]): DockerInfo = { val dockerBuilder = ContainerInfo.DockerInfo.newBuilder() .setImage(image) .setForcePullImage(forcePullImage) portMaps.foreach(dockerBuilder.addPortMappings(_)) + params.foreach(dockerBuilder.addParameters(_)) dockerBuilder.build } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala new file mode 100644 index 0000000000000..caf9d89fdd201 --- /dev/null +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.scalatest._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class MesosSchedulerBackendUtilSuite extends SparkFunSuite { + + test("ContainerInfo fails to parse invalid docker parameters") { + val conf = new SparkConf() + conf.set("spark.mesos.executor.docker.parameters", "a,b") + conf.set("spark.mesos.executor.docker.image", "test") + + val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val params = containerInfo.getDocker.getParametersList + + assert(params.size() == 0) + } + + test("ContainerInfo parses docker parameters") { + val conf = new SparkConf() + conf.set("spark.mesos.executor.docker.parameters", "a=1,b=2,c=3") + conf.set("spark.mesos.executor.docker.image", "test") + + val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val params = containerInfo.getDocker.getParametersList + assert(params.size() == 3) + assert(params.get(0).getKey == "a") + assert(params.get(0).getValue == "1") + assert(params.get(1).getKey == "b") + assert(params.get(1).getValue == "2") + assert(params.get(2).getKey == "c") + assert(params.get(2).getValue == "3") + } +} From ad935f526f57a9621c0a5ba082b85414c28282f4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 16 Apr 2017 14:36:42 +0100 Subject: [PATCH 0273/1765] [SPARK-20343][BUILD] Add avro dependency in core POM to resolve build failure in SBT Hadoop 2.6 master on Jenkins ## What changes were proposed in this pull request? This PR proposes to add ``` org.apache.avro avro ``` in core POM to see if it resolves the build failure as below: ``` [error] /home/jenkins/workspace/spark-master-test-sbt-hadoop-2.6/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala:123: value createDatumWriter is not a member of org.apache.avro.generic.GenericData [error] writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) [error] ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/2770/consoleFull ## How was this patch tested? I tried many ways but I was unable to reproduce this in my local. Sean also tried the way I did but he was also unable to reproduce this. Please refer the comments in https://github.com/apache/spark/pull/17477#issuecomment-294094092 Author: hyukjinkwon Closes #17642 from HyukjinKwon/SPARK-20343. --- core/pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index 97a463abbefdd..24ce36deeb169 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -33,6 +33,10 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro + org.apache.avro avro-mapred From 86d251c58591278a7c88745a1049e7a41db11964 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 16 Apr 2017 11:27:27 -0700 Subject: [PATCH 0274/1765] [SPARK-20278][R] Disable 'multiple_dots_linter' lint rule that is against project's code style ## What changes were proposed in this pull request? Currently, multi-dot separated variables in R is not allowed. For example, ```diff setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, asJsonArray = FALSE, ...) { + function(x, schema, as.json.array = FALSE, ...) { if (asJsonArray) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", ``` produces an error as below: ``` R/functions.R:2462:31: style: Words within variable and function names should be separated by '_' rather than '.'. function(x, schema, as.json.array = FALSE, ...) { ^~~~~~~~~~~~~ ``` This seems against https://google.github.io/styleguide/Rguide.xml#identifiers which says > The preferred form for variable names is all lower case letters and words separated with dots This looks because lintr by default https://github.com/jimhester/lintr follows http://r-pkgs.had.co.nz/style.html as written in the README.md. Few cases seems not following Google's one as "a few tweaks". Per [SPARK-6813](https://issues.apache.org/jira/browse/SPARK-6813), we follow Google's R Style Guide with few exceptions https://google.github.io/styleguide/Rguide.xml. This is also merged into Spark's website - https://github.com/apache/spark-website/pull/43 Also, it looks we have no limit on function name. This rule also looks affecting to the name of functions as written in the README.md. > `multiple_dots_linter`: check that function and variable names are separated by _ rather than .. ## How was this patch tested? Manually tested `./dev/lint-r`with the manual change below in `R/functions.R`: ```diff setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, asJsonArray = FALSE, ...) { + function(x, schema, as.json.array = FALSE, ...) { if (asJsonArray) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", ``` **Before** ```R R/functions.R:2462:31: style: Words within variable and function names should be separated by '_' rather than '.'. function(x, schema, as.json.array = FALSE, ...) { ^~~~~~~~~~~~~ ``` **After** ``` lintr checks passed. ``` Author: hyukjinkwon Closes #17590 from HyukjinKwon/disable-dot-in-name. --- R/pkg/.lintr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e6..ae50b28ec6166 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") From 24f09b39c7b947e52fda952676d5114c2540e732 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 17 Apr 2017 09:04:24 -0700 Subject: [PATCH 0275/1765] [SPARK-19828][R][FOLLOWUP] Rename asJsonArray to as.json.array in from_json function in R ## What changes were proposed in this pull request? This was suggested to be `as.json.array` at the first place in the PR to SPARK-19828 but we could not do this as the lint check emits an error for multiple dots in the variable names. After SPARK-20278, now we are able to use `multiple.dots.in.names`. `asJsonArray` in `from_json` function is still able to be changed as 2.2 is not released yet. So, this PR proposes to rename `asJsonArray` to `as.json.array`. ## How was this patch tested? Jenkins tests, local tests with `./R/run-tests.sh` and manual `./dev/lint-r`. Existing tests should cover this. Author: hyukjinkwon Closes #17653 from HyukjinKwon/SPARK-19828-followup. --- R/pkg/R/functions.R | 8 ++++---- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 449476dec5339..c311921fb33db 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2438,12 +2438,12 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_json #' #' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema} or array of \code{structType} if \code{asJsonArray} is set to \code{TRUE}. +#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. #' If the string is unparseable, the Column will contains the value NA. #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. -#' @param asJsonArray indicating if input string is JSON array of objects or a single object. +#' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @param ... additional named properties to control how the json is parsed, accepts the same #' options as the JSON data source. #' @@ -2459,8 +2459,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #'} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, asJsonArray = FALSE, ...) { - if (asJsonArray) { + function(x, schema, as.json.array = FALSE, ...) { + if (as.json.array) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", schema$jobj) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 3fbb618ddfc39..6a6c9a809ab13 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1454,7 +1454,7 @@ test_that("column functions", { jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) schema <- structType(structField("name", "string")) - arr <- collect(select(df, alias(from_json(df$people, schema, asJsonArray = TRUE), "arrcol"))) + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) expect_equal(ncol(arr), 1) expect_equal(nrow(arr), 1) expect_is(arr[[1]][[1]], "list") From 01ff0350a85b179715946c3bd4f003db7c5e3641 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 17 Apr 2017 09:50:20 -0700 Subject: [PATCH 0276/1765] [SPARK-20349][SQL] ListFunctions returns duplicate functions after using persistent functions ### What changes were proposed in this pull request? The session catalog caches some persistent functions in the `FunctionRegistry`, so there can be duplicates. Our Catalog API `listFunctions` does not handle it. It would be better if `SessionCatalog` API can de-duplciate the records, instead of doing it by each API caller. In `FunctionRegistry`, our functions are identified by the unquoted string. Thus, this PR is try to parse it using our parser interface and then de-duplicate the names. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17646 from gatorsmile/showFunctions. --- .../sql/catalyst/catalog/SessionCatalog.scala | 21 ++++++++++++++----- .../sql/execution/command/functions.scala | 4 +--- .../sql/hive/execution/HiveUDFSuite.scala | 17 +++++++++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 1417bccf657cd..3fbf83f3a38a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -22,6 +22,7 @@ import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.util.{Failure, Success, Try} import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -1202,15 +1203,25 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = { val dbName = formatDatabaseName(db) requireDbExists(dbName) - val dbFunctions = externalCatalog.listFunctions(dbName, pattern) - .map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) - .map { f => FunctionIdentifier(f) } + val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => + FunctionIdentifier(f, Some(dbName)) } + val loadedFunctions = + StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f => + // In functionRegistry, function names are stored as an unquoted format. + Try(parser.parseFunctionIdentifier(f)) match { + case Success(e) => e + case Failure(_) => + // The names of some built-in functions are not parsable by our parser, e.g., % + FunctionIdentifier(f) + } + } val functions = dbFunctions ++ loadedFunctions + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. functions.map { case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") case f => (f, "USER") - } + }.distinct } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index e0d0029369576..545082324f0d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -207,8 +207,6 @@ case class ShowFunctionsCommand( case (f, "USER") if showUserFunctions => f.unquotedString case (f, "SYSTEM") if showSystemFunctions => f.unquotedString } - // The session catalog caches some persistent functions in the FunctionRegistry - // so there can be duplicates. - functionNames.distinct.sorted.map(Row(_)) + functionNames.sorted.map(Row(_)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4bbf9259192ea..4446af2e75e00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -573,6 +573,23 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1)) } } + + test("Show persistent functions") { + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + withTempView("inputTable") { + testData.createOrReplaceTempView("inputTable") + withUserDefinedFunction("testUDFToListInt" -> false) { + val numFunc = spark.catalog.listFunctions().count() + sql(s"CREATE FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + assert(spark.catalog.listFunctions().count() == numFunc + 1) + checkAnswer( + sql("SELECT testUDFToListInt(s) FROM inputTable"), + Seq(Row(Seq(1, 2, 3)))) + assert(sql("show functions").count() == numFunc + 1) + assert(spark.catalog.listFunctions().count() == numFunc + 1) + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From e5fee3e4f853f906f0b476bb04ee35a15f1ae650 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Mon, 17 Apr 2017 11:17:57 -0700 Subject: [PATCH 0277/1765] [SPARK-17647][SQL] Fix backslash escaping in 'LIKE' patterns. ## What changes were proposed in this pull request? This patch fixes a bug in the way LIKE patterns are translated to Java regexes. The bug causes any character following an escaped backslash to be escaped, i.e. there is double-escaping. A concrete example is the following pattern:`'%\\%'`. The expected Java regex that this pattern should correspond to (according to the behavior described below) is `'.*\\.*'`, however the current situation leads to `'.*\\%'` instead. --- Update: in light of the discussion that ensued, we should explicitly define the expected behaviour of LIKE expressions, especially in certain edge cases. With the help of gatorsmile, we put together a list of different RDBMS and their variations wrt to certain standard features. | RDBMS\Features | Wildcards | Default escape [1] | Case sensitivity | | --- | --- | --- | --- | | [MS SQL Server](https://msdn.microsoft.com/en-us/library/ms179859.aspx) | _, %, [], [^] | none | no | | [Oracle](https://docs.oracle.com/cd/B12037_01/server.101/b10759/conditions016.htm) | _, % | none | yes | | [DB2 z/OS](http://www.ibm.com/support/knowledgecenter/SSEPEK_11.0.0/sqlref/src/tpc/db2z_likepredicate.html) | _, % | none | yes | | [MySQL](http://dev.mysql.com/doc/refman/5.7/en/string-comparison-functions.html) | _, % | none | no | | [PostreSQL](https://www.postgresql.org/docs/9.0/static/functions-matching.html) | _, % | \ | yes | | [Hive](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF) | _, % | none | yes | | Current Spark | _, % | \ | yes | [1] Default escape character: most systems do not have a default escape character, instead the user can specify one by calling a like expression with an escape argument [A] LIKE [B] ESCAPE [C]. This syntax is currently not supported by Spark, however I would volunteer to implement this feature in a separate ticket. The specifications are often quite terse and certain scenarios are undocumented, so here is a list of scenarios that I am uncertain about and would appreciate any input. Specifically I am looking for feedback on whether or not Spark's current behavior should be changed. 1. [x] Ending a pattern with the escape sequence, e.g. `like 'a\'`. PostreSQL gives an error: 'LIKE pattern must not end with escape character', which I personally find logical. Currently, Spark allows "non-terminated" escapes and simply ignores them as part of the pattern. According to [DB2's documentation](http://www.ibm.com/support/knowledgecenter/SSEPGG_9.7.0/com.ibm.db2.luw.messages.sql.doc/doc/msql00130n.html), ending a pattern in an escape character is invalid. _Proposed new behaviour in Spark: throw AnalysisException_ 2. [x] Empty input, e.g. `'' like ''` Postgres and DB2 will match empty input only if the pattern is empty as well, any other combination of empty input will not match. Spark currently follows this rule. 3. [x] Escape before a non-special character, e.g. `'a' like '\a'`. Escaping a non-wildcard character is not really documented but PostgreSQL just treats it verbatim, which I also find the least surprising behavior. Spark does the same. According to [DB2's documentation](http://www.ibm.com/support/knowledgecenter/SSEPGG_9.7.0/com.ibm.db2.luw.messages.sql.doc/doc/msql00130n.html), it is invalid to follow an escape character with anything other than an escape character, an underscore or a percent sign. _Proposed new behaviour in Spark: throw AnalysisException_ The current specification is also described in the operator's source code in this patch. ## How was this patch tested? Extra case in regex unit tests. Author: Jakob Odersky This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #15398 from jodersky/SPARK-17647. --- .../expressions/regexpExpressions.scala | 25 ++- .../spark/sql/catalyst/util/StringUtils.scala | 50 +++--- .../expressions/RegexpExpressionsSuite.scala | 161 +++++++++++------- .../sql/catalyst/util/StringUtilsSuite.scala | 4 +- 4 files changed, 153 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 49b779711308f..a36da8e94b3ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -69,7 +69,30 @@ abstract class StringRegexExpression extends BinaryExpression * Simple RegEx pattern matching function */ @ExpressionDescription( - usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.") + usage = "str _FUNC_ pattern - Returns true if str matches pattern, " + + "null if any arguments are null, false otherwise.", + extended = """ + Arguments: + str - a string expression + pattern - a string expression. The pattern is a string which is matched literally, with + exception to the following special symbols: + + _ matches any one character in the input (similar to . in posix regular expressions) + + % matches zero ore more characters in the input (similar to .* in posix regular + expressions) + + The escape character is '\'. If an escape character precedes a special symbol or another + escape character, the following character is matched literally. It is invalid to escape + any other character. + + Examples: + > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%' + true + + See also: + Use RLIKE to match with standard regular expressions. +""") case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index cde8bd5b9614c..ca22ea24207e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,32 +19,44 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} +import org.apache.spark.sql.AnalysisException import org.apache.spark.unsafe.types.UTF8String object StringUtils { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - def escapeLikeRegex(v: String): String = { - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => + /** + * Validate and convert SQL 'like' pattern to a Java regular expression. + * + * Underscores (_) are converted to '.' and percent signs (%) are converted to '.*', other + * characters are quoted literally. Escaping is done according to the rules specified in + * [[org.apache.spark.sql.catalyst.expressions.Like]] usage documentation. An invalid pattern will + * throw an [[AnalysisException]]. + * + * @param pattern the SQL pattern to convert + * @return the equivalent Java regular expression of the pattern + */ + def escapeLikeRegex(pattern: String): String = { + val in = pattern.toIterator + val out = new StringBuilder() + + def fail(message: String) = throw new AnalysisException( + s"the pattern '$pattern' is invalid, $message") + + while (in.hasNext) { + in.next match { + case '\\' if in.hasNext => + val c = in.next c match { - case '_' => "." - case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) + case '_' | '%' | '\\' => out ++= Pattern.quote(Character.toString(c)) + case _ => fail(s"the escape character is not allowed to precede '$c'") } - }.mkString - } else { - v + case '\\' => fail("it is not allowed to end with the escape character") + case '_' => out ++= "." + case '%' => out ++= ".*" + case c => out ++= Pattern.quote(Character.toString(c)) + } } + "(?s)" + out.result() // (?s) enables dotall mode, causing "." to match new lines } private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 5299549e7b4da..1ce150e091981 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -18,16 +18,38 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{IntegerType, StringType} /** * Unit tests for regular expression (regexp) related SQL expressions. */ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("LIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType).like("a"), null) + /** + * Check if a given expression evaluates to an expected output, in case the input is + * a literal and in case the input is in the form of a row. + * @tparam A type of input + * @param mkExpr the expression to test for a given input + * @param input value that will be used to create the expression, as literal and in the form + * of a row + * @param expected the expected output of the expression + * @param inputToExpression an implicit conversion from the input type to its corresponding + * sql expression + */ + def checkLiteralRow[A](mkExpr: Expression => Expression, input: A, expected: Any) + (implicit inputToExpression: A => Expression): Unit = { + checkEvaluation(mkExpr(input), expected) // check literal input + + val regex = 'a.string.at(0) + checkEvaluation(mkExpr(regex), expected, create_row(input)) // check row input + } + + test("LIKE Pattern") { + + // null handling + checkLiteralRow(Literal.create(null, StringType).like(_), "a", null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation( @@ -39,45 +61,64 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } + // simple patterns + checkLiteralRow("abdef" like _, "abdef", true) + checkLiteralRow("a_%b" like _, "a\\__b", true) + checkLiteralRow("addb" like _, "a_%b", true) + checkLiteralRow("addb" like _, "a\\__b", false) + checkLiteralRow("addb" like _, "a%\\%b", false) + checkLiteralRow("a_%b" like _, "a%\\%b", true) + checkLiteralRow("addb" like _, "a%", true) + checkLiteralRow("addb" like _, "**", false) + checkLiteralRow("abc" like _, "a%", true) + checkLiteralRow("abc" like _, "b%", false) + checkLiteralRow("abc" like _, "bc%", false) + checkLiteralRow("a\nb" like _, "a_b", true) + checkLiteralRow("ab" like _, "a%b", true) + checkLiteralRow("a\nb" like _, "a%b", true) + + // empty input + checkLiteralRow("" like _, "", true) + checkLiteralRow("a" like _, "", false) + checkLiteralRow("" like _, "a", false) + + // SI-17647 double-escaping backslash + checkLiteralRow("""\\\\""" like _, """%\\%""", true) + checkLiteralRow("""%%""" like _, """%%""", true) + checkLiteralRow("""\__""" like _, """\\\__""", true) + checkLiteralRow("""\\\__""" like _, """%\\%\%""", false) + checkLiteralRow("""_\\\%""" like _, """%\\""", false) + + // unicode + // scalastyle:off nonascii + checkLiteralRow("a\u20ACa" like _, "_\u20AC_", true) + checkLiteralRow("a€a" like _, "_€_", true) + checkLiteralRow("a€a" like _, "_\u20AC_", true) + checkLiteralRow("a\u20ACa" like _, "_€_", true) + // scalastyle:on nonascii + + // invalid escaping + val invalidEscape = intercept[AnalysisException] { + evaluate("""a""" like """\a""") + } + assert(invalidEscape.getMessage.contains("pattern")) + + val endEscape = intercept[AnalysisException] { + evaluate("""a""" like """a\""") + } + assert(endEscape.getMessage.contains("pattern")) + + // case + checkLiteralRow("A" like _, "a%", false) + checkLiteralRow("a" like _, "A%", false) + checkLiteralRow("AaA" like _, "_a_", true) - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + // example + checkLiteralRow("""%SystemDrive%\Users\John""" like _, """\%SystemDrive\%\\Users%""", true) } - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + test("RLIKE Regular Expression") { + checkLiteralRow(Literal.create(null, StringType) rlike _, "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) @@ -87,42 +128,32 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) + checkLiteralRow("abdef" rlike _, "abdef", true) + checkLiteralRow("abbbbc" rlike _, "a.*c", true) - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - checkEvaluation("pip" rlike "^(pi)*$", false) + checkLiteralRow("fofo" rlike _, "^fo", true) + checkLiteralRow("fo\no" rlike _, "^fo\no$", true) + checkLiteralRow("Bn" rlike _, "^Ba*n", true) + checkLiteralRow("afofo" rlike _, "fo", true) + checkLiteralRow("afofo" rlike _, "^fo", false) + checkLiteralRow("Baan" rlike _, "^Ba?n", false) + checkLiteralRow("axe" rlike _, "pi|apa", false) + checkLiteralRow("pip" rlike _, "^(pi)*$", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) + checkLiteralRow("abc" rlike _, "^ab", true) + checkLiteralRow("abc" rlike _, "^bc", false) + checkLiteralRow("abc" rlike _, "^ab", true) + checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { evaluate("abbbbc" rlike "**") } - } - - test("RLIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) + val regex = 'a.string.at(0) + evaluate("abbbbc" rlike regex, create_row("**")) } } - test("RegexReplace") { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index 2ffc18a8d14fb..78fee5135c3ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -24,9 +24,9 @@ class StringUtilsSuite extends SparkFunSuite { test("escapeLikeRegex") { assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") - assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") + assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E\\Q_\\E.\\Qb\\E") assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") - assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") + assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*\\Q%\\E\\Qb\\E") assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") From 0075562dd2551a31c35ca26922d6bd73cdb78ea4 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Mon, 17 Apr 2017 17:56:33 -0700 Subject: [PATCH 0278/1765] Typo fix: distitrbuted -> distributed ## What changes were proposed in this pull request? Typo fix: distitrbuted -> distributed ## How was this patch tested? Existing tests Author: Andrew Ash Closes #17664 from ash211/patch-1. --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 424bbca123190..b817570c0abf7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -577,7 +577,7 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => flist.foreach { file => val (_, localizedPath) = distribute(file, resType = resType) - // If addToClassPath, we ignore adding jar multiple times to distitrbuted cache. + // If addToClassPath, we ignore adding jar multiple times to distributed cache. if (addToClasspath) { if (localizedPath != null) { cachedSecondaryJarLinks += localizedPath From 33ea908af94152147e996a6dc8da41ada27d5af3 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 17 Apr 2017 17:58:10 -0700 Subject: [PATCH 0279/1765] [TEST][MINOR] Replace repartitionBy with distribute in CollapseRepartitionSuite ## What changes were proposed in this pull request? Replace non-existent `repartitionBy` with `distribute` in `CollapseRepartitionSuite`. ## How was this patch tested? local build and `catalyst/testOnly *CollapseRepartitionSuite` Author: Jacek Laskowski Closes #17657 from jaceklaskowski/CollapseRepartitionSuite. --- .../optimizer/CollapseRepartitionSuite.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala index 59d2dc46f00ce..8cc8decd65de1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala @@ -106,8 +106,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer) } - test("repartitionBy above repartition") { - // Always respects the top repartitionBy amd removes useless repartition + test("distribute above repartition") { + // Always respects the top distribute and removes useless repartition val query1 = testRelation .repartition(10) .distribute('a)(20) @@ -123,8 +123,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer) } - test("repartitionBy above coalesce") { - // Always respects the top repartitionBy amd removes useless coalesce below repartition + test("distribute above coalesce") { + // Always respects the top distribute and removes useless coalesce below repartition val query1 = testRelation .coalesce(10) .distribute('a)(20) @@ -140,8 +140,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer) } - test("repartition above repartitionBy") { - // Always respects the top repartition amd removes useless distribute below repartition + test("repartition above distribute") { + // Always respects the top repartition and removes useless distribute below repartition val query1 = testRelation .distribute('a)(10) .repartition(20) @@ -155,11 +155,10 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized1, correctAnswer) comparePlans(optimized2, correctAnswer) - } - test("coalesce above repartitionBy") { - // Remove useless coalesce above repartition + test("coalesce above distribute") { + // Remove useless coalesce above distribute val query1 = testRelation .distribute('a)(10) .coalesce(20) @@ -180,8 +179,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer2) } - test("collapse two adjacent repartitionBys into one") { - // Always respects the top repartitionBy + test("collapse two adjacent distributes into one") { + // Always respects the top distribute val query1 = testRelation .distribute('b)(10) .distribute('a)(20) From b0a1e93e93167b53058525a20a8b06f7df5f09a2 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 17 Apr 2017 23:55:40 -0700 Subject: [PATCH 0280/1765] [SPARK-17647][SQL][FOLLOWUP][MINOR] fix typo ## What changes were proposed in this pull request? fix typo ## How was this patch tested? manual Author: Felix Cheung Closes #17663 from felixcheung/likedoctypo. --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index a36da8e94b3ad..3fa84589e3c68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -79,7 +79,7 @@ abstract class StringRegexExpression extends BinaryExpression _ matches any one character in the input (similar to . in posix regular expressions) - % matches zero ore more characters in the input (similar to .* in posix regular + % matches zero or more characters in the input (similar to .* in posix regular expressions) The escape character is '\'. If an escape character precedes a special symbol or another From 07fd94e0d05e827fae65d6e0e1cb89e28c8f2771 Mon Sep 17 00:00:00 2001 From: Robert Stupp Date: Tue, 18 Apr 2017 11:02:43 +0100 Subject: [PATCH 0281/1765] [SPARK-20344][SCHEDULER] Duplicate call in FairSchedulableBuilder.addTaskSetManager ## What changes were proposed in this pull request? Eliminate the duplicate call to `Pool.getSchedulableByName()` in `FairSchedulableBuilder.addTaskSetManager` ## How was this patch tested? ./dev/run-tests Author: Robert Stupp Closes #17647 from snazy/20344-dup-call-master. --- .../spark/scheduler/SchedulableBuilder.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 417103436144a..5f3c280ec31ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -181,23 +181,23 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) } override def addTaskSetManager(manager: Schedulable, properties: Properties) { - var poolName = DEFAULT_POOL_NAME - var parentPool = rootPool.getSchedulableByName(poolName) - if (properties != null) { - poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) - parentPool = rootPool.getSchedulableByName(poolName) - if (parentPool == null) { - // we will create a new pool that user has configured in app - // instead of being defined in xml file - parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(parentPool) - logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + - "configured. This can happen when the file that pools are read from isn't set, or " + - s"when that file doesn't contain $poolName. Created $poolName with default " + - s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + - s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") + val poolName = if (properties != null) { + properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + } else { + DEFAULT_POOL_NAME } + var parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + // we will create a new pool that user has configured in app + // instead of being defined in xml file + parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, + DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + + "configured. This can happen when the file that pools are read from isn't set, or " + + s"when that file doesn't contain $poolName. Created $poolName with default " + + s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + + s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") } parentPool.addSchedulable(manager) logInfo("Added task set " + manager.name + " tasks to pool " + poolName) From d4f10cbbe1b9d13e43d80a50d204781e1c5c2da9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 18 Apr 2017 11:05:00 +0100 Subject: [PATCH 0282/1765] [SPARK-20343][BUILD] Force Avro 1.7.7 in sbt build to resolve build failure in SBT Hadoop 2.6 master on Jenkins ## What changes were proposed in this pull request? This PR proposes to force Avro's version to 1.7.7 in core to resolve the build failure as below: ``` [error] /home/jenkins/workspace/spark-master-test-sbt-hadoop-2.6/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala:123: value createDatumWriter is not a member of org.apache.avro.generic.GenericData [error] writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) [error] ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/2770/consoleFull Note that this is a hack and should be removed in the future. ## How was this patch tested? I only tested this actually overrides the dependency. I tried many ways but I was unable to reproduce this in my local. Sean also tried the way I did but he was also unable to reproduce this. Please refer the comments in https://github.com/apache/spark/pull/17477#issuecomment-294094092 Author: hyukjinkwon Closes #17651 from HyukjinKwon/SPARK-20343-sbt. --- pom.xml | 1 + project/SparkBuild.scala | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index c1174593c1922..14370d92a9080 100644 --- a/pom.xml +++ b/pom.xml @@ -142,6 +142,7 @@ 2.4.0 2.0.8 3.1.2 + 1.7.7 hadoop2 0.9.3 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e52baf51aed1a..77dae289f7758 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -318,8 +318,8 @@ object SparkBuild extends PomBuild { enable(MimaBuild.mimaSettings(sparkHome, x))(x) } - /* Generate and pick the spark build info from extra-resources */ - enable(Core.settings)(core) + /* Generate and pick the spark build info from extra-resources and override a dependency */ + enable(Core.settings ++ CoreDependencyOverrides.settings)(core) /* Unsafe settings */ enable(Unsafe.settings)(unsafe) @@ -443,6 +443,16 @@ object DockerIntegrationTests { ) } +/** + * Overrides to work around sbt's dependency resolution being different from Maven's in Unidoc. + * + * Note that, this is a hack that should be removed in the future. See SPARK-20343 + */ +object CoreDependencyOverrides { + lazy val settings = Seq( + dependencyOverrides += "org.apache.avro" % "avro" % "1.7.7") +} + /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ From 321b4f03bc983c582a3c6259019c077cdfac9d26 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 18 Apr 2017 20:12:21 +0800 Subject: [PATCH 0283/1765] [SPARK-20366][SQL] Fix recursive join reordering: inside joins are not reordered ## What changes were proposed in this pull request? If a plan has multi-level successive joins, e.g.: ``` Join / \ Union t5 / \ Join t4 / \ Join t3 / \ t1 t2 ``` Currently we fail to reorder the inside joins, i.e. t1, t2, t3. In join reorder, we use `OrderedJoin` to indicate a join has been ordered, such that when transforming down the plan, these joins don't need to be rerodered again. But there's a problem in the definition of `OrderedJoin`: The real join node is a parameter, but not a child. This breaks the transform procedure because `mapChildren` applies transform function on parameters which should be children. In this patch, we change `OrderedJoin` to a class having the same structure as a join node. ## How was this patch tested? Add a corresponding test case. Author: wangzhenhua Closes #17668 from wzhfy/recursiveReorder. --- .../optimizer/CostBasedJoinReorder.scala | 22 +++++---- .../catalyst/optimizer/JoinReorderSuite.scala | 49 +++++++++++++++++-- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index c704c2e6d36bd..51eca6ca33760 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} -import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} +import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType} import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -47,7 +47,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } // After reordering is finished, convert OrderedJoin back to Join result transformDown { - case oj: OrderedJoin => oj.join + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond) } } } @@ -87,22 +87,24 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, _: InnerLike, Some(cond)) => + case j @ Join(left, right, jt: InnerLike, Some(cond)) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) - OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan } +} - /** This is a wrapper class for a join node that has been ordered. */ - private case class OrderedJoin(join: Join) extends BinaryNode { - override def left: LogicalPlan = join.left - override def right: LogicalPlan = join.right - override def output: Seq[Attribute] = join.output - } +/** This is a mimic class for a join node that has been ordered. */ +case class OrderedJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 1922eb30fdce4..71db4e2e0ec4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -25,13 +25,12 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED, JOIN_REORDER_ENABLED} +import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy( - CASE_SENSITIVE -> true, CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) + override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -212,6 +211,50 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } + test("reorder recursively") { + // Original order: + // Join + // / \ + // Union t5 + // / \ + // Join t4 + // / \ + // Join t3 + // / \ + // t1 t2 + val bottomJoins = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val originalPlan = bottomJoins + .union(t4.select(nameToAttr("t4.v-1-10"))) + .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5"))) + + // Should be able to reorder the bottom part. + // Best order: + // Join + // / \ + // Union t5 + // / \ + // Join t4 + // / \ + // Join t2 + // / \ + // t1 t3 + val bestBottomPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = bestBottomPlan + .union(t4.select(nameToAttr("t4.v-1-10"))) + .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { From 1f81dda37cfc2049fabd6abd93ef3720d0aa03ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Tue, 18 Apr 2017 10:02:21 -0700 Subject: [PATCH 0284/1765] [SPARK-20354][CORE][REST-API] When I request access to the 'http: //ip:port/api/v1/applications' link, return 'sparkUser' is empty in REST API. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When I request access to the 'http: //ip:port/api/v1/applications' link, get the json. I need the 'sparkUser' field specific value, because my Spark big data management platform needs to filter through this field which user submits the application to facilitate my administration and query, but the current return of the json string is empty, causing me this Function can not be achieved, that is, I do not know who the specific application is submitted by this REST Api. **current return json:** [ { "id" : "app-20170417152053-0000", "name" : "KafkaWordCount", "attempts" : [ { "startTime" : "2017-04-17T07:20:51.395GMT", "endTime" : "1969-12-31T23:59:59.999GMT", "lastUpdated" : "2017-04-17T07:20:51.395GMT", "duration" : 0, **"sparkUser" : "",** "completed" : false, "endTimeEpoch" : -1, "startTimeEpoch" : 1492413651395, "lastUpdatedEpoch" : 1492413651395 } ] } ] **When I fix this question, return json:** [ { "id" : "app-20170417154201-0000", "name" : "KafkaWordCount", "attempts" : [ { "startTime" : "2017-04-17T07:41:57.335GMT", "endTime" : "1969-12-31T23:59:59.999GMT", "lastUpdated" : "2017-04-17T07:41:57.335GMT", "duration" : 0, **"sparkUser" : "mr",** "completed" : false, "startTimeEpoch" : 1492414917335, "endTimeEpoch" : -1, "lastUpdatedEpoch" : 1492414917335 } ] } ] ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17656 from guoxiaolongzte/SPARK-20354. --- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 7d31ac54a7177..bf4cf79e9faa3 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -117,7 +117,7 @@ private[spark] class SparkUI private ( endTime = new Date(-1), duration = 0, lastUpdated = new Date(startTime), - sparkUser = "", + sparkUser = getSparkUser, completed = false )) )) From f654b39a63d4f9b118733733c7ed2a1b58649e3d Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 18 Apr 2017 12:35:27 -0700 Subject: [PATCH 0285/1765] [SPARK-20360][PYTHON] reprs for interpreters ## What changes were proposed in this pull request? Establishes a very minimal `_repr_html_` for PySpark's `SparkContext`. ## How was this patch tested? nteract: ![screen shot 2017-04-17 at 3 41 29 pm](https://cloud.githubusercontent.com/assets/836375/25107701/d57090ba-2385-11e7-8147-74bc2c50a41b.png) Jupyter: ![screen shot 2017-04-17 at 3 53 19 pm](https://cloud.githubusercontent.com/assets/836375/25107725/05bf1fe8-2386-11e7-93e1-07a20c917dde.png) Hydrogen: ![screen shot 2017-04-17 at 3 49 55 pm](https://cloud.githubusercontent.com/assets/836375/25107664/a75e1ddc-2385-11e7-8477-258661833007.png) Author: Kyle Kelley Closes #17662 from rgbkrk/repr. --- python/pyspark/context.py | 26 ++++++++++++++++++++++++++ python/pyspark/sql/session.py | 11 +++++++++++ 2 files changed, 37 insertions(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2961cda553d6a..3be07325f4162 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -240,6 +240,32 @@ def signal_handler(signal, frame): if isinstance(threading.current_thread(), threading._MainThread): signal.signal(signal.SIGINT, signal_handler) + def __repr__(self): + return "".format( + master=self.master, + appName=self.appName, + ) + + def _repr_html_(self): + return """ +
    +

    SparkContext

    + +

    Spark UI

    + +
    +
    Version
    +
    v{sc.version}
    +
    Master
    +
    {sc.master}
    +
    AppName
    +
    {sc.appName}
    +
    +
    + """.format( + sc=self + ) + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2a..c1bf2bd76fb7c 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -221,6 +221,17 @@ def __init__(self, sparkContext, jsparkSession=None): or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + def _repr_html_(self): + return """ +
    +

    SparkSession - {catalogImplementation}

    + {sc_HTML} +
    + """.format( + catalogImplementation=self.conf.get("spark.sql.catalogImplementation"), + sc_HTML=self.sparkContext._repr_html_() + ) + @since(2.0) def newSession(self): """ From 74aa0df8f7f132b62754e5159262e4a5b9b641ab Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Apr 2017 16:10:40 -0700 Subject: [PATCH 0286/1765] [SPARK-20377][SS] Fix JavaStructuredSessionization example ## What changes were proposed in this pull request? Extra accessors in java bean class causes incorrect encoder generation, which corrupted the state when using timeouts. ## How was this patch tested? manually ran the example Author: Tathagata Das Closes #17676 from tdas/SPARK-20377. --- .../sql/streaming/JavaStructuredSessionization.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java index da3a5dfe8628b..d3c8516882fa6 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -76,8 +76,6 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio for (String word : lineWithTimestamp.getLine().split(" ")) { eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); } - System.out.println( - "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size()); return eventList.iterator(); } }; @@ -100,7 +98,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio // If timed out, then remove session and send final update if (state.hasTimedOut()) { SessionUpdate finalUpdate = new SessionUpdate( - sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true); + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true); state.remove(); return finalUpdate; @@ -133,7 +131,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio // Set timeout such that the session will be expired if no data received for 10 seconds state.setTimeoutDuration("10 seconds"); return new SessionUpdate( - sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false); + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false); } } }; @@ -215,7 +213,8 @@ public void setStartTimestampMs(long startTimestampMs) { public long getEndTimestampMs() { return endTimestampMs; } public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } - public long getDurationMs() { return endTimestampMs - startTimestampMs; } + public long calculateDuration() { return endTimestampMs - startTimestampMs; } + @Override public String toString() { return "SessionInfo(numEvents = " + numEvents + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; From e468a96c404eb54261ab219734f67dc2f5b06dc0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 19 Apr 2017 10:58:05 +0800 Subject: [PATCH 0287/1765] [SPARK-20254][SQL] Remove unnecessary data conversion for Dataset with primitive array ## What changes were proposed in this pull request? This PR elminates unnecessary data conversion, which is introduced by SPARK-19716, for Dataset with primitve array in the generated Java code. When we run the following example program, now we get the Java code "Without this PR". In this code, lines 56-82 are unnecessary since the primitive array in ArrayData can be converted into Java primitive array by using ``toDoubleArray()`` method. ``GenericArrayData`` is not required. ```java val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache ds.count ds.map(e => e).show ``` Without this PR ``` == Parsed Logical Plan == 'SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- 'MapElements , class [D, [StructField(value,ArrayType(DoubleType,false),true)], obj#24: [D +- 'DeserializeToObject unresolveddeserializer(unresolvedmapobjects(, getcolumnbyordinal(0, ArrayType(DoubleType,false)), None).toDoubleArray), obj#23: [D +- SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- ExternalRDD [obj#1] == Analyzed Logical Plan == value: array SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- MapElements , class [D, [StructField(value,ArrayType(DoubleType,false),true)], obj#24: [D +- DeserializeToObject mapobjects(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, true), - array element class: "scala.Double", - root class: "scala.Array"), value#2, None, MapObjects_builderValue5).toDoubleArray, obj#23: [D +- SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- ExternalRDD [obj#1] == Optimized Logical Plan == SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- MapElements , class [D, [StructField(value,ArrayType(DoubleType,false),true)], obj#24: [D +- DeserializeToObject mapobjects(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, true), - array element class: "scala.Double", - root class: "scala.Array"), value#2, None, MapObjects_builderValue5).toDoubleArray, obj#23: [D +- InMemoryRelation [value#2], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- *SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- Scan ExternalRDDScan[obj#1] == Physical Plan == *SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- *MapElements , obj#24: [D +- *DeserializeToObject mapobjects(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, true), - array element class: "scala.Double", - root class: "scala.Array"), value#2, None, MapObjects_builderValue5).toDoubleArray, obj#23: [D +- InMemoryTableScan [value#2] +- InMemoryRelation [value#2], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- *SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- Scan ExternalRDDScan[obj#1] ``` ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ if (deserializetoobject_funcResult == null) { /* 091 */ deserializetoobject_isNull = true; /* 092 */ } else { /* 093 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 098 */ } /* 099 */ /* 100 */ boolean mapelements_isNull = true; /* 101 */ double[] mapelements_value = null; /* 102 */ if (!false) { /* 103 */ mapelements_resultIsNull = false; /* 104 */ /* 105 */ if (!mapelements_resultIsNull) { /* 106 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 107 */ mapelements_argValue = deserializetoobject_value; /* 108 */ } /* 109 */ /* 110 */ mapelements_isNull = mapelements_resultIsNull; /* 111 */ if (!mapelements_isNull) { /* 112 */ Object mapelements_funcResult = null; /* 113 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 114 */ if (mapelements_funcResult == null) { /* 115 */ mapelements_isNull = true; /* 116 */ } else { /* 117 */ mapelements_value = (double[]) mapelements_funcResult; /* 118 */ } /* 119 */ /* 120 */ } /* 121 */ mapelements_isNull = mapelements_value == null; /* 122 */ } /* 123 */ /* 124 */ serializefromobject_resultIsNull = false; /* 125 */ /* 126 */ if (!serializefromobject_resultIsNull) { /* 127 */ serializefromobject_resultIsNull = mapelements_isNull; /* 128 */ serializefromobject_argValue = mapelements_value; /* 129 */ } /* 130 */ /* 131 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 132 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 133 */ serializefromobject_isNull = serializefromobject_value == null; /* 134 */ serializefromobject_holder.reset(); /* 135 */ /* 136 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 137 */ /* 138 */ if (serializefromobject_isNull) { /* 139 */ serializefromobject_rowWriter.setNullAt(0); /* 140 */ } else { /* 141 */ // Remember the current cursor so that we can calculate how many bytes are /* 142 */ // written later. /* 143 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 144 */ /* 145 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 146 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 147 */ // grow the global buffer before writing data. /* 148 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 149 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 150 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 151 */ /* 152 */ } else { /* 153 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 154 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 155 */ /* 156 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 157 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 158 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 159 */ } else { /* 160 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 161 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 162 */ } /* 163 */ } /* 164 */ } /* 165 */ /* 166 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 167 */ } /* 168 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 169 */ append(serializefromobject_result); /* 170 */ if (shouldStop()) return; /* 171 */ } /* 172 */ } ``` With this PR (eliminated lines 56-62 in the above code) ```java /* 047 */ protected void processNext() throws java.io.IOException { /* 048 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 049 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 050 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 051 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 052 */ /* 053 */ boolean deserializetoobject_isNull = true; /* 054 */ double[] deserializetoobject_value = null; /* 055 */ if (!inputadapter_isNull) { /* 056 */ deserializetoobject_isNull = false; /* 057 */ if (!deserializetoobject_isNull) { /* 058 */ Object deserializetoobject_funcResult = null; /* 059 */ deserializetoobject_funcResult = inputadapter_value.toDoubleArray(); /* 060 */ if (deserializetoobject_funcResult == null) { /* 061 */ deserializetoobject_isNull = true; /* 062 */ } else { /* 063 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 064 */ } /* 065 */ /* 066 */ } /* 067 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 068 */ } /* 069 */ /* 070 */ boolean mapelements_isNull = true; /* 071 */ double[] mapelements_value = null; /* 072 */ if (!false) { /* 073 */ mapelements_resultIsNull = false; /* 074 */ /* 075 */ if (!mapelements_resultIsNull) { /* 076 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 077 */ mapelements_argValue = deserializetoobject_value; /* 078 */ } /* 079 */ /* 080 */ mapelements_isNull = mapelements_resultIsNull; /* 081 */ if (!mapelements_isNull) { /* 082 */ Object mapelements_funcResult = null; /* 083 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 084 */ if (mapelements_funcResult == null) { /* 085 */ mapelements_isNull = true; /* 086 */ } else { /* 087 */ mapelements_value = (double[]) mapelements_funcResult; /* 088 */ } /* 089 */ /* 090 */ } /* 091 */ mapelements_isNull = mapelements_value == null; /* 092 */ } /* 093 */ /* 094 */ serializefromobject_resultIsNull = false; /* 095 */ /* 096 */ if (!serializefromobject_resultIsNull) { /* 097 */ serializefromobject_resultIsNull = mapelements_isNull; /* 098 */ serializefromobject_argValue = mapelements_value; /* 099 */ } /* 100 */ /* 101 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 102 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 103 */ serializefromobject_isNull = serializefromobject_value == null; /* 104 */ serializefromobject_holder.reset(); /* 105 */ /* 106 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 107 */ /* 108 */ if (serializefromobject_isNull) { /* 109 */ serializefromobject_rowWriter.setNullAt(0); /* 110 */ } else { /* 111 */ // Remember the current cursor so that we can calculate how many bytes are /* 112 */ // written later. /* 113 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 114 */ /* 115 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 116 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 117 */ // grow the global buffer before writing data. /* 118 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 119 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 120 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 121 */ /* 122 */ } else { /* 123 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 124 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 125 */ /* 126 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 127 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 128 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 129 */ } else { /* 130 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 131 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 132 */ } /* 133 */ } /* 134 */ } /* 135 */ /* 136 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 137 */ } /* 138 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 139 */ append(serializefromobject_result); /* 140 */ if (shouldStop()) return; /* 141 */ } /* 142 */ } ``` ## How was this patch tested? Add test suites into `DatasetPrimitiveSuite` Author: Kazuaki Ishizaki Closes #17568 from kiszk/SPARK-20254. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../expressions/objects/objects.scala | 5 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../sql/catalyst/optimizer/expressions.scala | 3 + .../sql/catalyst/optimizer/objects.scala | 13 ++++ .../optimizer/EliminateMapObjectsSuite.scala | 62 +++++++++++++++++++ 6 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9816b33ae8dff..d9f36f7f874d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2230,8 +2230,8 @@ class Analyzer( val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => inputData.dataType match { - case ArrayType(et, _) => - val expr = MapObjects(func, inputData, et, cls) transformUp { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f446c3e4a75f6..1a202ecf745c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -451,6 +451,8 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) */ @@ -458,11 +460,12 @@ object MapObjects { function: Expression => Expression, inputData: Expression, elementType: DataType, + elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d221b0611a892..dd768d18e8588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,7 +119,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: - Batch("Typed Filter Optimization", fixedPoint, + Batch("Object Expressions Optimization", fixedPoint, + EliminateMapObjects, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8445ee06bd89b..ea2c5d241d8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -368,6 +369,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + case AssertNotNull(c, _) if !c.nullable => c + // For Coalesce, remove null literals. case e @ Coalesce(children) => val newChildren = children.filterNot(isNullLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 257dbfac8c3e8..8cdc6425bcad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] { } } } + +/** + * Removes MapObjects when the following conditions are satisfied + * 1. Mapobject(... lambdavariable(..., false) ...), which means types for input and output + * are primitive types with non-nullable + * 2. no custom collection class specified representation of data item. + */ +object EliminateMapObjects extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala new file mode 100644 index 0000000000000..d4f37e2a5e877 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = { + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + SimplifyCasts, + EliminateMapObjects) :: Nil + } + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} From 702d85af2df9433254af6fa029683aa19c52a276 Mon Sep 17 00:00:00 2001 From: zero323 Date: Tue, 18 Apr 2017 19:59:18 -0700 Subject: [PATCH 0288/1765] [SPARK-20208][R][DOCS] Document R fpGrowth support ## What changes were proposed in this pull request? Document fpGrowth in: - vignettes - programming guide - code example ## How was this patch tested? Manual tests. Author: zero323 Closes #17557 from zero323/SPARK-20208. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 37 +++++++++++++++++++- examples/src/main/r/ml/fpm.R | 50 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/r/ml/fpm.R diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index a6ff650c33fea..f81dbab10b1e1 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -505,6 +505,10 @@ SparkR supports the following machine learning models and algorithms. * Alternating Least Squares (ALS) +#### Frequent Pattern Mining + +* FP-growth + #### Statistics * Kolmogorov-Smirnov Test @@ -707,7 +711,7 @@ summary(tweedieGLM1) ``` We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: ```{r} -tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 1.2, link.power = 0.0) summary(tweedieGLM2) ``` @@ -906,6 +910,37 @@ predicted <- predict(model, df) head(predicted) ``` +#### FP-growth + +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. + +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` + +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. + +```{r} +head(spark.freqItemsets(fpm)) +``` + +`spark.associationRules` returns a `SparkDataFrame` with the association rules. + +```{r} +head(spark.associationRules(fpm)) +``` + +We can make predictions based on the `antecedent`. + +```{r} +head(predict(fpm, df)) +``` + #### Kolmogorov-Smirnov Test `spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). diff --git a/examples/src/main/r/ml/fpm.R b/examples/src/main/r/ml/fpm.R new file mode 100644 index 0000000000000..89c4564457d9e --- /dev/null +++ b/examples/src/main/r/ml/fpm.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/fpm.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-fpm-example") + +# $example on$ +# Load training data + +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "1,2,5", "1,2,3,5", "1,2" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, itemsCol="items", minSupport=0.5, minConfidence=0.6) + +# Extracting frequent itemsets + +spark.freqItemsets(fpm) + +# Extracting association rules + +spark.associationRules(fpm) + +# Predict uses association rules to and combines possible consequents + +predict(fpm, df) + +# $example off$ + +sparkR.session.stop() From 608bf30f0b9759fd0b9b9f33766295550996a9eb Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Wed, 19 Apr 2017 15:52:47 +0800 Subject: [PATCH 0289/1765] [SPARK-20359][SQL] Avoid unnecessary execution in EliminateOuterJoin optimization that can lead to NPE Avoid necessary execution that can lead to NPE in EliminateOuterJoin and add test in DataFrameSuite to confirm NPE is no longer thrown ## What changes were proposed in this pull request? Change leftHasNonNullPredicate and rightHasNonNullPredicate to lazy so they are only executed when needed. ## How was this patch tested? Added test in DataFrameSuite that failed before this fix and now succeeds. Note that a test in catalyst project would be better but i am unsure how to do this. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Koert Kuipers Closes #17660 from koertkuipers/feat-catch-npe-in-eliminate-outer-join. --- .../apache/spark/sql/catalyst/optimizer/joins.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index c3ab58744953d..2fe3039774423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -134,8 +134,8 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + lazy val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + lazy val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) join.joinType match { case RightOuter if leftHasNonNullPredicate => Inner diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52bd4e19f8952..b4893b56a8a84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1722,4 +1722,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { "Cannot have map type columns in DataFrame which calls set operations")) } } + + test("SPARK-20359: catalyst outer join optimization should not throw npe") { + val df1 = Seq("a", "b", "c").toDF("x") + .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x")) + val df2 = Seq("a", "b").toDF("x1") + df1 + .join(df2, df1("x") === df2("x1"), "left_outer") + .filter($"x1".isNotNull || !$"y".isin("a!")) + .count + } } From 773754b6c1516c15b64846a00e491535cbcb1007 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 19 Apr 2017 16:01:28 +0800 Subject: [PATCH 0290/1765] [SPARK-20356][SQL] Pruned InMemoryTableScanExec should have correct output partitioning and ordering ## What changes were proposed in this pull request? The output of `InMemoryTableScanExec` can be pruned and mismatch with `InMemoryRelation` and its child plan's output. This causes wrong output partitioning and ordering. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17679 from viirya/SPARK-20356. --- .../columnar/InMemoryTableScanExec.scala | 4 +++- .../columnar/InMemoryColumnarQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 214e8d309de11..7063b08f7c644 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -42,7 +42,9 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { - val attrMap = AttributeMap(relation.child.output.zip(output)) + // attributes can be pruned so using relation's output. + // E.g., relation.output is [id, item] but this scan's output can be [item] only. + val attrMap = AttributeMap(relation.child.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 1e6a6a8ba3362..109b1d9db60d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -414,4 +414,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) } } + + test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") { + withSQLConf("spark.sql.shuffle.partitions" -> "200") { + val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group") + val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id") + val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct() + + df3.unpersist() + val agg_without_cache = df3.groupBy($"item").count() + + df3.cache() + val agg_with_cache = df3.groupBy($"item").count() + checkAnswer(agg_without_cache, agg_with_cache) + } + } } From 35378766ad7d3c494425a8781efe9cb9349732b7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 19 Apr 2017 12:18:54 +0100 Subject: [PATCH 0291/1765] [SPARK-20343][BUILD] Avoid Unidoc build only if Hadoop 2.6 is explicitly set in SBT build ## What changes were proposed in this pull request? This PR proposes two things as below: - Avoid Unidoc build only if Hadoop 2.6 is explicitly set in SBT build Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the documentation build fails on a specific machine & environment in Jenkins but it was unable to reproduce. So, this PR just checks an environment variable `AMPLAB_JENKINS_BUILD_PROFILE` that is set in Hadoop 2.6 SBT build against branches on Jenkins, and then disables Unidoc build. **Note that PR builder will still build it with Hadoop 2.6 & SBT.** ``` ======================================================================== Building Unidoc API Documentation ======================================================================== [info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: -Phadoop-2.6 -Pmesos -Pkinesis-asl -Pyarn -Phive-thriftserver -Phive unidoc Using /usr/java/jdk1.8.0_60 as default JAVA_HOME. ... ``` I checked the environment variables from the logs (first bit) as below: - **spark-master-test-sbt-hadoop-2.6** (this one is being failed) - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 SPARK_BRANCH=master AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.6 <- I use this variable AMPLAB_JENKINS="true" ``` - spark-master-test-sbt-hadoop-2.7 - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 SPARK_BRANCH=master AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.7 AMPLAB_JENKINS="true" ``` - spark-master-test-maven-hadoop-2.6 - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 HADOOP_PROFILE=hadoop-2.6 HADOOP_VERSION= SPARK_BRANCH=master AMPLAB_JENKINS="true" ``` - spark-master-test-maven-hadoop-2.7 - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 HADOOP_PROFILE=hadoop-2.7 HADOOP_VERSION= SPARK_BRANCH=master AMPLAB_JENKINS="true" ``` - PR builder - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/75843/consoleFull ``` JENKINS_MASTER_HOSTNAME=amp-jenkins-master JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 ``` Assuming from other logs in branch-2.1 - SBT & Hadoop 2.6 against branch-2.1 https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.1-test-sbt-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 SPARK_BRANCH=branch-2.1 AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.6 AMPLAB_JENKINS="true" ``` - Maven & Hadoop 2.6 against branch-2.1 https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.1-test-maven-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 HADOOP_PROFILE=hadoop-2.6 HADOOP_VERSION= SPARK_BRANCH=branch-2.1 AMPLAB_JENKINS="true" ``` We have been using the same convention for those variables. These are actually being used in `run-tests.py` script - here https://github.com/apache/spark/blob/master/dev/run-tests.py#L519-L520 - Revert the previous try After https://github.com/apache/spark/pull/17651, it seems the build still fails on SBT Hadoop 2.6 master. I am unable to reproduce this - https://github.com/apache/spark/pull/17477#issuecomment-294094092 and the reviewer was too. So, this got merged as it looks the only way to verify this is to merge it currently (as no one seems able to reproduce this). ## How was this patch tested? I only checked `is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6"` is working fine as expected as below: ```python >>> import collections >>> os = collections.namedtuple('os', 'environ')(environ={"AMPLAB_JENKINS_BUILD_PROFILE": "hadoop2.6"}) >>> print(not os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6") False >>> os = collections.namedtuple('os', 'environ')(environ={"AMPLAB_JENKINS_BUILD_PROFILE": "hadoop2.7"}) >>> print(not os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6") True >>> os = collections.namedtuple('os', 'environ')(environ={}) >>> print(not os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6") True ``` I tried many ways but I was unable to reproduce this in my local. Sean also tried the way I did but he was also unable to reproduce this. Please refer the comments in https://github.com/apache/spark/pull/17477#issuecomment-294094092 Author: hyukjinkwon Closes #17669 from HyukjinKwon/revert-SPARK-20343. --- dev/run-tests.py | 12 ++++++++++-- pom.xml | 1 - project/SparkBuild.scala | 14 ++------------ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 450b68123e1fc..818a0c9f48419 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -365,8 +365,16 @@ def build_spark_assembly_sbt(hadoop_version): print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) - # Make sure that Java and Scala API documentation can be generated - build_spark_unidoc_sbt(hadoop_version) + + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. + # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the + # documentation build fails on a specific machine & environment in Jenkins but it was unable + # to reproduce. Please see SPARK-20343. This is a band-aid fix that should be removed in + # the future. + is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6" + if not is_hadoop_version_2_6: + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): diff --git a/pom.xml b/pom.xml index 14370d92a9080..c1174593c1922 100644 --- a/pom.xml +++ b/pom.xml @@ -142,7 +142,6 @@ 2.4.0 2.0.8 3.1.2 - 1.7.7 hadoop2 0.9.3 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 77dae289f7758..e52baf51aed1a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -318,8 +318,8 @@ object SparkBuild extends PomBuild { enable(MimaBuild.mimaSettings(sparkHome, x))(x) } - /* Generate and pick the spark build info from extra-resources and override a dependency */ - enable(Core.settings ++ CoreDependencyOverrides.settings)(core) + /* Generate and pick the spark build info from extra-resources */ + enable(Core.settings)(core) /* Unsafe settings */ enable(Unsafe.settings)(unsafe) @@ -443,16 +443,6 @@ object DockerIntegrationTests { ) } -/** - * Overrides to work around sbt's dependency resolution being different from Maven's in Unidoc. - * - * Note that, this is a hack that should be removed in the future. See SPARK-20343 - */ -object CoreDependencyOverrides { - lazy val settings = Seq( - dependencyOverrides += "org.apache.avro" % "avro" % "1.7.7") -} - /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ From 71a8e9df12e547cb4716f954ecb762b358f862d5 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 19 Apr 2017 18:58:58 +0100 Subject: [PATCH 0292/1765] [SPARK-20036][DOC] Note incompatible dependencies on org.apache.kafka artifacts ## What changes were proposed in this pull request? Note that you shouldn't manually add dependencies on org.apache.kafka artifacts ## How was this patch tested? Doc only change, did jekyll build and looked at the page. Author: cody koeninger Closes #17675 from koeninger/SPARK-20036. --- docs/streaming-kafka-0-10-integration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index e3837013168dc..92c296a9e6bd3 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -12,6 +12,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} +**Do not** manually add dependencies on `org.apache.kafka` artifacts (e.g. `kafka-clients`). The `spark-streaming-kafka-0-10` artifact has the appropriate transitive dependencies already, and different versions may be incompatible in hard to diagnose ways. + ### Creating a Direct Stream Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 From 4fea7848c45d85ff3ad0863de5d1449d1fd1b4b0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 19 Apr 2017 13:10:44 -0700 Subject: [PATCH 0293/1765] [SPARK-20397][SPARKR][SS] Fix flaky test: test_streaming.R.Terminated by error ## What changes were proposed in this pull request? Checking a source parameter is asynchronous. When the query is created, it's not guaranteed that source has been created. This PR just increases the timeout of awaitTermination to ensure the parsing error is thrown. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17687 from zsxwing/SPARK-20397. --- R/pkg/inst/tests/testthat/test_streaming.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 03b1bd3dc1f44..1f4054a84df53 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -131,7 +131,7 @@ test_that("Terminated by error", { expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), NA) - expect_error(awaitTermination(q, 1), + expect_error(awaitTermination(q, 5 * 1000), paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", " 'maxFilesPerTrigger', must be a positive integer).*")) From 63824b2c8e010ba03013be498def236c654d4fed Mon Sep 17 00:00:00 2001 From: ptkool Date: Thu, 20 Apr 2017 09:51:13 +0800 Subject: [PATCH 0294/1765] [SPARK-20350] Add optimization rules to apply Complementation Laws. ## What changes were proposed in this pull request? Apply Complementation Laws during boolean expression simplification. ## How was this patch tested? Tested using unit tests, integration tests, and manual tests. Author: ptkool Author: Michael Styles Closes #17650 from ptkool/apply_complementation_laws. --- .../sql/catalyst/optimizer/expressions.scala | 5 +++++ .../BooleanSimplificationSuite.scala | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index ea2c5d241d8dd..34382bd272406 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -154,6 +154,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral + case a And b if Not(a).semanticEquals(b) => FalseLiteral + case a Or b if Not(a).semanticEquals(b) => TrueLiteral + case a And b if a.semanticEquals(Not(b)) => FalseLiteral + case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 935bff7cef2e8..c275f997ba6e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.Row class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -42,6 +43,16 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelationWithData = LocalRelation.fromExternalRows( + testRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { + val plan = testRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze val actual = Optimize.execute(plan) @@ -160,4 +171,12 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } + + test("Complementation Laws") { + checkCondition('a && !'a, testRelation) + checkCondition(!'a && 'a, testRelation) + + checkCondition('a || !'a, testRelationWithData) + checkCondition(!'a || 'a, testRelationWithData) + } } From 39e303a8b6db642c26dbc26ba92e87680f50e4da Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 19 Apr 2017 18:58:14 -0700 Subject: [PATCH 0295/1765] [MINOR][SS] Fix a missing space in UnsupportedOperationChecker error message ## What changes were proposed in this pull request? Also went through the same file to ensure other string concatenation are correct. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17691 from zsxwing/fix-error-message. --- .../sql/catalyst/analysis/UnsupportedOperationChecker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 3f76f26dbe4ec..6ab4153bac70e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -267,7 +267,7 @@ object UnsupportedOperationChecker { throwError("Limits are not supported on streaming DataFrames/Datasets") case Sort(_, _, _) if !containsCompleteData(subPlan) => - throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" + + throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + "aggregated DataFrame/Dataset in Complete output mode") case Sample(_, _, _, _, child) if child.isStreaming => From dd6d55d5de970662eccf024e5eae4e6821373d35 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 19 Apr 2017 19:53:40 -0700 Subject: [PATCH 0296/1765] [SPARK-20398][SQL] range() operator should include cancellation reason when killed ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-19820 adds a reason field for why tasks were killed. However, for backwards compatibility it left the old TaskKilledException constructor which defaults to "unknown reason". The range() operator should use the constructor that fills in the reason rather than dropping it on task kill. ## How was this patch tested? Existing tests, and I tested this manually. Author: Eric Liang Closes #17692 from ericl/fix-kill-reason-in-range. --- .../apache/spark/sql/execution/basicPhysicalOperators.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 44278e37c5276..233a105f4d93a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -463,9 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | $number = $batchEnd; | } | - | if ($taskContext.isInterrupted()) { - | throw new TaskKilledException(); - | } + | $taskContext.killTaskIfInterrupted(); | | long $nextBatchTodo; | if ($numElementsTodo > ${batchSize}L) { From bdc60569196e9ae4e9086c3e514a406a9e8b23a6 Mon Sep 17 00:00:00 2001 From: ymahajan Date: Wed, 19 Apr 2017 20:08:31 -0700 Subject: [PATCH 0297/1765] Fixed typos in docs ## What changes were proposed in this pull request? Typos at a couple of place in the docs. ## How was this patch tested? build including docs Please review http://spark.apache.org/contributing.html before opening a pull request. Author: ymahajan Closes #17690 from ymahajan/master. --- docs/sql-programming-guide.md | 2 +- docs/structured-streaming-programming-guide.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 28942b68fa20d..490c1ce8a7cc5 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -571,7 +571,7 @@ be created by calling the `table` method on a `SparkSession` with the name of th For file-based data source, e.g. text, parquet, json, etc. you can specify a custom table path via the `path` option, e.g. `df.write.option("path", "/some/path").saveAsTable("t")`. When the table is dropped, the custom table path will not be removed and the table data is still there. If no custom table path is -specifed, Spark will write data to a default table path under the warehouse directory. When the table is +specified, Spark will write data to a default table path under the warehouse directory. When the table is dropped, the default table path will be removed too. Starting from Spark 2.1, persistent datasource tables have per-partition metadata stored in the Hive metastore. This brings several benefits: diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 3cf7151819e2d..5b18cf2f3c2ef 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -778,7 +778,7 @@ windowedCounts = words \ In this example, we are defining the watermark of the query on the value of the column "timestamp", and also defining "10 minutes" as the threshold of how late is the data allowed to be. If this query is run in Update output mode (discussed later in [Output Modes](#output-modes) section), -the engine will keep updating counts of a window in the Resule Table until the window is older +the engine will keep updating counts of a window in the Result Table until the window is older than the watermark, which lags behind the current event time in column "timestamp" by 10 minutes. Here is an illustration. From 46c5749768fefd976097c7d5612ec184a4cfe1b9 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 19 Apr 2017 21:19:46 -0700 Subject: [PATCH 0298/1765] [SPARK-20375][R] R wrappers for array and map ## What changes were proposed in this pull request? Adds wrappers for `o.a.s.sql.functions.array` and `o.a.s.sql.functions.map` ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 Closes #17674 from zero323/SPARK-20375. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 53 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 17 ++++++++ 4 files changed, 80 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ca45c6f9b0a96..b6b559adf06ea 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -213,6 +213,8 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "create_array", + "create_map", "hash", "cume_dist", "date_add", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index c311921fb33db..f854df11e5769 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3652,3 +3652,56 @@ setMethod("posexplode", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc) column(jc) }) + +#' create_array +#' +#' Creates a new array column. The input columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family normal_funcs +#' @rdname create_array +#' @name create_array +#' @aliases create_array,Column-method +#' @export +#' @examples \dontrun{create_array(df$x, df$y, df$z)} +#' @note create_array since 2.3.0 +setMethod("create_array", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "array", jcols) + column(jc) + }) + +#' create_map +#' +#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' e.g. (key1, value1, key2, value2, ...). +#' The key columns must all have the same data type, and can't be null. +#' The value columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family normal_funcs +#' @rdname create_map +#' @name create_map +#' @aliases create_map,Column-method +#' @export +#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} +#' @note create_map since 2.3.0 +setMethod("create_map", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 945676c7f10b3..da46823f52a17 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -942,6 +942,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname create_array +#' @export +setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) + +#' @rdname create_map +#' @export +setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) + #' @rdname hash #' @export setGeneric("hash", function(x, ...) { standardGeneric("hash") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6a6c9a809ab13..9e87a47106994 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1461,6 +1461,23 @@ test_that("column functions", { expect_equal(length(arr$arrcol[[1]]), 2) expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + + # Test create_array() and create_map() + df <- as.DataFrame(data.frame( + x = c(1.0, 2.0), y = c(-1.0, 3.0), z = c(-2.0, 5.0) + )) + + arrs <- collect(select(df, create_array(df$x, df$y, df$z))) + expect_equal(arrs[, 1], list(list(1, -1, -2), list(2, 3, 5))) + + maps <- collect(select( + df, create_map(lit("x"), df$x, lit("y"), df$y, lit("z"), df$z))) + + expect_equal( + maps[, 1], + lapply( + list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), + as.environment)) }) test_that("column binary mathfunctions", { From 55bea56911a958f6d3ec3ad96fb425cc71ec03f4 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 20 Apr 2017 11:13:48 +0100 Subject: [PATCH 0299/1765] [SPARK-20156][SQL][FOLLOW-UP] Java String toLowerCase "Turkish locale bug" in Database and Table DDLs ### What changes were proposed in this pull request? Database and Table names conform the Hive standard ("[a-zA-z_0-9]+"), i.e. if this name only contains characters, numbers, and _. When calling `toLowerCase` on the names, we should add `Locale.ROOT` to the `toLowerCase`for avoiding inadvertent locale-sensitive variation in behavior (aka the "Turkish locale problem"). ### How was this patch tested? Added a test case Author: Xiao Li Closes #17655 from gatorsmile/locale. --- .../ResolveTableValuedFunctions.scala | 4 ++- .../sql/catalyst/catalog/SessionCatalog.scala | 4 +-- .../spark/sql/internal/SharedState.scala | 4 ++- .../sql/execution/command/DDLSuite.scala | 19 +++++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 28 ++++++++++++++++++- 5 files changed, 54 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 8841309939c24..de6de24350f23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ @@ -103,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase()) match { + builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3fbf83f3a38a2..6c6d600190b66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -115,14 +115,14 @@ class SessionCatalog( * Format table name, taking into account case sensitivity. */ protected[this] def formatTableName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** * Format database name, taking into account case sensitivity. */ protected[this] def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 0289471bf841a..d06dbaa2d0abc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.Locale + import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -114,7 +116,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. - val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase(Locale.ROOT) if (externalCatalog.databaseExists(globalTempDB)) { throw new SparkException( s"$globalTempDB is a system preserved database, please rename your existing database " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fe74ab49f91bd..2f4eb1b15519b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2295,5 +2295,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + test(s"basic DDL using locale tr - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withLocale("tr") { + val dbName = "DaTaBaSe_I" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + + val tabName = "tAb_I" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET") + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1") + checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil) + } + } + } + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6a4cc95d36bea..b5ad73b746a8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import java.io.File import java.net.URI import java.nio.file.Files -import java.util.UUID +import java.util.{Locale, UUID} import scala.language.implicitConversions import scala.util.control.NonFatal @@ -228,6 +228,32 @@ private[sql] trait SQLTestUtils } } + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name") + } + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns. From c6f62c5b8106534007df31ca8c460064b89b450b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 20 Apr 2017 14:29:59 +0200 Subject: [PATCH 0300/1765] [SPARK-20405][SQL] Dataset.withNewExecutionId should be private ## What changes were proposed in this pull request? Dataset.withNewExecutionId is only used in Dataset itself and should be private. ## How was this patch tested? N/A - this is a simple visibility change. Author: Reynold Xin Closes #17699 from rxin/SPARK-20405. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 520663f624408..c6dcd93bbda66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2778,7 +2778,7 @@ class Dataset[T] private[sql]( * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with * an execution. */ - private[sql] def withNewExecutionId[U](body: => U): U = { + private def withNewExecutionId[U](body: => U): U = { SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } From b91873db0930c6fe885c27936e1243d5fabd03ed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 20 Apr 2017 16:59:38 +0200 Subject: [PATCH 0301/1765] [SPARK-20409][SQL] fail early if aggregate function in GROUP BY ## What changes were proposed in this pull request? It's illegal to have aggregate function in GROUP BY, and we should fail at analysis phase, if this happens. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17704 from cloud-fan/minor. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++---------- .../sql/catalyst/analysis/CheckAnalysis.scala | 7 ++++++- .../sql-tests/results/group-by-ordinal.sql.out | 4 ++-- .../apache/spark/sql/DataFrameAggregateSuite.scala | 7 +++++++ 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d9f36f7f874d7..175bfb3e8085d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -966,7 +966,7 @@ class Analyzer( case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. - case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => @@ -983,17 +983,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index da0c6b098f5ce..61797bc34dc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper { } def checkValidGroupingExprs(expr: Expression): Unit = { + if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + failAnalysis( + "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) + } + // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( @@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper { } } - aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + aggregateExprs.foreach(checkValidAggregateExpression) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c0930bbde69a4..d03681d0ea59c 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 +aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT)); -- !query 12 @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 +aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT)); -- !query 13 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb7df..8569c2d76b694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) ) } + + test("aggregate function in GROUP BY") { + val e = intercept[AnalysisException] { + testData.groupBy(sum($"key")).count() + } + assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) + } } From c5a31d160f47ba51bb9f8a4f3141851034640fc7 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Thu, 20 Apr 2017 18:49:39 +0200 Subject: [PATCH 0302/1765] [SPARK-20407][TESTS] ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test ## What changes were proposed in this pull request? SharedSQLContext.afterEach now calls DebugFilesystem.assertNoOpenStreams inside eventually. SQLTestUtils withTempDir calls waitForTasksToFinish before deleting the directory. ## How was this patch tested? Added new test in ParquetQuerySuite based on the flaky test Author: Bogdan Raducanu Closes #17701 from bogdanrdc/SPARK-20407. --- .../parquet/ParquetQuerySuite.scala | 35 ++++++++++++++++++- .../apache/spark/sql/test/SQLTestUtils.scala | 19 ++++++++-- .../spark/sql/test/SharedSQLContext.scala | 13 ++++--- 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index c36609586c807..2efff3f57d7d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.SparkException +import org.apache.spark.{DebugFilesystem, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + /** + * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop + * to increase the chance of failure + */ + ignore("SPARK-20407 ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + for (i <- 1 to 100) { + DebugFilesystem.clearOpenStreams() + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreCorruptFiles() + } + assert(exception.getMessage().contains("is not a Parquet file")) + } + DebugFilesystem.assertNoOpenStreams() + } + } + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b5ad73b746a8b..44c0fc70d066b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -22,11 +22,13 @@ import java.net.URI import java.nio.file.Files import java.util.{Locale, UUID} +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -49,7 +51,7 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils - extends SparkFunSuite + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData { self => @@ -138,6 +140,15 @@ private[sql] trait SQLTestUtils } } + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -146,7 +157,11 @@ private[sql] trait SQLTestUtils */ protected def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e122b39f6fc40..3d76e05f616d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.test +import scala.concurrent.duration._ + import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { protected val sparkConf = new SparkConf() @@ -84,6 +85,10 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected override def afterEach(): Unit = { super.afterEach() - DebugFilesystem.assertNoOpenStreams() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } } } From b2ebadfd55283348b8a8b37e28075fca0798228a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 20 Apr 2017 09:55:10 -0700 Subject: [PATCH 0303/1765] [SPARK-20358][CORE] Executors failing stage on interrupted exception thrown by cancelled tasks ## What changes were proposed in this pull request? This was a regression introduced by my earlier PR here: https://github.com/apache/spark/pull/17531 It turns out NonFatal() does not in fact catch InterruptedException. ## How was this patch tested? Extended cancellation unit test coverage. The first test fails before this patch. cc JoshRosen mridulm Author: Eric Liang Closes #17659 from ericl/spark-20358. --- .../org/apache/spark/executor/Executor.scala | 3 ++- .../org/apache/spark/SparkContextSuite.scala | 26 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 83469c5ff0600..18f04391d64c3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -432,7 +432,8 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case NonFatal(_) if task != null && task.reasonIfKilled.isDefined => + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 735f4454e299e..7e26139a2bead 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -540,10 +540,24 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } - // Launches one task that will run forever. Once the SparkListener detects the task has + testCancellingTasks("that raise interrupted exception on cancel") { + Thread.sleep(9999999) + } + + // SPARK-20217 should not fail stage if task throws non-interrupted exception + testCancellingTasks("that raise runtime exception on cancel") { + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + throw new RuntimeException("killed") + } + } + + // Launches one task that will block forever. Once the SparkListener detects the task has // started, kill and re-schedule it. The second run of the task will complete immediately. // If this test times out, then the first version of the task wasn't killed successfully. - test("Killing tasks") { + def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) SparkContextSuite.isTaskStarted = false @@ -572,13 +586,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu // first attempt will hang if (!SparkContextSuite.isTaskStarted) { SparkContextSuite.isTaskStarted = true - try { - Thread.sleep(9999999) - } catch { - case t: Throwable => - // SPARK-20217 should not fail stage if task throws non-interrupted exception - throw new RuntimeException("killed") - } + blockFn } // second attempt succeeds immediately } From d95e4d9d6a9705c534549add6d4a73d554e47274 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 20 Apr 2017 22:35:48 +0200 Subject: [PATCH 0304/1765] [SPARK-20334][SQL] Return a better error message when correlated predicates contain aggregate expression that has mixture of outer and local references. ## What changes were proposed in this pull request? Address a follow up in [comment](https://github.com/apache/spark/pull/16954#discussion_r105718880) Currently subqueries with correlated predicates containing aggregate expression having mixture of outer references and local references generate a codegen error like following : ```SQL SELECT t1a FROM t1 GROUP BY 1 HAVING EXISTS (SELECT 1 FROM t2 WHERE t2a < min(t1a + t2a)); ``` Exception snippet. ``` Cannot evaluate expression: min((input[0, int, false] + input[4, int, false])) at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:226) at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:87) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:106) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:103) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:103) ``` After this PR, a better error message is issued. ``` org.apache.spark.sql.AnalysisException Error in query: Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. Aggregate expression: min((t1.`t1a` + t2.`t2a`)), Outer references: t1.`t1a`, Local references: t2.`t2a`.; ``` ## How was this patch tested? Added tests in SQLQueryTestSuite. Author: Dilip Biswal Closes #17636 from dilipbiswal/subquery_followup1. --- .../sql/catalyst/analysis/Analyzer.scala | 49 +++++++--- .../negative-cases/invalid-correlation.sql | 74 +++++++++----- .../invalid-correlation.sql.out | 96 ++++++++++++++----- .../org/apache/spark/sql/SubquerySuite.scala | 23 ++++- 4 files changed, 181 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 175bfb3e8085d..eafeb4ac1ae55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1204,6 +1204,28 @@ class Analyzer( private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { val outerReferences = ArrayBuffer.empty[Expression] + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. + """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } + // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { if (hasOuterReferences(p)) { @@ -1211,9 +1233,12 @@ class Analyzer( } } - // Make sure a plan's expressions do not contain outer references - def failOnOuterReference(p: LogicalPlan): Unit = { - if (p.expressions.exists(containsOuter)) { + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + s"clauses:\n$p") @@ -1283,9 +1308,9 @@ class Analyzer( // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. case s: Sort => - failOnOuterReference(s) + failOnInvalidOuterReference(s) case r: RepartitionByExpression => - failOnOuterReference(r) + failOnInvalidOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. @@ -1299,6 +1324,8 @@ class Analyzer( case _: EqualTo | _: EqualNullSafe => false case _ => true } + + failOnInvalidOuterReference(f) // The aggregate expressions are treated in a special way by getOuterReferences. If the // aggregate expression contains only outer reference attributes then the entire aggregate // expression is isolated as an OuterReference. @@ -1308,7 +1335,7 @@ class Analyzer( // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. case p: Project => - failOnOuterReference(p) + failOnInvalidOuterReference(p) // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains @@ -1316,7 +1343,7 @@ class Analyzer( // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. case a: Aggregate => - failOnOuterReference(a) + failOnInvalidOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) // Join can host correlated expressions. @@ -1324,7 +1351,7 @@ class Analyzer( joinType match { // Inner join, like Filter, can be anywhere. case _: InnerLike => - failOnOuterReference(j) + failOnInvalidOuterReference(j) // Left outer join's right operand cannot be on a correlation path. // LeftAnti and ExistenceJoin are special cases of LeftOuter. @@ -1335,12 +1362,12 @@ class Analyzer( // Any correlated references in the subplan // of the right operand cannot be pulled up. case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnOuterReference(j) + failOnInvalidOuterReference(j) failOnOuterReferenceInSubTree(right) // Likewise, Right outer join's left operand cannot be on a correlation path. case RightOuter => - failOnOuterReference(j) + failOnInvalidOuterReference(j) failOnOuterReferenceInSubTree(left) // Any other join types not explicitly listed above, @@ -1356,7 +1383,7 @@ class Analyzer( // Note: // Generator with join=false is treated as Category 4. case g: Generate if g.join => - failOnOuterReference(g) + failOnInvalidOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql index cf93c5a835971..e22cade936792 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -1,42 +1,72 @@ -- The test file contains negative test cases -- of invalid queries where error messages are expected. -create temporary view t1 as select * from values +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 2, 3) -as t1(t1a, t1b, t1c); +AS t1(t1a, t1b, t1c); -create temporary view t2 as select * from values +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 0, 1) -as t2(t2a, t2b, t2c); +AS t2(t2a, t2b, t2c); -create temporary view t3 as select * from values +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) -as t3(t3a, t3b, t3c); +AS t3(t3a, t3b, t3c); -- TC 01.01 -- The column t2b in the SELECT of the subquery is invalid -- because it is neither an aggregate function nor a GROUP BY column. -select t1a, t2b -from t1, t2 -where t1b = t2c -and t2b = (select max(avg) - from (select t2b, avg(t2b) avg - from t2 - where t2a = t1.t1b +SELECT t1a, t2b +FROM t1, t2 +WHERE t1b = t2c +AND t2b = (SELECT max(avg) + FROM (SELECT t2b, avg(t2b) avg + FROM t2 + WHERE t2a = t1.t1b ) ) ; -- TC 01.02 -- Invalid due to the column t2b not part of the output from table t2. -select * -from t1 -where t1a in (select min(t2a) - from t2 - group by t2c - having t2c in (select max(t3c) - from t3 - group by t3b - having t3b > t2b )) +SELECT * +FROM t1 +WHERE t1a IN (SELECT min(t2a) + FROM t2 + GROUP BY t2c + HAVING t2c IN (SELECT max(t3c) + FROM t3 + GROUP BY t3b + HAVING t3b > t2b )) ; +-- TC 01.03 +-- Invalid due to mixure of outer and local references under an AggegatedExpression +-- in a correlated predicate +SELECT t1a +FROM t1 +GROUP BY 1 +HAVING EXISTS (SELECT 1 + FROM t2 + WHERE t2a < min(t1a + t2a)); + +-- TC 01.04 +-- Invalid due to mixure of outer and local references under an AggegatedExpression +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT 1 + FROM t3 + GROUP BY 1 + HAVING min(t2a + t3a) > 1)); + +-- TC 01.05 +-- Invalid due to outer reference appearing in projection list +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT min(t2a) + FROM t3)); + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index f7bbb35aad6ce..e4b1a2dbc675c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -1,11 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 5 +-- Number of queries: 8 -- !query 0 -create temporary view t1 as select * from values +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 2, 3) -as t1(t1a, t1b, t1c) +AS t1(t1a, t1b, t1c) -- !query 0 schema struct<> -- !query 0 output @@ -13,9 +13,9 @@ struct<> -- !query 1 -create temporary view t2 as select * from values +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 0, 1) -as t2(t2a, t2b, t2c) +AS t2(t2a, t2b, t2c) -- !query 1 schema struct<> -- !query 1 output @@ -23,9 +23,9 @@ struct<> -- !query 2 -create temporary view t3 as select * from values +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) -as t3(t3a, t3b, t3c) +AS t3(t3a, t3b, t3c) -- !query 2 schema struct<> -- !query 2 output @@ -33,13 +33,13 @@ struct<> -- !query 3 -select t1a, t2b -from t1, t2 -where t1b = t2c -and t2b = (select max(avg) - from (select t2b, avg(t2b) avg - from t2 - where t2a = t1.t1b +SELECT t1a, t2b +FROM t1, t2 +WHERE t1b = t2c +AND t2b = (SELECT max(avg) + FROM (SELECT t2b, avg(t2b) avg + FROM t2 + WHERE t2a = t1.t1b ) ) -- !query 3 schema @@ -50,17 +50,67 @@ grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate funct -- !query 4 -select * -from t1 -where t1a in (select min(t2a) - from t2 - group by t2c - having t2c in (select max(t3c) - from t3 - group by t3b - having t3b > t2b )) +SELECT * +FROM t1 +WHERE t1a IN (SELECT min(t2a) + FROM t2 + GROUP BY t2c + HAVING t2c IN (SELECT max(t3c) + FROM t3 + GROUP BY t3b + HAVING t3b > t2b )) -- !query 4 schema struct<> -- !query 4 output org.apache.spark.sql.AnalysisException resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); + + +-- !query 5 +SELECT t1a +FROM t1 +GROUP BY 1 +HAVING EXISTS (SELECT 1 + FROM t2 + WHERE t2a < min(t1a + t2a)) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. Aggregate expression: min((t1.`t1a` + t2.`t2a`)), Outer references: t1.`t1a`, Local references: t2.`t2a`.; + + +-- !query 6 +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT 1 + FROM t3 + GROUP BY 1 + HAVING min(t2a + t3a) > 1)) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. Aggregate expression: min((t2.`t2a` + t3.`t3a`)), Outer references: t2.`t2a`, Local references: t3.`t3a`.; + + +-- !query 7 +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT min(t2a) + FROM t3)) +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: +Aggregate [min(outer(t2a#x)) AS min(outer())#x] ++- SubqueryAlias t3 + +- Project [t3a#x, t3b#x, t3c#x] + +- SubqueryAlias t3 + +- LocalRelation [t3a#x, t3b#x, t3c#x] +; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0f0199cbe2777..131abf7c1e5d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -822,12 +822,25 @@ class SubquerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - | select c2 - | from t1 - | where exists (select * - | from t2 lateral view explode(arr_c2) q as c2 - where t1.c1 = t2.c1)""".stripMargin), + | SELECT c2 + | FROM t1 + | WHERE EXISTS (SELECT * + | FROM t2 LATERAL VIEW explode(arr_c2) q AS c2 + WHERE t1.c1 = t2.c1)""".stripMargin), Row(1) :: Row(0) :: Nil) + + val msg1 = intercept[AnalysisException] { + sql( + """ + | SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT * + | FROM t1 LATERAL VIEW explode(t2.arr_c2) q AS c2 + | WHERE t1.c1 = t2.c1) + """.stripMargin) + } + assert(msg1.getMessage.contains( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING")) } } From 033206355339677812a250b2b64818a261871fd2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 20 Apr 2017 22:37:04 +0200 Subject: [PATCH 0305/1765] [SPARK-20410][SQL] Make sparkConf a def in SharedSQLContext ## What changes were proposed in this pull request? It is kind of annoying that `SharedSQLContext.sparkConf` is a val when overriding test cases, because you cannot call `super` on it. This PR makes it a function. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17705 from hvanhovell/SPARK-20410. --- .../spark/sql/AggregateHashMapSuite.scala | 35 ++++++++----------- .../DatasetSerializerRegistratorSuite.scala | 12 +++---- .../DataSourceScanExecRedactionSuite.scala | 11 ++---- .../datasources/FileSourceStrategySuite.scala | 2 +- .../CompactibleFileStreamLogSuite.scala | 4 +-- .../streaming/HDFSMetadataLogSuite.scala | 4 +-- .../spark/sql/test/SharedSQLContext.scala | 7 ++-- 7 files changed, 32 insertions(+), 43 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 3e85d95523125..7e61a68025158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { +import org.apache.spark.SparkConf - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") - super.beforeAll() - } +class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -38,12 +37,9 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo } class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { - - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -55,15 +51,14 @@ class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeA } } -class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with -BeforeAndAfter { +class TwoLevelAggregateHashMapWithVectorizedMapSuite + extends DataFrameAggregateSuite + with BeforeAndAfter { - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 92c5656f65bb4..68f7de047b392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql import com.esotericsoftware.kryo.{Kryo, Serializer} import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.test.TestSparkSession /** * Test suite to test Kryo custom registrators. @@ -30,12 +30,10 @@ import org.apache.spark.sql.test.TestSparkSession class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - /** - * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]]. - */ - protected override def beforeAll(): Unit = { - sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) - super.beforeAll() + + override protected def sparkConf: SparkConf = { + // Make sure we use the KryoRegistrator + super.sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) } test("Kryo registrator") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 05a2b2c862c73..f7f1ccea281c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -18,22 +18,17 @@ package org.apache.spark.sql.execution import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils /** * Suite that tests the redaction of DataSourceScanExec */ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { - import Utils._ - - override def beforeAll(): Unit = { - sparkConf.set("spark.redaction.string.regex", - "file:/[\\w_]+") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.string.regex", "file:/[\\w_]+") test("treeString is redacted") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index f36162858bf7a..8703fe96e5878 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ - protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") + protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1") test("unpartitioned table, single partition") { val table = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 20ac06f048c6f..3d480b148db55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 662c4466b21b2..7689bc03a4ccf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") private implicit def toOption[A](a: A): Option[A] = Option(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 3d76e05f616d5..81c69a338abcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.{SparkSession, SQLContext} */ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - protected val sparkConf = new SparkConf() + protected def sparkConf = { + new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + } /** * The [[TestSparkSession]] to use for all tests in this suite. @@ -51,8 +53,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventua protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { - new TestSparkSession( - sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + new TestSparkSession(sparkConf) } /** From 592f5c89349f3c5b6ec0531c6514b8f7d95ad8da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 20 Apr 2017 16:02:09 -0700 Subject: [PATCH 0306/1765] [SPARK-20172][CORE] Add file permission check when listing files in FsHistoryProvider ## What changes were proposed in this pull request? In the current Spark's HistoryServer we expected to get `AccessControlException` during listing all the files, but unfortunately it was not worked because we actually doesn't check the access permission and no other calls will throw such exception. What was worse is that this check will be deferred until reading files, which is not necessary and quite verbose, since it will be printed out the exception in every 10 seconds when checking the files. So here with this fix, we actually check the read permission during listing the files, which could avoid unnecessary file read later on and suppress the verbose log. ## How was this patch tested? Add unit test to verify. Author: jerryshao Closes #17495 from jerryshao/SPARK-20172. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 23 +++++ .../deploy/history/FsHistoryProvider.scala | 28 +++--- .../spark/deploy/SparkHadoopUtilSuite.scala | 97 +++++++++++++++++++ .../history/FsHistoryProviderSuite.scala | 16 ++- 4 files changed, 145 insertions(+), 19 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index bae7a3f307f52..9cc321af4bde2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -353,6 +354,28 @@ class SparkHadoopUtil extends Logging { } buffer.toString } + + private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { + val perm = status.getPermission + val ugi = UserGroupInformation.getCurrentUser + + if (ugi.getShortUserName == status.getOwner) { + if (perm.getUserAction.implies(mode)) { + return true + } + } else if (ugi.getGroupNames.contains(status.getGroup)) { + if (perm.getGroupAction.implies(mode)) { + return true + } + } else if (perm.getOtherAction.implies(mode)) { + return true + } + + logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + + s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + + s"${if (status.isDirectory) "d" else "-"}$perm") + false + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9012736bc2745..f4235df245128 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,7 +27,8 @@ import scala.xml.Node import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -318,21 +319,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - try { - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally - // reading a garbage file is safe, but we would log an error which can be scary to - // the end-user. - !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() - } catch { - case e: AccessControlException => - // Do not use "logInfo" since these messages can get pretty noisy if printed on - // every poll. - logDebug(s"No permission to read $entry, ignoring.") - false - } + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() && + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => @@ -445,7 +439,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the log files in the list and merge the list of old applications with new ones */ - private def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala new file mode 100644 index 0000000000000..ab24a76e20a30 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.security.PrivilegedExceptionAction + +import scala.util.Random + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.permission.{FsAction, FsPermission} +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { + test("check file permission") { + import FsAction._ + val testUser = s"user-${Random.nextInt(100)}" + val testGroups = Array(s"group-${Random.nextInt(100)}") + val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) + + testUgi.doAs(new PrivilegedExceptionAction[Void] { + override def run(): Void = { + val sparkHadoopUtil = new SparkHadoopUtil + + // If file is owned by user and user has access permission + var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user but user has no access permission + status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + val otherUser = s"test-${Random.nextInt(100)}" + val otherGroup = s"test-${Random.nextInt(100)}" + + // If file is owned by user's group and user's group has access permission + status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user's group but user's group has no access permission + status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + // If file is owned by other user and this user has access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by other user but this user has no access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + null + } + }) + } + + private def fileStatus( + owner: String, + group: String, + userAction: FsAction, + groupAction: FsAction, + otherAction: FsAction): FileStatus = { + new FileStatus(0L, + false, + 0, + 0L, + 0L, + 0L, + new FsPermission(userAction, groupAction, otherAction), + owner, + group, + null) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index ec580a44b8e76..456158d41b93f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -130,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("SPARK-3697: ignore directories that cannot be read.") { + test("SPARK-3697: ignore files that cannot be read.") { // setReadable(...) does not work on Windows. Please refer JDK-6728842. assume(!Utils.isWindows) + + class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { + var mergeApplicationListingCall = 0 + override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + super.mergeApplicationListing(fileStatus) + mergeApplicationListingCall += 1 + } + } + val provider = new TestFsHistoryProvider + val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), @@ -145,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) logFile2.setReadable(false, false) - val provider = new FsHistoryProvider(createTestConf()) updateAndCheck(provider) { list => list.size should be (1) } + + provider.mergeApplicationListingCall should be (1) } test("history file is renamed from inprogress to completed") { From 0368eb9d86634c83b3140ce3190cb9e0d0b7fd86 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 21 Apr 2017 09:49:42 +0800 Subject: [PATCH 0307/1765] [SPARK-20367] Properly unescape column names of partitioning columns parsed from paths. ## What changes were proposed in this pull request? When infering partitioning schema from paths, the column in parsePartitionColumn should be unescaped with unescapePathName, just like it is being done in e.g. parsePathFragmentAsSeq. ## How was this patch tested? Added a test to FileIndexSuite. Author: Juliusz Sompolski Closes #17703 from juliuszsompolski/SPARK-20367. --- .../execution/datasources/PartitioningUtils.scala | 2 +- .../sql/execution/datasources/FileIndexSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c3583209efc56..2d70172487e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -243,7 +243,7 @@ object PartitioningUtils { if (equalSignIndex == -1) { None } else { - val columnName = columnSpec.take(equalSignIndex) + val columnName = unescapePathName(columnSpec.take(equalSignIndex)) assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") val rawColumnValue = columnSpec.drop(equalSignIndex + 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index a9511cbd9e4cf..b4616826e40b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} @@ -236,6 +237,17 @@ class FileIndexSuite extends SharedSQLContext { val fileStatusCache = FileStatusCache.getOrCreate(spark) fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) } + + test("SPARK-20367 - properly unescape column names in inferPartitioning") { + withTempPath { path => + val colToUnescape = "Column/#%'?" + spark + .range(1) + .select(col("id").as(colToUnescape), col("id")) + .write.partitionBy(colToUnescape).parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { From 760c8d088df1d35d7b8942177d47bc1677daf143 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 21 Apr 2017 10:06:12 +0800 Subject: [PATCH 0308/1765] [SPARK-20329][SQL] Make timezone aware expression without timezone unresolved ## What changes were proposed in this pull request? A cast expression with a resolved time zone is not equal to a cast expression without a resolved time zone. The `ResolveAggregateFunction` assumed that these expression were the same, and would fail to resolve `HAVING` clauses which contain a `Cast` expression. This is in essence caused by the fact that a `TimeZoneAwareExpression` can be resolved without a set time zone. This PR fixes this, and makes a `TimeZoneAwareExpression` unresolved as long as it has no TimeZone set. ## How was this patch tested? Added a regression test to the `SQLQueryTestSuite.having` file. Author: Herman van Hovell Closes #17641 from hvanhovell/SPARK-20329. --- .../sql/catalyst/analysis/Analyzer.scala | 20 +----- .../analysis/ResolveInlineTables.scala | 10 +-- .../catalyst/analysis/timeZoneAnalysis.scala | 61 +++++++++++++++++++ .../spark/sql/catalyst/analysis/view.scala | 4 +- .../expressions/datetimeExpressions.scala | 4 +- .../analysis/ResolveInlineTablesSuite.scala | 10 +-- .../catalyst/analysis/TypeCoercionSuite.scala | 35 ++++++----- .../sql/catalyst/expressions/CastSuite.scala | 4 +- .../expressions/DateExpressionsSuite.scala | 6 +- .../expressions/ExpressionEvalHelper.scala | 7 ++- .../spark/sql/execution/SparkPlanner.scala | 2 +- .../datasources/DataSourceStrategy.scala | 20 +++--- .../sql/execution/datasources/rules.scala | 6 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../resources/sql-tests/inputs/having.sql | 3 + .../sql-tests/results/having.sql.out | 11 +++- .../spark/sql/sources/BucketedReadSuite.scala | 3 +- .../sql/sources/DataSourceAnalysisSuite.scala | 16 +++-- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- 19 files changed, 148 insertions(+), 78 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index eafeb4ac1ae55..dcadbbc90f438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -150,6 +150,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -161,8 +162,6 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), - Batch("ResolveTimeZone", Once, - ResolveTimeZone), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, @@ -2368,23 +2367,6 @@ class Analyzer( } } } - - /** - * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local - * time zone. - */ - object ResolveTimeZone extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - // Casts could be added in the subquery plan through the rule TypeCoercion while coercing - // the types between the value expression and list query expression of IN expression. - // We need to subject the subquery plan through ResolveTimeZone again to setup timezone - // information for time zone aware expressions. - case e: ListQuery => e.withNewPlan(apply(e.plan)) - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index a991dd96e2828..f2df3e132629f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ -case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { val castedExpr = if (e.dataType.sameType(targetType)) { e } else { - Cast(e, targetType) + cast(e, targetType) } - castedExpr.transform { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - }.eval() + castedExpr.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala new file mode 100644 index 0000000000000..a27aa845bf0ae --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local + * time zone. + */ +case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { + private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressions(transformTimeZoneExprs) + + def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) +} + +/** + * Mix-in trait for constructing valid [[Cast]] expressions. + */ +trait CastSupport { + /** + * Configuration used to create a valid cast expression. + */ + def conf: SQLConf + + /** + * Create a Cast expression with the session local time zone. + */ + def cast(child: Expression, dataType: DataType): Cast = { + Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bd54c257d98d..ea46dd7282401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. */ -case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver @@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") } else { - Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, + Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) } case (_, originAttr) => originAttr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f8fe774823e5b..bb8fd5032d63d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * Common base class for time zone aware expressions. */ trait TimeZoneAwareExpression extends Expression { + /** The expression is only resolved when the time zone has been set. */ + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined /** the timezone ID to be used to evaluate value. */ def timeZoneId: Option[String] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index f45a826869842..d0fe815052256 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** @@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) - val converted = ResolveInlineTables(conf).convert(table) + val withTimeZone = ResolveTimeZone(conf).apply(table) + val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] - assert(converted.output.map(_.dataType) == Seq(TimestampType)) - assert(converted.data.size == 1) - assert(converted.data(0).getLong(0) == correct) + assert(output.map(_.dataType) == Seq(TimestampType)) + assert(data.size == 1) + assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 011d09ff60641..2624f5586fd5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest { } } + private val timeZoneResolver = ResolveTimeZone(new SQLConf) + + private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + } + test("WidenSetOperationTypes for except and intersect") { val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), @@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("p", ByteType)(), AttributeReference("q", DoubleType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val unionRelation = wt( + val unionRelation = widenSetOperationTypes( Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] assert(unionRelation.children.length == 4) checkOutput(unionRelation.children.head, expectedTypes) @@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest { } } - val dp = TypeCoercion.WidenSetOperationTypes - val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Union] - val r2 = dp(Except(left1, right1)).asInstanceOf[Except] - val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest { val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a7ffa884d2286..22f3f3514fa41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String */ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { - private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = { + private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = { v match { case lit: Expression => Cast(lit, targetType, timeZoneId) case _ => Cast(Literal(v), targetType, timeZoneId) @@ -47,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } private def checkNullCast(from: DataType, to: DataType): Unit = { - checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null) + checkEvaluation(cast(Literal.create(null, from), to), null) } test("null cast") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 9978f35a03810..ca89bf7db0b4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Seconds") { assert(Second(Literal.create(null, DateType), gmtId).resolved === false) - assert(Second(Cast(Literal(d), TimestampType), None).resolved === true) + assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true) checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15) checkEvaluation(Second(Literal(ts), gmtId), 15) @@ -220,7 +220,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Hour") { assert(Hour(Literal.create(null, DateType), gmtId).resolved === false) - assert(Hour(Literal(ts), None).resolved === true) + assert(Hour(Literal(ts), gmtId).resolved === true) checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13) checkEvaluation(Hour(Literal(ts), gmtId), 13) @@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Minute") { assert(Minute(Literal.create(null, DateType), gmtId).resolved === false) - assert(Minute(Literal(ts), None).resolved === true) + assert(Minute(Literal(ts), gmtId).resolved === true) checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation( Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1ba6dd1c5e8ca..b6399edb68dd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val serializer = new JavaSerializer(new SparkConf()).newInstance - val expr: Expression = serializer.deserialize(serializer.serialize(expression)) + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6566502bd8a8a..4e718d609c921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,7 +36,7 @@ class SparkPlanner( experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( FileSourceStrategy :: - DataSourceStrategy :: + DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: JoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d83d512e702d..d307122b5c70d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. */ -case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { def resolver: Resolver = conf.resolver @@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { val potentialSpecs = staticPartitions.filter { case (partKey, partValue) => resolver(field.name, partKey) } - if (potentialSpecs.size == 0) { + if (potentialSpecs.isEmpty) { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) + Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] /** * A Strategy for planning scans over data sources defined using the sources API. */ -object DataSourceStrategy extends Strategy with Logging { +case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { + import DataSourceStrategy._ + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging { // Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, bucketColumn :: Nil) @@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging { private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { toCatalystRDD(relation, relation.output, rdd) } +} +object DataSourceStrategy { /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * @@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging { * all [[Filter]]s that are completely filtered at the DataSource. */ protected[sql] def selectFilters( - relation: BaseRelation, - predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are // called `predicate`s, while all data source filters of type `sources.Filter` are simply called diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7abf2ae5166b5..3f4a78580f1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -315,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -367,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Alias(Cast(actual, expected.dataType), expected.name)( + Alias(cast(actual, expected.dataType), expected.name)( explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2b14eca919fa4..df7c3678b7807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 364c022d959dc..868a911e787f6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; -- SPARK-11032: resolve having correctly SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); + +-- SPARK-20329: make sure we handle timezones correctly +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1; diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index e0923832673cb..d87ee5221647f 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query 0 @@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) struct -- !query 3 output 1 + + +-- !query 4 +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1 +-- !query 4 schema +struct<(a + CAST(b AS BIGINT)):bigint> +-- !query 4 output +3 +7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9b65419dba234..ba0ca666b5c14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column @@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) val matchedBuckets = new BitSet(numBuckets) bucketValues.foreach { value => - matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index b16c9f8fc96b2..735e07c21373a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + def cast(e: Expression, dt: DataType): Expression = { + Cast(e, dt, Option(conf.sessionLocalTimeZone)) + } + val rule = DataSourceAnalysis(conf) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { @@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { if (!caseSensitive) { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), @@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")), @@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { // Test the case having a single static partition column. { val nonPartitionedAttributes = Seq('e.int, 'f.int) - val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType)) + val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1")), @@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { val dynamicPartitionAttributes = Seq('g.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType)) ++ + Seq(cast(Literal("1"), IntegerType)) ++ dynamicPartitionAttributes val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 9d3b31f39c0f5..e16c9e46b7723 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -101,7 +101,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ Seq( FileSourceStrategy, - DataSourceStrategy, + DataSourceStrategy(conf), SpecialLimits, InMemoryScans, HiveTableScans, From 48d760d028dd73371f99d084c4195dbc4dda5267 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 20 Apr 2017 19:40:21 -0700 Subject: [PATCH 0309/1765] [SPARK-20281][SQL] Print the identical Range parameters of SparkContext APIs and SQL in explain ## What changes were proposed in this pull request? This pr modified code to print the identical `Range` parameters of SparkContext APIs and SQL in `explain` output. In the current master, they internally use `defaultParallelism` for `splits` by default though, they print different strings in explain output; ``` scala> spark.range(4).explain == Physical Plan == *Range (0, 4, step=1, splits=Some(8)) scala> sql("select * from range(4)").explain == Physical Plan == *Range (0, 4, step=1, splits=None) ``` ## How was this patch tested? Added tests in `SQLQuerySuite` and modified some results in the existing tests. Author: Takeshi Yamamuro Closes #17670 from maropu/SPARK-20281. --- .../apache/spark/sql/execution/basicPhysicalOperators.scala | 3 ++- .../sql-tests/results/sql-compatibility-functions.sql.out | 2 +- .../resources/sql-tests/results/table-valued-functions.sql.out | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 233a105f4d93a..d3efa428a6db8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -332,6 +332,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { def start: Long = range.start + def end: Long = range.end def step: Long = range.step def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) def numElements: BigInt = range.numElements @@ -538,7 +539,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = range.simpleString + override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" } /** diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 9f0b95994be53..732b11050f461 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -88,7 +88,7 @@ Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nul == Physical Plan == *Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=None) ++- *Range (0, 2, step=1, splits=2) -- !query 9 diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index acd4ecf14617e..e2ee970d35f60 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -102,4 +102,4 @@ EXPLAIN select * from RaNgE(2) struct -- !query 8 output == Physical Plan == -*Range (0, 2, step=1, splits=None) +*Range (0, 2, step=1, splits=2) From e2b3d2367a563d4600d8d87b5317e71135c362f0 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 21 Apr 2017 00:05:03 -0700 Subject: [PATCH 0310/1765] [SPARK-20420][SQL] Add events to the external catalog ## What changes were proposed in this pull request? It is often useful to be able to track changes to the `ExternalCatalog`. This PR makes the `ExternalCatalog` emit events when a catalog object is changed. Events are fired before and after the change. The following events are fired per object: - Database - CreateDatabasePreEvent: event fired before the database is created. - CreateDatabaseEvent: event fired after the database has been created. - DropDatabasePreEvent: event fired before the database is dropped. - DropDatabaseEvent: event fired after the database has been dropped. - Table - CreateTablePreEvent: event fired before the table is created. - CreateTableEvent: event fired after the table has been created. - RenameTablePreEvent: event fired before the table is renamed. - RenameTableEvent: event fired after the table has been renamed. - DropTablePreEvent: event fired before the table is dropped. - DropTableEvent: event fired after the table has been dropped. - Function - CreateFunctionPreEvent: event fired before the function is created. - CreateFunctionEvent: event fired after the function has been created. - RenameFunctionPreEvent: event fired before the function is renamed. - RenameFunctionEvent: event fired after the function has been renamed. - DropFunctionPreEvent: event fired before the function is dropped. - DropFunctionPreEvent: event fired after the function has been dropped. The current events currently only contain the names of the object modified. We add more events, and more details at a later point. A user can monitor changes to the external catalog by adding a listener to the Spark listener bus checking for `ExternalCatalogEvent`s using the `SparkListener.onOtherEvent` hook. A more direct approach is add listener directly to the `ExternalCatalog`. ## How was this patch tested? Added the `ExternalCatalogEventSuite`. Author: Herman van Hovell Closes #17710 from hvanhovell/SPARK-20420. --- .../catalyst/catalog/ExternalCatalog.scala | 85 +++++++- .../catalyst/catalog/InMemoryCatalog.scala | 22 +- .../spark/sql/catalyst/catalog/events.scala | 158 +++++++++++++++ .../catalog/ExternalCatalogEventSuite.scala | 188 ++++++++++++++++++ .../spark/sql/internal/SharedState.scala | 7 + .../spark/sql/hive/HiveExternalCatalog.scala | 22 +- 6 files changed, 457 insertions(+), 25 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 08a01e8601897..974ef900e2eed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -30,7 +31,8 @@ import org.apache.spark.sql.types.StructType * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog { +abstract class ExternalCatalog + extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { import CatalogTypes.TablePartitionSpec protected def requireDbExists(db: String): Unit = { @@ -61,9 +63,22 @@ abstract class ExternalCatalog { // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + doCreateDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + + final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + doDropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } - def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -88,11 +103,39 @@ abstract class ExternalCatalog { // Tables // -------------------------------------------------------------------------- - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(CreateTablePreEvent(db, name)) + doCreateTable(tableDefinition, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } - def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit + protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - def renameTable(db: String, oldName: String, newName: String): Unit + final def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + doDropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + protected def doDropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit + + final def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + doRenameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + protected def doRenameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -269,11 +312,30 @@ abstract class ExternalCatalog { // Functions // -------------------------------------------------------------------------- - def createFunction(db: String, funcDefinition: CatalogFunction): Unit + final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + doCreateFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } - def dropFunction(db: String, funcName: String): Unit + protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit - def renameFunction(db: String, oldName: String, newName: String): Unit + final def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + doDropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + protected def doDropFunction(db: String, funcName: String): Unit + + final def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + doRenameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + protected def doRenameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction @@ -281,4 +343,9 @@ abstract class ExternalCatalog { def listFunctions(db: String, pattern: String): Seq[String] + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9ca1c71d1dcb1..81dd8efc0015f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,10 @@ class InMemoryCatalog( } } - override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireTableExists(db, oldName) requireTableNotExists(db, newName) val oldDesc = catalog(db).tables(oldName) @@ -565,18 +568,21 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override def dropFunction(db: String, funcName: String): Unit = synchronized { + override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala new file mode 100644 index 0000000000000..459973a13bb10 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.scheduler.SparkListenerEvent + +/** + * Event emitted by the external catalog when it is modified. Events are either fired before or + * after the modification (the event should document this). + */ +trait ExternalCatalogEvent extends SparkListenerEvent + +/** + * Listener interface for external catalog modification events. + */ +trait ExternalCatalogEventListener { + def onEvent(event: ExternalCatalogEvent): Unit +} + +/** + * Event fired when a database is create or dropped. + */ +trait DatabaseEvent extends ExternalCatalogEvent { + /** + * Database of the object that was touched. + */ + val database: String +} + +/** + * Event fired before a database is created. + */ +case class CreateDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been created. + */ +case class CreateDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired before a database is dropped. + */ +case class DropDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been dropped. + */ +case class DropDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired when a table is created, dropped or renamed. + */ +trait TableEvent extends DatabaseEvent { + /** + * Name of the table that was touched. + */ + val name: String +} + +/** + * Event fired before a table is created. + */ +case class CreateTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been created. + */ +case class CreateTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is dropped. + */ +case class DropTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been dropped. + */ +case class DropTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is renamed. + */ +case class RenameTablePreEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired after a table has been renamed. + */ +case class RenameTableEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired when a function is created, dropped or renamed. + */ +trait FunctionEvent extends DatabaseEvent { + /** + * Name of the function that was touched. + */ + val name: String +} + +/** + * Event fired before a function is created. + */ +case class CreateFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been created. + */ +case class CreateFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is dropped. + */ +case class DropFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been dropped. + */ +case class DropFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is renamed. + */ +case class RenameFunctionPreEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent + +/** + * Event fired after a function has been renamed. + */ +case class RenameFunctionEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala new file mode 100644 index 0000000000000..2539ea615ff92 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import java.net.URI +import java.nio.file.{Files, Path} + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.types.StructType + +/** + * Test Suite for external catalog events + */ +class ExternalCatalogEventSuite extends SparkFunSuite { + + protected def newCatalog: ExternalCatalog = new InMemoryCatalog() + + private def testWithCatalog( + name: String)( + f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { + val catalog = newCatalog + val recorder = mutable.Buffer.empty[ExternalCatalogEvent] + catalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + recorder += event + } + }) + f(catalog, (expected: Seq[ExternalCatalogEvent]) => { + val actual = recorder.clone() + recorder.clear() + assert(expected === actual) + }) + } + + private def createDbDefinition(uri: URI): CatalogDatabase = { + CatalogDatabase(name = "db5", description = "", locationUri = uri, Map.empty) + } + + private def createDbDefinition(): CatalogDatabase = { + createDbDefinition(preparePath(Files.createTempDirectory("db_"))) + } + + private def preparePath(path: Path): URI = path.normalize().toUri + + testWithCatalog("database") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createDatabase(dbDefinition, ignoreIfExists = true) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + intercept[AnalysisException] { + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + } + checkEvents(CreateDatabasePreEvent("db5") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) + } + checkEvents(DropDatabasePreEvent("db4") :: Nil) + + catalog.dropDatabase("db5", ignoreIfNotExists = false, cascade = false) + checkEvents(DropDatabasePreEvent("db5") :: DropDatabaseEvent("db5") :: Nil) + + catalog.dropDatabase("db4", ignoreIfNotExists = true, cascade = false) + checkEvents(DropDatabasePreEvent("db4") :: DropDatabaseEvent("db4") :: Nil) + } + + testWithCatalog("table") { (catalog, checkEvents) => + val path1 = Files.createTempDirectory("db_") + val path2 = Files.createTempDirectory(path1, "tbl_") + val uri1 = preparePath(path1) + val uri2 = preparePath(path2) + + // CREATE + val dbDefinition = createDbDefinition(uri1) + + val storage = CatalogStorageFormat.empty.copy( + locationUri = Option(uri2)) + val tableDefinition = CatalogTable( + identifier = TableIdentifier("tbl1", Some("db5")), + tableType = CatalogTableType.MANAGED, + storage = storage, + schema = new StructType().add("id", "long")) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = false) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = true) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + intercept[AnalysisException] { + catalog.createTable(tableDefinition, ignoreIfExists = false) + } + checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + + // RENAME + catalog.renameTable("db5", "tbl1", "tbl2") + checkEvents( + RenameTablePreEvent("db5", "tbl1", "tbl2") :: + RenameTableEvent("db5", "tbl1", "tbl2") :: Nil) + + intercept[AnalysisException] { + catalog.renameTable("db5", "tbl1", "tbl2") + } + checkEvents(RenameTablePreEvent("db5", "tbl1", "tbl2") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropTable("db5", "tbl1", ignoreIfNotExists = false, purge = true) + } + checkEvents(DropTablePreEvent("db5", "tbl1") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = false, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = true, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + } + + testWithCatalog("function") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + val functionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn7", Some("db5")), + className = "", + resources = Seq.empty) + + val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4") + val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createFunction("db5", functionDefinition) + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: CreateFunctionEvent("db5", "fn7") :: Nil) + + intercept[AnalysisException] { + catalog.createFunction("db5", functionDefinition) + } + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: Nil) + + // RENAME + catalog.renameFunction("db5", "fn7", "fn4") + checkEvents( + RenameFunctionPreEvent("db5", "fn7", "fn4") :: + RenameFunctionEvent("db5", "fn7", "fn4") :: Nil) + intercept[AnalysisException] { + catalog.renameFunction("db5", "fn7", "fn4") + } + checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropFunction("db5", "fn7") + } + checkEvents(DropFunctionPreEvent("db5", "fn7") :: Nil) + + catalog.dropFunction("db5", "fn4") + checkEvents(DropFunctionPreEvent("db5", "fn4") :: DropFunctionEvent("db5", "fn4") :: Nil) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index d06dbaa2d0abc..f834569e59b7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -109,6 +109,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + /** * A manager for global temporary views. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8b0fdf49cefab..71e33c46b9aed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -141,13 +141,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -194,7 +194,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -456,7 +456,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -465,7 +465,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = withClient { val rawTable = getRawTable(db, oldName) // Note that Hive serde tables don't use path option in storage properties to store the value @@ -1056,7 +1059,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override def createFunction( + override protected def doCreateFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1069,12 +1072,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override def dropFunction(db: String, name: String): Unit = withClient { + override protected def doDropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = withClient { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) client.renameFunction(db, oldName, newName) From 34767997e0c6cb28e1fac8cb650fa3511f260ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9?= Date: Fri, 21 Apr 2017 08:52:18 +0100 Subject: [PATCH 0311/1765] Small rewording about history server use case MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hello PR #10991 removed the built-in history view from Spark Standalone, so the history server is no longer useful to Yarn or Mesos only. Author: Hervé Closes #17709 from dud225/patch-1. --- docs/monitoring.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index da954385dc452..3e577c5f36778 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -27,8 +27,8 @@ in the UI to persisted storage. ## Viewing After the Fact -If Spark is run on Mesos or YARN, it is still possible to construct the UI of an -application through Spark's history server, provided that the application's event logs exist. +It is still possible to construct the UI of an application through Spark's history server, +provided that the application's event logs exist. You can start the history server by executing: ./sbin/start-history-server.sh From c9e6035e1fb825d280eaec3bdfc1e4d362897ffd Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 21 Apr 2017 22:11:24 +0800 Subject: [PATCH 0312/1765] [SPARK-20412] Throw ParseException from visitNonOptionalPartitionSpec instead of returning null values. ## What changes were proposed in this pull request? If a partitionSpec is supposed to not contain optional values, a ParseException should be thrown, and not nulls returned. The nulls can later cause NullPointerExceptions in places not expecting them. ## How was this patch tested? A query like "SHOW PARTITIONS tbl PARTITION(col1='val1', col2)" used to throw a NullPointerException. Now it throws a ParseException. Author: Juliusz Sompolski Closes #17707 from juliuszsompolski/SPARK-20412. --- .../spark/sql/catalyst/parser/AstBuilder.scala | 5 ++++- .../sql/execution/command/DDLCommandSuite.scala | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e1db1ef5b8695..2cf06d15664d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -215,7 +215,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ protected def visitNonOptionalPartitionSpec( ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 97c61dc8694bc..8a6bc62fec96c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -530,13 +530,13 @@ class DDLCommandSuite extends PlanTest { """.stripMargin val sql4 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar', |'field.delim' = ',') """.stripMargin val sql5 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed1 = parser.parsePlan(sql1) @@ -558,12 +558,12 @@ class DDLCommandSuite extends PlanTest { tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDePropertiesCommand( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -832,6 +832,14 @@ class DDLCommandSuite extends PlanTest { assert(e.contains("Found duplicate keys 'a'")) } + test("empty values in non-optional partition specs") { + val e = intercept[ParseException] { + parser.parsePlan( + "SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)") + }.getMessage + assert(e.contains("Found an empty partition key 'b'")) + } + test("drop table") { val tableName1 = "db.tab" val tableName2 = "tab" From a750a595976791cb8a77063f690ea8f82ea75a8f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 21 Apr 2017 22:25:35 +0800 Subject: [PATCH 0313/1765] [SPARK-20341][SQL] Support BigInt's value that does not fit in long value range ## What changes were proposed in this pull request? This PR avoids an exception in the case where `scala.math.BigInt` has a value that does not fit into long value range (e.g. `Long.MAX_VALUE+1`). When we run the following code by using the current Spark, the following exception is thrown. This PR keeps the value using `BigDecimal` if we detect such an overflow case by catching `ArithmeticException`. Sample program: ``` case class BigIntWrapper(value:scala.math.BigInt)``` spark.createDataset(BigIntWrapper(scala.math.BigInt("10000000000000000002"))::Nil).show ``` Exception: ``` Error while encoding: java.lang.ArithmeticException: BigInteger out of long range staticinvoke(class org.apache.spark.sql.types.Decimal$, DecimalType(38,0), apply, assertnotnull(assertnotnull(input[0, org.apache.spark.sql.BigIntWrapper, true])).value, true) AS value#0 java.lang.RuntimeException: Error while encoding: java.lang.ArithmeticException: BigInteger out of long range staticinvoke(class org.apache.spark.sql.types.Decimal$, DecimalType(38,0), apply, assertnotnull(assertnotnull(input[0, org.apache.spark.sql.BigIntWrapper, true])).value, true) AS value#0 at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.toRow(ExpressionEncoder.scala:290) at org.apache.spark.sql.SparkSession$$anonfun$2.apply(SparkSession.scala:454) at org.apache.spark.sql.SparkSession$$anonfun$2.apply(SparkSession.scala:454) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.immutable.List.map(List.scala:285) at org.apache.spark.sql.SparkSession.createDataset(SparkSession.scala:454) at org.apache.spark.sql.Agg$$anonfun$18.apply$mcV$sp(MySuite.scala:192) at org.apache.spark.sql.Agg$$anonfun$18.apply(MySuite.scala:192) at org.apache.spark.sql.Agg$$anonfun$18.apply(MySuite.scala:192) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) ... Caused by: java.lang.ArithmeticException: BigInteger out of long range at java.math.BigInteger.longValueExact(BigInteger.java:4531) at org.apache.spark.sql.types.Decimal.set(Decimal.scala:140) at org.apache.spark.sql.types.Decimal$.apply(Decimal.scala:434) at org.apache.spark.sql.types.Decimal.apply(Decimal.scala) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source) at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.toRow(ExpressionEncoder.scala:287) ... 59 more ``` ## How was this patch tested? Add new test suite into `DecimalSuite` Author: Kazuaki Ishizaki Closes #17684 from kiszk/SPARK-20341. --- .../org/apache/spark/sql/types/Decimal.scala | 20 +++++++++++++------ .../apache/spark/sql/types/DecimalSuite.scala | 6 ++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index e8f6884c025c2..80916ee9c5379 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -132,14 +132,22 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + * If the value is not in the range of long, convert it to BigDecimal and + * the precision and scale are based on the converted value. + * + * This code avoids BigDecimal object allocation as possible to improve runtime efficiency */ def set(bigintval: BigInteger): Decimal = { - this.decimalVal = null - this.longVal = bigintval.longValueExact() - this._precision = DecimalType.MAX_PRECISION - this._scale = 0 - this + try { + this.decimalVal = null + this.longVal = bigintval.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } catch { + case _: ArithmeticException => + set(BigDecimal(bigintval)) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 714883a4099cf..93c231e30b49b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -212,4 +212,10 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } } } + + test("SPARK-20341: support BigInt's value does not fit in long value range") { + val bigInt = scala.math.BigInt("9223372036854775808") + val decimal = Decimal.apply(bigInt) + assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") + } } From eb00378f0eed6afbf328ae6cd541cc202d14c1f0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 21 Apr 2017 17:58:13 +0000 Subject: [PATCH 0314/1765] [SPARK-20423][ML] fix MLOR coeffs centering when reg == 0 ## What changes were proposed in this pull request? When reg == 0, MLOR has multiple solutions and we need to centralize the coeffs to get identical result. BUT current implementation centralize the `coefficientMatrix` by the global coeffs means. In fact the `coefficientMatrix` should be centralized on each feature index itself. Because, according to the MLOR probability distribution function, it can be proven easily that: suppose `{ w0, w1, .. w(K-1) }` make up the `coefficientMatrix`, then `{ w0 + c, w1 + c, ... w(K - 1) + c}` will also be the equivalent solution. `c` is an arbitrary vector of `numFeatures` dimension. reference https://core.ac.uk/download/pdf/6287975.pdf So that we need to centralize the `coefficientMatrix` on each feature dimension separately. **We can also confirm this through R library `glmnet`, that MLOR in `glmnet` always generate coefficients result that the sum of each dimension is all `zero`, when reg == 0.** ## How was this patch tested? Tests added. Author: WeichenXu Closes #17706 from WeichenXu123/mlor_center. --- .../spark/ml/classification/LogisticRegression.scala | 11 ++++++++--- .../ml/classification/LogisticRegressionSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 965ce3d6f275f..bc8154692e52c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -609,9 +609,14 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val denseValues = denseCoefficientMatrix.values - val coefficientMean = denseValues.sum / denseValues.length - denseCoefficientMatrix.update(_ - coefficientMean) + val centers = Array.fill(numFeatures)(0.0) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + centers(j) += v + } + centers.transform(_ / numCoefficientSets) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + denseCoefficientMatrix.update(i, j, v - centers(j)) + } } // center the intercepts when using multinomial algorithm diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index c858b9bbfc256..83f575e83828f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1139,6 +1139,9 @@ class LogisticRegressionSuite 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector ~== interceptsR relTol 0.05) @@ -1204,6 +1207,9 @@ class LogisticRegressionSuite -0.3180040, 0.9679074, -0.2252219, -0.4319914, 0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) From fd648bff63f91a30810910dfc5664eea0ff5e6f9 Mon Sep 17 00:00:00 2001 From: zero323 Date: Fri, 21 Apr 2017 12:06:21 -0700 Subject: [PATCH 0315/1765] [SPARK-20371][R] Add wrappers for collect_list and collect_set ## What changes were proposed in this pull request? Adds wrappers for `collect_list` and `collect_set`. ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 Closes #17672 from zero323/SPARK-20371. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 40 +++++++++++++++++++++++ R/pkg/R/generics.R | 9 +++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 22 +++++++++++++ 4 files changed, 73 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b6b559adf06ea..e804e30e14b86 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -203,6 +203,8 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "collect_list", + "collect_set", "column", "concat", "concat_ws", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f854df11e5769..e7decb91867bd 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3705,3 +3705,43 @@ setMethod("create_map", jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) column(jc) }) + +#' collect_list +#' +#' Creates a list of objects with duplicates. +#' +#' @param x Column to compute on +#' +#' @rdname collect_list +#' @name collect_list +#' @family agg_funcs +#' @aliases collect_list,Column-method +#' @export +#' @examples \dontrun{collect_list(df$x)} +#' @note collect_list since 2.3.0 +setMethod("collect_list", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc) + column(jc) + }) + +#' collect_set +#' +#' Creates a list of objects with duplicate elements eliminated. +#' +#' @param x Column to compute on +#' +#' @rdname collect_set +#' @name collect_set +#' @family agg_funcs +#' @aliases collect_set,Column-method +#' @export +#' @examples \dontrun{collect_set(df$x)} +#' @note collect_set since 2.3.0 +setMethod("collect_set", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index da46823f52a17..61d248ebd2e3e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname collect_list +#' @export +setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) + +#' @rdname collect_set +#' @export +setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) + #' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) @@ -1358,6 +1366,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) + ###################### Spark.ML Methods ########################## #' @rdname fitted diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9e87a47106994..bf2093fdc475a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1731,6 +1731,28 @@ test_that("group by, agg functions", { expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + # Test collect_list and collect_set + gd3_collections_local <- collect( + agg(gd3, collect_set(df8$age), collect_list(df8$age)) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]), + c(30) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]), + c(30, 30) + ) + + expect_equal( + sort(unlist( + gd3_collections_local[gd3_collections_local$name == "Justin", 3] + )), + c(1, 19) + ) + unlink(jsonPath2) unlink(jsonPath3) }) From ad290402aa1d609abf5a2883a6d87fa8bc2bd517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Fri, 21 Apr 2017 20:08:26 +0100 Subject: [PATCH 0316/1765] [SPARK-20401][DOC] In the spark official configuration document, the 'spark.driver.supervise' configuration parameter specification and default values are necessary. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Use the REST interface submits the spark job. e.g. curl -X POST http://10.43.183.120:6066/v1/submissions/create --header "Content-Type:application/json;charset=UTF-8" --data'{ "action": "CreateSubmissionRequest", "appArgs": [ "myAppArgument" ], "appResource": "/home/mr/gxl/test.jar", "clientSparkVersion": "2.2.0", "environmentVariables": { "SPARK_ENV_LOADED": "1" }, "mainClass": "cn.zte.HdfsTest", "sparkProperties": { "spark.jars": "/home/mr/gxl/test.jar", **"spark.driver.supervise": "true",** "spark.app.name": "HdfsTest", "spark.eventLog.enabled": "false", "spark.submit.deployMode": "cluster", "spark.master": "spark://10.43.183.120:6066" } }' **I hope that make sure that the driver is automatically restarted if it fails with non-zero exit code. But I can not find the 'spark.driver.supervise' configuration parameter specification and default values from the spark official document.** ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17696 from guoxiaolongzte/SPARK-20401. --- docs/configuration.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 2687f542b8bd3..6b65d2bcb83e5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -213,6 +213,14 @@ of the most common options to set are: and typically can have up to 50 characters. + + spark.driver.supervise + false + + If true, restarts the driver automatically if it fails with a non-zero exit status. + Only has effect in Spark standalone mode or Mesos cluster deploy mode. + + Apart from these, the following properties are also available, and may be useful in some situations: From 05a451491d535c0828413ce2eb06fe94571069ac Mon Sep 17 00:00:00 2001 From: eatoncys Date: Sat, 22 Apr 2017 12:29:35 +0100 Subject: [PATCH 0317/1765] [SPARK-20386][SPARK CORE] modify the log info if the block exists on the slave already ## What changes were proposed in this pull request? Modify the added memory size to memSize-originalMemSize if the block exists on the slave already since if the block exists, the added memory size should be memSize-originalMemSize; if originalMemSize is bigger than memSize ,then the log info should be Removed memory, removed size should be originalMemSize-memSize ## How was this patch tested? Multiple runs on existing unit tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: eatoncys Closes #17683 from eatoncys/SPARK-20386. --- .../storage/BlockManagerMasterEndpoint.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 467c3e0e6b51f..6f85b9e4d6c73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -497,11 +497,17 @@ private[spark] class BlockManagerInfo( updateLastSeenMs() - if (_blocks.containsKey(blockId)) { + val blockExists = _blocks.containsKey(blockId) + var originalMemSize: Long = 0 + var originalDiskSize: Long = 0 + var originalLevel: StorageLevel = StorageLevel.NONE + + if (blockExists) { // The block exists on the slave already. val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize + originalLevel = blockStatus.storageLevel + originalMemSize = blockStatus.memSize + originalDiskSize = blockStatus.diskSize if (originalLevel.useMemory) { _remainingMem += originalMemSize @@ -520,32 +526,44 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) + if (blockExists) { + logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(memSize)}," + + s" original size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } else { + logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(memSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + if (blockExists) { + logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(diskSize)}," + + s" original size: ${Utils.bytesToString(originalDiskSize)})") + } else { + logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(diskSize)})") + } } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } - } else if (_blocks.containsKey(blockId)) { + } else if (blockExists) { // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) _cachedBlocks -= blockId - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) + if (originalLevel.useMemory) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + + s" (size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + if (originalLevel.useDisk) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + + s" (size: ${Utils.bytesToString(originalDiskSize)})") } } } From b3c572a6b332b79fef72c309b9038b3c939dcba2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 22 Apr 2017 09:41:58 -0700 Subject: [PATCH 0318/1765] [SPARK-20430][SQL] Initialise RangeExec parameters in a driver side ## What changes were proposed in this pull request? This pr initialised `RangeExec` parameters in a driver side. In the current master, a query below throws `NullPointerException`; ``` sql("SET spark.sql.codegen.wholeStage=false") sql("SELECT * FROM range(1)").show 17/04/20 17:11:05 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.NullPointerException at org.apache.spark.sql.execution.SparkPlan.sparkContext(SparkPlan.scala:54) at org.apache.spark.sql.execution.RangeExec.numSlices(basicPhysicalOperators.scala:343) at org.apache.spark.sql.execution.RangeExec$$anonfun$20.apply(basicPhysicalOperators.scala:506) at org.apache.spark.sql.execution.RangeExec$$anonfun$20.apply(basicPhysicalOperators.scala:505) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:844) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:844) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323) at org.apache.spark.rdd.RDD.iterator(RDD.scala:287) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323) at org.apache.spark.rdd.RDD.iterator(RDD.scala:287) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:108) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:320) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) ``` ## How was this patch tested? Added a test in `DataFrameRangeSuite`. Author: Takeshi Yamamuro Closes #17717 from maropu/SPARK-20430. --- .../spark/sql/execution/basicPhysicalOperators.scala | 10 +++++----- .../org/apache/spark/sql/DataFrameRangeSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d3efa428a6db8..64698d5527578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -331,11 +331,11 @@ case class SampleExec( case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { - def start: Long = range.start - def end: Long = range.end - def step: Long = range.step - def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) - def numElements: BigInt = range.numElements + val start: Long = range.start + val end: Long = range.end + val step: Long = range.step + val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + val numElements: BigInt = range.numElements override val output: Seq[Attribute] = range.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 5e323c02b253d..7b495656b93d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -185,6 +185,12 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + + test("SPARK-20430 Initialize Range parameters in a driver side") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) + } + } } object DataFrameRangeSuite { From 8765bc17d0439032d0378686c4f2b17df2432abc Mon Sep 17 00:00:00 2001 From: Michael Patterson Date: Sat, 22 Apr 2017 19:58:54 -0700 Subject: [PATCH 0319/1765] [SPARK-20132][DOCS] Add documentation for column string functions ## What changes were proposed in this pull request? Add docstrings to column.py for the Column functions `rlike`, `like`, `startswith`, and `endswith`. Pass these docstrings through `_bin_op` There may be a better place to put the docstrings. I put them immediately above the Column class. ## How was this patch tested? I ran `make html` on my local computer to remake the documentation, and verified that the html pages were displaying the docstrings correctly. I tried running `dev-tests`, and the formatting tests passed. However, my mvn build didn't work I think due to issues on my computer. These docstrings are my original work and free license. davies has done the most recent work reorganizing `_bin_op` Author: Michael Patterson Closes #17469 from map222/patterson-documentation. --- python/pyspark/sql/column.py | 70 ++++++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index ec05c18d4f062..46c1707cb6c37 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -250,11 +250,50 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods + _rlike_doc = """ + Return a Boolean :class:`Column` based on a regex match. + + :param other: an extended regex expression + + >>> df.filter(df.name.rlike('ice$')).collect() + [Row(age=2, name=u'Alice')] + """ + _like_doc = """ + Return a Boolean :class:`Column` based on a SQL LIKE match. + + :param other: a SQL LIKE pattern + + See :func:`rlike` for a regex version + + >>> df.filter(df.name.like('Al%')).collect() + [Row(age=2, name=u'Alice')] + """ + _startswith_doc = """ + Return a Boolean :class:`Column` based on a string match. + + :param other: string at end of line (do not use a regex `^`) + + >>> df.filter(df.name.startswith('Al')).collect() + [Row(age=2, name=u'Alice')] + >>> df.filter(df.name.startswith('^Al')).collect() + [] + """ + _endswith_doc = """ + Return a Boolean :class:`Column` based on matching end of string. + + :param other: string at end of line (do not use a regex `$`) + + >>> df.filter(df.name.endswith('ice')).collect() + [Row(age=2, name=u'Alice')] + >>> df.filter(df.name.endswith('ice$')).collect() + [] + """ + contains = _bin_op("contains") - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") + rlike = ignore_unicode_prefix(_bin_op("rlike", _rlike_doc)) + like = ignore_unicode_prefix(_bin_op("like", _like_doc)) + startswith = ignore_unicode_prefix(_bin_op("startsWith", _startswith_doc)) + endswith = ignore_unicode_prefix(_bin_op("endsWith", _endswith_doc)) @ignore_unicode_prefix @since(1.3) @@ -303,8 +342,27 @@ def isin(self, *cols): desc = _unary_op("desc", "Returns a sort expression based on the" " descending order of the given column name.") - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + _isNull_doc = """ + True if the current expression is null. Often combined with + :func:`DataFrame.filter` to select rows with null values. + + >>> from pyspark.sql import Row + >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() + >>> df2.filter(df2.height.isNull()).collect() + [Row(height=None, name=u'Alice')] + """ + _isNotNull_doc = """ + True if the current expression is null. Often combined with + :func:`DataFrame.filter` to select rows with non-null values. + + >>> from pyspark.sql import Row + >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() + >>> df2.filter(df2.height.isNotNull()).collect() + [Row(height=80, name=u'Tom')] + """ + + isNull = ignore_unicode_prefix(_unary_op("isNull", _isNull_doc)) + isNotNull = ignore_unicode_prefix(_unary_op("isNotNull", _isNotNull_doc)) @since(1.3) def alias(self, *alias, **kwargs): From 2eaf4f3fe3595ae341a3a5ce886b859992dea5b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Sun, 23 Apr 2017 13:33:14 +0100 Subject: [PATCH 0320/1765] [SPARK-20385][WEB-UI] Submitted Time' field, the date format needs to be formatted, in running Drivers table or Completed Drivers table in master web ui. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Submitted Time' field, the date format **needs to be formatted**, in running Drivers table or Completed Drivers table in master web ui. Before fix this problem e.g. Completed Drivers Submission ID **Submitted Time** Worker State Cores Memory Main Class driver-20170419145755-0005 **Wed Apr 19 14:57:55 CST 2017** worker-20170419145250-zdh120-40412 FAILED 1 1024.0 MB cn.zte.HdfsTest please see the attachment:https://issues.apache.org/jira/secure/attachment/12863977/before_fix.png After fix this problem e.g. Completed Drivers Submission ID **Submitted Time** Worker State Cores Memory Main Class driver-20170419145755-0006 **2017/04/19 16:01:25** worker-20170419145250-zdh120-40412 FAILED 1 1024.0 MB cn.zte.HdfsTest please see the attachment:https://issues.apache.org/jira/secure/attachment/12863976/after_fix.png 'Submitted Time' field, the date format **has been formatted**, in running Applications table or Completed Applicationstable in master web ui, **it is correct.** e.g. Running Applications Application ID Name Cores Memory per Executor **Submitted Time** User State Duration app-20170419160910-0000 (kill) SparkSQL::10.43.183.120 1 5.0 GB **2017/04/19 16:09:10** root RUNNING 53 s **Format after the time easier to observe, and consistent with the applications table,so I think it's worth fixing.** ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17682 from guoxiaolongzte/SPARK-20385. --- .../apache/spark/deploy/master/ui/ApplicationPage.scala | 2 +- .../org/apache/spark/deploy/master/ui/MasterPage.scala | 2 +- .../org/apache/spark/deploy/mesos/ui/DriverPage.scala | 4 ++-- .../apache/spark/deploy/mesos/ui/MesosClusterPage.scala | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 946a92882141c..a8d721f3e0d49 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -83,7 +83,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
  • -
  • Submit Date: {app.submitDate}
  • +
  • Submit Date: {UIUtils.formatDate(app.submitDate)}
  • State: {app.state}
  • { if (!app.isFinished) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index e722a24d4a89e..9351c72094e34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -252,7 +252,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } {driver.id} {killLink} - {driver.submitDate} + {UIUtils.formatDate(driver.submitDate)} {driver.worker.map(w => if (w.isAlive()) { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index cd98110ddcc02..127fadabcce53 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -101,7 +101,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Launch Time - {state.startDate} + {UIUtils.formatDate(state.startDate)} Finish Time @@ -154,7 +154,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Memory{driver.mem} - Submitted{driver.submissionDate} + Submitted{UIUtils.formatDate(driver.submissionDate)} Supervise{driver.supervise} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 13ba7d311e57d..c9107c3e73d3f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -68,7 +68,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} cpus: {submission.cores}, mem: {submission.mem} @@ -88,10 +88,10 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {id} {historyCol} - {state.driverDescription.submissionDate} + {UIUtils.formatDate(state.driverDescription.submissionDate)} {state.driverDescription.command.mainClass} cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} - {state.startDate} + {UIUtils.formatDate(state.startDate)} {state.slaveId.getValue} {stateString(state.mesosTaskStatus)} @@ -101,7 +101,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} {submission.retryState.get.lastFailureStatus} {submission.retryState.get.nextRetry} From e9f97154bc4af60376a550238315d7fc57099f9c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 24 Apr 2017 09:34:38 +0100 Subject: [PATCH 0321/1765] [BUILD] Close stale PRs ## What changes were proposed in this pull request? This pr proposed to close stale PRs. Currently, we have 400+ open PRs and there are some stale PRs whose JIRA tickets have been already closed and whose JIRA tickets does not exist (also, they seem not to be minor issues). // Open PRs whose JIRA tickets have been already closed Closes #11785 Closes #13027 Closes #13614 Closes #13761 Closes #15197 Closes #14006 Closes #12576 Closes #15447 Closes #13259 Closes #15616 Closes #14473 Closes #16638 Closes #16146 Closes #17269 Closes #17313 Closes #17418 Closes #17485 Closes #17551 Closes #17463 Closes #17625 // Open PRs whose JIRA tickets does not exist and they are not minor issues Closes #10739 Closes #15193 Closes #15344 Closes #14804 Closes #16993 Closes #17040 Closes #15180 Closes #17238 ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #17734 from maropu/resolved_pr. From 776a2c0e91dfea170ea1c489118e1d42c4121f35 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 24 Apr 2017 17:21:42 +0800 Subject: [PATCH 0322/1765] [SPARK-20439][SQL] Fix Catalog API listTables and getTable when failed to fetch table metadata ### What changes were proposed in this pull request? `spark.catalog.listTables` and `spark.catalog.getTable` does not work if we are unable to retrieve table metadata due to any reason (e.g., table serde class is not accessible or the table type is not accepted by Spark SQL). After this PR, the APIs still return the corresponding Table without the description and tableType) ### How was this patch tested? Added a test case Author: Xiao Li Closes #17730 from gatorsmile/listTables. --- .../spark/sql/internal/CatalogImpl.scala | 28 +++++++++++++++---- .../sql/hive/execution/HiveDDLSuite.scala | 8 ++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index aebb663df5c92..0b8e53868c999 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -98,14 +99,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(tables, sparkSession) } + /** + * Returns a Table for the given table/view or temporary view. + * + * Note that this function requires the table already exists in the Catalog. + * + * If the table metadata retrieval failed due to any reason (e.g., table serde class + * is not accessible or the table type is not accepted by Spark SQL), this function + * still returns the corresponding Table without the description and tableType) + */ private def makeTable(tableIdent: TableIdentifier): Table = { - val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val metadata = try { + Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)) + } catch { + case NonFatal(_) => None + } val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = metadata.identifier.database.orNull, - description = metadata.comment.orNull, - tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + database = metadata.map(_.identifier.database).getOrElse(tableIdent.database).orNull, + description = metadata.map(_.comment.orNull).orNull, + tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull, isTemporary = isTemp) } @@ -197,7 +211,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * `AnalysisException` when no `Table` can be found. */ override def getTable(dbName: String, tableName: String): Table = { - makeTable(TableIdentifier(tableName, Option(dbName))) + if (tableExists(dbName, tableName)) { + makeTable(TableIdentifier(tableName, Option(dbName))) + } else { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3906968aaff10..16a99321bad33 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1197,6 +1197,14 @@ class HiveDDLSuite s"CREATE INDEX $indexName ON TABLE $tabName (a) AS 'COMPACT' WITH DEFERRED REBUILD") val indexTabName = spark.sessionState.catalog.listTables("default", s"*$indexName*").head.table + + // Even if index tables exist, listTables and getTable APIs should still work + checkAnswer( + spark.catalog.listTables().toDF(), + Row(indexTabName, "default", null, null, false) :: + Row(tabName, "default", null, "MANAGED", false) :: Nil) + assert(spark.catalog.getTable("default", indexTabName).name === indexTabName) + intercept[TableAlreadyExistsException] { sql(s"CREATE TABLE $indexTabName(b int)") } From 90264aced7cfdf265636517b91e5d1324fe60112 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Mon, 24 Apr 2017 23:43:06 +0800 Subject: [PATCH 0323/1765] [SPARK-18901][ML] Require in LR LogisticAggregator is redundant ## What changes were proposed in this pull request? In MultivariateOnlineSummarizer, `add` and `merge` have check for weights and feature sizes. The checks in LR are redundant, which are removed from this PR. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com Closes #17478 from wangmiao1981/logit. --- .../apache/spark/ml/classification/LogisticRegression.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index bc8154692e52c..44b3478e0c3dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1571,9 +1571,6 @@ private class LogisticAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1596,8 +1593,6 @@ private class LogisticAggregator( * @return This LogisticAggregator object. */ def merge(other: LogisticAggregator): this.type = { - require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " + - s"LogisticAggregator. Expecting $numFeatures but got ${other.numFeatures}.") if (other.weightSum != 0.0) { weightSum += other.weightSum From 8a272ddc9d2359a724aa89ae2f8de121a4aa7ac2 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 24 Apr 2017 10:56:57 -0700 Subject: [PATCH 0324/1765] [SPARK-20438][R] SparkR wrappers for split and repeat ## What changes were proposed in this pull request? Add wrappers for `o.a.s.sql.functions`: - `split` as `split_string` - `repeat` as `repeat_string` ## How was this patch tested? Existing tests, additional unit tests, `check-cran.sh` Author: zero323 Closes #17729 from zero323/SPARK-20438. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 58 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 34 +++++++++++++ 4 files changed, 102 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index e804e30e14b86..95d5cc6d1c78e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -300,6 +300,7 @@ exportMethods("%in%", "rank", "regexp_extract", "regexp_replace", + "repeat_string", "reverse", "rint", "rlike", @@ -323,6 +324,7 @@ exportMethods("%in%", "sort_array", "soundex", "spark_partition_id", + "split_string", "stddev", "stddev_pop", "stddev_samp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e7decb91867bd..752e4c5c7189d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3745,3 +3745,61 @@ setMethod("collect_set", jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) column(jc) }) + +#' split_string +#' +#' Splits string on regular expression. +#' +#' Equivalent to \code{split} SQL function +#' +#' @param x Column to compute on +#' @param pattern Java regular expression +#' +#' @rdname split_string +#' @family string_funcs +#' @aliases split_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, split_string(df$value, "\\s+"))) +#' +#' # This is equivalent to the following SQL expression +#' head(selectExpr(df, "split(value, '\\\\s+')")) +#' } +#' @note split_string 2.3.0 +setMethod("split_string", + signature(x = "Column", pattern = "character"), + function(x, pattern) { + jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern) + column(jc) + }) + +#' repeat_string +#' +#' Repeats string n times. +#' +#' Equivalent to \code{repeat} SQL function +#' +#' @param x Column to compute on +#' @param n Number of repetitions +#' +#' @rdname repeat_string +#' @family string_funcs +#' @aliases repeat_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' first(select(df, repeat_string(df$value, 3))) +#' +#' # This is equivalent to the following SQL expression +#' first(selectExpr(df, "repeat(value, 3)")) +#' } +#' @note repeat_string 2.3.0 +setMethod("repeat_string", + signature(x = "Column", n = "numeric"), + function(x, n) { + jc <- callJStatic("org.apache.spark.sql.functions", "repeat", x@jc, numToInt(n)) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 61d248ebd2e3e..5e7a1c60c2b3b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1192,6 +1192,10 @@ setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) +#' @rdname repeat_string +#' @export +setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) + #' @rdname reverse #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1257,6 +1261,10 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) +#' @rdname split_string +#' @export +setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index bf2093fdc475a..c21ba2f1a138b 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1546,6 +1546,40 @@ test_that("string operators", { expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") + + l4 <- list(list(a = "a.b@c.d 1\\b")) + df4 <- createDataFrame(l4) + expect_equal( + collect(select(df4, split_string(df4$a, "\\s+")))[1, 1], + list(list("a.b@c.d", "1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\.")))[1, 1], + list(list("a", "b@c", "d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "@")))[1, 1], + list(list("a.b", "c.d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\\\")))[1, 1], + list(list("a.b@c.d 1", "b")) + ) + + l5 <- list(list(a = "abc")) + df5 <- createDataFrame(l5) + expect_equal( + collect(select(df5, repeat_string(df5$a, 1L)))[1, 1], + "abc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, 3)))[1, 1], + "abcabcabc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, -1)))[1, 1], + "" + ) }) test_that("date functions on a DataFrame", { From 5280d93e6ecec7327e7fcd3d8d1cb90e01e774fc Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 24 Apr 2017 18:18:59 -0700 Subject: [PATCH 0325/1765] [SPARK-20239][CORE] Improve HistoryServer's ACL mechanism ## What changes were proposed in this pull request? Current SHS (Spark History Server) two different ACLs: * ACL of base URL, it is controlled by "spark.acls.enabled" or "spark.ui.acls.enabled", and with this enabled, only user configured with "spark.admin.acls" (or group) or "spark.ui.view.acls" (or group), or the user who started SHS could list all the applications, otherwise none of them can be listed. This will also affect REST APIs which listing the summary of all apps and one app. * Per application ACL. This is controlled by "spark.history.ui.acls.enabled". With this enabled only history admin user and user/group who ran this app can access the details of this app. With this two ACLs, we may encounter several unexpected behaviors: 1. if base URL's ACL (`spark.acls.enable`) is enabled but user A has no view permission. User "A" cannot see the app list but could still access details of it's own app. 2. if ACLs of base URL (`spark.acls.enable`) is disabled, then user "A" could download any application's event log, even it is not run by user "A". 3. The changes of Live UI's ACL will affect History UI's ACL which share the same conf file. The unexpected behaviors is mainly because we have two different ACLs, ideally we should have only one to manage all. So to improve SHS's ACL mechanism, here in this PR proposed to: 1. Disable "spark.acls.enable" and only use "spark.history.ui.acls.enable" for history server. 2. Check permission for event-log download REST API. With this PR: 1. Admin user could see/download the list of all applications, as well as application details. 2. Normal user could see the list of all applications, but can only download and check the details of applications accessible to him. ## How was this patch tested? New UTs are added, also verified in real cluster. CC tgravescs vanzin please help to review, this PR changes the semantics you did previously. Thanks a lot. Author: jerryshao Closes #17582 from jerryshao/SPARK-20239. --- .../history/ApplicationHistoryProvider.scala | 4 ++-- .../spark/deploy/history/HistoryServer.scala | 8 ++++++++ .../spark/status/api/v1/ApiRootResource.scala | 18 +++++++++++++++--- .../deploy/history/HistoryServerSuite.scala | 14 ++++++++------ 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index d7d82800b8b55..6d8758a3d3b1d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return Count of application event logs that are currently under process */ def getEventLogsUnderProcess(): Int = { - return 0; + 0 } /** @@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis */ def getLastUpdatedTime(): Long = { - return 0; + 0 } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 54f39f7620e5d..d9c8fda99ef97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -301,6 +301,14 @@ object HistoryServer extends Logging { logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") config.set(SecurityManager.SPARK_AUTH_CONF, "false") } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + new SecurityManager(config) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 00f918c09c66b..f17b637754826 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". + withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext { case None => throw new NotFoundException("no such app: " + appId) } } - } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 764156c3edc41..95acb9a54440f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); - + logDir.deleteOnExit() } test("ui and api authorization checks") { - val appId = "app-20161115172038-0000" - val owner = "jose" + val appId = "local-1430917381535" + val owner = "irashid" val admin = "root" val other = "alice" @@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val port = server.boundPort val testUrls = Seq( - s"http://localhost:$port/api/v1/applications/$appId/jobs", - s"http://localhost:$port/history/$appId/jobs/") + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") tests.foreach { case (user, expectedCode) => testUrls.foreach { url => From f44c8a843ca512b319f099477415bc13eca2e373 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Apr 2017 21:48:04 -0700 Subject: [PATCH 0326/1765] [SPARK-20453] Bump master branch version to 2.3.0-SNAPSHOT This patch bumps the master branch version to `2.3.0-SNAPSHOT`. Author: Josh Rosen Closes #17753 from JoshRosen/SPARK-20453. --- assembly/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ repl/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 37 files changed, 42 insertions(+), 37 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 9d8607d9137c6..742a4a1531e71 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c069..066970f24205f 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c10fb1ddb9f..2de882adcb582 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 5e5a80bd44467..a8488d8d1b704 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 1356c4723b662..6b81fc2b2b040 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 9345dc8f0cc4b..f7e586ee777e1 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f03a4da5e7152..680d0413b1616 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 24ce36deeb169..7f245b5b6384a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/docs/_config.yml b/docs/_config.yml index 83bb30598d153..21255ef7a5c45 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.2.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.2.0 +SPARK_VERSION: 2.3.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.3.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 91c2e81ebed2f..e674e799f24a3 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8948df2da89e2..0fa87a697454b 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index f8ef8a991316d..71016bc645ca7 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 6d547c46d6a2d..12630840e79dc 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 46901d64eda97..87a09642405a7 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 295142cbfdff9..75df886ca44f6 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 6cf448e65e8b4..557d27296345f 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 88499240cd569..6c98cb04fcfa6 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 3fedd9eda1959..f9c2dcb38dc0e 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 8368a1f12218d..849c8b465f99e 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 90bb0e4987c82..48783d65826aa 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index daa79e79163b9..40a751a652fa9 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 7da27817ebafd..36d555066b181 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8df33660ea9d1..cb30e4a4af4bc 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 025cd84f20f0e..e9b46c4cf0ffa 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 663f7fb0b010d..043d13609fd26 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 82f840b0fc269..572670dc11b42 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index c1174593c1922..a65692e0d1318 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index feae76a087dec..dbf933f28a784 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.3.x + lazy val v23excludes = v22excludes ++ Seq( + ) + // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( // [SPARK-19652][UI] Do auth checks for REST API access. @@ -1003,6 +1007,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes case v if v.startsWith("2.0") => v20excludes diff --git a/repl/pom.xml b/repl/pom.xml index a256ae3b84183..6d133a3cfff7d 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 03846d9f5a3be..20b53f2d8f987 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index a1b641c8eeb84..71d4ad681e169 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 765c92b8d3b9e..8d80f8eca5dba 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index b203f31a76f03..e170133f0f0bf 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9c879218ddc0d..a5a8e2640586c 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0f249d7d59351..09dcc4055e000 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index de1be9c13e05f..fea882ad11230 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 938ba2f6ac201..7ba4dc9842f1b 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml From 31345fde82ada1f8bb12807b250b04726a1f6aa6 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 25 Apr 2017 13:05:20 +0800 Subject: [PATCH 0327/1765] [SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit ## What changes were proposed in this pull request? In `randomSplit`, It is possible that the underlying dataset doesn't guarantee the ordering of rows in its constituent partitions each time a split is materialized which could result in overlapping splits. To prevent this, as part of SPARK-12662, we explicitly sort each input partition to make the ordering deterministic. Given that `MapTypes` cannot be sorted this patch explicitly prunes them out from the sort order. Additionally, if the resulting sort order is empty, this patch then materializes the dataset to guarantee determinism. ## How was this patch tested? Extended `randomSplit on reordered partitions` in `DataFrameStatSuite` to also test for dataframes with mapTypes nested mapTypes. Author: Sameer Agarwal Closes #17751 from sameeragarwal/randomsplit2. --- .../scala/org/apache/spark/sql/Dataset.scala | 18 +++++--- .../apache/spark/sql/DataFrameStatSuite.scala | 43 ++++++++++++------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c6dcd93bbda66..06dd5500718de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1726,15 +1726,23 @@ class Dataset[T] private[sql]( // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the - // ordering deterministic. - // MapType cannot be sorted. - val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) - .map(SortOrder(_, Ascending)), global = false, logicalPlan) + // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out + // from the sort order. + val sortOrder = logicalPlan.output + .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .map(SortOrder(_, Ascending)) + val plan = if (sortOrder.nonEmpty) { + Sort(sortOrder, global = false, logicalPlan) + } else { + // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism + cache() + logicalPlan + } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 97890a035a62f..dd118f88e3bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. - val data = - sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - // Verify that the splits don't overlap - assert(splits(0).intersect(splits(1)).collect().isEmpty) + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") { From c8f1219510f469935aa9ff0b1c92cfe20372377c Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 25 Apr 2017 09:13:50 +0100 Subject: [PATCH 0328/1765] [SPARK-20455][DOCS] Fix Broken Docker IT Docs ## What changes were proposed in this pull request? Just added the Maven `test`goal. ## How was this patch tested? No test needed, just a trivial documentation fix. Author: Armin Braun Closes #17756 from original-brownbear/SPARK-20455. --- docs/building-spark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index e99b70f7a8b47..0f551bc66b8c9 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -232,7 +232,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. ./build/mvn install -DskipTests - ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 or From 0bc7a90210aad9025c1e1bdc99f8e723c1bf0fbf Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Tue, 25 Apr 2017 09:18:36 +0100 Subject: [PATCH 0329/1765] [SPARK-20404][CORE] Using Option(name) instead of Some(name) Using Option(name) instead of Some(name) to prevent runtime failures when using accumulators created like the following ``` sparkContext.accumulator(0, null) ``` Author: Sergey Zhemzhitsky Closes #17740 from szhem/SPARK-20404-null-acc-names. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 99efc4893fda4..0ec1bdd39b2f5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** From 387565cf14b490810f9479ff3adbf776e2edecdc Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Tue, 25 Apr 2017 16:30:36 +0800 Subject: [PATCH 0330/1765] [SPARK-18901][FOLLOWUP][ML] Require in LR LogisticAggregator is redundant ## What changes were proposed in this pull request? This is a follow-up PR of #17478. ## How was this patch tested? Existing tests Author: wangmiao1981 Closes #17754 from wangmiao1981/followup. --- .../scala/org/apache/spark/ml/classification/LinearSVC.scala | 5 ++--- .../org/apache/spark/ml/regression/LinearRegression.scala | 5 ----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index f76b14eeeb542..7507c7539d4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -458,9 +458,7 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") + if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -512,6 +510,7 @@ private class LinearSVCAggregator( * @return This LinearSVCAggregator object. */ def merge(other: LinearSVCAggregator): this.type = { + if (other.weightSum != 0.0) { weightSum += other.weightSum lossSum += other.lossSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7e3c8fa5b6e6..eaad54985229e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -971,9 +971,6 @@ private class LeastSquaresAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1005,8 +1002,6 @@ private class LeastSquaresAggregator( * @return This LeastSquaresAggregator object. */ def merge(other: LeastSquaresAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") if (other.weightSum != 0) { totalCnt += other.totalCnt From 67eef47acfd26f1f0be3e8ef10453514f3655f62 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 25 Apr 2017 17:10:41 +0000 Subject: [PATCH 0331/1765] [SPARK-20449][ML] Upgrade breeze version to 0.13.1 ## What changes were proposed in this pull request? Upgrade breeze version to 0.13.1, which fixed some critical bugs of L-BFGS-B. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang Closes #17746 from yanboliang/spark-20449. --- LICENSE | 1 + .../tests/testthat/test_mllib_classification.R | 10 +++++----- dev/deps/spark-deps-hadoop-2.6 | 12 +++++++----- dev/deps/spark-deps-hadoop-2.7 | 12 +++++++----- .../GeneralizedLinearRegression.scala | 4 ++-- .../spark/mllib/clustering/LDAModel.scala | 14 ++++---------- .../spark/mllib/optimization/LBFGSSuite.scala | 4 ++-- pom.xml | 2 +- python/pyspark/ml/classification.py | 18 ++++++++---------- 9 files changed, 37 insertions(+), 40 deletions(-) diff --git a/LICENSE b/LICENSE index 7950dd6ceb6db..c21032a1fd274 100644 --- a/LICENSE +++ b/LICENSE @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 459254d271a58..af7cbdccf5d5d 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -288,18 +288,18 @@ test_that("spark.mlp", { c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "0.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +310,8 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) + expect_equal(head(summary$weights, 5), list(-0.5793153, -4.652961, 6.216155, -6.649478, + -10.51147), tolerance = 1e-3) }) test_that("spark.naiveBayes", { diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 73dc1f9a1398c..9287bd47cf113 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -129,6 +129,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -162,13 +164,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 6bf0923a1d751..ab1de3d3dd8ad 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -130,6 +130,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -163,13 +165,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index d6093a01c671c..bff0d9bbb46ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -894,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Probit extends Link("probit") { - override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu) override def deriv(mu: Double): Double = { - 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu)) } override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7fd722a332923..15b723dadcff7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => - // TODO: Remove work-around for the breeze bug. - // https://github.com/scalanlp/breeze/issues/561 - val topIndices = if (k == topicCounts.length) { - Seq.range(0, k) - } else { - argtopk(topicCounts, k) - } + val topIndices = argtopk(topicCounts, k) val sumCounts = sum(topicCounts) val weights = if (sumCounts != 0) { - topicCounts(topIndices) / sumCounts + topicCounts(topIndices).toArray.map(_ / sumCounts) } else { - topicCounts(topIndices) + topicCounts(topIndices).toArray } - (docID.toLong, topIndices.toArray, weights.toArray) + (docID.toLong, topIndices.toArray, weights) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 572959200f47f..3d6a9f8d84cac 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -191,8 +191,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers // With smaller convergenceTol, it takes more steps. assert(lossLBFGS3.length > lossLBFGS2.length) - // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed. - assert(lossLBFGS3.length == 6) + // Based on observation, lossLBFGS3 runs 7 iterations, no theoretically guaranteed. + assert(lossLBFGS3.length == 7) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } diff --git a/pom.xml b/pom.xml index a65692e0d1318..b6654c1411d25 100644 --- a/pom.xml +++ b/pom.xml @@ -658,7 +658,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + 0.13.1 diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4fc357e42d71..864968390ace9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -190,9 +190,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.5...]) + DenseVector([5.4...]) >>> blorModel.intercept - -2.68... + -2.63... >>> mdf = sc.parallelize([ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), @@ -200,12 +200,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", ... family="multinomial") >>> mlorModel = mlor.fit(mdf) - >>> print(mlorModel.coefficientMatrix) - DenseMatrix([[-2.3...], - [ 0.2...], - [ 2.1... ]]) + >>> mlorModel.coefficientMatrix + DenseMatrix(3, 1, [-2.3..., 0.2..., 2.1...], 1) >>> mlorModel.interceptVector - DenseVector([2.0..., 0.8..., -2.8...]) + DenseVector([2.1..., 0.6..., -2.8...]) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction @@ -213,7 +211,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> result.probability DenseVector([0.99..., 0.00...]) >>> result.rawPrediction - DenseVector([8.22..., -8.22...]) + DenseVector([8.12..., -8.12...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 @@ -1490,9 +1488,9 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) >>> [x.coefficients for x in model.models] - [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])] + [DenseVector([4.9791, 2.426]), DenseVector([-4.1198, -5.9326]), DenseVector([-3.314, 5.2423])] >>> [x.intercept for x in model.models] - [-3.64747..., 2.55078..., -1.10165...] + [-5.06544..., 2.30341..., -1.29133...] >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() >>> model.transform(test0).head().prediction 1.0 From 0a7f5f2798b6e8b2ba15e8b3aa07d5953ad1c695 Mon Sep 17 00:00:00 2001 From: ding Date: Tue, 25 Apr 2017 11:20:32 -0700 Subject: [PATCH 0332/1765] [SPARK-5484][GRAPHX] Periodically do checkpoint in Pregel ## What changes were proposed in this pull request? Pregel-based iterative algorithms with more than ~50 iterations begin to slow down and eventually fail with a StackOverflowError due to Spark's lack of support for long lineage chains. This PR causes Pregel to checkpoint the graph periodically if the checkpoint directory is set. This PR moves PeriodicGraphCheckpointer.scala from mllib to graphx, moves PeriodicRDDCheckpointer.scala, PeriodicCheckpointer.scala from mllib to core ## How was this patch tested? unit tests, manual tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: ding Author: dding3 Author: Michael Allman Closes #15125 from dding3/cp2_pregel. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../rdd/util}/PeriodicRDDCheckpointer.scala | 3 +- .../spark/util}/PeriodicCheckpointer.scala | 14 ++- .../org/apache/spark/rdd/SortingSuite.scala | 2 +- .../util}/PeriodicRDDCheckpointerSuite.scala | 8 +- docs/configuration.md | 14 +++ docs/graphx-programming-guide.md | 9 +- .../org/apache/spark/graphx/Pregel.scala | 25 ++++- .../util}/PeriodicGraphCheckpointer.scala | 13 ++- .../PeriodicGraphCheckpointerSuite.scala | 105 +++++++++--------- .../org/apache/spark/ml/clustering/LDA.scala | 3 +- .../ml/tree/impl/GradientBoostedTrees.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- 13 files changed, 128 insertions(+), 76 deletions(-) rename {mllib/src/main/scala/org/apache/spark/mllib/impl => core/src/main/scala/org/apache/spark/rdd/util}/PeriodicRDDCheckpointer.scala (97%) rename {mllib/src/main/scala/org/apache/spark/mllib/impl => core/src/main/scala/org/apache/spark/util}/PeriodicCheckpointer.scala (95%) rename {mllib/src/test/scala/org/apache/spark/mllib/impl => core/src/test/scala/org/apache/spark/util}/PeriodicRDDCheckpointerSuite.scala (96%) rename {mllib/src/main/scala/org/apache/spark/mllib/impl => graphx/src/main/scala/org/apache/spark/graphx/util}/PeriodicGraphCheckpointer.scala (91%) rename {mllib/src/test/scala/org/apache/spark/mllib/impl => graphx/src/test/scala/org/apache/spark/graphx/util}/PeriodicGraphCheckpointerSuite.scala (70%) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675332d1b..63a87e7f09d85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428e..ab72addb2466b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4e..ce06e18879a49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a2..7f20206202cb9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala similarity index 96% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index 14adf8c29fc6b..f9e1b791c86ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.utils import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { import PeriodicRDDCheckpointerSuite._ diff --git a/docs/configuration.md b/docs/configuration.md index 6b65d2bcb83e5..87b76322cae51 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE) +### GraphX + + + + + + + + +
    Property NameDefaultMeaning
    spark.graphx.pregel.checkpointInterval-1 + Checkpoint interval for graph and message in Pregel. It used to avoid stackOverflowError due to long lineage chains + after lots of iterations. The checkpoint is disabled by default. +
    + ### Deploy diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e271b28fb4f28..76aa7b405e18c 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,9 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note calls to graph.cache have been removed): +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, +say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): {% highlight scala %} class GraphOps[VD, ED] { @@ -722,6 +724,7 @@ class GraphOps[VD, ED] { : Graph[VD, ED] = { // Receive the initial message at each vertex var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache() + // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -734,8 +737,8 @@ class GraphOps[VD, ED] { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. - messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 646462b4a8350..755c6febc48e6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,7 +19,10 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -122,27 +125,39 @@ object Pregel extends Logging { require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + val checkpointInterval = graph.vertices.sparkContext.getConf + .getInt("spark.graphx.pregel.checkpointInterval", -1) + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)) + val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval, graph.vertices.sparkContext) + graphCheckpointer.update(g) + // compute the messages var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) + val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)]( + checkpointInterval, graph.vertices.sparkContext) + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) var activeMessages = messages.count() + // Loop var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { // Receive the messages and update the vertices. prevG = g - g = g.joinVertices(messages)(vprog).cache() + g = g.joinVertices(messages)(vprog) + graphCheckpointer.update(g) val oldMessages = messages // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. messages = GraphXUtils.mapReduceTriplets( - g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))) // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) activeMessages = messages.count() logInfo("Pregel finished iteration " + i) @@ -154,7 +169,9 @@ object Pregel extends Logging { // count the iteration i += 1 } - messages.unpersist(blocking = false) + messageCheckpointer.unpersistDataSet() + graphCheckpointer.deleteAllCheckpoints() + messageCheckpointer.deleteAllCheckpoints() g } // end of apply diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala similarity index 91% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala index 80074897567eb..fda501aa757d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** @@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicGraphCheckpointer[VD, ED]( +private[spark] class PeriodicGraphCheckpointer[VD, ED]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { @@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.vertices.persist() + /* We need to use cache because persist does not honor the default storage level requested + * when constructing the graph. Only cache does that. + */ + data.vertices.cache() } if (data.edges.getStorageLevel == StorageLevel.NONE) { - data.edges.persist() + data.edges.cache() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala similarity index 70% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala index a13e7f63a9296..e0c65e6940f66 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala @@ -15,77 +15,81 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.graphx.{Edge, Graph} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext { import PeriodicGraphCheckpointerSuite._ test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] - val graph1 = createGraph(sc) - val checkpointer = - new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkPersistence(graphsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkPersistence(graphsToCheck, iteration) - iteration += 1 + withSpark { sc => + val graph1 = createGraph(sc) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } } } test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var graphsToCheck = Seq.empty[GraphToCheck] - sc.setCheckpointDir(path) - val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( - checkpointInterval, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graph1.edges.count() - graph1.vertices.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkCheckpoint(graphsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graph.vertices.count() - graph.edges.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkCheckpoint(graphsToCheck, iteration, checkpointInterval) - iteration += 1 - } + withSpark { sc => + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } - checkpointer.deleteAllCheckpoints() - graphsToCheck.foreach { graph => - confirmCheckpointRemoved(graph.graph) - } + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(tempDir) + } } } private object PeriodicGraphCheckpointerSuite { + private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) @@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite { Edge[Double](3, 4, 0)) def createGraph(sc: SparkContext): Graph[Double, Double] = { - Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + Graph.fromEdges[Double, Double]( + sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel) } def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { @@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite { assert(graph.vertices.getStorageLevel == StorageLevel.NONE) assert(graph.edges.getStorageLevel == StorageLevel.NONE) } else { - assert(graph.vertices.getStorageLevel != StorageLevel.NONE) - assert(graph.edges.getStorageLevel != StorageLevel.NONE) + assert(graph.vertices.getStorageLevel == defaultStorageLevel) + assert(graph.edges.getStorageLevel == defaultStorageLevel) } } catch { case _: AssertionError => diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 2f50dc7c85f35..e3026c8efa823 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -45,9 +44,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils - private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter with HasSeed with HasCheckpointInterval { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 4c525c0714ec5..ce2bd7b430f43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 48bae4276c480..3697a9b46dd84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel From caf392025ce21d701b503112060fa016d5eabe04 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 25 Apr 2017 17:05:20 -0700 Subject: [PATCH 0333/1765] [SPARK-18127] Add hooks and extension points to Spark ## What changes were proposed in this pull request? This patch adds support for customizing the spark session by injecting user-defined custom extensions. This allows a user to add custom analyzer rules/checks, optimizer rules, planning strategies or even a customized parser. ## How was this patch tested? Unit Tests in SparkSessionExtensionSuite Author: Sameer Agarwal Closes #17724 from sameeragarwal/session-extensions. --- .../sql/catalyst/parser/ParseDriver.scala | 9 +- .../sql/catalyst/parser/ParserInterface.scala | 35 +++- .../spark/sql/internal/StaticSQLConf.scala | 6 + .../org/apache/spark/sql/SparkSession.scala | 45 ++++- .../spark/sql/SparkSessionExtensions.scala | 171 ++++++++++++++++++ .../internal/BaseSessionStateBuilder.scala | 33 +++- .../sql/SparkSessionExtensionSuite.scala | 144 +++++++++++++++ 7 files changed, 418 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 80ab75cc17fab..dcccbd0ed8d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class AbstractSqlParser extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. + override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) } @@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } /** Creates FunctionIdentifier for a given SQL string. */ - def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => - astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index db3598bde04d3..75240d2196222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,30 +17,51 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** * Interface for a parser. */ +@DeveloperApi trait ParserInterface { - /** Creates LogicalPlan for a given SQL string. */ + /** + * Parse a string to a [[LogicalPlan]]. + */ + @throws[ParseException]("Text cannot be parsed to a LogicalPlan") def parsePlan(sqlText: String): LogicalPlan - /** Creates Expression for a given SQL string. */ + /** + * Parse a string to an [[Expression]]. + */ + @throws[ParseException]("Text cannot be parsed to an Expression") def parseExpression(sqlText: String): Expression - /** Creates TableIdentifier for a given SQL string. */ + /** + * Parse a string to a [[TableIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a TableIdentifier") def parseTableIdentifier(sqlText: String): TableIdentifier - /** Creates FunctionIdentifier for a given SQL string. */ + /** + * Parse a string to a [[FunctionIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier /** - * Creates StructType for a given SQL string, which is a comma separated list of field - * definitions which will preserve the correct Hive metadata. + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list + * of field definitions which will preserve the correct Hive metadata. */ + @throws[ParseException]("Text cannot be parsed to a schema") def parseTableSchema(sqlText: String): StructType + + /** + * Parse a string to a [[DataType]]. + */ + @throws[ParseException]("Text cannot be parsed to a DataType") + def parseDataType(sqlText: String): DataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index af1a9cee2962a..c6c0a605d89ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -81,4 +81,10 @@ object StaticSQLConf { "SQL configuration and the current database.") .booleanConf .createWithDefault(false) + + val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions") + .doc("Name of the class used to configure Spark Session extensions. The class should " + + "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") + .stringConf + .createOptional } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 95f3463dfe62b..a519492ed8f4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -77,11 +77,12 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], - @transient private val parentSessionState: Option[SessionState]) + @transient private val parentSessionState: Option[SessionState], + @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None, None) + this(sc, None, None, new SparkSessionExtensions) } sparkContext.assertNotStopped() @@ -219,7 +220,7 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions) } /** @@ -235,7 +236,7 @@ class SparkSession private( * implementation is Hive, this will initialize the metastore, which may take some time. */ private[sql] def cloneSession(): SparkSession = { - val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions) result.sessionState // force copy of SessionState result } @@ -754,6 +755,8 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] val extensions = new SparkSessionExtensions + private[this] var userSuppliedContext: Option[SparkContext] = None private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { @@ -847,6 +850,17 @@ object SparkSession { } } + /** + * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules, + * Optimizer rules, Planning Strategies or a customized parser. + * + * @since 2.2.0 + */ + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + f(extensions) + this + } + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new * one based on the options set in this builder. @@ -903,7 +917,26 @@ object SparkSession { } sc } - session = new SparkSession(sparkContext) + + // Initialize extensions if the user has defined a configurator class. + val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + if (extensionConfOption.isDefined) { + val extensionConfClassName = extensionConfOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + + session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala new file mode 100644 index 0000000000000..f99c108161f94 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: Experimental :: + * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + * + * This current provides the following extension points: + * - Analyzer Rules. + * - Check Analysis Rules + * - Optimizer Rules. + * - Planning Strategies. + * - Customized Parser. + * - (External) Catalog listeners. + * + * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for + * example: + * {{{ + * SparkSession.builder() + * .master("...") + * .conf("...", true) + * .withExtensions { extensions => + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectParser { (session, parser) => + * ... + * } + * } + * .getOrCreate() + * }}} + * + * Note that none of the injected builders should assume that the [[SparkSession]] is fully + * initialized and should not touch the session's internals (e.g. the SessionState). + */ +@DeveloperApi +@Experimental +@InterfaceStability.Unstable +class SparkSessionExtensions { + type RuleBuilder = SparkSession => Rule[LogicalPlan] + type CheckRuleBuilder = SparkSession => LogicalPlan => Unit + type StrategyBuilder = SparkSession => Strategy + type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + resolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed as part of the resolution phase of analysis. + */ + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + postHocResolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed after resolution. + */ + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder] + + /** + * Build the check analysis `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { + checkRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will + * be executed after the analysis phase. A check analysis rule is used to detect problems with a + * LogicalPlan and should throw an exception when a problem is found. + */ + def injectCheckRule(builder: CheckRuleBuilder): Unit = { + checkRuleBuilders += builder + } + + private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + optimizerRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be + * executed during the operator optimization batch. An optimizer rule is used to improve the + * quality of an analyzed logical plan; these rules should never modify the result of the + * LogicalPlan. + */ + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + + private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { + plannerStrategyBuilders.map(_.apply(session)) + } + + /** + * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will + * be used to convert a `LogicalPlan` into a executable + * [[org.apache.spark.sql.execution.SparkPlan]]. + */ + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder] + + private[sql] def buildParser( + session: SparkSession, + initial: ParserInterface): ParserInterface = { + parserBuilders.foldLeft(initial) { (parser, builder) => + builder(session, parser) + } + } + + /** + * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session + * and an initial parser. The latter allows for a user to create a partial parser and to delegate + * to the underlying parser for completeness. If a user injects more parsers, then the parsers + * are stacked on top of each other. + */ + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index df7c3678b7807..2a801d87b12eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder( */ protected def newBuilder: NewBuilder + /** + * Session extensions defined in the [[SparkSession]]. + */ + protected def extensions: SparkSessionExtensions = session.extensions + /** * Extract entries from `SparkConf` and put them in the `SQLConf` */ @@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder( * * Note: this depends on the `conf` field. */ - protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + protected lazy val sqlParser: ParserInterface = { + extensions.buildParser(session, new SparkSqlParser(conf)) + } /** * ResourceLoader that is used to load function resources and jars. @@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildResolutionRules(session) + } /** * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of @@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildPostHocResolutionRules(session) + } /** * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating @@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + protected def customCheckRules: Seq[LogicalPlan => Unit] = { + extensions.buildCheckRules(session) + } /** * Logical query plan optimizer. @@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = { + extensions.buildOptimizerRules(session) + } /** * Planner that converts optimized logical plans to physical plans. @@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `planner` function. */ - protected def customPlanningStrategies: Seq[Strategy] = Nil + protected def customPlanningStrategies: Seq[Strategy] = { + extensions.buildPlannerStrategies(session) + } /** * Create a query execution object. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala new file mode 100644 index 0000000000000..43db79663322a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Test cases for the [[SparkSessionExtensions]]. + */ +class SparkSessionExtensionSuite extends SparkFunSuite { + type ExtensionsBuilder = SparkSessionExtensions => Unit + private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = { + val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("inject analyzer rule") { + withSession(_.injectResolutionRule(MyRule)) { session => + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } + } + + test("inject check analysis rule") { + withSession(_.injectCheckRule(MyCheckRule)) { session => + assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) + } + } + + test("inject optimizer rule") { + withSession(_.injectOptimizerRule(MyRule)) { session => + assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) + } + } + + test("inject spark planner strategy") { + withSession(_.injectPlannerStrategy(MySparkStrategy)) { session => + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + } + } + + test("inject parser") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + } + withSession(extension) { session => + assert(session.sessionState.sqlParser == CatalystSqlParser) + } + } + + test("inject stacked parsers") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + extensions.injectParser(MyParser) + extensions.injectParser(MyParser) + } + withSession(extension) { session => + val parser = MyParser(session, MyParser(session, CatalystSqlParser)) + assert(session.sessionState.sqlParser == parser) + } + } + + test("use custom class for extensions") { + val session = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .getOrCreate() + try { + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } finally { + stop(session) + } + } +} + +case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan +} + +case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { + override def apply(plan: LogicalPlan): Unit = { } +} + +case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty +} + +case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { + override def parsePlan(sqlText: String): LogicalPlan = + delegate.parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = + delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = + delegate.parseDataType(sqlText) +} + +class MyExtensions extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectPlannerStrategy(MySparkStrategy) + e.injectResolutionRule(MyRule) + } +} From 57e1da39464131329318b723caa54df9f55fa54f Mon Sep 17 00:00:00 2001 From: Eric Wasserman Date: Wed, 26 Apr 2017 11:42:43 +0800 Subject: [PATCH 0334/1765] [SPARK-16548][SQL] Inconsistent error handling in JSON parsing SQL functions ## What changes were proposed in this pull request? change to using Jackson's `com.fasterxml.jackson.core.JsonFactory` public JsonParser createParser(String content) ## How was this patch tested? existing unit tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Wasserman Closes #17693 from ewasserman/SPARK-20314. --- .../catalyst/expressions/jsonExpressions.scala | 12 +++++++++--- .../expressions/JsonExpressionsSuite.scala | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index df4d406b84d60..9fb0ea68153d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} import scala.util.parsing.combinator.RegexParsers @@ -149,7 +149,10 @@ case class GetJsonObject(json: Expression, path: Expression) if (parsed.isDefined) { try { - Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( + new ByteArrayInputStream(jsonStr.getBytes), "UTF-8"))) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -393,7 +396,10 @@ case class JsonTuple(children: Seq[Expression]) } try { - Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( + new ByteArrayInputStream(json.getBytes), "UTF-8"))) { parser => parseRow(parser, input) } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index c5b72235e5db0..4402ad4e9a9e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -39,6 +39,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { |"fb:testid":"1234"} |""".stripMargin + /* invalid json with leading nulls would trigger java.io.CharConversionException + in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ + val badJson = "\0\0\0A\1AAA" + test("$.store.bicycle") { checkEvaluation( GetJsonObject(Literal(json), Literal("$.store.bicycle")), @@ -224,6 +228,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) } + test("SPARK-16548: character conversion") { + checkEvaluation( + GetJsonObject(Literal(badJson), Literal("$.a")), + null + ) + } + test("non foldable literal") { checkEvaluation( GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")), @@ -340,6 +351,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(null, null, null, null, null)) } + test("SPARK-16548: json_tuple - invalid json with leading nulls") { + checkJsonTuple( + JsonTuple(Literal(badJson) :: jsonTupleQuery), + InternalRow(null, null, null, null, null)) + } + test("json_tuple - preserve newlines") { checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), From df58a95a33b739462dbe84e098839af2a8643d45 Mon Sep 17 00:00:00 2001 From: zero323 Date: Tue, 25 Apr 2017 22:00:45 -0700 Subject: [PATCH 0335/1765] [SPARK-20437][R] R wrappers for rollup and cube ## What changes were proposed in this pull request? - Add `rollup` and `cube` methods and corresponding generics. - Add short description to the vignette. ## How was this patch tested? - Existing unit tests. - Additional unit tests covering new features. - `check-cran.sh`. Author: zero323 Closes #17728 from zero323/SPARK-20437. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 73 +++++++++++++++- R/pkg/R/generics.R | 8 ++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 102 ++++++++++++++++++++++ R/pkg/vignettes/sparkr-vignettes.Rmd | 15 ++++ docs/sparkr.md | 30 +++++++ 6 files changed, 229 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 95d5cc6d1c78e..2800461658483 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -101,6 +101,7 @@ exportMethods("arrange", "createOrReplaceTempView", "crossJoin", "crosstab", + "cube", "dapply", "dapplyCollect", "describe", @@ -143,6 +144,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "rollup", "sample", "sample_frac", "sampleBy", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 88a138fd8eb1f..cd6f03a13d7c7 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1321,7 +1321,7 @@ setMethod("toRDD", #' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them. #' #' @param x a SparkDataFrame. -#' @param ... variable(s) (character names(s) or Column(s)) to group on. +#' @param ... character name(s) or Column(s) to group on. #' @return A GroupedData. #' @family SparkDataFrame functions #' @aliases groupBy,SparkDataFrame-method @@ -1337,6 +1337,7 @@ setMethod("toRDD", #' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") #' } #' @note groupBy since 1.4.0 +#' @seealso \link{agg}, \link{cube}, \link{rollup} setMethod("groupBy", signature(x = "SparkDataFrame"), function(x, ...) { @@ -3642,3 +3643,73 @@ setMethod("checkpoint", df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) dataFrame(df) }) + +#' cube +#' +#' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases cube,SparkDataFrame-method +#' @rdname cube +#' @name cube +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(cube(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(cube(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note cube since 2.3.0 +#' @seealso \link{agg}, \link{groupBy}, \link{rollup} +setMethod("cube", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "cube", jcol) + groupedData(sgd) + }) + +#' rollup +#' +#' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases rollup,SparkDataFrame-method +#' @rdname rollup +#' @name rollup +#' @export +#' @examples +#'\dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(rollup(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(rollup(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note rollup since 2.3.0 +#' @seealso \link{agg}, \link{cube}, \link{groupBy} +setMethod("rollup", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "rollup", jcol) + groupedData(sgd) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5e7a1c60c2b3b..749ee9b54cc80 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -483,6 +483,10 @@ setGeneric("createOrReplaceTempView", # @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) +#' @rdname cube +#' @export +setGeneric("cube", function(x, ...) { standardGeneric("cube") }) + #' @rdname dapply #' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) @@ -631,6 +635,10 @@ setGeneric("sample", standardGeneric("sample") }) +#' @rdname rollup +#' @export +setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) + #' @rdname sample #' @export setGeneric("sample_frac", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index c21ba2f1a138b..2cef7191d4f2a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1816,6 +1816,108 @@ test_that("pivot GroupedData column", { expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings"))) }) +test_that("test multi-dimensional aggregations with cube and rollup", { + df <- createDataFrame(data.frame( + id = 1:6, + year = c(2016, 2016, 2016, 2017, 2017, 2017), + salary = c(10000, 15000, 20000, 22000, 32000, 21000), + department = c("management", "rnd", "sales", "management", "rnd", "sales") + )) + + actual_cube <- collect( + orderBy( + agg( + cube(df, "year", "department"), + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + ), + "year", "department" + ) + ) + + expected_cube <- data.frame( + year = c(rep(NA, 4), rep(2016, 4), rep(2017, 4)), + department = rep(c(NA, "management", "rnd", "sales"), times = 3), + total_salary = c( + 120000, # Total + 10000 + 22000, 15000 + 32000, 20000 + 21000, # Department only + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + # Mean by department + mean(c(10000, 22000)), mean(c(15000, 32000)), mean(c(20000, 21000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_cube, expected_cube) + + # cube should accept column objects + expect_equal( + count(sum(cube(df, df$year, df$department), "salary")), + 12 + ) + + # cube without columns should result in a single aggregate + expect_equal( + collect(agg(cube(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) + + actual_rollup <- collect( + orderBy( + agg( + rollup(df, "year", "department"), + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + ), + "year", "department" + ) + ) + + expected_rollup <- data.frame( + year = c(NA, rep(2016, 4), rep(2017, 4)), + department = c(NA, rep(c(NA, "management", "rnd", "sales"), times = 2)), + total_salary = c( + 120000, # Total + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_rollup, expected_rollup) + + # cube should accept column objects + expect_equal( + count(sum(rollup(df, df$year, df$department), "salary")), + 9 + ) + + # rollup without columns should result in a single aggregate + expect_equal( + collect(agg(rollup(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) +}) + test_that("arrange() and orderBy() on a DataFrame", { df <- read.json(jsonPath) sorted <- arrange(df, df$age) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index f81dbab10b1e1..4b9d6c3806098 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -308,6 +308,21 @@ numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) head(numCyl) ``` +Use `cube` or `rollup` to compute subtotals across multiple dimensions. + +```{r} +mean(cube(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while + +```{r} +mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for all possible combinations of grouping columns. + + #### Operating on Columns SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. diff --git a/docs/sparkr.md b/docs/sparkr.md index a1a35a7757e57..e015ab260fca8 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -264,6 +264,36 @@ head(arrange(waiting_counts, desc(waiting_counts$count))) {% endhighlight %} +In addition to standard aggregations, SparkR supports [OLAP cube](https://en.wikipedia.org/wiki/OLAP_cube) operators `cube`: + +
    +{% highlight r %} +head(agg(cube(df, "cyl", "disp", "gear"), avg(df$mpg))) +## cyl disp gear avg(mpg) +##1 NA 140.8 4 22.8 +##2 4 75.7 4 30.4 +##3 8 400.0 3 19.2 +##4 8 318.0 3 15.5 +##5 NA 351.0 NA 15.8 +##6 NA 275.8 NA 16.3 +{% endhighlight %} +
    + +and `rollup`: + +
    +{% highlight r %} +head(agg(rollup(df, "cyl", "disp", "gear"), avg(df$mpg))) +## cyl disp gear avg(mpg) +##1 4 75.7 4 30.4 +##2 8 400.0 3 19.2 +##3 8 318.0 3 15.5 +##4 4 78.7 NA 32.4 +##5 8 304.0 3 15.2 +##6 4 79.0 NA 27.3 +{% endhighlight %} +
    + ### Operating on Columns SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. From 7a365257e934e838bd90f6a0c50362bf47202b0e Mon Sep 17 00:00:00 2001 From: anabranch Date: Wed, 26 Apr 2017 09:49:05 +0100 Subject: [PATCH 0336/1765] [SPARK-20400][DOCS] Remove References to 3rd Party Vendor Tools ## What changes were proposed in this pull request? Simple documentation change to remove explicit vendor references. ## How was this patch tested? NA Please review http://spark.apache.org/contributing.html before opening a pull request. Author: anabranch Closes #17695 from anabranch/remove-vendor. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 87b76322cae51..8b53e92ccd416 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2270,8 +2270,8 @@ should be included on Spark's classpath: * `hdfs-site.xml`, which provides default behaviors for the HDFS client. * `core-site.xml`, which sets the default filesystem name. -The location of these configuration files varies across CDH and HDP versions, but -a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create +The location of these configuration files varies across Hadoop versions, but +a common location is inside of `/etc/hadoop/conf`. Some tools create configurations on-the-fly, but offer a mechanisms to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` From 7fecf5130163df9c204a2764d121a7011d007f4e Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Wed, 26 Apr 2017 08:23:31 -0500 Subject: [PATCH 0337/1765] =?UTF-8?q?[SPARK-19812]=20YARN=20shuffle=20serv?= =?UTF-8?q?ice=20fails=20to=20relocate=20recovery=20DB=20acro=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ss NFS directories ## What changes were proposed in this pull request? Change from using java Files.move to use Hadoop filesystem operations to move the directories. The java Files.move does not work when moving directories across NFS mounts and in fact also says that if the directory has entries you should do a recursive move. We are already using Hadoop filesystem here so just use the local filesystem from there as it handles this properly. Note that the DB here is actually a directory of files and not just a single file, hence the change in the name of the local var. ## How was this patch tested? Ran YarnShuffleServiceSuite unit tests. Unfortunately couldn't easily add one here since involves NFS. Ran manual tests to verify that the DB directories were properly moved across NFS mounted directories. Have been running this internally for weeks. Author: Tom Graves Closes #17748 from tgravescs/SPARK-19812. --- .../network/yarn/YarnShuffleService.java | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index c7620d0fe1288..4acc203153e5a 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; -import java.nio.file.Files; import java.util.List; import java.util.Map; @@ -340,9 +339,9 @@ protected Path getRecoveryPath(String fileName) { * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise * it will uses a YARN local dir. */ - protected File initRecoveryDb(String dbFileName) { + protected File initRecoveryDb(String dbName) { if (_recoveryPath != null) { - File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName); + File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName); if (recoveryFile.exists()) { return recoveryFile; } @@ -350,7 +349,7 @@ protected File initRecoveryDb(String dbFileName) { // db doesn't exist in recovery path go check local dirs for it String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir : localDirs) { - File f = new File(new Path(dir).toUri().getPath(), dbFileName); + File f = new File(new Path(dir).toUri().getPath(), dbName); if (f.exists()) { if (_recoveryPath == null) { // If NM recovery is not enabled, we should specify the recovery path using NM local @@ -363,17 +362,21 @@ protected File initRecoveryDb(String dbFileName) { // make sure to move all DBs to the recovery path from the old NM local dirs. // If another DB was initialized first just make sure all the DBs are in the same // location. - File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName); - if (!newLoc.equals(f)) { + Path newLoc = new Path(_recoveryPath, dbName); + Path copyFrom = new Path(f.toURI()); + if (!newLoc.equals(copyFrom)) { + logger.info("Moving " + copyFrom + " to: " + newLoc); try { - Files.move(f.toPath(), newLoc.toPath()); + // The move here needs to handle moving non-empty directories across NFS mounts + FileSystem fs = FileSystem.getLocal(_conf); + fs.rename(copyFrom, newLoc); } catch (Exception e) { // Fail to move recovery file to new path, just continue on with new DB location logger.error("Failed to move recovery file {} to the path {}", - dbFileName, _recoveryPath.toString(), e); + dbName, _recoveryPath.toString(), e); } } - return newLoc; + return new File(newLoc.toUri().getPath()); } } } @@ -381,7 +384,7 @@ protected File initRecoveryDb(String dbFileName) { _recoveryPath = new Path(localDirs[0]); } - return new File(_recoveryPath.toUri().getPath(), dbFileName); + return new File(_recoveryPath.toUri().getPath(), dbName); } /** From dbb06c689c157502cb081421baecce411832aad8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 26 Apr 2017 21:34:18 +0800 Subject: [PATCH 0338/1765] [MINOR][ML] Fix some PySpark & SparkR flaky tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Some PySpark & SparkR tests run with tiny dataset and tiny ```maxIter```, which means they are not converged. I don’t think checking intermediate result during iteration make sense, and these intermediate result may vulnerable and not stable, so we should switch to check the converged result. We hit this issue at #17746 when we upgrade breeze to 0.13.1. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #17757 from yanboliang/flaky-test. --- .../testthat/test_mllib_classification.R | 17 +---- python/pyspark/ml/classification.py | 71 ++++++++++--------- 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index af7cbdccf5d5d..cbc7087182868 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -284,22 +284,11 @@ test_that("spark.mlp", { c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "0.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +299,6 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-0.5793153, -4.652961, 6.216155, -6.649478, - -10.51147), tolerance = 1e-3) }) test_that("spark.naiveBayes", { diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 864968390ace9..a9756ea4af99a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -185,34 +185,33 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> bdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)), + ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)), + ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF() + >>> blor = LogisticRegression(regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.4...]) + DenseVector([-1.080..., -0.646...]) >>> blorModel.intercept - -2.63... - >>> mdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), - ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF() - >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", - ... family="multinomial") + 3.112... + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> mdf = spark.read.format("libsvm").load(data_path) + >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial") >>> mlorModel = mlor.fit(mdf) >>> mlorModel.coefficientMatrix - DenseMatrix(3, 1, [-2.3..., 0.2..., 2.1...], 1) + SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1) >>> mlorModel.interceptVector - DenseVector([2.1..., 0.6..., -2.8...]) - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + DenseVector([0.04..., -0.42..., 0.37...]) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction - 0.0 + 1.0 >>> result.probability - DenseVector([0.99..., 0.00...]) + DenseVector([0.02..., 0.97...]) >>> result.rawPrediction - DenseVector([8.12..., -8.12...]) - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + DenseVector([-3.54..., 3.54...]) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 >>> blor.setParams("vector") @@ -222,8 +221,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> lr_path = temp_path + "/lr" >>> blor.save(lr_path) >>> lr2 = LogisticRegression.load(lr_path) - >>> lr2.getMaxIter() - 5 + >>> lr2.getRegParam() + 0.01 >>> model_path = temp_path + "/lr_model" >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) @@ -1480,31 +1479,33 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> df = sc.parallelize([ - ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)), - ... Row(label=1.0, features=Vectors.sparse(2, [], [])), - ... Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> df = spark.read.format("libsvm").load(data_path) + >>> lr = LogisticRegression(regParam=0.01) >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) - >>> [x.coefficients for x in model.models] - [DenseVector([4.9791, 2.426]), DenseVector([-4.1198, -5.9326]), DenseVector([-3.314, 5.2423])] + >>> model.models[0].coefficients + DenseVector([0.5..., -1.0..., 3.4..., 4.2...]) + >>> model.models[1].coefficients + DenseVector([-2.1..., 3.1..., -2.6..., -2.3...]) + >>> model.models[2].coefficients + DenseVector([0.3..., -3.4..., 1.0..., -1.1...]) >>> [x.intercept for x in model.models] - [-5.06544..., 2.30341..., -1.29133...] - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() + [-2.7..., -2.5..., -1.3...] + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF() >>> model.transform(test0).head().prediction - 1.0 - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() - >>> model.transform(test1).head().prediction 0.0 - >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF() - >>> model.transform(test2).head().prediction + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction 2.0 + >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF() + >>> model.transform(test2).head().prediction + 0.0 >>> model_path = temp_path + "/ovr_model" >>> model.save(model_path) >>> model2 = OneVsRestModel.load(model_path) >>> model2.transform(test0).head().prediction - 1.0 + 0.0 .. versionadded:: 2.0.0 """ From 66dd5b83ff95d5f91f37dcdf6aac89faa0b871c5 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 26 Apr 2017 09:01:50 -0500 Subject: [PATCH 0339/1765] [SPARK-20391][CORE] Rename memory related fields in ExecutorSummay ## What changes were proposed in this pull request? This is a follow-up of #14617 to make the name of memory related fields more meaningful. Here for the backward compatibility, I didn't change `maxMemory` and `memoryUsed` fields. ## How was this patch tested? Existing UT and local verification. CC squito and tgravescs . Author: jerryshao Closes #17700 from jerryshao/SPARK-20391. --- .../apache/spark/ui/static/executorspage.js | 48 +++++++++-------- .../org/apache/spark/status/api/v1/api.scala | 11 ++-- .../apache/spark/ui/exec/ExecutorsPage.scala | 21 ++++---- .../executor_memory_usage_expectation.json | 51 +++++++++++-------- ...xecutor_node_blacklisting_expectation.json | 51 +++++++++++-------- 5 files changed, 105 insertions(+), 77 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 930a0698928d1..cb9922d23c445 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -253,10 +253,14 @@ $(document).ready(function () { var deadTotalBlacklisted = 0; response.forEach(function (exec) { - exec.onHeapMemoryUsed = exec.hasOwnProperty('onHeapMemoryUsed') ? exec.onHeapMemoryUsed : 0; - exec.maxOnHeapMemory = exec.hasOwnProperty('maxOnHeapMemory') ? exec.maxOnHeapMemory : 0; - exec.offHeapMemoryUsed = exec.hasOwnProperty('offHeapMemoryUsed') ? exec.offHeapMemoryUsed : 0; - exec.maxOffHeapMemory = exec.hasOwnProperty('maxOffHeapMemory') ? exec.maxOffHeapMemory : 0; + var memoryMetrics = { + usedOnHeapStorageMemory: 0, + usedOffHeapStorageMemory: 0, + totalOnHeapStorageMemory: 0, + totalOffHeapStorageMemory: 0 + }; + + exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? exec.memoryMetrics : memoryMetrics; }); response.forEach(function (exec) { @@ -264,10 +268,10 @@ $(document).ready(function () { allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; - allOnHeapMemoryUsed += exec.onHeapMemoryUsed; - allOnHeapMaxMemory += exec.maxOnHeapMemory; - allOffHeapMemoryUsed += exec.offHeapMemoryUsed; - allOffHeapMaxMemory += exec.maxOffHeapMemory; + allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -286,10 +290,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; - activeOnHeapMemoryUsed += exec.onHeapMemoryUsed; - activeOnHeapMaxMemory += exec.maxOnHeapMemory; - activeOffHeapMemoryUsed += exec.offHeapMemoryUsed; - activeOffHeapMaxMemory += exec.maxOffHeapMemory; + activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -308,10 +312,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; - deadOnHeapMemoryUsed += exec.onHeapMemoryUsed; - deadOnHeapMaxMemory += exec.maxOnHeapMemory; - deadOffHeapMemoryUsed += exec.offHeapMemoryUsed; - deadOffHeapMaxMemory += exec.maxOffHeapMemory; + deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -431,10 +435,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.onHeapMemoryUsed; + return row.memoryMetrics.usedOnHeapStorageMemory; else - return (formatBytes(row.onHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOnHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('on_heap_memory') @@ -443,10 +447,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.offHeapMemoryUsed; + return row.memoryMetrics.usedOffHeapStorageMemory; else - return (formatBytes(row.offHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOffHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('off_heap_memory') diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index d159b9450ef5c..56d8e51732ffd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -76,10 +76,13 @@ class ExecutorSummary private[spark]( val isBlacklisted: Boolean, val maxMemory: Long, val executorLogs: Map[String, String], - val onHeapMemoryUsed: Option[Long], - val offHeapMemoryUsed: Option[Long], - val maxOnHeapMemory: Option[Long], - val maxOffHeapMemory: Option[Long]) + val memoryMetrics: Option[MemoryMetrics]) + +class MemoryMetrics private[spark]( + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) class JobData private[spark]( val jobId: Int, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 0a3c63d14ca8a..b7cbed468517c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} import org.apache.spark.ui.{UIUtils, WebUIPage} // This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive @@ -114,10 +114,16 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem - val onHeapMemUsed = status.onHeapMemUsed - val offHeapMemUsed = status.offHeapMemUsed - val maxOnHeapMem = status.maxOnHeapMem - val maxOffHeapMem = status.maxOffHeapMem + val memoryMetrics = for { + onHeapUsed <- status.onHeapMemUsed + offHeapUsed <- status.offHeapMemUsed + maxOnHeap <- status.maxOnHeapMem + maxOffHeap <- status.maxOffHeapMem + } yield { + new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) + } + + val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -142,10 +148,7 @@ private[spark] object ExecutorsPage { taskSummary.isBlacklisted, maxMem, taskSummary.executorLogs, - onHeapMemUsed, - offHeapMemUsed, - maxOnHeapMem, - maxOffHeapMem + memoryMetrics ) } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index e732af2663503..0f94e3b255dbc 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index e732af2663503..0f94e3b255dbc 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] From 99c6cf9ef16bf8fae6edb23a62e46546a16bca80 Mon Sep 17 00:00:00 2001 From: Michal Szafranski Date: Wed, 26 Apr 2017 11:21:25 -0700 Subject: [PATCH 0340/1765] [SPARK-20473] Enabling missing types in ColumnVector.Array ## What changes were proposed in this pull request? ColumnVector implementations originally did not support some Catalyst types (float, short, and boolean). Now that they do, those types should be also added to the ColumnVector.Array. ## How was this patch tested? Tested using existing unit tests. Author: Michal Szafranski Closes #17772 from michal-databricks/spark-20473. --- .../apache/spark/sql/execution/vectorized/ColumnVector.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 354c878aca000..b105e60a2d34a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -180,7 +180,7 @@ public Object[] array() { @Override public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); + return data.getBoolean(offset + ordinal); } @Override @@ -188,7 +188,7 @@ public boolean getBoolean(int ordinal) { @Override public short getShort(int ordinal) { - throw new UnsupportedOperationException(); + return data.getShort(offset + ordinal); } @Override @@ -199,7 +199,7 @@ public short getShort(int ordinal) { @Override public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); + return data.getFloat(offset + ordinal); } @Override From a277ae80a2836e6533b338d2b9c4e59ed8a1daae Mon Sep 17 00:00:00 2001 From: Michal Szafranski Date: Wed, 26 Apr 2017 12:47:37 -0700 Subject: [PATCH 0341/1765] [SPARK-20474] Fixing OnHeapColumnVector reallocation ## What changes were proposed in this pull request? OnHeapColumnVector reallocation copies to the new storage data up to 'elementsAppended'. This variable is only updated when using the ColumnVector.appendX API, while ColumnVector.putX is more commonly used. ## How was this patch tested? Tested using existing unit tests. Author: Michal Szafranski Closes #17773 from michal-databricks/spark-20474. --- .../vectorized/OnHeapColumnVector.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 9b410bacff5df..94ed32294cfae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -410,53 +410,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, capacity); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, capacity); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } } else if (resultStruct != null) { @@ -466,7 +466,7 @@ protected void reserveInternal(int newCapacity) { } byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity); nulls = newNulls; capacity = newCapacity; From 2ba1eba371213d1ac3d1fa1552e5906e043c2ee4 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Wed, 26 Apr 2017 13:54:40 -0700 Subject: [PATCH 0342/1765] [SPARK-12868][SQL] Allow adding jars from hdfs ## What changes were proposed in this pull request? Spark 2.2 is going to be cut, it'll be great if SPARK-12868 can be resolved before that. There have been several PRs for this like [PR#16324](https://github.com/apache/spark/pull/16324) , but all of them are inactivity for a long time or have been closed. This PR added a SparkUrlStreamHandlerFactory, which relies on 'protocol' to choose the appropriate UrlStreamHandlerFactory like FsUrlStreamHandlerFactory to create URLStreamHandler. ## How was this patch tested? 1. Add a new unit test. 2. Check manually. Before: throw an exception with " failed unknown protocol: hdfs" screen shot 2017-03-17 at 9 07 36 pm After: screen shot 2017-03-18 at 11 42 18 am Author: Weiqing Yang Closes #17342 from weiqingy/SPARK-18910. --- .../org/apache/spark/sql/internal/SharedState.scala | 10 +++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index f834569e59b7f..a93b701146077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.internal +import java.net.URL import java.util.Locale import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -154,7 +156,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } -object SharedState { +object SharedState extends Logging { + try { + URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory()) + } catch { + case e: Error => + logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory") + } private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd9296a3f0ff..3ecbf96b41961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext +import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -2606,4 +2607,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } + + test("SPARK-12868: Allow adding jars from hdfs ") { + val jarFromHdfs = "hdfs://doesnotmatter/test.jar" + val jarFromInvalidFs = "fffs://doesnotmatter/test.jar" + + // if 'hdfs' is not supported, MalformedURLException will be thrown + new URL(jarFromHdfs) + + intercept[MalformedURLException] { + new URL(jarFromInvalidFs) + } + } } From 66636ef0b046e5d1f340c3b8153d7213fa9d19c7 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Wed, 26 Apr 2017 17:06:21 -0700 Subject: [PATCH 0343/1765] [SPARK-20435][CORE] More thorough redaction of sensitive information This change does a more thorough redaction of sensitive information from logs and UI Add unit tests that ensure that no regressions happen that leak sensitive information to the logs. The motivation for this change was appearance of password like so in `SparkListenerEnvironmentUpdate` in event logs under some JVM configurations: `"sun.java.command":"org.apache.spark.deploy.SparkSubmit ... --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ..." ` Previously redaction logic was only checking if the key matched the secret regex pattern, it'd redact it's value. That worked for most cases. However, in the above case, the key (sun.java.command) doesn't tell much, so the value needs to be searched. This PR expands the check to check for values as well. ## How was this patch tested? New unit tests added that ensure that no sensitive information is present in the event logs or the yarn logs. Old unit test in UtilsSuite was modified because the test was asserting that a non-sensitive property's value won't be redacted. However, the non-sensitive value had the literal "secret" in it which was causing it to redact. Simply updating the non-sensitive property's value to another arbitrary value (that didn't have "secret" in it) fixed it. Author: Mark Grover Closes #17725 from markgrover/spark-20435. --- .../spark/internal/config/package.scala | 4 +-- .../scheduler/EventLoggingListener.scala | 16 ++++++--- .../scala/org/apache/spark/util/Utils.scala | 22 +++++++++--- .../spark/deploy/SparkSubmitSuite.scala | 34 +++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 10 ++++-- docs/configuration.md | 4 +-- .../spark/deploy/yarn/YarnClusterSuite.scala | 32 +++++++++++++---- 7 files changed, 100 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 89aeea4939086..2f0a3064be111 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -244,8 +244,8 @@ package object config { ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + "driver and executor environments contain sensitive information. When this regex matches " + - "a property, its value is redacted from the environment UI and various logs like YARN " + - "and event logs.") + "a property key or value, the value is redacted from the environment UI and various logs " + + "like YARN and event logs.") .regexConf .createWithDefault("(?i)secret|password".r) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index aecb3a980e7c1..a7dbf87915b27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -252,11 +252,17 @@ private[spark] class EventLoggingListener( private[spark] def redactEvent( event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { - // "Spark Properties" entry will always exist because the map is always populated with it. - val redactedProps = Utils.redact(sparkConf, event.environmentDetails("Spark Properties")) - val redactedEnvironmentDetails = event.environmentDetails + - ("Spark Properties" -> redactedProps) - SparkListenerEnvironmentUpdate(redactedEnvironmentDetails) + // environmentDetails maps a string descriptor to a set of properties + // Similar to: + // "JVM Information" -> jvmInformation, + // "Spark Properties" -> sparkProperties, + // ... + // where jvmInformation, sparkProperties, etc. are sequence of tuples. + // We go through the various of properties and redact sensitive information from them. + val redactedProps = event.environmentDetails.map{ case (name, props) => + name -> Utils.redact(sparkConf, props) + } + SparkListenerEnvironmentUpdate(redactedProps) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 943dde0723271..e042badcdd4a4 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2606,10 +2606,24 @@ private[spark] object Utils extends Logging { } private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { - kvs.map { kv => - redactionPattern.findFirstIn(kv._1) - .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } - .getOrElse(kv) + // If the sensitive information regex matches with either the key or the value, redact the value + // While the original intent was to only redact the value if the key matched with the regex, + // we've found that especially in verbose mode, the value of the property may contain sensitive + // information like so: + // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \ + // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ... + // + // And, in such cases, simply searching for the sensitive information regex in the key name is + // not sufficient. The values themselves have to be searched as well and redacted if matched. + // This does mean we may be accounting more false positives - for example, if the value of an + // arbitrary property contained the term 'password', we may redact the value from the UI and + // logs. In order to work around it, user would have to make the spark.redaction.regex property + // more specific. + kvs.map { case (key, value) => + redactionPattern.findFirstIn(key) + .orElse(redactionPattern.findFirstIn(value)) + .map { _ => (key, REDACTION_REPLACEMENT_TEXT) } + .getOrElse((key, value)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c2ec01a03d04..a43839a8815f9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,8 +21,10 @@ import java.io._ import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import scala.io.Source import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -34,6 +36,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -404,6 +407,37 @@ class SparkSubmitSuite runSparkSubmit(args) } + test("launch simple application with spark-submit with redaction") { + val testDir = Utils.createTempDir() + testDir.deleteOnExit() + val testDirPath = new Path(testDir.getAbsolutePath()) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + try { + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", + "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.testing=true", + "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", + "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", + unusedJar.toString) + runSparkSubmit(args) + val listStatus = fileSystem.listStatus(testDirPath) + val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem) + Source.fromInputStream(logData).getLines().foreach { line => + assert(!line.contains("secret_password")) + } + } finally { + Utils.deleteRecursively(testDir) + } + } + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8ed09749ffd54..3339d5b35d3b2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1010,15 +1010,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD", "spark.my.password", "spark.my.sECreT") - secretKeys.foreach { key => sparkConf.set(key, "secret_password") } + secretKeys.foreach { key => sparkConf.set(key, "sensitive_value") } // Set a non-secret key - sparkConf.set("spark.regular.property", "not_a_secret") + sparkConf.set("spark.regular.property", "regular_value") + // Set a property with a regular key but secret in the value + sparkConf.set("spark.sensitive.property", "has_secret_in_value") // Redact sensitive information val redactedConf = Utils.redact(sparkConf, sparkConf.getAll).toMap // Assert that secret information got redacted while the regular property remained the same secretKeys.foreach { key => assert(redactedConf(key) === Utils.REDACTION_REPLACEMENT_TEXT) } - assert(redactedConf("spark.regular.property") === "not_a_secret") + assert(redactedConf("spark.regular.property") === "regular_value") + assert(redactedConf("spark.sensitive.property") === Utils.REDACTION_REPLACEMENT_TEXT) + } } diff --git a/docs/configuration.md b/docs/configuration.md index 8b53e92ccd416..1d8d963016c71 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -372,8 +372,8 @@ Apart from these, the following properties are also available, and may be useful
    diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 99fb58a28934a..59adb7e22d185 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -24,6 +24,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} @@ -87,24 +88,30 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testBasicYarnApp(false) } - test("run Spark in yarn-client mode with different configurations") { + test("run Spark in yarn-client mode with different configurations, ensuring redaction") { testBasicYarnApp(true, Map( "spark.driver.memory" -> "512m", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } - test("run Spark in yarn-cluster mode with different configurations") { + test("run Spark in yarn-cluster mode with different configurations, ensuring redaction") { testBasicYarnApp(false, Map( "spark.driver.memory" -> "512m", "spark.driver.cores" -> "1", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } @@ -349,6 +356,7 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc private object YarnClusterDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000 + val SECRET_PASSWORD = "secret_password" def main(args: Array[String]): Unit = { if (args.length != 1) { @@ -395,6 +403,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(executorInfos.nonEmpty) executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) + info.logUrlMap.values.foreach { url => + val log = Source.fromURL(url).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Executor logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } } // If we are running in yarn-cluster mode, verify that driver logs links and present and are @@ -406,8 +421,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(driverLogs.contains("stderr")) assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") - // Ensure that this is a valid URL, else this will throw an exception - new URL(urlStr) + driverLogs.foreach { kv => + val log = Source.fromURL(kv._2).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Driver logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) From b4724db19a10387a803cd7beec14facf7ad1894a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 26 Apr 2017 22:18:01 -0700 Subject: [PATCH 0344/1765] [SPARK-20425][SQL] Support a vertical display mode for Dataset.show ## What changes were proposed in this pull request? This pr added a new display mode for `Dataset.show` to print output rows vertically (one line per column value). In the current master, when printing Dataset with many columns, the readability is low like; ``` scala> val df = spark.range(100).selectExpr((0 until 100).map(i => s"rand() AS c$i"): _*) scala> df.show|c0 |c1 |c2 |c3 |c4 |c5 |c6 |c7 |c8 |c9 |c10 |c11 |c12 |c13 |c14 |c15 |c16 |c17 |c18 |c19 |c20 |c21 |c22 |c23 |c24 |c25 |c26 |c27 |c28 |c29 |c30 |c31 |c32 |c33 |c34 |c35 |c36 |c37 |c38 |c39 |c40 |c41 |c42 |c43 |c44 |c45 |c46 |c47 |c48 |c49 |c50 |c51 |c52 |c53 |c54 |c55 |c56 |c57 |c58 |c59 |c60 |c61 |c62 |c63 |c64 |c65 |c66 |c67 |c68 |c69 |c70 |c71 |c72 |c73 |c74 |c75 |c76 |c77 |c78 |c79 |c80 |c81 |c82 |c83 |c84 |c85 |c86 |c87 |c88 |c89 |c90 |c91 |c92 |c93 |c94 |c95 |c96 |c97 |c98 |c99 ||0.6306087152476858|0.9174349686288383|0.5511324165035159|0.3320844128641819 |0.7738486877101489|0.2154915886962553|0.4754997600674299 |0.922780639280355 |0.7136894772661909|0.2277580838165979|0.5926874459847249|0.40311408392226633|0.467830264333843 |0.8330466896984213|0.1893258482389527|0.6320849515511165 |0.7530911056912044 |0.06700254871955424|0.370528597355559 |0.2755437445193154|0.23704391110980128|0.8067400174905822|0.13597793616251852|0.1708888820162453|0.01672725007605702|0.983118121881555 |0.25040195628629924|0.060537253723083384|0.20000530582637488|0.3400572407133511|0.9375689433322597 |0.057039316954370256|0.8053269714347623|0.5247817572228813|0.28419308820527944|0.9798908885194533 |0.31805988175678146|0.7034448027077574|0.5400575751346084|0.25336322371116216|0.9361634546853429|0.6118681368289798|0.6295081549153907 |0.13417468943957422|0.41617137072255794|0.7267230869252035|0.023792726137561115|0.5776157058356362 |0.04884204913195467|0.26728716103441275|0.646680370807925 |0.9782712690657244 |0.16434031314818154|0.20985522381321275|0.24739842475440077 |0.26335189682977334|0.19604841662422068|0.10742950487300651|0.20283136488091502|0.3100312319723688|0.886959006630645 |0.25157102269776244|0.34428775168410786|0.3500506818575777|0.3781142441912052 |0.8560316444386715|0.4737104888956839|0.735903101602148|0.02236617130529006|0.8769074095835873 |0.2001426662503153|0.5534032319238532 |0.7289496620397098|0.41955191309992157|0.9337700133660436 |0.34059094378451005|0.6419144759403556|0.08167496930341167|0.9947099478497635|0.48010888605366586|0.22314796858167918|0.17786598882331306|0.7351521162297135 |0.5422057170020095 |0.9521927872726792 |0.7459825486368227 |0.40907708791990627|0.8903819313311575|0.7251413746923618 |0.2977174938745204 |0.9515209660203555|0.9375968604766713|0.5087851740042524|0.4255237544908751 |0.8023768698664653|0.48003189618006703|0.1775841829745185|0.09050775629268382|0.6743909291138167 |0.2498415755876865 | |0.6866473844170801|0.4774360641212433|0.631696201340726 |0.33979113021468343|0.5663049010847052|0.7280190472258865|0.41370958502324806|0.9977433873622218|0.7671957338989901|0.2788708556233931|0.3355106391656496|0.88478952319287 |0.0333974166999893|0.6061744715862606|0.9617779139652359|0.22484954822341863|0.12770906021550898|0.5577789629508672 |0.2877649024640704|0.5566577406549361|0.9334933255278052 |0.9166720585157266|0.9689249324600591 |0.6367502457478598|0.7993572745928459 |0.23213222324218108|0.11928284054154137|0.6173493362456599 |0.0505122058694798 |0.9050228629552983|0.17112767911121707|0.47395598348370005 |0.5820498657823081|0.6241124650645072|0.18587258258036776|0.14987593554122225|0.3079446253653946 |0.9414228822867968|0.8362276265462365|0.9155655305576353 |0.5121559807153562|0.8963362656525707|0.22765970274318037|0.8177039187132797 |0.8190326635933787 |0.5256005177032199|0.8167598457269669 |0.030936807130934496|0.6733006585281015 |0.4208049626816347 |0.24603085738518538|0.22719198954208153|0.1622280557565281 |0.22217325159218038|0.014684419513742553|0.08987111517447499|0.2157764759142622 |0.8223414104088321 |0.4868624404491777 |0.4016191733088167|0.6169281906889263|0.15603611040433385|0.18289285085714913|0.9538408988218972|0.15037154865295121|0.5364516961987454|0.8077254873163031|0.712600478545675|0.7277477241003857 |0.19822912960348305|0.8305051199208777|0.18631911396566114|0.8909532487898342|0.3470409226992506 |0.35306974180587636|0.9107058868891469 |0.3321327206004986|0.48952332459050607|0.3630403307479373|0.5400046826340376 |0.5387377194310529 |0.42860539421837585|0.23214101630985995|0.21438968839794847|0.15370603160082352|0.04355605642700022|0.6096006707067466 |0.6933354157094292|0.06302172470859002|0.03174631856164001|0.664243581650643 |0.7833239547446621|0.696884598352864 |0.34626385933237736|0.9263495598791336|0.404818892816584 |0.2085585394755507|0.6150004897990109 |0.05391193524302473|0.28188484028329097|only showing top 2 rows ``` `psql`, CLI for PostgreSQL, supports a vertical display mode for this case like: http://stackoverflow.com/questions/9604723/alternate-output-format-for-psql ``` -RECORD 0------------------- c0 | 0.6306087152476858 c1 | 0.9174349686288383 c2 | 0.5511324165035159 ... c98 | 0.05391193524302473 c99 | 0.28188484028329097 -RECORD 1------------------- c0 | 0.6866473844170801 c1 | 0.4774360641212433 c2 | 0.631696201340726 ... c98 | 0.05391193524302473 c99 | 0.28188484028329097 only showing top 2 rows ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #17733 from maropu/SPARK-20425. --- R/pkg/R/DataFrame.R | 8 +- python/pyspark/sql/dataframe.py | 15 +- .../scala/org/apache/spark/sql/Dataset.scala | 149 ++++++++++++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 112 +++++++++++++ 4 files changed, 247 insertions(+), 37 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cd6f03a13d7c7..7e57ba6287bb8 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -194,6 +194,7 @@ setMethod("isLocal", #' 20 characters will be truncated. However, if set greater than zero, #' truncates strings longer than \code{truncate} characters and all cells #' will be aligned right. +#' @param vertical whether print output rows vertically (one line per column value). #' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases showDF,SparkDataFrame-method @@ -210,12 +211,13 @@ setMethod("isLocal", #' @note showDF since 1.4.0 setMethod("showDF", signature(x = "SparkDataFrame"), - function(x, numRows = 20, truncate = TRUE) { + function(x, numRows = 20, truncate = TRUE, vertical = FALSE) { if (is.logical(truncate) && truncate) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20), vertical) } else { truncate2 <- as.numeric(truncate) - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2), + vertical) } cat(s) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 774caf53f3a4b..ff21bb5d2fb3f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -290,13 +290,15 @@ def isStreaming(self): return self._jdf.isStreaming() @since(1.3) - def show(self, n=20, truncate=True): + def show(self, n=20, truncate=True, vertical=False): """Prints the first ``n`` rows to the console. :param n: Number of rows to show. :param truncate: If set to True, truncate strings longer than 20 chars by default. If set to a number greater than one, truncates long strings to length ``truncate`` and align cells right. + :param vertical: If set to True, print output rows vertically (one line + per column value). >>> df DataFrame[age: int, name: string] @@ -314,11 +316,18 @@ def show(self, n=20, truncate=True): | 2| Ali| | 5| Bob| +---+----+ + >>> df.show(vertical=True) + -RECORD 0----- + age | 2 + name | Alice + -RECORD 1----- + age | 5 + name | Bob """ if isinstance(truncate, bool) and truncate: - print(self._jdf.showString(n, 20)) + print(self._jdf.showString(n, 20, vertical)) else: - print(self._jdf.showString(n, int(truncate))) + print(self._jdf.showString(n, int(truncate), vertical)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 06dd5500718de..147e7651ce55b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -240,8 +240,10 @@ class Dataset[T] private[sql]( * @param _numRows Number of rows to show * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). */ - private[sql] def showString(_numRows: Int, truncate: Int = 20): String = { + private[sql] def showString( + _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows @@ -277,46 +279,80 @@ class Dataset[T] private[sql]( val sb = new StringBuilder val numCols = schema.fieldNames.length + // We set a minimum column width at '3' + val minimumColWidth = 3 - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) + if (!vertical) { + // Initialise the width of each column to a minimum value + val colWidths = Array.fill(numCols)(minimumColWidth) - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } } - }.addString(sb, "|", "|", "|\n") - sb.append(sep) + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => + // column names + rows.head.zipWithIndex.map { case (cell, i) => if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) + StringUtils.leftPad(cell, colWidths(i)) } else { - StringUtils.rightPad(cell.toString, colWidths(i)) + StringUtils.rightPad(cell, colWidths(i)) } }.addString(sb, "|", "|", "|\n") - } - sb.append(sep) + sb.append(sep) + + // data + rows.tail.foreach { + _.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + } else { + // Extended display mode enabled + val fieldNames = rows.head + val dataRows = rows.tail + + // Compute the width of field name and data columns + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => + math.max(curMax, fieldName.length) + } + val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => + math.max(curMax, row.map(_.length).reduceLeftOption[Int] { case (cellMax, cell) => + math.max(cellMax, cell) + }.getOrElse(0)) + } + + dataRows.zipWithIndex.foreach { case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad( + s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex.map { case (cell, j) => + val fieldName = StringUtils.rightPad(fieldNames(j), fieldNameColWidth) + val data = StringUtils.rightPad(cell, dataColWidth) + s" $fieldName | $data " + }.addString(sb, "", "\n", "\n") + } + } - // For Data that has more than "numRows" records - if (hasMoreData) { + // Print a footer + if (vertical && data.isEmpty) { + // In a vertical mode, print an empty row set explicitly + sb.append("(0 rows)\n") + } else if (hasMoreData) { + // For Data that has more than "numRows" records val rowsString = if (numRows == 1) "row" else "rows" sb.append(s"only showing top $numRows $rowsString\n") } @@ -663,8 +699,59 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ + def show(numRows: Int, truncate: Int): Unit = show(numRows, truncate, vertical = false) + + /** + * Displays the Dataset in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * + * If `vertical` enabled, this command prints output rows vertically (one line per column value)? + * + * {{{ + * -RECORD 0------------------- + * year | 1980 + * month | 12 + * AVG('Adj Close) | 0.503218 + * AVG('Adj Close) | 0.595103 + * -RECORD 1------------------- + * year | 1981 + * month | 01 + * AVG('Adj Close) | 0.523289 + * AVG('Adj Close) | 0.570307 + * -RECORD 2------------------- + * year | 1982 + * month | 02 + * AVG('Adj Close) | 0.436504 + * AVG('Adj Close) | 0.475256 + * -RECORD 3------------------- + * year | 1983 + * month | 03 + * AVG('Adj Close) | 0.410516 + * AVG('Adj Close) | 0.442194 + * -RECORD 4------------------- + * year | 1984 + * month | 04 + * AVG('Adj Close) | 0.450090 + * AVG('Adj Close) | 0.483521 + * }}} + * + * @param numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + * @group action + * @since 2.3.0 + */ // scalastyle:off println - def show(numRows: Int, truncate: Int): Unit = println(showString(numRows, truncate)) + def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = + println(showString(numRows, truncate, vertical)) // scalastyle:on println /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b4893b56a8a84..ef0de6f6f4ff1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -764,6 +764,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10, truncate = 20) === expectedAnswerForTrue) } + test("showString: truncate = [0, 20], vertical = true") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = "-RECORD 0----------------------\n" + + " value | 1 \n" + + "-RECORD 1----------------------\n" + + " value | 111111111111111111111 \n" + assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse) + val expectedAnswerForTrue = "-RECORD 0---------------------\n" + + " value | 1 \n" + + "-RECORD 1---------------------\n" + + " value | 11111111111111111... \n" + assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue) + } + test("showString: truncate = [3, 17]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() @@ -785,6 +800,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10, truncate = 17) === expectedAnswerForTrue) } + test("showString: truncate = [3, 17], vertical = true") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = "-RECORD 0----\n" + + " value | 1 \n" + + "-RECORD 1----\n" + + " value | 111 \n" + assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse) + val expectedAnswerForTrue = "-RECORD 0------------------\n" + + " value | 1 \n" + + "-RECORD 1------------------\n" + + " value | 11111111111111... \n" + assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -795,6 +825,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(-1) === expectedAnswer) } + test("showString(negative), vertical = true") { + val expectedAnswer = "(0 rows)\n" + assert(testData.select($"*").showString(-1, vertical = true) === expectedAnswer) + } + test("showString(0)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -805,6 +840,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(0), vertical = true") { + val expectedAnswer = "(0 rows)\n" + assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) + } + test("showString: array") { val df = Seq( (Array(1, 2, 3), Array(1, 2, 3)), @@ -820,6 +860,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: array, vertical = true") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = "-RECORD 0--------\n" + + " _1 | [1, 2, 3] \n" + + " _2 | [1, 2, 3] \n" + + "-RECORD 1--------\n" + + " _1 | [2, 3, 4] \n" + + " _2 | [2, 3, 4] \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("showString: binary") { val df = Seq( ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), @@ -835,6 +889,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary, vertical = true") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = "-RECORD 0---------------\n" + + " _1 | [31 32] \n" + + " _2 | [41 42 43 2E] \n" + + "-RECORD 1---------------\n" + + " _1 | [33 34] \n" + + " _2 | [31 32 33 34 36] \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), @@ -850,6 +918,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: minimum column width, vertical = true") { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + val expectedAnswer = "-RECORD 0--\n" + + " _1 | 1 \n" + + " _2 | 1 \n" + + "-RECORD 1--\n" + + " _1 | 2 \n" + + " _2 | 2 \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("SPARK-7319 showString") { val expectedAnswer = """+---+-----+ ||key|value| @@ -861,6 +943,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1) === expectedAnswer) } + test("SPARK-7319 showString, vertical = true") { + val expectedAnswer = "-RECORD 0----\n" + + " key | 1 \n" + + " value | 1 \n" + + "only showing top 1 row\n" + assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| @@ -870,6 +960,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer) } + test("SPARK-7327 show with empty dataFrame, vertical = true") { + assert(testData.select($"*").filter($"key" < 0).showString(1, vertical = true) === "(0 rows)\n") + } + test("SPARK-18350 show with session local timezone") { val d = Date.valueOf("2016-12-01") val ts = Timestamp.valueOf("2016-12-01 00:00:00") @@ -894,6 +988,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-18350 show with session local timezone, vertical = true") { + val d = Date.valueOf("2016-12-01") + val ts = Timestamp.valueOf("2016-12-01 00:00:00") + val df = Seq((d, ts)).toDF("d", "ts") + val expectedAnswer = "-RECORD 0------------------\n" + + " d | 2016-12-01 \n" + + " ts | 2016-12-01 00:00:00 \n" + assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + + val expectedAnswer = "-RECORD 0------------------\n" + + " d | 2016-12-01 \n" + + " ts | 2016-12-01 08:00:00 \n" + assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) + } + } + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) From b58cf77c4db49ba236b779905a943f025c6aaedd Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 27 Apr 2017 00:29:43 -0700 Subject: [PATCH 0345/1765] [DOCS][MINOR] Add missing since to SparkR repeat_string note. ## What changes were proposed in this pull request? Replace note repeat_string 2.3.0 with note repeat_string since 2.3.0 ## How was this patch tested? `create-docs.sh` Author: zero323 Closes #17779 from zero323/REPEAT-NOTE. --- R/pkg/R/functions.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 752e4c5c7189d..6b91fa5bde671 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3796,7 +3796,7 @@ setMethod("split_string", #' # This is equivalent to the following SQL expression #' first(selectExpr(df, "repeat(value, 3)")) #' } -#' @note repeat_string 2.3.0 +#' @note repeat_string since 2.3.0 setMethod("repeat_string", signature(x = "Column", n = "numeric"), function(x, n) { From ba7666274e71f1903e5050a5e53fbdcd21debde5 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 27 Apr 2017 00:34:20 -0700 Subject: [PATCH 0346/1765] [SPARK-20208][DOCS][FOLLOW-UP] Add FP-Growth to SparkR programming guide ## What changes were proposed in this pull request? Add `spark.fpGrowth` to SparkR programming guide. ## How was this patch tested? Manual tests. Author: zero323 Closes #17775 from zero323/SPARK-20208-FOLLOW-UP. --- docs/sparkr.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index e015ab260fca8..c3336ac2ce86a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -504,6 +504,10 @@ SparkR supports the following machine learning algorithms currently: * [`spark.als`](api/R/spark.als.html): [`Alternating Least Squares (ALS)`](ml-collaborative-filtering.html#collaborative-filtering) +#### Frequent Pattern Mining + +* [`spark.fpGrowth`](api/R/spark.fpGrowth.html) : [`FP-growth`](ml-frequent-pattern-mining.html#fp-growth) + #### Statistics * [`spark.kstest`](api/R/spark.kstest.html): `Kolmogorov-Smirnov Test` From 7633933e54ffb08ab9d959be5f76c26fae29d1d9 Mon Sep 17 00:00:00 2001 From: Davis Shepherd Date: Thu, 27 Apr 2017 18:06:12 +0000 Subject: [PATCH 0347/1765] [SPARK-20483] Mesos Coarse mode may starve other Mesos frameworks ## What changes were proposed in this pull request? Set maxCores to be a multiple of the smallest executor that can be launched. This ensures that we correctly detect the condition where no more executors will be launched when spark.cores.max is not a multiple of spark.executor.cores ## How was this patch tested? This was manually tested with other sample frameworks measuring their incoming offers to determine if starvation would occur. dbtsai mgummelt Author: Davis Shepherd Closes #17786 from dgshep/fix_mesos_max_cores. --- .../MesosCoarseGrainedSchedulerBackend.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 2a36ec4fa8112..8f5b97ccb1f85 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -60,8 +60,16 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) + private val executorCoresOption = conf.getOption("spark.executor.cores").map(_.toInt) + + private val minCoresPerExecutor = executorCoresOption.getOrElse(1) + // Maximum number of cores to acquire - private val maxCores = maxCoresOption.getOrElse(Int.MaxValue) + private val maxCores = { + val cores = maxCoresOption.getOrElse(Int.MaxValue) + // Set maxCores to a multiple of smallest executor we can launch + cores - (cores % minCoresPerExecutor) + } private val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) @@ -489,8 +497,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def executorCores(offerCPUs: Int): Int = { - sc.conf.getInt("spark.executor.cores", - math.min(offerCPUs, maxCores - totalCoresAcquired)) + executorCoresOption.getOrElse( + math.min(offerCPUs, maxCores - totalCoresAcquired) + ) } override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { From 561e9cc390b429e4252f59f00a7ca4f6f8c853f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 27 Apr 2017 11:31:01 -0700 Subject: [PATCH 0348/1765] [SPARK-20421][CORE] Mark internal listeners as deprecated. These listeners weren't really meant for external consumption, but they're public and marked with DeveloperApi. Adding the deprecated tag warns people that they may soon go away (as they will as part of the work for SPARK-18085). Note that not all types made public by https://github.com/apache/spark/pull/648 are being deprecated. Some remaining types are still exposed through the SparkListener API. Also note the text for StorageStatus is a tiny bit different, since I'm not so sure I'll be able to remove it. But the effect for the users should be the same (they should stop trying to use it). Author: Marcelo Vanzin Closes #17766 from vanzin/SPARK-20421. --- .../scala/org/apache/spark/storage/StorageStatusListener.scala | 1 + core/src/main/scala/org/apache/spark/storage/StorageUtils.scala | 1 + core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala | 1 + core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 1 + .../scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 1 + core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala | 1 + 6 files changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 1b30d4fa93bc0..ac60f795915a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -30,6 +30,7 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 8f0d181fc8fe5..e9694fdbca2de 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi +@deprecated("This class may be removed or made private in a future release.", "2.2.0") class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 70b3ffd95e605..8c18464e6477a 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -32,6 +32,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en * A SparkListener that prepares information to be displayed on the EnvironmentTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 03851293eb2f1..aabf6e0c63c02 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -62,6 +62,7 @@ private[ui] case class ExecutorTaskSummary( * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f78db5ab80d15..8870187f2219c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -41,6 +41,7 @@ import org.apache.spark.ui.jobs.UIData._ * updating the internal data structures concurrently. */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Define a handful of type aliases so that data structures' types can serve as documentation. diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index c212362557be6..148efb134e14f 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing From 85c6ce61930490e2247fb4b0e22dfebbb8b6a1ee Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 27 Apr 2017 14:06:07 -0500 Subject: [PATCH 0349/1765] [SPARK-20426] Lazy initialization of FileSegmentManagedBuffer for shuffle service. ## What changes were proposed in this pull request? When application contains large amount of shuffle blocks. NodeManager requires lots of memory to keep metadata(`FileSegmentManagedBuffer`) in `StreamManager`. When the number of shuffle blocks is big enough. NodeManager can run OOM. This pr proposes to do lazy initialization of `FileSegmentManagedBuffer` in shuffle service. ## How was this patch tested? Manually test. Author: jinxing Closes #17744 from jinxing64/SPARK-20426. --- .../shuffle/ExternalShuffleBlockHandler.java | 31 ++++++++++++------- .../ExternalShuffleBlockHandlerSuite.java | 4 +-- .../ExternalShuffleIntegrationSuite.java | 5 ++- .../network/netty/NettyBlockRpcServer.scala | 9 +++--- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 6daf9609d76dc..c0f1da50f5e65 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; import com.codahale.metrics.Gauge; @@ -30,7 +30,6 @@ import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,14 +92,25 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - List blocks = Lists.newArrayList(); - long totalBlockSize = 0; - for (String blockId : msg.blockIds) { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, blockId); - totalBlockSize += block != null ? block.size() : 0; - blocks.add(block); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + Iterator iter = new Iterator() { + private int index = 0; + + @Override + public boolean hasNext() { + return index < msg.blockIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, + msg.blockIds[index]); + index++; + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + }; + + long streamId = streamManager.registerStream(client.getClientId(), iter); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -109,7 +119,6 @@ protected void handleMessage( getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - metrics.blockTransferRateBytes.mark(totalBlockSize); } finally { responseDelayContext.stop(); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e47a72c9d16cc..4d48b18970386 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -88,8 +88,6 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -107,6 +105,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b8ae04eefb972..7a33b6821792c 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -216,9 +216,8 @@ public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still fail, as we start by checking for all block. - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df7023..305fd9a6de10d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -56,11 +56,12 @@ class NettyBlockRpcServer( message match { case openBlocks: OpenBlocks => - val blocks: Seq[ManagedBuffer] = - openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val blocksNum = openBlocks.blockIds.length + val blocks = for (i <- (0 until blocksNum).view) + yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) - logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) + logTrace(s"Registered streamId $streamId with $blocksNum buffers") + responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. From 26ac2ce05cbaf8f152347219403e31491e9c9bf1 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Thu, 27 Apr 2017 12:08:16 -0700 Subject: [PATCH 0350/1765] [SPARK-20482][SQL] Resolving Casts is too strict on having time zone set ## What changes were proposed in this pull request? Relax the requirement that a `TimeZoneAwareExpression` has to have its `timeZoneId` set to be considered resolved. With this change, a `Cast` (which is a `TimeZoneAwareExpression`) can be considered resolved if the `(fromType, toType)` combination doesn't require time zone information. Also de-relaxed test cases in `CastSuite` so Casts in that test suite don't get a default`timeZoneId = Option("GMT")`. ## How was this patch tested? Ran the de-relaxed`CastSuite` and it's passing. Also ran the SQL unit tests and they're passing too. Author: Kris Mok Closes #17777 from rednaxelafx/fix-catalyst-cast-timezone. --- .../spark/sql/catalyst/expressions/Cast.scala | 32 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 4 +-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bb1273f5c3d84..a53ef426f79b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,6 +89,31 @@ object Cast { case _ => false } + /** + * Return true if we need to use the `timeZone` information casting `from` type to `to` type. + * The patterns matched reflect the current implementation in the Cast node. + * c.f. usage of `timeZone` in: + * * Cast.castToString + * * Cast.castToDate + * * Cast.castToTimestamp + */ + def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { + case (StringType, TimestampType) => true + case (DateType, TimestampType) => true + case (TimestampType, StringType) => true + case (TimestampType, DateType) => true + case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType) + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue) + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).exists { + case (fromField, toField) => + needsTimeZone(fromField.dataType, toField.dataType) + } + case _ => false + } + /** * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int, * timestamp -> date. @@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; + // Otherwise behave like Expression.resolved. + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) + + private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 22f3f3514fa41..a7ffa884d2286 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String */ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { - private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = { + private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = { v match { case lit: Expression => Cast(lit, targetType, timeZoneId) case _ => Cast(Literal(v), targetType, timeZoneId) @@ -47,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } private def checkNullCast(from: DataType, to: DataType): Unit = { - checkEvaluation(cast(Literal.create(null, from), to), null) + checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null) } test("null cast") { From a4aa4665a6775b514b714c88b70576090d2b4a7e Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 27 Apr 2017 12:13:16 -0700 Subject: [PATCH 0351/1765] [SPARK-20487][SQL] `HiveTableScan` node is quite verbose in explained plan ## What changes were proposed in this pull request? Changed `TreeNode.argString` to handle `CatalogTable` separately (otherwise it would call the default `toString` on the `CatalogTable`) ## How was this patch tested? - Expanded scope of existing unit test to ensure that verbose information is not present - Manual testing Before ``` scala> hc.sql(" SELECT * FROM my_table WHERE name = 'foo' ").explain(true) == Parsed Logical Plan == 'Project [*] +- 'Filter ('name = foo) +- 'UnresolvedRelation `my_table` == Analyzed Logical Plan == user_id: bigint, name: string, ds: string Project [user_id#13L, name#14, ds#15] +- Filter (name#14 = foo) +- SubqueryAlias my_table +- CatalogRelation CatalogTable( Database: default Table: my_table Owner: tejasp Created: Fri Apr 14 17:05:50 PDT 2017 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Provider: hive Properties: [serialization.format=1] Statistics: 9223372036854775807 bytes Location: file:/tmp/warehouse/my_table Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe InputFormat: org.apache.hadoop.mapred.TextInputFormat OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat Partition Provider: Catalog Partition Columns: [`ds`] Schema: root -- user_id: long (nullable = true) -- name: string (nullable = true) -- ds: string (nullable = true) ), [user_id#13L, name#14], [ds#15] == Optimized Logical Plan == Filter (isnotnull(name#14) && (name#14 = foo)) +- CatalogRelation CatalogTable( Database: default Table: my_table Owner: tejasp Created: Fri Apr 14 17:05:50 PDT 2017 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Provider: hive Properties: [serialization.format=1] Statistics: 9223372036854775807 bytes Location: file:/tmp/warehouse/my_table Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe InputFormat: org.apache.hadoop.mapred.TextInputFormat OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat Partition Provider: Catalog Partition Columns: [`ds`] Schema: root -- user_id: long (nullable = true) -- name: string (nullable = true) -- ds: string (nullable = true) ), [user_id#13L, name#14], [ds#15] == Physical Plan == *Filter (isnotnull(name#14) && (name#14 = foo)) +- HiveTableScan [user_id#13L, name#14, ds#15], CatalogRelation CatalogTable( Database: default Table: my_table Owner: tejasp Created: Fri Apr 14 17:05:50 PDT 2017 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Provider: hive Properties: [serialization.format=1] Statistics: 9223372036854775807 bytes Location: file:/tmp/warehouse/my_table Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe InputFormat: org.apache.hadoop.mapred.TextInputFormat OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat Partition Provider: Catalog Partition Columns: [`ds`] Schema: root -- user_id: long (nullable = true) -- name: string (nullable = true) -- ds: string (nullable = true) ), [user_id#13L, name#14], [ds#15] ``` After ``` scala> hc.sql(" SELECT * FROM my_table WHERE name = 'foo' ").explain(true) == Parsed Logical Plan == 'Project [*] +- 'Filter ('name = foo) +- 'UnresolvedRelation `my_table` == Analyzed Logical Plan == user_id: bigint, name: string, ds: string Project [user_id#13L, name#14, ds#15] +- Filter (name#14 = foo) +- SubqueryAlias my_table +- CatalogRelation `default`.`my_table`, [user_id#13L, name#14], [ds#15] == Optimized Logical Plan == Filter (isnotnull(name#14) && (name#14 = foo)) +- CatalogRelation `default`.`my_table`, [user_id#13L, name#14], [ds#15] == Physical Plan == *Filter (isnotnull(name#14) && (name#14 = foo)) +- HiveTableScan [user_id#13L, name#14, ds#15], CatalogRelation `default`.`my_table`, [user_id#13L, name#14], [ds#15] ``` Author: Tejas Patil Closes #17780 from tejasapatil/SPARK-20487_verbose_plan. --- .../spark/sql/catalyst/trees/TreeNode.scala | 1 + .../sql/hive/execution/HiveExplainSuite.scala | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cc4c0835954ba..b091315f24f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -444,6 +444,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil + case table: CatalogTable => table.identifier :: Nil case other => other :: Nil }.mkString(", ") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8a37bc3665d32..ebafe6de0c830 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -47,7 +47,23 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==") + "== Optimized Logical Plan ==", + "Owner", + "Database", + "Created", + "Last Access", + "Type", + "Provider", + "Properties", + "Statistics", + "Location", + "Serde Library", + "InputFormat", + "OutputFormat", + "Partition Provider", + "Schema" + ) + checkKeywordsExist(sql(" explain extended select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", From 039e32ca19d113e3be2c09171c7c921698be7ab8 Mon Sep 17 00:00:00 2001 From: Davis Shepherd Date: Thu, 27 Apr 2017 20:25:52 +0000 Subject: [PATCH 0352/1765] [SPARK-20483][MINOR] Test for Mesos Coarse mode may starve other Mesos frameworks ## What changes were proposed in this pull request? Add test case for scenarios where executor.cores is set as a (non)divisor of spark.cores.max This tests the change in #17786 ## How was this patch tested? Ran the existing test suite with the new tests dbtsai Author: Davis Shepherd Closes #17788 from dgshep/add_mesos_test. --- ...osCoarseGrainedSchedulerBackendSuite.scala | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index c040f05d93b3a..0418bfbaa5ed8 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -199,6 +199,40 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyDeclinedOffer(driver, createOfferId("o2"), true) } + test("mesos declines offers with a filter when maxCores not a multiple of executor.cores") { + val maxCores = 4 + val executorCores = 3 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyDeclinedOffer(driver, createOfferId("o2"), true) + } + + test("mesos declines offers with a filter when reached spark.cores.max with executor.cores") { + val maxCores = 4 + val executorCores = 2 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyTaskLaunched(driver, "o2") + verifyDeclinedOffer(driver, createOfferId("o3"), true) + } + test("mesos assigns tasks round-robin on offers") { val executorCores = 4 val maxCores = executorCores * 2 From 606432a13ad22d862c7cb5028ad6fe73c9985423 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 27 Apr 2017 20:48:43 +0000 Subject: [PATCH 0353/1765] [SPARK-20047][ML] Constrained Logistic Regression ## What changes were proposed in this pull request? MLlib ```LogisticRegression``` should support bound constrained optimization (only for L2 regularization). Users can add bound constraints to coefficients to make the solver produce solution in the specified range. Under the hood, we call Breeze [```L-BFGS-B```](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGSB.scala) as the solver for bound constrained optimization. But in the current breeze implementation, there are some bugs in L-BFGS-B, and https://github.com/scalanlp/breeze/pull/633 fixed them. We need to upgrade dependent breeze later, and currently we use the workaround L-BFGS-B in this PR temporary for reviewing. ## How was this patch tested? Unit tests. Author: Yanbo Liang Closes #17715 from yanboliang/spark-20047. --- .../classification/LogisticRegression.scala | 223 ++++++++- .../LogisticRegressionSuite.scala | 466 +++++++++++++++++- 2 files changed, 682 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 44b3478e0c3dd..d7dde329ed004 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -178,11 +178,86 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } + /** + * The lower bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", + "The lower bounds on coefficients if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) + + /** + * The upper bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", + "The upper bounds on coefficients if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) + + /** + * The lower bounds on intercepts if fitting under bound constrained optimization. + * The bounds vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", + "The lower bounds on intercepts if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) + + /** + * The upper bounds on intercepts if fitting under bound constrained optimization. + * The bound vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", + "The upper bounds on intercepts if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) + + protected def usingBoundConstrainedOptimization: Boolean = { + isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) || + isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts) + } + override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { checkThresholdConsistency() + if (usingBoundConstrainedOptimization) { + require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " + + s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.") + } + if (!$(fitIntercept)) { + require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), + "Pls don't set bounds on intercepts if fitting without intercept.") + } super.validateAndTransformSchema(schema, fitting, featuresDataType) } } @@ -217,6 +292,9 @@ class LogisticRegression @Since("1.2.0") ( * For alpha in (0,1), the penalty is a combination of L1 and L2. * Default is 0.0 which is an L2 penalty. * + * Note: Fitting under bound constrained optimization only supports L2 regularization, + * so throws exception if this param is non-zero value. + * * @group setParam */ @Since("1.4.0") @@ -312,6 +390,71 @@ class LogisticRegression @Since("1.2.0") ( def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Set the lower bounds on coefficients if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) + + /** + * Set the upper bounds on coefficients if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) + + /** + * Set the lower bounds on intercepts if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) + + /** + * Set the upper bounds on intercepts if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) + + private def assertBoundConstrainedOptimizationParamsValid( + numCoefficientSets: Int, + numFeatures: Int): Unit = { + if (isSet(lowerBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && + $(lowerBoundsOnCoefficients).numCols == numFeatures) + } + if (isSet(upperBoundsOnCoefficients)) { + require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && + $(upperBoundsOnCoefficients).numCols == numFeatures) + } + if (isSet(lowerBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).size == numCoefficientSets) + } + if (isSet(upperBoundsOnIntercepts)) { + require($(upperBoundsOnIntercepts).size == numCoefficientSets) + } + if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " + + "less than or equal to upperBoundsOnCoefficients, but found: " + + s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + + s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") + } + if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " + + "less than or equal to upperBoundsOnIntercepts, but found: " + + s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + + s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") + } + } + private var optInitialModel: Option[LogisticRegressionModel] = None private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { @@ -378,6 +521,11 @@ class LogisticRegression @Since("1.2.0") ( } val numCoefficientSets = if (isMultinomial) numClasses else 1 + // Check params interaction is valid if fitting under bound constrained optimization. + if (usingBoundConstrainedOptimization) { + assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures) + } + if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + @@ -397,7 +545,7 @@ class LogisticRegression @Since("1.2.0") ( val isConstantLabel = histogram.count(_ != 0.0) == 1 - if ($(fitIntercept) && isConstantLabel) { + if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + s"will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax @@ -434,8 +582,53 @@ class LogisticRegression @Since("1.2.0") ( $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial, $(aggregationDepth)) + val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets + + val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = { + if (usingBoundConstrainedOptimization) { + val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity) + val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity) + val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients) + val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients) + val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts) + val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts) + + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (featureIndex < numFeatures) { + if (isSetLowerBoundsOnCoefficients) { + lowerBounds(i) = $(lowerBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + if (isSetUpperBoundsOnCoefficients) { + upperBounds(i) = $(upperBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + } else { + if (isSetLowerBoundsOnIntercepts) { + lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex) + } + if (isSetUpperBoundsOnIntercepts) { + upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex) + } + } + i += 1 + } + (lowerBounds, upperBounds) + } else { + (null, null) + } + } + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + if (lowerBounds != null && upperBounds != null) { + new BreezeLBFGSB( + BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol)) + } else { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + } } else { val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { @@ -546,6 +739,26 @@ class LogisticRegression @Since("1.2.0") ( math.log(histogram(1) / histogram(0))) } + if (usingBoundConstrainedOptimization) { + // Make sure all initial values locate in the corresponding bound. + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, lowerBounds(i)) + } else if ( + initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, upperBounds(i)) + } + i += 1 + } + } + val states = optimizer.iterations(new CachedDiffFunction(costFun), new BDV[Double](initialCoefWithInterceptMatrix.toArray)) @@ -599,7 +812,7 @@ class LogisticRegression @Since("1.2.0") ( if (isIntercept) interceptVec.toArray(classIndex) = value } - if ($(regParam) == 0.0 && isMultinomial) { + if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) { /* When no regularization is applied, the multinomial coefficients lack identifiability because we do not use a pivot class. We can add any constant value to the coefficients @@ -620,7 +833,7 @@ class LogisticRegression @Since("1.2.0") ( } // center the intercepts when using multinomial algorithm - if ($(fitIntercept) && isMultinomial) { + if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) { val interceptArray = interceptVec.toArray val interceptMean = interceptArray.sum / interceptArray.length (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 83f575e83828f..bf6bfe30bfe20 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -150,6 +150,54 @@ class LogisticRegressionSuite assert(!model.hasSummary) } + test("logistic regression: illegal params") { + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnCoefficients1 = Matrices.dense(1, 4, Array(0.0, 1.0, 1.0, 0.0)) + val upperBoundsOnCoefficients2 = Matrices.dense(1, 3, Array(1.0, 0.0, 1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(1.0) + + // Work well when only set bound in one side. + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .fit(binaryDataset) + + withClue("bound constrained optimization only supports L2 regularization") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setElasticNetParam(1.0) + .fit(binaryDataset) + } + } + + withClue("lowerBoundsOnCoefficients should less than or equal to upperBoundsOnCoefficients") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients1) + .fit(binaryDataset) + } + } + + withClue("the coefficients bound matrix mismatched with shape (1, number of features)") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients2) + .fit(binaryDataset) + } + } + + withClue("bounds on intercepts should not be set if fitting without intercept") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(false) + .fit(binaryDataset) + } + } + } + test("empty probabilityCol") { val lr = new LogisticRegression().setProbabilityCol("") val model = lr.fit(smallBinaryDataset) @@ -610,6 +658,107 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = Vectors.dense(0.06079437, 0.0, -0.26351059, -0.59102199) + val interceptExpected1 = 1.0 + + assert(model1.intercept ~== interceptExpected1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== interceptExpected1 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Bound constrained optimization with bound on both side. + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(0.0, -1.0, 0.0, -1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(0.0) + + val trainer3 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(binaryDataset) + val model4 = trainer4.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = Vectors.dense(0.0, 0.0, 0.0, -0.71708632) + val interceptExpected3 = 0.58776113 + + assert(model3.intercept ~== interceptExpected3 relTol 1E-3) + assert(model3.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model4.intercept ~== interceptExpected3 relTol 1E-3) + assert(model4.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Bound constrained optimization with infinite bound on both side. + val trainer5 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(binaryDataset) + val model6 = trainer6.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570) + val interceptExpected5 = 2.7355261 + + assert(model5.intercept ~== interceptExpected5 relTol 1E-3) + assert(model5.coefficients ~= coefficientsExpected5 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model6.intercept ~== interceptExpected5 relTol 1E-3) + assert(model6.coefficients ~= coefficientsExpected5 relTol 1E-3) + } + test("binary logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) .setWeightCol("weight") @@ -650,6 +799,34 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept without regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)).toSparse + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected = Vectors.dense(0.20847553, 0.0, -0.24240289, -0.55568071) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") @@ -815,6 +992,40 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = Vectors.dense(-0.06985003, 0.0, -0.04794278, -0.10168595) + val interceptExpectedWithStd = 0.45750141 + val coefficientsExpected = Vectors.dense(-0.0494524, 0.0, -0.11360797, -0.06313577) + val interceptExpected = 0.53722967 + + assert(model1.intercept ~== interceptExpectedWithStd relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== interceptExpected relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") @@ -864,6 +1075,35 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = Vectors.dense(-0.00796538, 0.0, -0.0394228, -0.0873314) + val coefficientsExpected = Vectors.dense(0.01105972, 0.0, -0.08574949, -0.05079558) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with ElasticNet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") @@ -1084,7 +1324,6 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept without regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) @@ -1152,6 +1391,110 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = new DenseMatrix(3, 4, Array( + 2.52076464, 2.73596057, 1.87984904, 2.73264492, + 1.93302281, 3.71363303, 1.50681746, 1.93398782, + 2.37839917, 1.93601818, 1.81924758, 2.45191255), isTransposed = true) + val interceptsExpected1 = Vectors.dense(1.00010477, 3.44237083, 4.86740286) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected1) + assert(model1.interceptVector ~== interceptsExpected1 relTol 0.01) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected1) + assert(model2.interceptVector ~== interceptsExpected1 relTol 0.01) + + // Bound constrained optimization with bound on both side. + val upperBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(2.0)) + val upperBoundsOnIntercepts = Vectors.dense(Array.fill(3)(2.0)) + + val trainer3 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(multinomialDataset) + val model4 = trainer4.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = new DenseMatrix(3, 4, Array( + 1.61967097, 1.16027835, 1.45131448, 1.97390431, + 1.30529317, 2.0, 1.12985473, 1.26652854, + 1.61647195, 1.0, 1.40642959, 1.72985589), isTransposed = true) + val interceptsExpected3 = Vectors.dense(1.0, 2.0, 2.0) + + checkCoefficientsEquivalent(model3.coefficientMatrix, coefficientsExpected3) + assert(model3.interceptVector ~== interceptsExpected3 relTol 0.01) + checkCoefficientsEquivalent(model4.coefficientMatrix, coefficientsExpected3) + assert(model4.interceptVector ~== interceptsExpected3 relTol 0.01) + + // Bound constrained optimization with infinite bound on both side. + val trainer5 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(multinomialDataset) + val model6 = trainer6.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = new DenseMatrix(3, 4, Array( + 0.24337896, -0.05916156, 0.14446790, 0.35976165, + -0.3443375, 0.9181331, -0.2283959, -0.4388066, + 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) + val interceptsExpected5 = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + + checkCoefficientsEquivalent(model5.coefficientMatrix, coefficientsExpected5) + assert(model5.interceptVector ~== interceptsExpected5 relTol 0.01) + checkCoefficientsEquivalent(model6.coefficientMatrix, coefficientsExpected5) + assert(model6.interceptVector ~== interceptsExpected5 relTol 0.01) + } + test("multinomial logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) @@ -1220,6 +1563,35 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept without regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.62410051, 1.38219391, 1.34486618, 1.74641729, + 1.23058989, 2.71787825, 1.0, 1.00007073, + 1.79478632, 1.14360459, 1.33011603, 1.55093897), isTransposed = true) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with L1 regularization") { // use tighter constraints because OWL-QN solver takes longer to converge @@ -1518,6 +1890,46 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.0, 1.01647497, + 1.0, 1.44105616, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpectedWithStd = Vectors.dense(2.52055893, 1.0, 2.560682) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03189386, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpected = Vectors.dense(1.06418835, 1.0, 1.20494701) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd relTol 0.01) + assert(model1.interceptVector ~== interceptsExpectedWithStd relTol 0.01) + assert(model2.coefficientMatrix ~== coefficientsExpected relTol 0.01) + assert(model2.interceptVector ~== interceptsExpected relTol 0.01) + } + test("multinomial logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") @@ -1615,6 +2027,41 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.01324653, 1.0, 1.0, 1.0415767, + 1.0, 1.0, 1.0, 1.0, + 1.02244888, 1.0, 1.0, 1.0), isTransposed = true) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03932259, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.03274649, 1.0), isTransposed = true) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd absTol 0.01) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + assert(model2.coefficientMatrix ~== coefficientsExpected absTol 0.01) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with elasticnet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) @@ -2273,4 +2720,19 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** + * When no regularization is applied, the multinomial coefficients lack identifiability + * because we do not use a pivot class. We can add any constant value to the coefficients + * and get the same likelihood. If fitting under bound constrained optimization, we don't + * choose the mean centered coefficients like what we do for unbound problems, since they + * may out of the bounds. We use this function to check whether two coefficients are equivalent. + */ + def checkCoefficientsEquivalent(coefficients1: Matrix, coefficients2: Matrix): Unit = { + coefficients1.colIter.zip(coefficients2.colIter).foreach { case (col1: Vector, col2: Vector) => + (col1.asBreeze - col2.asBreeze).toArray.toSeq.sliding(2).foreach { + case Seq(v1, v2) => assert(v1 ~= v2 absTol 1E-3) + } + } + } } From 01c999e7f94d5e6c2fce67304dc62351dfbdf963 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 27 Apr 2017 13:55:03 -0700 Subject: [PATCH 0354/1765] [SPARK-20461][CORE][SS] Use UninterruptibleThread for Executor and fix the potential hang in CachedKafkaConsumer ## What changes were proposed in this pull request? This PR changes Executor's threads to `UninterruptibleThread` so that we can use `runUninterruptibly` in `CachedKafkaConsumer`. However, this is just best effort to avoid hanging forever. If the user uses`CachedKafkaConsumer` in another thread (e.g., create a new thread or Future), the potential hang may still happen. ## How was this patch tested? The new added test. Author: Shixiong Zhu Closes #17761 from zsxwing/int. --- .../org/apache/spark/executor/Executor.scala | 19 +++++++++++++++++-- .../spark/util/UninterruptibleThread.scala | 8 +++++++- .../apache/spark/executor/ExecutorSuite.scala | 13 +++++++++++++ .../sql/kafka010/CachedKafkaConsumer.scala | 15 +++++++++++++-- 4 files changed, 50 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 18f04391d64c3..51b6c373c4daf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal +import com.google.common.util.concurrent.ThreadFactoryBuilder + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -84,7 +86,20 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory(new ThreadFactory { + override def newThread(r: Runnable): Thread = + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder + }) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } private val executorSource = new ExecutorSource(threadPool, executorId) // Pool used for threads that supervise task killing / cancellation private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index f0b68f0cb7e29..27922b31949b6 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy * * Note: "runUninterruptibly" should be called only in `this` thread. */ -private[spark] class UninterruptibleThread(name: String) extends Thread(name) { +private[spark] class UninterruptibleThread( + target: Runnable, + name: String) extends Thread(target, name) { + + def this(name: String) { + this(null, name) + } /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new Object diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f47e574b4fc4b..efcad140350b9 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.UninterruptibleThread class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } + test("Executor's worker threads should be UninterruptibleThread") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor thread test") + .set("spark.ui.enabled", "false") + sc = new SparkContext(conf) + val executorThread = sc.parallelize(Seq(1), 1).map { _ => + Thread.currentThread.getClass.getName + }.collect().head + assert(executorThread === classOf[UninterruptibleThread].getName) + } + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 6d76904fb0e59..bf6c0900c97e1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.util.UninterruptibleThread /** @@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private( case class AvailableOffsetRange(earliest: Long, latest: Long) + private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { + case ut: UninterruptibleThread => + ut.runUninterruptibly(body) + case _ => + logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " + + "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894") + body + } + /** * Return the available offset range of the current partition. It's a pair of the earliest offset * and the latest offset. */ - def getAvailableOffsetRange(): AvailableOffsetRange = { + def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible { consumer.seekToBeginning(Set(topicPartition).asJava) val earliestOffset = consumer.position(topicPartition) consumer.seekToEnd(Set(topicPartition).asJava) @@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + failOnDataLoss: Boolean): + ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") From 823baca2cb8edb62885af547d3511c9e8923cefd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 27 Apr 2017 13:58:44 -0700 Subject: [PATCH 0355/1765] [SPARK-20452][SS][KAFKA] Fix a potential ConcurrentModificationException for batch Kafka DataFrame ## What changes were proposed in this pull request? Cancel a batch Kafka query but one of task cannot be cancelled, and rerun the same DataFrame may cause ConcurrentModificationException because it may launch two tasks sharing the same group id. This PR always create a new consumer when `reuseKafkaConsumer = false` to avoid ConcurrentModificationException. It also contains other minor fixes. ## How was this patch tested? Jenkins. Author: Shixiong Zhu Closes #17752 from zsxwing/kafka-fix. --- .../sql/kafka010/CachedKafkaConsumer.scala | 12 +- .../sql/kafka010/KafkaOffsetReader.scala | 6 +- .../spark/sql/kafka010/KafkaRelation.scala | 30 +++- .../sql/kafka010/KafkaSourceProvider.scala | 147 ++++++++---------- .../spark/sql/kafka010/KafkaSourceRDD.scala | 19 ++- .../spark/streaming/kafka010/KafkaRDD.scala | 2 +- 6 files changed, 119 insertions(+), 97 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index bf6c0900c97e1..7c4f38e02fb2a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -287,7 +287,7 @@ private[kafka010] case class CachedKafkaConsumer private( reportDataLoss0(failOnDataLoss, finalMessage, cause) } - private def close(): Unit = consumer.close() + def close(): Unit = consumer.close() private def seek(offset: Long): Unit = { logDebug(s"Seeking to $groupId $topicPartition $offset") @@ -382,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging { // If this is reattempt at running the task, then invalidate cache and start with // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { removeKafkaConsumer(topic, partition, kafkaParams) val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) consumer.inuse = true @@ -398,6 +398,14 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } + /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ + def createUncached( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { + new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) + } + private def reportDataLoss0( failOnDataLoss: Boolean, finalMessage: String, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 2696d6f089d2f..3e65949a6fd1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - consumer.close() - kafkaReaderThread.shutdownNow() + runUninterruptibly { + consumer.close() + } + kafkaReaderThread.shutdown() } /** diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index f180bbad6e363..97bd283169323 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.util.UUID import org.apache.kafka.common.TopicPartition @@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String private[kafka010] class KafkaRelation( override val sqlContext: SQLContext, - kafkaReader: KafkaOffsetReader, - executorKafkaParams: ju.Map[String, Object], + strategy: ConsumerStrategy, sourceOptions: Map[String, String], + specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, startingOffsets: KafkaOffsetRangeLimit, endingOffsets: KafkaOffsetRangeLimit) @@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation( override def schema: StructType = KafkaOffsetReader.kafkaSchema override def buildScan(): RDD[Row] = { + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy, + KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams), + sourceOptions, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + // Leverage the KafkaReader to obtain the relevant partition offsets - val fromPartitionOffsets = getPartitionOffsets(startingOffsets) - val untilPartitionOffsets = getPartitionOffsets(endingOffsets) + val (fromPartitionOffsets, untilPartitionOffsets) = { + try { + (getPartitionOffsets(kafkaOffsetReader, startingOffsets), + getPartitionOffsets(kafkaOffsetReader, endingOffsets)) + } finally { + kafkaOffsetReader.close() + } + } + // Obtain topicPartitions in both from and until partition offset, ignoring // topic partitions that were added and/or deleted between the two above calls. if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) { @@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation( offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. + val executorKafkaParams = + KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr => @@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation( } private def getPartitionOffsets( + kafkaReader: KafkaOffsetReader, kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = { def validateTopicPartitions(partitions: Set[TopicPartition], partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index ab1ce347cbe34..3cb4d8cad12cc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -111,10 +111,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { validateBatchOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters @@ -131,20 +127,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) assert(endingRelationOffsets != EarliestOffsetRangeLimit) - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaRelation( sqlContext, - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - failOnDataLoss(caseInsensitiveParams), - startingRelationOffsets, - endingRelationOffsets) + strategy(caseInsensitiveParams), + sourceOptions = parameters, + specifiedKafkaParams = specifiedKafkaParams, + failOnDataLoss = failOnDataLoss(caseInsensitiveParams), + startingOffsets = startingRelationOffsets, + endingOffsets = endingRelationOffsets) } override def createSink( @@ -213,46 +203,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } - private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = - ConfigUpdater("source", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial - // offsets by itself instead of counting on KafkaConsumer. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") - - // So that consumers in the driver does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - - private def kafkaParamsForExecutors( - specifiedKafkaParams: Map[String, String], uniqueGroupId: String) = - ConfigUpdater("executor", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Make sure executors do only what the driver tells them. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") - - // So that consumers in executors do not mess with any existing group id - .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") - - // So that consumers in executors does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -414,30 +364,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister logWarning("maxOffsetsPerTrigger option ignored in batch queries") } } - - /** Class to conveniently update Kafka config params, while logging the changes */ - private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { - private val map = new ju.HashMap[String, Object](kafkaParams.asJava) - - def set(key: String, value: Object): this.type = { - map.put(key, value) - logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") - this - } - - def setIfUnset(key: String, value: Object): ConfigUpdater = { - if (!map.containsKey(key)) { - map.put(key, value) - logInfo(s"$module: Set $key to $value") - } - this - } - - def build(): ju.Map[String, Object] = map - } } -private[kafka010] object KafkaSourceProvider { +private[kafka010] object KafkaSourceProvider extends Logging { private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" @@ -459,4 +388,66 @@ private[kafka010] object KafkaSourceProvider { case None => defaultOffsets } } + + def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial + // offsets by itself instead of counting on KafkaConsumer. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + def kafkaParamsForExecutors( + specifiedKafkaParams: Map[String, String], + uniqueGroupId: String): ju.Map[String, Object] = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + /** Class to conveniently update Kafka config params, while logging the changes */ + private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") + this + } + + def setIfUnset(key: String, value: Object): ConfigUpdater = { + if (!map.containsKey(key)) { + map.put(key, value) + logDebug(s"$module: Set $key to $value") + } + this + } + + def build(): ju.Map[String, Object] = map + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 6fb3473eb75f5..9d9e2aaba8079 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD( context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic - if (!reuseKafkaConsumer) { - // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique - // to each task (i.e., append the task's unique partition id), because we will have - // multiple tasks (e.g., in the case of union) reading from the same topic partitions - val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val id = TaskContext.getPartitionId() - executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id) - } val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + val consumer = + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we + // uses `assign`, we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + } val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD( override protected def close(): Unit = { if (!reuseKafkaConsumer) { // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams) + consumer.close() } else { // Indicate that we're no longer using this consumer CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 4c6e2ce87e295..62cdf5b1134e4 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V]( val consumer = if (useConsumerCache) { CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber > 1) { + if (context.attemptNumber >= 1) { // just in case the prior attempt failures were cache related CachedKafkaConsumer.remove(groupId, part.topic, part.partition) } From b90bf520fd7b979a90d1377cfc2ee7f0bf82c705 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Apr 2017 19:38:14 -0700 Subject: [PATCH 0356/1765] [SPARK-12837][CORE] Do not send the name of internal accumulator to executor side ## What changes were proposed in this pull request? When sending accumulator updates back to driver, the network overhead is pretty big as there are a lot of accumulators, e.g. `TaskMetrics` will send about 20 accumulators everytime, there may be a lot of `SQLMetric` if the query plan is complicated. Therefore, it's critical to reduce the size of serialized accumulator. A simple way is to not send the name of internal accumulators to executor side, as it's unnecessary. When executor sends accumulator updates back to driver, we can look up the accumulator name in `AccumulatorContext` easily. Note that, we still need to send names of normal accumulators, as the user code run at executor side may rely on accumulator names. In the future, we should reimplement `TaskMetrics` to not rely on accumulators and use custom serialization. Tried on the example in https://issues.apache.org/jira/browse/SPARK-12837, the size of serialized accumulator has been cut down by about 40%. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #17596 from cloud-fan/oom. --- .../apache/spark/executor/TaskMetrics.scala | 29 ++++++------- .../org/apache/spark/scheduler/Task.scala | 13 +++--- .../org/apache/spark/util/AccumulatorV2.scala | 28 +++++++------ .../spark/scheduler/TaskContextSuite.scala | 2 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- .../SpecificParquetRecordReaderBase.java | 12 +++--- .../parquet/ParquetFilterSuite.scala | 42 +++++++++++++++---- 8 files changed, 76 insertions(+), 54 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dfd2f818acdac..a3ce3d1ccc5e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - accumulators.find { acc => - acc.name.isDefined && acc.name.get == name - } + private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. + internalAccums.filter(a => !a.isZero || a == _resultSize) } } @@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging { */ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics - val (internalAccums, externalAccums) = - accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) - - internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] - tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + for (acc <- accums) { + val name = acc.name + if (name.isDefined && tm.nameToAccums.contains(name.get)) { + val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]] + tmAcc.metadata = acc.metadata + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + } else { + tm.externalAccums += acc + } } - - tm.externalAccums ++= externalAccums tm } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7fd2918960cd0..5c337b992c840 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -182,14 +182,11 @@ private[spark] abstract class Task[T]( */ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.internalAccums.filter { a => - // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its - // value will be updated at driver side. - // Note: internal accumulators representing task metrics always count failed values - !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) - // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter - // them out. - } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + // Note: internal accumulators representing task metrics always count failed values + context.taskMetrics.nonZeroInternalAccums() ++ + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not + // filter them out. + context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 7479de55140ea..a65ec75cc5db6 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { * Returns the name of this accumulator, can only be called after registration. */ final def name: Option[String] = { - assertMetadataNotNull() - metadata.name + if (atDriverSide) { + AccumulatorContext.get(id).flatMap(_.metadata.name) + } else { + assertMetadataNotNull() + metadata.name + } } /** @@ -161,7 +165,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - copyAcc.metadata = metadata + val isInternalAcc = + (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) || + getClass.getSimpleName == "SQLMetric" + if (isInternalAcc) { + // Do not serialize the name of internal accumulator and send it to executor. + copyAcc.metadata = metadata.copy(name = None) + } else { + copyAcc.metadata = metadata + } copyAcc } else { this @@ -263,16 +275,6 @@ private[spark] object AccumulatorContext { originals.clear() } - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - originals.values().asScala.find { ref => - val acc = ref.get - acc != null && acc.name.isDefined && acc.name.get == name - }.map(_.get) - } - // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8f576daa77d15..b22da565d86e7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark sc = new SparkContext("local", "test") // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 93964a2d56743..48be3be81755a 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics val inputMetrics = taskMetrics.inputMetrics diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a64dbeae47294..a77c8e3cab4e8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { hasHadoopInput: Boolean, hasOutput: Boolean, hasRecords: Boolean = true) = { - val t = TaskMetrics.empty + val t = TaskMetrics.registered // Set CPU times same as wall times for testing purpose t.setExecutorDeserializeTime(a) t.setExecutorDeserializeCpuTime(a) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index eb97118872ea1..0bab321a657d6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -153,14 +153,14 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont } // For test purpose. - // If the predefined accumulator exists, the row group number to read will be updated - // to the accumulator. So we can check if the row groups are filtered or not in test case. + // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read + // will be updated to the accumulator. So we can check if the row groups are filtered or not + // in test case. TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = taskContext.taskMetrics() - .lookForAccumulatorByName("numRowGroups"); - if (accu.isDefined()) { - ((LongAccumulator)accu.get()).add((long)blocks.size()); + Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + ((AccumulatorV2)accu.get()).add(blocks.size()); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9a3328fcecee8..dd53b561326f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.{AccumulatorContext, LongAccumulator} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val path = s"${dir.getCanonicalPath}/table" (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) - Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { - val accu = new LongAccumulator - accu.register(sparkContext, Some("numRowGroups")) + Seq(true, false).foreach { enablePushDown => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition(_.foreach(v => accu.add(0))) df.collect - val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") - assert(numRowGroups.isDefined) - assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + if (enablePushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } AccumulatorContext.remove(accu.id) } } @@ -537,3 +539,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + +class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { + private var _sum = 0 + + override def isZero: Boolean = _sum == 0 + + override def copy(): AccumulatorV2[Integer, Integer] = { + val acc = new NumRowGroupsAcc() + acc._sum = _sum + acc + } + + override def reset(): Unit = _sum = 0 + + override def add(v: Integer): Unit = _sum += v + + override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match { + case a: NumRowGroupsAcc => _sum += a._sum + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + override def value: Integer = _sum +} From 7fe8249793bd3eed4fa67cb4a210264a80520786 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Thu, 27 Apr 2017 22:29:47 -0700 Subject: [PATCH 0357/1765] [SPARKR][DOC] Document LinearSVC in R programming guide ## What changes were proposed in this pull request? add link to svmLinear in the SparkR programming document. ## How was this patch tested? Build doc manually and click the link to the document. It looks good. Author: wangmiao1981 Closes #17797 from wangmiao1981/doc. --- docs/sparkr.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index c3336ac2ce86a..c85cfd45c4567 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -482,6 +482,7 @@ SparkR supports the following machine learning algorithms currently: * [`spark.logit`](api/R/spark.logit.html): [`Logistic Regression`](ml-classification-regression.html#logistic-regression) * [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron (MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier) * [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive Bayes`](ml-classification-regression.html#naive-bayes) +* [`spark.svmLinear`](api/R/spark.svmLinear.html): [`Linear Support Vector Machine`](ml-classification-regression.html#linear-support-vector-machine) #### Regression From e3c816043389e227db5e7a328c7c554209b4f394 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 28 Apr 2017 14:16:40 +0800 Subject: [PATCH 0358/1765] [SPARK-20476][SQL] Block users to create a table that use commas in the column names ### What changes were proposed in this pull request? ```SQL hive> create table t1(`a,` string); OK Time taken: 1.399 seconds hive> create table t2(`a,` string, b string); FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. java.lang.RuntimeException: MetaException(message:org.apache.hadoop.hive.serde2.SerDeException org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe: columns has 3 elements while columns.types has 2 elements!) hive> create table t2(`a,` string, b string) stored as parquet; FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. java.lang.IllegalArgumentException: ParquetHiveSerde initialization failed. Number of column name and column type differs. columnNames = [a, , b], columnTypes = [string, string] ``` It has a bug in Hive metastore. When users do not provide alias name in the SELECT query, we call `toPrettySQL` to generate the alias name. For example, the string `get_json_object(jstring, '$.f1')` will be the alias name for the function call in the statement ```SQL SELECT key, get_json_object(jstring, '$.f1') FROM tempView ``` Above is not an issue for the SELECT query statements. However, for CTAS, we hit the issue due to a bug in Hive metastore. Hive metastore does not like the column names containing commas and returned a confusing error message, like: ``` 17/04/26 23:12:56 ERROR [hive.log(397) -- main]: error in initSerDe: org.apache.hadoop.hive.serde2.SerDeException org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe: columns has 2 elements while columns.types has 1 elements! org.apache.hadoop.hive.serde2.SerDeException: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe: columns has 2 elements while columns.types has 1 elements! ``` Thus, this PR is to block users to create a table in Hive metastore when the table table has a column containing commas in the name. ### How was this patch tested? Added a test case Author: Xiao Li Closes #17781 from gatorsmile/blockIllegalColumnNames. --- .../spark/sql/hive/HiveExternalCatalog.scala | 18 ++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 24 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 71e33c46b9aed..ba48facff2933 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -137,6 +137,22 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + /** + * Checks the validity of column names. Hive metastore disallows the table to use comma in + * data column names. Partition columns do not have such a restriction. Views do not have such + * a restriction. + */ + private def verifyColumnNames(table: CatalogTable): Unit = { + if (table.tableType != VIEW) { + table.dataSchema.map(_.name).foreach { colName => + if (colName.contains(",")) { + throw new AnalysisException("Cannot create a table having a column whose name contains " + + s"commas in Hive metastore. Table: ${table.identifier}; Column: $colName") + } + } + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- @@ -202,6 +218,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) + verifyColumnNames(tableDefinition) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) @@ -614,6 +631,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) val withNewSchema = rawTable.copy(schema = schema) + verifyColumnNames(withNewSchema) // Add table metadata such as table schema, partition columns, etc. to table properties. val updatedTable = withNewSchema.copy( properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 75f3744ff35be..c944f28d10ef4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1976,6 +1976,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("Auto alias construction of get_json_object") { + val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring") + val expectedMsg = "Cannot create a table having a column whose name contains commas " + + "in Hive metastore. Table: `default`.`t`; Column: get_json_object(jstring, $.f1)" + + withTable("t") { + val e = intercept[AnalysisException] { + df.select($"key", functions.get_json_object($"jstring", "$.f1")) + .write.format("hive").saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + + withTempView("tempView") { + withTable("t") { + df.createTempView("tempView") + val e = intercept[AnalysisException] { + sql("CREATE TABLE t AS SELECT key, get_json_object(jstring, '$.f1') FROM tempView") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { withTable("spark_19912") { Seq( From 59e3a564448777657125b6f65057ed20d0162d13 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 28 Apr 2017 14:41:53 +0800 Subject: [PATCH 0359/1765] [SPARK-14471][SQL] Aliases in SELECT could be used in GROUP BY ## What changes were proposed in this pull request? This pr added a new rule in `Analyzer` to resolve aliases in `GROUP BY`. The current master throws an exception if `GROUP BY` clauses have aliases in `SELECT`; ``` scala> spark.sql("select a a1, a1 + 1 as b, count(1) from t group by a1") org.apache.spark.sql.AnalysisException: cannot resolve '`a1`' given input columns: [a]; line 1 pos 51; 'Aggregate ['a1], [a#83L AS a1#87L, ('a1 + 1) AS b#88, count(1) AS count(1)#90L] +- SubqueryAlias t +- Project [id#80L AS a#83L] +- Range (0, 10, step=1, splits=Some(8)) at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:77) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:74) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) ``` ## How was this patch tested? Added tests in `SQLQuerySuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17191 from maropu/SPARK-14471. --- .../sql/catalyst/analysis/Analyzer.scala | 71 ++++++++++++------- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../sql-tests/inputs/group-by-ordinal.sql | 3 + .../resources/sql-tests/inputs/group-by.sql | 18 +++++ .../results/group-by-ordinal.sql.out | 22 ++++-- .../sql-tests/results/group-by.sql.out | 66 ++++++++++++++++- 6 files changed, 156 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcadbbc90f438..72e7d5dd3638d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -136,6 +136,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: @@ -172,7 +173,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -200,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -242,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -614,7 +615,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -786,7 +787,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -844,11 +845,10 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -961,7 +961,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -997,6 +997,27 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + * This rule is expected to run after [[ResolveReferences]] applied. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + // This is a strict check though, we put this to apply the rule only in alias expressions + def notResolvableByChild(attrName: String): Boolean = + !child.output.exists(a => resolver(a.name, attrName)) + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute if notResolvableByChild(u.name) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + case e => e + }) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -1006,7 +1027,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1130,7 +1151,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1469,7 +1490,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1484,7 +1505,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1510,7 +1531,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1682,7 +1703,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1740,7 +1761,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2057,7 +2078,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2102,7 +2123,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2167,7 +2188,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2232,7 +2253,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2318,7 +2339,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2352,7 +2373,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2406,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2474,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1798e22b9fc..b24419a41edb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -421,6 +421,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .booleanConf + .createWithDefault(true) + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = @@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging { def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 9c8d851e36e9b..6566338f3d4a9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; -- group by ordinal followed by having select count(a), a from (select 1 as a) tmp group by 2 having a > 0; +-- mixed cases: group-by ordinals and aliases +select a, a AS k, count(b) from data group by k, 1; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed43153004..a7994f3beaff3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,21 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aliases in SELECT could be used in GROUP BY +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; + +-- Aggregate functions cannot be used in GROUP BY +SELECT COUNT(b) AS k FROM testData GROUP BY k; + +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; + +-- turn off group by aliases +set spark.sql.groupByAliases=false; + +-- Check analysis exceptions +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index d03681d0ea59c..9ecbe19078dd6 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 20 -- !query 0 @@ -173,16 +173,26 @@ struct -- !query 17 -set spark.sql.groupByOrdinal=false +select a, a AS k, count(b) from data group by k, 1 -- !query 17 schema -struct +struct -- !query 17 output -spark.sql.groupByOrdinal false +1 1 2 +2 2 2 +3 3 2 -- !query 18 -select sum(b) from data group by -1 +set spark.sql.groupByOrdinal=false -- !query 18 schema -struct +struct -- !query 18 output +spark.sql.groupByOrdinal false + + +-- !query 19 +select sum(b) from data group by -1 +-- !query 19 schema +struct +-- !query 19 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0e..6bf9dff883c1e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 22 -- !query 0 @@ -139,3 +139,67 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct -- !query 14 output 1 1 + + +-- !query 15 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +1 2 +2 2 +3 2 +NULL 1 + + +-- !query 16 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1 +-- !query 16 schema +struct +-- !query 16 output +2 2 +3 2 + + +-- !query 17 +SELECT COUNT(b) AS k FROM testData GROUP BY k +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`); + + +-- !query 18 +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 20 +set spark.sql.groupByAliases=false +-- !query 20 schema +struct +-- !query 20 output +spark.sql.groupByAliases false + + +-- !query 21 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 From 8c911adac56a1b1d95bc19915e0070ce7305257c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 28 Apr 2017 08:49:35 +0100 Subject: [PATCH 0360/1765] [SPARK-20465][CORE] Throws a proper exception when any temp directory could not be got ## What changes were proposed in this pull request? This PR proposes to throw an exception with better message rather than `ArrayIndexOutOfBoundsException` when temp directories could not be created. Running the commands below: ```bash ./bin/spark-shell --conf spark.local.dir=/NONEXISTENT_DIR_ONE,/NONEXISTENT_DIR_TWO ``` produces ... **Before** ``` Exception in thread "main" java.lang.ExceptionInInitializerError ... Caused by: java.lang.ArrayIndexOutOfBoundsException: 0 ... ``` **After** ``` Exception in thread "main" java.lang.ExceptionInInitializerError ... Caused by: java.io.IOException: Failed to get a temp directory under [/NONEXISTENT_DIR_ONE,/NONEXISTENT_DIR_TWO]. ... ``` ## How was this patch tested? Unit tests in `LocalDirsSuite.scala`. Author: hyukjinkwon Closes #17768 from HyukjinKwon/throws-temp-dir-exception. --- .../scala/org/apache/spark/util/Utils.scala | 6 ++++- .../apache/spark/storage/LocalDirsSuite.scala | 23 ++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e042badcdd4a4..4d37db96dfc37 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -740,7 +740,11 @@ private[spark] object Utils extends Logging { * always return a single directory. */ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf)(0) + getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val configuredLocalDirs = getConfiguredLocalDirs(conf) + throw new IOException( + s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } } private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index c7074078d8fd2..f7b3a2754f0ea 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import org.scalatest.BeforeAndAfter @@ -33,9 +33,13 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + after { + Utils.clearLocalRootDirs() + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) @@ -43,7 +47,7 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) @@ -51,4 +55,17 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { assert(new File(Utils.getLocalDir(conf)).exists()) } + test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { + val path1 = "/NONEXISTENT_PATH_ONE" + val path2 = "/NONEXISTENT_PATH_TWO" + assert(!new File(path1).exists()) + assert(!new File(path2).exists()) + val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") + val message = intercept[IOException] { + Utils.getLocalDir(conf) + }.getMessage + // If any temporary directory could not be retrieved under the given paths above, it should + // throw an exception with the message that includes the paths. + assert(message.contains(s"$path1,$path2")) + } } From 733b81b835f952ab96723c749461d6afc0c71974 Mon Sep 17 00:00:00 2001 From: Bill Chambers Date: Fri, 28 Apr 2017 10:18:31 -0700 Subject: [PATCH 0361/1765] [SPARK-20496][SS] Bug in KafkaWriter Looks at Unanalyzed Plans ## What changes were proposed in this pull request? We didn't enforce analyzed plans in Spark 2.1 when writing out to Kafka. ## How was this patch tested? New unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bill Chambers Closes #17804 from anabranch/SPARK-20496-2. --- .../apache/spark/sql/kafka010/KafkaWriter.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index a637d52c933a3..61936e32fd837 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -47,7 +47,7 @@ private[kafka010] object KafkaWriter extends Logging { queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.logical.output + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic == None) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.logical.output + val schema = queryExecution.analyzed.output validateQuery(queryExecution, kafkaParameters, topic) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { queryExecution.toRdd.foreachPartition { iter => diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 4bd052d249eca..2ab336c7ac476 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, DataType} @@ -108,6 +109,21 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { s"save mode overwrite not allowed for kafka")) } + test("SPARK-20496: batch - enforce analyzed plans") { + val inputEvents = + spark.range(1, 1000) + .select(to_json(struct("*")) as 'value) + + val topic = newTopic() + testUtils.createTopic(topic) + // used to throw UnresolvedException + inputEvents.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + } + test("streaming - write to kafka with topic field") { val input = MemoryStream[String] val topic = newTopic() From 5d71f3db83138bf50749dcd425ef7365c34bd799 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Fri, 28 Apr 2017 14:06:57 -0700 Subject: [PATCH 0362/1765] [SPARK-20514][CORE] Upgrade Jetty to 9.3.11.v20160721 Upgrade Jetty so it can work with Hadoop 3 (alpha 2 release, in particular). Without this change, because of incompatibily between Jetty versions, Spark fails to compile when built against Hadoop 3 ## How was this patch tested? Unit tests being run. Author: Mark Grover Closes #17790 from markgrover/spark-20514. --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index bdbdba5780856..edf328b5ae538 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -29,8 +29,8 @@ import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.handler.gzip.GzipHandler import org.eclipse.jetty.servlet._ -import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue diff --git a/pom.xml b/pom.xml index b6654c1411d25..517ebc5c83fc6 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 10.12.1.1 1.8.2 1.6.0 - 9.2.16.v20160414 + 9.3.11.v20160721 3.1.0 0.8.0 2.4.0 From ebff519c5ead31536e17a5b16cc47c2bf380d55e Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 28 Apr 2017 14:47:17 -0700 Subject: [PATCH 0363/1765] [SPARK-20471] Remove AggregateBenchmark testsuite warning: Two level hashmap is disabled but vectorized hashmap is enabled What changes were proposed in this pull request? remove AggregateBenchmark testsuite warning: such as '14:26:33.220 WARN org.apache.spark.sql.execution.aggregate.HashAggregateExec: Two level hashmap is disabled but vectorized hashmap is enabled.' How was this patch tested? unit tests: AggregateBenchmark Modify the 'ignore function for 'test funtion Author: caoxuewen Closes #17771 from heary-cao/AggregateBenchmark. --- .../spark/sql/execution/benchmark/AggregateBenchmark.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8a2993bdf4b28..8a798fb444696 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,6 +107,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -148,6 +149,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -187,6 +189,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -225,6 +228,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -273,6 +277,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } From 77bcd77ed5fbd91fe61849cca76a8dffe5e4d6b2 Mon Sep 17 00:00:00 2001 From: Aaditya Ramesh Date: Fri, 28 Apr 2017 15:28:56 -0700 Subject: [PATCH 0364/1765] [SPARK-19525][CORE] Add RDD checkpoint compression support ## What changes were proposed in this pull request? This PR adds RDD checkpoint compression support and add a new config `spark.checkpoint.compress` to enable/disable it. Credit goes to aramesh117 Closes #17024 ## How was this patch tested? The new unit test. Author: Shixiong Zhu Author: Aaditya Ramesh Closes #17789 from zsxwing/pr17024. --- .../spark/internal/config/package.scala | 6 +++ .../spark/rdd/ReliableCheckpointRDD.scala | 24 ++++++++++- .../org/apache/spark/CheckpointSuite.scala | 41 +++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 2f0a3064be111..7f7921d56f49e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -272,4 +272,10 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val CHECKPOINT_COMPRESS = + ConfigBuilder("spark.checkpoint.compress") + .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + + "spark.io.compression.codec.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index e0a29b48314fb..37c67cee55f90 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{FileNotFoundException, IOException} +import java.util.concurrent.TimeUnit import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CHECKPOINT_COMPRESS +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { originalRDD: RDD[T], checkpointDir: String, blockSize: Int = -1): ReliableCheckpointRDD[T] = { + val checkpointStartTimeNs = System.nanoTime() val sc = originalRDD.sparkContext @@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) } + val checkpointDurationMs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs) + logInfo(s"Checkpointing took $checkpointDurationMs ms.") + val newRDD = new ReliableCheckpointRDD[T]( sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { @@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging { val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) + val fileStream = fs.create(tempOutputPath, false, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream) + } else { + fileStream + } } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, @@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) + val fileInputStream = { + val fileStream = fs.open(path, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream) + } else { + fileStream + } + } val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index b117c7709b46f..ee70a3399efed 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,8 +21,10 @@ import java.io.File import scala.reflect.ClassTag +import com.google.common.io.ByteStreams import org.apache.hadoop.fs.Path +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -580,3 +582,42 @@ object CheckpointSuite { ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } + +class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { + + test("checkpoint compression") { + val checkpointDir = Utils.createTempDir() + try { + val conf = new SparkConf() + .set("spark.checkpoint.compress", "true") + .set("spark.ui.enabled", "false") + sc = new SparkContext("local", "test", conf) + sc.setCheckpointDir(checkpointDir.toString) + val rdd = sc.makeRDD(1 to 20, numSlices = 1) + rdd.checkpoint() + assert(rdd.collect().toSeq === (1 to 20)) + + // Verify that RDD is checkpointed + assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]]) + + val checkpointPath = new Path(rdd.getCheckpointFile.get) + val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration) + val checkpointFile = + fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get + + // Verify the checkpoint file is compressed, in other words, can be decompressed + val compressedInputStream = CompressionCodec.createCodec(conf) + .compressedInputStream(fs.open(checkpointFile)) + try { + ByteStreams.toByteArray(compressedInputStream) + } finally { + compressedInputStream.close() + } + + // Verify that the compressed content can be read back + assert(rdd.collect().toSeq === (1 to 20)) + } finally { + Utils.deleteRecursively(checkpointDir) + } + } +} From 814a61a867ded965433c944c90961df529ac83ab Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Fri, 28 Apr 2017 23:12:26 -0700 Subject: [PATCH 0365/1765] [SPARK-20487][SQL] Display `serde` for `HiveTableScan` node in explained plan ## What changes were proposed in this pull request? This was a suggestion by rxin at https://github.com/apache/spark/pull/17780#issuecomment-298073408 ## How was this patch tested? - modified existing unit test - manual testing: ``` scala> hc.sql(" SELECT * FROM tejasp_bucketed_partitioned_1 where name = '' ").explain(true) == Parsed Logical Plan == 'Project [*] +- 'Filter ('name = ) +- 'UnresolvedRelation `tejasp_bucketed_partitioned_1` == Analyzed Logical Plan == user_id: bigint, name: string, ds: string Project [user_id#24L, name#25, ds#26] +- Filter (name#25 = ) +- SubqueryAlias tejasp_bucketed_partitioned_1 +- CatalogRelation `default`.`tejasp_bucketed_partitioned_1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [user_id#24L, name#25], [ds#26] == Optimized Logical Plan == Filter (isnotnull(name#25) && (name#25 = )) +- CatalogRelation `default`.`tejasp_bucketed_partitioned_1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [user_id#24L, name#25], [ds#26] == Physical Plan == *Filter (isnotnull(name#25) && (name#25 = )) +- HiveTableScan [user_id#24L, name#25, ds#26], CatalogRelation `default`.`tejasp_bucketed_partitioned_1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [user_id#24L, name#25], [ds#26] ``` Author: Tejas Patil Closes #17806 from tejasapatil/add_serde. --- .../org/apache/spark/sql/catalyst/trees/TreeNode.scala | 6 +++++- .../apache/spark/sql/hive/execution/HiveExplainSuite.scala | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index b091315f24f1f..2109c1c23b706 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -444,7 +444,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil - case table: CatalogTable => table.identifier :: Nil + case table: CatalogTable => + table.storage.serde match { + case Some(serde) => table.identifier :: serde :: Nil + case _ => table.identifier :: Nil + } case other => other :: Nil }.mkString(", ") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index ebafe6de0c830..aa1ca2909074f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -43,7 +43,9 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), - "== Physical Plan ==") + "== Physical Plan ==", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", From b28c3bc2020a6936e4ac4c28d49fd832a952af42 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Sat, 29 Apr 2017 10:31:01 -0700 Subject: [PATCH 0366/1765] [SPARK-20477][SPARKR][DOC] Document R bisecting k-means in R programming guide ## What changes were proposed in this pull request? Add hyper link in the SparkR programming guide. ## How was this patch tested? Build doc and manually check the doc link. Author: wangmiao1981 Closes #17805 from wangmiao1981/doc. --- docs/sparkr.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index c85cfd45c4567..16b1ef6512420 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -497,6 +497,7 @@ SparkR supports the following machine learning algorithms currently: #### Clustering +* [`spark.bisectingKmeans`](api/R/spark.bisectingKmeans.html): [`Bisecting k-means`](ml-clustering.html#bisecting-k-means) * [`spark.gaussianMixture`](api/R/spark.gaussianMixture.html): [`Gaussian Mixture Model (GMM)`](ml-clustering.html#gaussian-mixture-model-gmm) * [`spark.kmeans`](api/R/spark.kmeans.html): [`K-Means`](ml-clustering.html#k-means) * [`spark.lda`](api/R/spark.lda.html): [`Latent Dirichlet Allocation (LDA)`](ml-clustering.html#latent-dirichlet-allocation-lda) From add9d1bba5cf33218a115428a03d3c76a514aa86 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 29 Apr 2017 10:51:45 -0700 Subject: [PATCH 0367/1765] [SPARK-19791][ML] Add doc and example for fpgrowth ## What changes were proposed in this pull request? Add a new section for fpm Add Example for FPGrowth in scala and Java updated: Rewrite transform to be more compact. ## How was this patch tested? local doc generation. Author: Yuhao Yang Closes #17130 from hhbyyh/fpmdoc. --- docs/_data/menu-ml.yaml | 2 + docs/ml-frequent-pattern-mining.md | 87 +++++++++++++++++++ docs/mllib-frequent-pattern-mining.md | 2 +- .../examples/ml/JavaFPGrowthExample.java | 77 ++++++++++++++++ .../src/main/python/ml/fpgrowth_example.py | 56 ++++++++++++ .../spark/examples/ml/FPGrowthExample.scala | 67 ++++++++++++++ .../org/apache/spark/ml/fpm/FPGrowth.scala | 35 ++++---- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 2 + 8 files changed, 310 insertions(+), 18 deletions(-) create mode 100644 docs/ml-frequent-pattern-mining.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java create mode 100644 examples/src/main/python/ml/fpgrowth_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index 0c6b9b20a6e4b..047423f75aec1 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -8,6 +8,8 @@ url: ml-clustering.html - text: Collaborative filtering url: ml-collaborative-filtering.html +- text: Frequent Pattern Mining + url: ml-frequent-pattern-mining.html - text: Model selection and tuning url: ml-tuning.html - text: Advanced topics diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md new file mode 100644 index 0000000000000..81634de8aade7 --- /dev/null +++ b/docs/ml-frequent-pattern-mining.md @@ -0,0 +1,87 @@ +--- +layout: global +title: Frequent Pattern Mining +displayTitle: Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## FP-Growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffixes of transactions, +and hence is more scalable than a single-machine implementation. +We refer users to the papers for more details. + +`spark.ml`'s FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `minConfidence`: minimum confidence for generating Association Rule. Confidence is an indication of how often an + association rule has been found to be true. For example, if in the transactions itemset `X` appears 4 times, `X` + and `Y` co-occur only 2 times, the confidence for the rule `X => Y` is then 2/4 = 0.5. The parameter will not + affect the mining for frequent itemsets, but specify the minimum confidence for generating association rules + from frequent itemsets. +* `numPartitions`: the number of partitions used to distribute the work. By default the param is not set, and + number of partitions of the input dataset is used. + +The `FPGrowthModel` provides: + +* `freqItemsets`: frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) +* `associationRules`: association rules generated with confidence above `minConfidence`, in the format of + DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]). +* `transform`: For each transaction in `itemsCol`, the `transform` method will compare its items against the antecedents + of each association rule. If the record contains all the antecedents of a specific association rule, the rule + will be considered as applicable and its consequents will be added to the prediction result. The transform + method will summarize the consequents from all the applicable rules as prediction. The prediction column has + the same data type as `itemsCol` and does not contain existing items in the `itemsCol`. + + +**Examples** + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.fpm.FPGrowth) for more details. + +{% include_example scala/org/apache/spark/examples/ml/FPGrowthExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/fpm/FPGrowth.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaFPGrowthExample.java %} +
    + +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.fpm.FPGrowth) for more details. + +{% include_example python/ml/fpgrowth_example.py %} +
    + +
    + +Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. + +{% include_example r/ml/fpm.R %} +
    + +
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 93e3f0b2d2267..c9cd7cc85e754 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -24,7 +24,7 @@ explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). -PFP distributes the work of growing FP-trees based on the suffices of transactions, +PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java new file mode 100644 index 0000000000000..717ec21c8b203 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.fpm.FPGrowth; +import org.apache.spark.ml.fpm.FPGrowthModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example demonstrating FPGrowth. + * Run with + *
    + * bin/run-example ml.JavaFPGrowthExample
    + * 
    + */ +public class JavaFPGrowthExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaFPGrowthExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Arrays.asList("1 2 5".split(" "))), + RowFactory.create(Arrays.asList("1 2 3 5".split(" "))), + RowFactory.create(Arrays.asList("1 2".split(" "))) + ); + StructType schema = new StructType(new StructField[]{ new StructField( + "items", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + Dataset itemsDF = spark.createDataFrame(data, schema); + + FPGrowthModel model = new FPGrowth() + .setItemsCol("items") + .setMinSupport(0.5) + .setMinConfidence(0.6) + .fit(itemsDF); + + // Display frequent itemsets. + model.freqItemsets().show(); + + // Display generated association rules. + model.associationRules().show(); + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(itemsDF).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py new file mode 100644 index 0000000000000..c92c3c27abb21 --- /dev/null +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating FPGrowth. +Run with: + bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("FPGrowthExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (0, [1, 2, 5]), + (1, [1, 2, 3, 5]), + (2, [1, 2]) + ], ["id", "items"]) + + fpGrowth = FPGrowth(itemsCol="items", minSupport=0.5, minConfidence=0.6) + model = fpGrowth.fit(df) + + # Display frequent itemsets. + model.freqItemsets.show() + + # Display generated association rules. + model.associationRules.show() + + # transform examines the input items against all the association rules and summarize the + # consequents as prediction + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala new file mode 100644 index 0000000000000..59110d70de550 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println + +// $example on$ +import org.apache.spark.ml.fpm.FPGrowth +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating FP-Growth. + * Run with + * {{{ + * bin/run-example ml.FPGrowthExample + * }}} + */ +object FPGrowthExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val dataset = spark.createDataset(Seq( + "1 2 5", + "1 2 3 5", + "1 2") + ).map(t => t.split(" ")).toDF("items") + + val fpgrowth = new FPGrowth().setItemsCol("items").setMinSupport(0.5).setMinConfidence(0.6) + val model = fpgrowth.fit(dataset) + + // Display frequent itemsets. + model.freqItemsets.show() + + // Display generated association rules. + model.associationRules.show() + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(dataset).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d604c1ac001a2..8f00daa59f1a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.fpm -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.hadoop.fs.Path @@ -54,7 +53,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { /** * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears - * more than (minSupport * size-of-the-dataset) times will be output + * more than (minSupport * size-of-the-dataset) times will be output in the frequent itemsets. * Default: 0.3 * @group param */ @@ -82,8 +81,8 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { def getNumPartitions: Int = $(numPartitions) /** - * Minimal confidence for generating Association Rule. - * Note that minConfidence has no effect during fitting. + * Minimal confidence for generating Association Rule. minConfidence will not affect the mining + * for frequent itemsets, but will affect the association rules generation. * Default: 0.8 * @group param */ @@ -118,7 +117,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in * Han et al., Mining frequent patterns without - * candidate generation. Note null values in the feature column are ignored during fit(). + * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see * Association rule learning (Wikipedia) @@ -167,7 +166,6 @@ class FPGrowth @Since("2.2.0") ( } val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) - val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) @@ -196,7 +194,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { * :: Experimental :: * Model fitted by FPGrowth. * - * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long]) + * @param freqItemsets frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) */ @Since("2.2.0") @Experimental @@ -244,10 +242,13 @@ class FPGrowthModel private[ml] ( /** * The transform method first generates the association rules according to the frequent itemsets. - * Then for each association rule, it will examine the input items against antecedents and - * summarize the consequents as prediction. The prediction column has the same data type as the - * input column(Array[T]) and will not contain existing items in the input column. The null - * values in the feature columns are treated as empty sets. + * Then for each transaction in itemsCol, the transform method will compare its items against the + * antecedents of each association rule. If the record contains all the antecedents of a + * specific association rule, the rule will be considered as applicable and its consequents + * will be added to the prediction result. The transform method will summarize the consequents + * from all the applicable rules as prediction. The prediction column has the same data type as + * the input column(Array[T]) and will not contain existing items in the input column. The null + * values in the itemsCol columns are treated as empty sets. * WARNING: internally it collects association rules to the driver and uses broadcast for * efficiency. This may bring pressure to driver memory for large set of association rules. */ @@ -335,13 +336,13 @@ private[fpm] object AssociationRules { /** * Computes the association rules with confidence above minConfidence. - * @param dataset DataFrame("items", "freq") containing frequent itemset obtained from - * algorithms like [[FPGrowth]]. + * @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained + * from algorithms like [[FPGrowth]]. * @param itemsCol column name for frequent itemsets - * @param freqCol column name for frequent itemsets count - * @param minConfidence minimum confidence for the result association rules - * @return a DataFrame("antecedent", "consequent", "confidence") containing the association - * rules. + * @param freqCol column name for appearance count of the frequent itemsets + * @param minConfidence minimum confidence for generating the association rules + * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]) + * containing the association rules. */ def getAssociationRulesFromFP[T: ClassTag]( dataset: Dataset[_], diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 6806cb03bc42b..87f8b9034dde8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -122,6 +122,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + // numPartitions should not have default value. + assert(fpGrowth.isDefined(fpGrowth.numPartitions) === false) MLTestingUtils.checkCopyAndUids(fpGrowth, model) ParamsSuite.checkParams(fpGrowth) ParamsSuite.checkParams(model) From ee694cdff6fdb47f23370038f87f8594a80a8f27 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Sat, 29 Apr 2017 10:58:48 -0700 Subject: [PATCH 0368/1765] [SPARK-20533][SPARKR] SparkR Wrappers Model should be private and value should be lazy ## What changes were proposed in this pull request? MultilayerPerceptronClassifierWrapper model should be private. LogisticRegressionWrapper.scala rFeatures and rCoefficients should be lazy. ## How was this patch tested? Unit tests. Author: wangmiao1981 Closes #17808 from wangmiao1981/lazy. --- .../org/apache/spark/ml/r/LogisticRegressionWrapper.scala | 4 ++-- .../spark/ml/r/MultilayerPerceptronClassifierWrapper.scala | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index c96f99cb83434..703bcdf4ca725 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -40,13 +40,13 @@ private[r] class LogisticRegressionWrapper private ( private val lrModel: LogisticRegressionModel = pipeline.stages(1).asInstanceOf[LogisticRegressionModel] - val rFeatures: Array[String] = if (lrModel.getFitIntercept) { + lazy val rFeatures: Array[String] = if (lrModel.getFitIntercept) { Array("(Intercept)") ++ features } else { features } - val rCoefficients: Array[Double] = { + lazy val rCoefficients: Array[Double] = { val numRows = lrModel.coefficientMatrix.numRows val numCols = lrModel.coefficientMatrix.numCols val numColsWithIntercept = if (lrModel.getFitIntercept) numCols + 1 else numCols diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index d34de30931143..48c87743dee60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -36,11 +36,11 @@ private[r] class MultilayerPerceptronClassifierWrapper private ( import MultilayerPerceptronClassifierWrapper._ - val mlpModel: MultilayerPerceptronClassificationModel = + private val mlpModel: MultilayerPerceptronClassificationModel = pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel] - val weights: Array[Double] = mlpModel.weights.toArray - val layers: Array[Int] = mlpModel.layers + lazy val weights: Array[Double] = mlpModel.weights.toArray + lazy val layers: Array[Int] = mlpModel.layers def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) From 70f1bcd7bcd42b30eabcf06a9639363f1ca4b449 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 29 Apr 2017 11:02:17 -0700 Subject: [PATCH 0369/1765] [SPARK-20493][R] De-duplicate parse logics for DDL-like type strings in R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It seems we are using `SQLUtils.getSQLDataType` for type string in structField. It looks we can replace this with `CatalystSqlParser.parseDataType`. They look similar DDL-like type definitions as below: ```scala scala> Seq(Tuple1(Tuple1("a"))).toDF.show() ``` ``` +---+ | _1| +---+ |[a]| +---+ ``` ```scala scala> Seq(Tuple1(Tuple1("a"))).toDF.select($"_1".cast("struct<_1:string>")).show() ``` ``` +---+ | _1| +---+ |[a]| +---+ ``` Such type strings looks identical when R’s one as below: ```R > write.df(sql("SELECT named_struct('_1', 'a') as struct"), "/tmp/aa", "parquet") > collect(read.df("/tmp/aa", "parquet", structType(structField("struct", "struct<_1:string>")))) struct 1 a ``` R’s one is stricter because we are checking the types via regular expressions in R side ahead. Actual logics there look a bit different but as we check it ahead in R side, it looks replacing it would not introduce (I think) no behaviour changes. To make this sure, the tests dedicated for it were added in SPARK-20105. (It looks `structField` is the only place that calls this method). ## How was this patch tested? Existing tests - https://github.com/apache/spark/blob/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L143-L194 should cover this. Author: hyukjinkwon Closes #17785 from HyukjinKwon/SPARK-20493. --- R/pkg/R/utils.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 13 +++++- R/pkg/inst/tests/testthat/test_utils.R | 6 +-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 43 +------------------ 4 files changed, 24 insertions(+), 46 deletions(-) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fbc89e98847bf..d29af00affb98 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -864,6 +864,14 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "no such table - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "parse error - ", first), call. = FALSE) } else { stop(stacktrace, call. = FALSE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2cef7191d4f2a..1a3d6df437d7e 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -150,7 +150,12 @@ test_that("structField type strings", { binary = "BinaryType", boolean = "BooleanType", timestamp = "TimestampType", - date = "DateType") + date = "DateType", + tinyint = "ByteType", + smallint = "ShortType", + int = "IntegerType", + bigint = "LongType", + decimal = "DecimalType(10,0)") complexTypes <- list("map" = "MapType(StringType,IntegerType,true)", "array" = "ArrayType(StringType,true)", @@ -174,7 +179,11 @@ test_that("structField type strings", { numeric = "numeric", character = "character", raw = "raw", - logical = "logical") + logical = "logical", + short = "short", + varchar = "varchar", + long = "long", + char = "char") complexErrors <- list("map" = " integer", "array" = "String", diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 6d006eccf665e..1ca383da26ec2 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -167,13 +167,13 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - method <- "getSQLDataType" + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, - "unknown"), + "col", "unknown", TRUE), error = function(e) { captureJVMException(e, method) }), - "Error in getSQLDataType : illegal argument - Invalid type unknown") + "parse error - .*DataType unknown.*not supported.") }) test_that("hashCode", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a26d00411fbaa..d94e528a3ad47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ @@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging { def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) } - def getSQLDataType(dataType: String): DataType = { - dataType match { - case "byte" => org.apache.spark.sql.types.ByteType - case "integer" => org.apache.spark.sql.types.IntegerType - case "float" => org.apache.spark.sql.types.FloatType - case "double" => org.apache.spark.sql.types.DoubleType - case "numeric" => org.apache.spark.sql.types.DoubleType - case "character" => org.apache.spark.sql.types.StringType - case "string" => org.apache.spark.sql.types.StringType - case "binary" => org.apache.spark.sql.types.BinaryType - case "raw" => org.apache.spark.sql.types.BinaryType - case "logical" => org.apache.spark.sql.types.BooleanType - case "boolean" => org.apache.spark.sql.types.BooleanType - case "timestamp" => org.apache.spark.sql.types.TimestampType - case "date" => org.apache.spark.sql.types.DateType - case r"\Aarray<(.+)${elemType}>\Z" => - org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) - case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" => - if (keyType != "string" && keyType != "character") { - throw new IllegalArgumentException("Key type of a map must be string or character") - } - org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) - case r"\Astruct<(.+)${fieldsStr}>\Z" => - if (fieldsStr(fieldsStr.length - 1) == ',') { - throw new IllegalArgumentException(s"Invalid type $dataType") - } - val fields = fieldsStr.split(",") - val structFields = fields.map { field => - field match { - case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => - createStructField(fieldName, fieldType, true) - - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - createStructType(structFields) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { - val dtObj = getSQLDataType(dataType) + val dtObj = CatalystSqlParser.parseDataType(dataType) StructField(name, dtObj, nullable) } From d228cd0b0243773a1c834414a240d1c553ab7af6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 29 Apr 2017 13:46:40 -0700 Subject: [PATCH 0370/1765] [SPARK-20442][PYTHON][DOCS] Fill up documentations for functions in Column API in PySpark ## What changes were proposed in this pull request? This PR proposes to fill up the documentation with examples for `bitwiseOR`, `bitwiseAND`, `bitwiseXOR`. `contains`, `asc` and `desc` in `Column` API. Also, this PR fixes minor typos in the documentation and matches some of the contents between Scala doc and Python doc. Lastly, this PR suggests to use `spark` rather than `sc` in doc tests in `Column` for Python documentation. ## How was this patch tested? Doc tests were added and manually tested with the commands below: `./python/run-tests.py --module pyspark-sql` `./python/run-tests.py --module pyspark-sql --python-executable python3` `./dev/lint-python` Output was checked via `make html` under `./python/docs`. The snapshots will be left on the codes with comments. Author: hyukjinkwon Closes #17737 from HyukjinKwon/SPARK-20442. --- python/pyspark/sql/column.py | 104 ++++++++++++++---- .../expressions/bitwiseExpressions.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 31 +++--- 3 files changed, 99 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 46c1707cb6c37..b8df37f25180f 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -185,9 +185,43 @@ def __contains__(self, item): "in a string column or 'array_contains' function for an array column.") # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") + _bitwiseOR_doc = """ + Compute bitwise OR of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise or(|) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseOR(df.b)).collect() + [Row((a | b)=235)] + """ + _bitwiseAND_doc = """ + Compute bitwise AND of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise and(&) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseAND(df.b)).collect() + [Row((a & b)=10)] + """ + _bitwiseXOR_doc = """ + Compute bitwise XOR of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise xor(^) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseXOR(df.b)).collect() + [Row((a ^ b)=225)] + """ + + bitwiseOR = _bin_op("bitwiseOR", _bitwiseOR_doc) + bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc) + bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc) @since(1.3) def getItem(self, key): @@ -195,7 +229,7 @@ def getItem(self, key): An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df = spark.createDataFrame([([1, 2], {"key": "value"})], ["l", "d"]) >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() +----+------+ |l[0]|d[key]| @@ -217,7 +251,7 @@ def getField(self, name): An expression that gets a field by name in a StructField. >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df = spark.createDataFrame([Row(r=Row(a=1, b="b"))]) >>> df.select(df.r.getField("b")).show() +---+ |r.b| @@ -250,8 +284,17 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods + _contains_doc = """ + Contains the other element. Returns a boolean :class:`Column` based on a string match. + + :param other: string in line + + >>> df.filter(df.name.contains('o')).collect() + [Row(age=5, name=u'Bob')] + """ _rlike_doc = """ - Return a Boolean :class:`Column` based on a regex match. + SQL RLIKE expression (LIKE with Regex). Returns a boolean :class:`Column` based on a regex + match. :param other: an extended regex expression @@ -259,7 +302,7 @@ def __iter__(self): [Row(age=2, name=u'Alice')] """ _like_doc = """ - Return a Boolean :class:`Column` based on a SQL LIKE match. + SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match. :param other: a SQL LIKE pattern @@ -269,9 +312,9 @@ def __iter__(self): [Row(age=2, name=u'Alice')] """ _startswith_doc = """ - Return a Boolean :class:`Column` based on a string match. + String starts with. Returns a boolean :class:`Column` based on a string match. - :param other: string at end of line (do not use a regex `^`) + :param other: string at start of line (do not use a regex `^`) >>> df.filter(df.name.startswith('Al')).collect() [Row(age=2, name=u'Alice')] @@ -279,7 +322,7 @@ def __iter__(self): [] """ _endswith_doc = """ - Return a Boolean :class:`Column` based on matching end of string. + String ends with. Returns a boolean :class:`Column` based on a string match. :param other: string at end of line (do not use a regex `$`) @@ -289,7 +332,7 @@ def __iter__(self): [] """ - contains = _bin_op("contains") + contains = ignore_unicode_prefix(_bin_op("contains", _contains_doc)) rlike = ignore_unicode_prefix(_bin_op("rlike", _rlike_doc)) like = ignore_unicode_prefix(_bin_op("like", _like_doc)) startswith = ignore_unicode_prefix(_bin_op("startsWith", _startswith_doc)) @@ -337,27 +380,40 @@ def isin(self, *cols): return Column(jc) # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") + _asc_doc = """ + Returns a sort expression based on the ascending order of the given column name + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.select(df.name).orderBy(df.name.asc()).collect() + [Row(name=u'Alice'), Row(name=u'Tom')] + """ + _desc_doc = """ + Returns a sort expression based on the descending order of the given column name. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.select(df.name).orderBy(df.name.desc()).collect() + [Row(name=u'Tom'), Row(name=u'Alice')] + """ + + asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) _isNull_doc = """ - True if the current expression is null. Often combined with - :func:`DataFrame.filter` to select rows with null values. + True if the current expression is null. >>> from pyspark.sql import Row - >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() - >>> df2.filter(df2.height.isNull()).collect() + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.filter(df.height.isNull()).collect() [Row(height=None, name=u'Alice')] """ _isNotNull_doc = """ - True if the current expression is null. Often combined with - :func:`DataFrame.filter` to select rows with non-null values. + True if the current expression is NOT null. >>> from pyspark.sql import Row - >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() - >>> df2.filter(df2.height.isNotNull()).collect() + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.filter(df.height.isNotNull()).collect() [Row(height=80, name=u'Tom')] """ @@ -527,7 +583,7 @@ def _test(): .appName("sql.column tests")\ .getOrCreate() sc = spark.sparkContext - globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 2918040771433..425efbb6c96c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -86,7 +86,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor of two numbers. + * A function that calculates bitwise xor({@literal ^}) of two numbers. * * Code generation inherited from BinaryArithmetic. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 43de2de7e7094..b23ab1fa3514a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -779,7 +779,7 @@ class Column(val expr: Expression) extends Logging { def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** - * SQL like expression. + * SQL like expression. Returns a boolean column based on a SQL LIKE match. * * @group expr_ops * @since 1.3.0 @@ -787,7 +787,8 @@ class Column(val expr: Expression) extends Logging { def like(literal: String): Column = withExpr { Like(expr, lit(literal).expr) } /** - * SQL RLIKE expression (LIKE with Regex). + * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex + * match. * * @group expr_ops * @since 1.3.0 @@ -838,7 +839,7 @@ class Column(val expr: Expression) extends Logging { } /** - * Contains the other element. + * Contains the other element. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -846,7 +847,7 @@ class Column(val expr: Expression) extends Logging { def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } /** - * String starts with. + * String starts with. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -854,7 +855,7 @@ class Column(val expr: Expression) extends Logging { def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } /** - * String starts with another string literal. + * String starts with another string literal. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -862,7 +863,7 @@ class Column(val expr: Expression) extends Logging { def startsWith(literal: String): Column = this.startsWith(lit(literal)) /** - * String ends with. + * String ends with. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -870,7 +871,7 @@ class Column(val expr: Expression) extends Logging { def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } /** - * String ends with another string literal. + * String ends with another string literal. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -1008,7 +1009,7 @@ class Column(val expr: Expression) extends Logging { def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to)) /** - * Returns an ordering used in sorting. + * Returns a sort expression based on the descending order of the column. * {{{ * // Scala * df.sort(df("age").desc) @@ -1023,7 +1024,8 @@ class Column(val expr: Expression) extends Logging { def desc: Column = withExpr { SortOrder(expr, Descending) } /** - * Returns a descending ordering used in sorting, where null values appear before non-null values. + * Returns a sort expression based on the descending order of the column, + * and null values appear before non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing first. * df.sort(df("age").desc_nulls_first) @@ -1038,7 +1040,8 @@ class Column(val expr: Expression) extends Logging { def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } /** - * Returns a descending ordering used in sorting, where null values appear after non-null values. + * Returns a sort expression based on the descending order of the column, + * and null values appear after non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing last. * df.sort(df("age").desc_nulls_last) @@ -1053,7 +1056,7 @@ class Column(val expr: Expression) extends Logging { def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } /** - * Returns an ascending ordering used in sorting. + * Returns a sort expression based on ascending order of the column. * {{{ * // Scala: sort a DataFrame by age column in ascending order. * df.sort(df("age").asc) @@ -1068,7 +1071,8 @@ class Column(val expr: Expression) extends Logging { def asc: Column = withExpr { SortOrder(expr, Ascending) } /** - * Returns an ascending ordering used in sorting, where null values appear before non-null values. + * Returns a sort expression based on ascending order of the column, + * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. * df.sort(df("age").asc_nulls_last) @@ -1083,7 +1087,8 @@ class Column(val expr: Expression) extends Logging { def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } /** - * Returns an ordering used in sorting, where null values appear after non-null values. + * Returns a sort expression based on ascending order of the column, + * and null values appear after non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing last. * df.sort(df("age").asc_nulls_last) From 4d99b95ad0d0c7ef909c8e492ec45e94cf0189b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Sun, 30 Apr 2017 09:06:25 +0100 Subject: [PATCH 0371/1765] [SPARK-20521][DOC][CORE] The default of 'spark.worker.cleanup.appDataTtl' should be 604800 in spark-standalone.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently, our project needs to be set to clean up the worker directory cleanup cycle is three days. When I follow http://spark.apache.org/docs/latest/spark-standalone.html, configure the 'spark.worker.cleanup.appDataTtl' parameter, I configured to 3 * 24 * 3600. When I start the spark service, the startup fails, and the worker log displays the error log as follows: 2017-04-28 15:02:03,306 INFO Utils: Successfully started service 'sparkWorker' on port 48728. Exception in thread "main" java.lang.NumberFormatException: For input string: "3 * 24 * 3600" at java.lang.NumberFormatException.forInputString(NumberFormatException.java:65) at java.lang.Long.parseLong(Long.java:430) at java.lang.Long.parseLong(Long.java:483) at scala.collection.immutable.StringLike$class.toLong(StringLike.scala:276) at scala.collection.immutable.StringOps.toLong(StringOps.scala:29) at org.apache.spark.SparkConf$$anonfun$getLong$2.apply(SparkConf.scala:380) at org.apache.spark.SparkConf$$anonfun$getLong$2.apply(SparkConf.scala:380) at scala.Option.map(Option.scala:146) at org.apache.spark.SparkConf.getLong(SparkConf.scala:380) at org.apache.spark.deploy.worker.Worker.(Worker.scala:100) at org.apache.spark.deploy.worker.Worker$.startRpcEnvAndEndpoint(Worker.scala:730) at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:709) at org.apache.spark.deploy.worker.Worker.main(Worker.scala) **Because we put 7 * 24 * 3600 as a string, forced to convert to the dragon type, will lead to problems in the program.** **So I think the default value of the current configuration should be a specific long value, rather than 7 * 24 * 3600,should be 604800. Because it would mislead users for similar configurations, resulting in spark start failure.** ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17798 from guoxiaolongzte/SPARK-20521. --- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1c0b60f7b9346..34ced9ed7b462 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -242,7 +242,7 @@ SPARK_WORKER_OPTS supports the following system properties:
    - + ``` Here `num` field represents number of attempts, this is not equal to REST APIs. In the REST API, if attempt id is not existed the URL should be `api/v1/applications//logs`, otherwise the URL should be `api/v1/applications///logs`. Using `` to represent `` will lead to the issue of "no such app". Manual verification. CC ajbozarth can you please review this change, since you add this feature before? Thanks! Author: jerryshao Closes #17795 from jerryshao/SPARK-20517. --- .../org/apache/spark/ui/static/historypage-template.html | 2 +- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 42e2d9abdeb5e..6ba3b092dc658 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -77,7 +77,7 @@ - + {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 54810edaf1460..1f89306403cd5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -120,6 +120,9 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } From 6fc6cf88d871f5b05b0ad1a504e0d6213cf9d331 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Mon, 1 May 2017 11:37:30 -0700 Subject: [PATCH 0380/1765] [SPARK-20464][SS] Add a job group and description for streaming queries and fix cancellation of running jobs using the job group ## What changes were proposed in this pull request? Job group: adding a job group is required to properly cancel running jobs related to a query. Description: the new description makes it easier to group the batches of a query by sorting by name in the Spark Jobs UI. ## How was this patch tested? - Unit tests - UI screenshot - Order by job id: ![screen shot 2017-04-27 at 5 10 09 pm](https://cloud.githubusercontent.com/assets/7865120/25509468/15452274-2b6e-11e7-87ba-d929816688cf.png) - Order by description: ![screen shot 2017-04-27 at 5 10 22 pm](https://cloud.githubusercontent.com/assets/7865120/25509474/1c298512-2b6e-11e7-99b8-fef1ef7665c1.png) - Order by job id (no query name): ![screen shot 2017-04-27 at 5 21 33 pm](https://cloud.githubusercontent.com/assets/7865120/25509482/28c96dc8-2b6e-11e7-8df0-9d3cdbb05e36.png) - Order by description (no query name): ![screen shot 2017-04-27 at 5 21 44 pm](https://cloud.githubusercontent.com/assets/7865120/25509489/37674742-2b6e-11e7-9357-b5c38ec16ac4.png) Author: Kunal Khamar Closes #17765 from kunalkhamar/sc-6696. --- .../scala/org/apache/spark/ui/UIUtils.scala | 2 +- .../execution/streaming/StreamExecution.scala | 12 ++++ .../spark/sql/streaming/StreamSuite.scala | 66 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index e53d6907bc404..79b0d81af52b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -446,7 +446,7 @@ private[spark] object UIUtils extends Logging { val xml = XML.loadString(s"""$desc""") // Verify that this has only anchors and span (we are wrapping in span) - val allowedNodeLabels = Set("a", "span") + val allowedNodeLabels = Set("a", "span", "br") val illegalNodes = xml \\ "_" filterNot { case node: Node => allowedNodeLabels.contains(node.label) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bcf0d970f7ec1..affc2018c43cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -252,6 +252,8 @@ class StreamExecution( */ private def runBatches(): Unit = { try { + sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, + interruptOnCancel = true) if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } @@ -289,6 +291,7 @@ class StreamExecution( if (currentBatchId < 0) { // We'll do this initialization only once populateStartOffsets(sparkSessionToRunBatches) + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) logDebug(s"Stream running from $committedOffsets to $availableOffsets") } else { constructNextBatch() @@ -308,6 +311,7 @@ class StreamExecution( logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) } else { currentStatus = currentStatus.copy(isDataAvailable = false) updateStatusMessage("Waiting for data to arrive") @@ -684,8 +688,11 @@ class StreamExecution( // intentionally state.set(TERMINATED) if (microBatchThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) microBatchThread.interrupt() microBatchThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) } logInfo(s"Query $prettyIdString was stopped") } @@ -825,6 +832,11 @@ class StreamExecution( } } + private def getBatchDescriptionString: String = { + val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString + Option(name).map(_ + "
    ").getOrElse("") + + s"id = $id
    runId = $runId
    batch = $batchDescription" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 13fe51a557733..01ea62a9de4d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -25,6 +25,8 @@ import scala.util.control.ControlThrowable import org.apache.commons.io.FileUtils +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand @@ -500,6 +502,70 @@ class StreamSuite extends StreamTest { } } } + + test("calling stop() on a query cancels related jobs") { + val input = MemoryStream[Int] + val query = input + .toDS() + .map { i => + while (!org.apache.spark.TaskContext.get().isInterrupted()) { + // keep looping till interrupted by query.stop() + Thread.sleep(100) + } + i + } + .writeStream + .format("console") + .start() + + input.addData(1) + // wait for jobs to start + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().nonEmpty) + } + + query.stop() + // make sure jobs are stopped + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().isEmpty) + } + } + + test("batch id is updated correctly in the job description") { + val queryName = "memStream" + @volatile var jobDescription: String = null + def assertDescContainsQueryNameAnd(batch: Integer): Unit = { + // wait for listener event to be processed + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + assert(jobDescription.contains(queryName) && jobDescription.contains(s"batch = $batch")) + } + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescription = jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION) + } + }) + + val input = MemoryStream[Int] + val query = input + .toDS() + .map(_ + 1) + .writeStream + .format("memory") + .queryName(queryName) + .start() + + input.addData(1) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 0) + input.addData(2, 3) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 1) + input.addData(4) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 2) + query.stop() + } } abstract class FakeSource extends StreamSourceProvider { From 2b2dd08e975dd7fbf261436aa877f1d7497ed31f Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 1 May 2017 14:48:02 -0700 Subject: [PATCH 0381/1765] [SPARK-20540][CORE] Fix unstable executor requests. There are two problems fixed in this commit. First, the ExecutorAllocationManager sets a timeout to avoid requesting executors too often. However, the timeout is always updated based on its value and a timeout, not the current time. If the call is delayed by locking for more than the ongoing scheduler timeout, the manager will request more executors on every run. This seems to be the main cause of SPARK-20540. The second problem is that the total number of requested executors is not tracked by the CoarseGrainedSchedulerBackend. Instead, it calculates the value based on the current status of 3 variables: the number of known executors, the number of executors that have been killed, and the number of pending executors. But, the number of pending executors is never less than 0, even though there may be more known than requested. When executors are killed and not replaced, this can cause the request sent to YARN to be incorrect because there were too many executors due to the scheduler's state being slightly out of date. This is fixed by tracking the currently requested size explicitly. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #17813 from rdblue/SPARK-20540-fix-dynamic-allocation. --- .../spark/ExecutorAllocationManager.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 32 ++++++++++++++++--- .../StandaloneDynamicAllocationSuite.scala | 6 ++-- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 261b3329a7b9c..fcc72ff49276d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4eedaaea61195..dc82bb7704727 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -413,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -487,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -524,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -589,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 9839dcf8535db..bf7480d79f8a1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -356,12 +356,13 @@ class StandaloneDynamicAllocationSuite test("kill the same executor twice (SPARK-9795)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -380,12 +381,13 @@ class StandaloneDynamicAllocationSuite test("the pending replacement executors should not be lost (SPARK-10515)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about From af726cd6117de05c6e3b9616b8699d884a53651b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 May 2017 17:01:05 -0700 Subject: [PATCH 0382/1765] [SPARK-20459][SQL] JdbcUtils throws IllegalStateException: Cause already initialized after getting SQLException ## What changes were proposed in this pull request? Avoid failing to initCause on JDBC exception with cause initialized to null ## How was this patch tested? Existing tests Author: Sean Owen Closes #17800 from srowen/SPARK-20459. --- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 5fc3c2753b6cf..0183805d56257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -652,8 +652,17 @@ object JdbcUtils extends Logging { case e: SQLException => val cause = e.getNextException if (cause != null && e.getCause != cause) { + // If there is no cause already, set 'next exception' as cause. If cause is null, + // it *may* be because no cause was set yet if (e.getCause == null) { - e.initCause(cause) + try { + e.initCause(cause) + } catch { + // Or it may be null because the cause *was* explicitly initialized, to *null*, + // in which case this fails. There is no other way to detect it. + // addSuppressed in this case as well. + case _: IllegalStateException => e.addSuppressed(cause) + } } else { e.addSuppressed(cause) } From 259860d23d1740954b739b639c5bdc3ede65ed25 Mon Sep 17 00:00:00 2001 From: ptkool Date: Mon, 1 May 2017 17:05:35 -0700 Subject: [PATCH 0383/1765] [SPARK-20463] Add support for IS [NOT] DISTINCT FROM. ## What changes were proposed in this pull request? Add support for the SQL standard distinct predicate to SPARK SQL. ``` IS [NOT] DISTINCT FROM ``` ## How was this patch tested? Tested using unit tests, integration tests, manual tests. Author: ptkool Closes #17764 from ptkool/is_not_distinct_from. --- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 5 +++++ .../spark/sql/catalyst/parser/ExpressionParserSuite.scala | 5 +++++ 3 files changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1ecb3d1958f43..14c511f670606 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -534,6 +534,7 @@ predicate | NOT? kind=IN '(' query ')' | NOT? kind=(RLIKE | LIKE) pattern=valueExpression | IS NOT? kind=NULL + | IS NOT? kind=DISTINCT FROM right=valueExpression ; valueExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a48a693a95c93..d2a9b4a9a9f59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -935,6 +935,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * - (NOT) LIKE * - (NOT) RLIKE * - IS (NOT) NULL. + * - IS (NOT) DISTINCT FROM */ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { // Invert a predicate if it has a valid NOT clause. @@ -962,6 +963,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { IsNotNull(e) case SqlBaseParser.NULL => IsNull(e) + case SqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case SqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e7f3b64a71130..eb68eb9851b85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -167,6 +167,11 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a = b is not null", ('a === 'b).isNotNull) } + test("is distinct expressions") { + assertEqual("a is distinct from b", !('a <=> 'b)) + assertEqual("a is not distinct from b", 'a <=> 'b) + } + test("binary arithmetic expressions") { // Simple operations assertEqual("a * b", 'a * 'b) From 943a684b9827ca294ed06a46431507538d40a134 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 1 May 2017 17:42:53 -0700 Subject: [PATCH 0384/1765] [SPARK-20548] Disable ReplSuite.newProductSeqEncoder with REPL defined class ## What changes were proposed in this pull request? `newProductSeqEncoder with REPL defined class` in `ReplSuite` has been failing in-deterministically : https://spark-tests.appspot.com/failed-tests over the last few days. Disabling the test until a fix is in place. https://spark.test.databricks.com/job/spark-master-test-sbt-hadoop-2.7/176/testReport/junit/org.apache.spark.repl/ReplSuite/newProductSeqEncoder_with_REPL_defined_class/history/ ## How was this patch tested? N/A Author: Sameer Agarwal Closes #17823 from sameeragarwal/disable-test. --- .../src/test/scala/org/apache/spark/repl/ReplSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 121a02a9be0a1..8fe27080cac66 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -474,7 +474,8 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) } - test("newProductSeqEncoder with REPL defined class") { + // TODO: [SPARK-20548] Fix and re-enable + ignore("newProductSeqEncoder with REPL defined class") { val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", """ |case class Click(id: Int) From d20a976e8918ca8d607af452301e8014fe14e64a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 1 May 2017 21:03:48 -0700 Subject: [PATCH 0385/1765] [SPARK-20192][SPARKR][DOC] SparkR migration guide to 2.2.0 ## What changes were proposed in this pull request? Updating R Programming Guide ## How was this patch tested? manually Author: Felix Cheung Closes #17816 from felixcheung/r22relnote. --- docs/sparkr.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index 16b1ef6512420..6dbd02a48890d 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -644,3 +644,11 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.1.0 - `join` no longer performs Cartesian Product by default, use `crossJoin` instead. + +## Upgrading to SparkR 2.2.0 + + - A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation has been made to match the one in Scala. + - The method `createExternalTable` has been deprecated to be replaced by `createTable`. Either methods can be called to create external or managed table. Additional catalog methods have also been added. + - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`. + - `spark.lda` was not setting the optimizer correctly. It has been corrected. + - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`. From 90d77e971f6b3fa268e411279f34bc1db4321991 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 1 May 2017 21:39:17 -0700 Subject: [PATCH 0386/1765] [SPARK-20532][SPARKR] Implement grouping and grouping_id ## What changes were proposed in this pull request? Adds R wrappers for: - `o.a.s.sql.functions.grouping` as `o.a.s.sql.functions.is_grouping` (to avoid shading `base::grouping` - `o.a.s.sql.functions.grouping_id` ## How was this patch tested? Existing unit tests, additional unit tests. `check-cran.sh`. Author: zero323 Closes #17807 from zero323/SPARK-20532. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 84 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 56 ++++++++++++++- 4 files changed, 148 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index e8de34d9371a0..7ecd168137e8d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -249,6 +249,8 @@ exportMethods("%<=>%", "getField", "getItem", "greatest", + "grouping_bit", + "grouping_id", "hex", "histogram", "hour", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f9687d680e7a2..38384a89919a2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3890,3 +3890,87 @@ setMethod("not", jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) column(jc) }) + +#' grouping_bit +#' +#' Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' +#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. +#' +#' @param x Column to compute on +#' +#' @rdname grouping_bit +#' @name grouping_bit +#' @family agg_funcs +#' @aliases grouping_bit,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' } +#' @note grouping_bit since 2.3.0 +setMethod("grouping_bit", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc) + column(jc) + }) + +#' grouping_id +#' +#' Returns the level of grouping. +#' +#' Equals to \code{ +#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) +#' } +#' +#' @param x Column to compute on +#' @param ... additional Column(s) (optional). +#' +#' @rdname grouping_id +#' @name grouping_id +#' @family agg_funcs +#' @aliases grouping_id,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' } +#' @note grouping_id since 2.3.0 +setMethod("grouping_id", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ef36765a7a725..e02d46426a5a6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1052,6 +1052,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) +#' @rdname grouping_bit +#' @export +setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) + +#' @rdname grouping_id +#' @export +setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) + #' @rdname hex #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 08296354ca7ed..12867c15d1f95 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1848,7 +1848,11 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( cube(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), + expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1875,6 +1879,30 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 1, 1, 1, # by department + 0, # 2016 + 0, 0, 0, # 2016 by department + 0, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_department = c( + 1, # global + 0, 0, 0, # by department + 1, # 2016 + 0, 0, 0, # 2016 by department + 1, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_id = c( + 3, # 11 + 2, 2, 2, # 10 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) @@ -1896,7 +1924,10 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( rollup(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1920,6 +1951,27 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 0, # 2016 + 0, 0, 0, # 2016 each department + 0, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_department = c( + 1, # global + 1, # 2016 + 0, 0, 0, # 2016 each department + 1, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_id = c( + 3, # 11 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) From afb21bf22a59c9416c04637412fb69d1442e6826 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 2 May 2017 13:56:41 +0800 Subject: [PATCH 0387/1765] [SPARK-20537][CORE] Fixing OffHeapColumnVector reallocation ## What changes were proposed in this pull request? As #17773 revealed `OnHeapColumnVector` may copy a part of the original storage. `OffHeapColumnVector` reallocation also copies to the new storage data up to 'elementsAppended'. This variable is only updated when using the `ColumnVector.appendX` API, while `ColumnVector.putX` is more commonly used. This PR copies the new storage data up to the previously-allocated size in`OffHeapColumnVector`. ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #17811 from kiszk/SPARK-20537. --- .../vectorized/OffHeapColumnVector.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e988c0722bd72..a7d3744d00e91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -436,28 +436,29 @@ public void loadBytes(ColumnVector.Array array) { // Split out the slow path. @Override protected void reserveInternal(int newCapacity) { + int oldCapacity = (this.data == 0L) ? 0 : capacity; if (this.resultArray != null) { this.lengthData = - Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = - Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { - this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + this.nulls = Platform.reallocateMemory(nulls, oldCapacity, newCapacity); + Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } } From 86174ea89b39a300caaba6baffac70f3dc702788 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 2 May 2017 14:08:16 +0800 Subject: [PATCH 0388/1765] [SPARK-20549] java.io.CharConversionException: Invalid UTF-32' in JsonToStructs ## What changes were proposed in this pull request? A fix for the same problem was made in #17693 but ignored `JsonToStructs`. This PR uses the same fix for `JsonToStructs`. ## How was this patch tested? Regression test Author: Burak Yavuz Closes #17826 from brkyvz/SPARK-20549. --- .../spark/sql/catalyst/expressions/jsonExpressions.scala | 8 +++----- .../spark/sql/catalyst/json/CreateJacksonParser.scala | 7 +++++-- .../sql/catalyst/expressions/JsonExpressionsSuite.scala | 7 +++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 9fb0ea68153d2..6b90354367f40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -151,8 +151,7 @@ case class GetJsonObject(json: Expression, path: Expression) try { /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson detect character encoding which could fail for some malformed strings */ - Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( - new ByteArrayInputStream(jsonStr.getBytes), "UTF-8"))) { parser => + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -398,9 +397,8 @@ case class JsonTuple(children: Seq[Expression]) try { /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson detect character encoding which could fail for some malformed strings */ - Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( - new ByteArrayInputStream(json.getBytes), "UTF-8"))) { - parser => parseRow(parser, input) + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parseRow(parser, input) } } catch { case _: JsonProcessingException => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index e0ed03a68981a..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.InputStream +import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text @@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable { val bb = record.getByteBuffer assert(bb.hasArray) - jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + val bain = new ByteArrayInputStream( + bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 4402ad4e9a9e5..65d5c3a582b16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -453,6 +453,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("SPARK-20549: from_json bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + null) + } + test("from_json with timestamp") { val schema = StructType(StructField("t", TimestampType) :: Nil) From e300a5a145820ecd466885c73245d6684e8cb0aa Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Tue, 2 May 2017 10:49:13 +0200 Subject: [PATCH 0389/1765] [SPARK-20300][ML][PYSPARK] Python API for ALSModel.recommendForAllUsers,Items Add Python API for `ALSModel` methods `recommendForAllUsers`, `recommendForAllItems` ## How was this patch tested? New doc tests. Author: Nick Pentreath Closes #17622 from MLnick/SPARK-20300-pyspark-recall. --- python/pyspark/ml/recommendation.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 8bc899a0788bb..bcfb36880eb02 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> user_recs = model.recommendForAllUsers(3) + >>> user_recs.where(user_recs.user == 0)\ + .select("recommendations.item", "recommendations.rating").collect() + [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])] + >>> item_recs = model.recommendForAllItems(3) + >>> item_recs.where(item_recs.item == 2)\ + .select("recommendations.user", "recommendations.rating").collect() + [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -384,6 +392,28 @@ def itemFactors(self): """ return self._call_java("itemFactors") + @since("2.2.0") + def recommendForAllUsers(self, numItems): + """ + Returns top `numItems` items recommended for each user, for all users. + + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForAllUsers", numItems) + + @since("2.2.0") + def recommendForAllItems(self, numUsers): + """ + Returns top `numUsers` users recommended for each item, for all items. + + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForAllItems", numUsers) + if __name__ == "__main__": import doctest From b1e639ab09d3a7a1545119e45a505c9a04308353 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 2 May 2017 16:49:24 +0800 Subject: [PATCH 0390/1765] [SPARK-19235][SQL][TEST][FOLLOW-UP] Enable Test Cases in DDLSuite with Hive Metastore ### What changes were proposed in this pull request? This is a follow-up of enabling test cases in DDLSuite with Hive Metastore. It consists of the following remaining tasks: - Run all the `alter table` and `drop table` DDL tests against data source tables when using Hive metastore. - Do not run any `alter table` and `drop table` DDL test against Hive serde tables when using InMemoryCatalog. - Reenable `alter table: set serde partition` and `alter table: set serde` tests for Hive serde tables. ### How was this patch tested? N/A Author: Xiao Li Closes #17524 from gatorsmile/cleanupDDLSuite. --- .../sql/execution/command/DDLSuite.scala | 291 ++++++++---------- .../apache/spark/sql/test/SQLTestUtils.scala | 3 +- .../sql/hive/execution/HiveDDLSuite.scala | 73 ++++- 3 files changed, 195 insertions(+), 172 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2f4eb1b15519b..0abcff76060f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -49,7 +49,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() @@ -70,46 +71,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo tracksPartitionsInCatalog = true) } - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - - test("alter table: change column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - test("create a managed Hive source table") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") val tabName = "tbl" @@ -163,7 +124,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" } - protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -205,8 +169,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ignoreIfExists = false) } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -223,6 +190,46 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -835,32 +842,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set properties") { - testSetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: change column") { - testChangeColumn(isDatasourceTable = false) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -885,10 +866,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -957,17 +934,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } test("show databases") { sql("CREATE DATABASE showdb2B") @@ -1011,18 +981,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1046,22 +1012,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - assert(catalog.getTableMetadata(tableIdent).provider == Some("csv")) - } - protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1084,13 +1042,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1121,15 +1079,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) @@ -1171,13 +1129,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties if (isUsingHiveMetastore) { @@ -1187,8 +1145,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getTableMetadata(tableIdent).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) } else { assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) } @@ -1229,18 +1191,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties if (isUsingHiveMetastore) { @@ -1250,8 +1212,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getPartition(tableIdent, spec).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde)) } else { assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) } @@ -1295,6 +1261,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1303,11 +1272,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition @@ -1354,6 +1320,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1362,7 +1331,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) @@ -1370,9 +1339,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4, part5)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") @@ -1407,20 +1373,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1451,14 +1417,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getMetadata(colName: String): Metadata = { val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => resolver(field.name, colName) @@ -1601,13 +1567,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { @@ -1837,22 +1805,25 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("tbl"), Row(1)) val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - - sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") - spark.catalog.refreshTable("tbl") - // SET LOCATION won't move data from previous table path to new table path. - assert(spark.table("tbl").count() == 0) - // the previous table path should be still there. - assert(new File(defaultTablePath).exists()) - - sql("INSERT INTO tbl SELECT 2") - checkAnswer(spark.table("tbl"), Row(2)) - // newly inserted data will go to the new table path. - assert(dir.listFiles().nonEmpty) - - sql("DROP TABLE tbl") - // the new table path will be removed after DROP TABLE. - assert(!dir.exists()) + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. + assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } } } } @@ -2125,7 +2096,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for database") { - try { + withDatabase ("tmpdb") { withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) @@ -2140,8 +2111,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(tblloc.listFiles().nonEmpty) } } - } finally { - spark.sql("DROP DATABASE IF EXISTS tmpdb") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 44c0fc70d066b..f6d47734d7e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -237,7 +237,7 @@ private[sql] trait SQLTestUtils try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE ${DEFAULT_DATABASE}") + spark.sql(s"USE $DEFAULT_DATABASE") } spark.sql(s"DROP DATABASE $dbName CASCADE") } @@ -251,6 +251,7 @@ private[sql] trait SQLTestUtils dbNames.foreach { name => spark.sql(s"DROP DATABASE IF EXISTS $name") } + spark.sql(s"USE $DEFAULT_DATABASE") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 16a99321bad33..341e03b5e57fb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -50,15 +50,28 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean): CatalogTable = { val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), - compressed = false, - properties = Map("serialization.format" -> "1")) + if (isDataSource) { + val serde = HiveSerDe.sourceToSerDe("parquet") + assert(serde.isDefined, "The default format is not Hive compatible") + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = serde.get.inputFormat, + outputFormat = serde.get.outputFormat, + serde = serde.get.serde, + compressed = false, + properties = Map("serialization.format" -> "1")) + } else { + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + } val metadata = new MetadataBuilder() .putString("key", "value") .build() @@ -71,7 +84,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA .add("col2", "string") .add("a", "int") .add("b", "int"), - provider = Some("hive"), + provider = if (isDataSource) Some("parquet") else Some("hive"), partitionColumnNames = Seq("a", "b"), createTime = 0L, tracksPartitionsInCatalog = true) @@ -107,6 +120,46 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA ) } + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set properties") { + testSetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties") { + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: rename partition") { + testRenamePartitions(isDatasourceTable = false) + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + } class HiveDDLSuite From 13f47dc5033a99df8d9ec18f2ce373119462f7bc Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 2 May 2017 09:37:01 -0700 Subject: [PATCH 0391/1765] [SPARK-20490][SPARKR][DOC] add family tag for not function ## What changes were proposed in this pull request? doc only ## How was this patch tested? manual Author: Felix Cheung Closes #17828 from felixcheung/rnotfamily. --- R/pkg/R/functions.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 38384a89919a2..3d47b09ce5513 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3871,6 +3871,7 @@ setMethod("posexplode_outer", #' @rdname not #' @name not #' @aliases not,Column-method +#' @family normal_funcs #' @export #' @examples \dontrun{ #' df <- createDataFrame(data.frame( From ef3df9125a30f8fb817fe855b74d7130be45b0ee Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 2 May 2017 14:30:06 -0700 Subject: [PATCH 0392/1765] [SPARK-20421][CORE] Add a missing deprecation tag. In the previous patch I deprecated StorageStatus, but not the method in SparkContext that exposes that class publicly. So deprecate the method too. Author: Marcelo Vanzin Closes #17824 from vanzin/SPARK-20421. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0ec1bdd39b2f5..f7c32e5f0cec5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1734,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus From b946f3160eb7953fb30edf1f097ea87be75b33e7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 May 2017 10:08:46 +0800 Subject: [PATCH 0393/1765] [SPARK-20558][CORE] clear InheritableThreadLocal variables in SparkContext when stopping it ## What changes were proposed in this pull request? To better understand this problem, let's take a look at an example first: ``` object Main { def main(args: Array[String]): Unit = { var t = new Test new Thread(new Runnable { override def run() = {} }).start() println("first thread finished") t.a = null t = new Test new Thread(new Runnable { override def run() = {} }).start() } } class Test { var a = new InheritableThreadLocal[String] { override protected def childValue(parent: String): String = { println("parent value is: " + parent) parent } } a.set("hello") } ``` The result is: ``` parent value is: hello first thread finished parent value is: hello parent value is: hello ``` Once an `InheritableThreadLocal` has been set value, child threads will inherit its value as long as it has not been GCed, so setting the variable which holds the `InheritableThreadLocal` to `null` doesn't work as we expected. In `SparkContext`, we have an `InheritableThreadLocal` for local properties, we should clear it when stopping `SparkContext`, or all the future child threads will still inherit it and copy the properties and waste memory. This is the root cause of https://issues.apache.org/jira/browse/SPARK-20548 , which creates/stops `SparkContext` many times and finally have a lot of `InheritableThreadLocal` alive, and cause OOM when starting new threads in the internal thread pools. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17833 from cloud-fan/core. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f7c32e5f0cec5..7dbceb9c5c1a3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1939,6 +1939,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() From 6235132a8ce64bb12d825d0a65e5dd052d1ee647 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 May 2017 22:44:27 -0700 Subject: [PATCH 0394/1765] [SPARK-20567] Lazily bind in GenerateExec It is not valid to eagerly bind with the child's output as this causes failures when we attempt to canonicalize the plan (replacing the attribute references with dummies). Author: Michael Armbrust Closes #17838 from marmbrus/fixBindExplode. --- .../spark/sql/execution/GenerateExec.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 1812a1152cb48..c35e5638e9273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -78,7 +78,7 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) + lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f796a4cb4a398..4345a70601c34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int] From db2fb84b4a3c45daa449cc9232340193ce8eb37d Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 3 May 2017 10:58:05 +0200 Subject: [PATCH 0395/1765] [SPARK-6227][MLLIB][PYSPARK] Implement PySpark wrappers for SVD and PCA (v2) Add PCA and SVD to PySpark's wrappers for `RowMatrix` and `IndexedRowMatrix` (SVD only). Based on #7963, updated. ## How was this patch tested? New doc tests and unit tests. Ran all examples locally. Author: MechCoder Author: Nick Pentreath Closes #17621 from MLnick/SPARK-6227-pyspark-svd-pca. --- docs/mllib-dimensionality-reduction.md | 29 +-- .../spark/examples/mllib/JavaPCAExample.java | 27 ++- .../spark/examples/mllib/JavaSVDExample.java | 27 +-- .../python/mllib/pca_rowmatrix_example.py | 46 ++++ examples/src/main/python/mllib/svd_example.py | 48 +++++ .../mllib/PCAOnRowMatrixExample.scala | 4 +- .../spark/examples/mllib/SVDExample.scala | 11 +- python/pyspark/mllib/linalg/distributed.py | 199 +++++++++++++++++- python/pyspark/mllib/tests.py | 63 ++++++ 9 files changed, 408 insertions(+), 46 deletions(-) create mode 100644 examples/src/main/python/mllib/pca_rowmatrix_example.py create mode 100644 examples/src/main/python/mllib/svd_example.py diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 539cbc1b3163a..a72680d52a26c 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -76,13 +76,14 @@ Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/ The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. + +
    +Refer to the [`SingularValueDecomposition` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.SingularValueDecomposition) for details on the API. -In order to run the above application, follow the instructions -provided in the [Self-Contained -Applications](quick-start.html#self-contained-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +{% include_example python/mllib/svd_example.py %} +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
    @@ -118,17 +119,21 @@ Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feat The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. -The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. {% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %} - -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +
    + +The following code demonstrates how to compute principal components on a `RowMatrix` +and use them to project the vectors into a low-dimensional space. + +Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for details on the API. + +{% include_example python/mllib/pca_rowmatrix_example.py %} + +
    + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java index 3077f557ef886..0a7dc621e1110 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -39,21 +40,25 @@ public class JavaPCAExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("PCA Example"); SparkContext sc = new SparkContext(conf); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + Matrix pc = mat.computePrincipalComponents(4); + + // Project the rows to the linear space spanned by the top 4 principal components. RowMatrix projected = mat.multiply(pc); // $example off$ Vector[] collectPartitions = (Vector[])projected.rows().collect(); @@ -61,6 +66,6 @@ public static void main(String[] args) { for (Vector vector : collectPartitions) { System.out.println("\t" + vector); } - sc.stop(); + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java index 3730e60f68803..802be3960a337 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -43,22 +44,22 @@ public static void main(String[] args) { JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = jsc.parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 singular values and corresponding singular vectors. - SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); + // Compute the top 5 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(5, true, 1.0E-9d); + RowMatrix U = svd.U(); // The U factor is a RowMatrix. + Vector s = svd.s(); // The singular values are stored in a local dense vector. + Matrix V = svd.V(); // The V factor is a local dense matrix. // $example off$ Vector[] collectPartitions = (Vector[]) U.rows().collect(); System.out.println("U factor is:"); diff --git a/examples/src/main/python/mllib/pca_rowmatrix_example.py b/examples/src/main/python/mllib/pca_rowmatrix_example.py new file mode 100644 index 0000000000000..49b9b1bbe08e9 --- /dev/null +++ b/examples/src/main/python/mllib/pca_rowmatrix_example.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonPCAOnRowMatrixExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + # Compute the top 4 principal components. + # Principal components are stored in a local dense matrix. + pc = mat.computePrincipalComponents(4) + + # Project the rows to the linear space spanned by the top 4 principal components. + projected = mat.multiply(pc) + # $example off$ + collected = projected.rows.collect() + print("Projected Row Matrix of principal component:") + for vector in collected: + print(vector) + sc.stop() diff --git a/examples/src/main/python/mllib/svd_example.py b/examples/src/main/python/mllib/svd_example.py new file mode 100644 index 0000000000000..5b220fdb3fd67 --- /dev/null +++ b/examples/src/main/python/mllib/svd_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVDExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + + # Compute the top 5 singular values and corresponding singular vectors. + svd = mat.computeSVD(5, computeU=True) + U = svd.U # The U factor is a RowMatrix. + s = svd.s # The singular values are stored in a local dense vector. + V = svd.V # The V factor is a local dense matrix. + # $example off$ + collected = U.rows.collect() + print("U factor is:") + for vector in collected: + print(vector) + print("Singular values are: %s" % s) + print("V factor is:\n%s" % V) + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala index a137ba2a2f9d3..da43a8d9c7e80 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -39,9 +39,9 @@ object PCAOnRowMatrixExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 4 principal components. // Principal components are stored in a local dense matrix. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala index b286a3f7b9096..769ae2a3a88b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix // $example off$ +/** + * Example for SingularValueDecomposition. + */ object SVDExample { def main(args: Array[String]): Unit = { @@ -41,15 +44,15 @@ object SVDExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 5 singular values and corresponding singular vectors. val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) val U: RowMatrix = svd.U // The U factor is a RowMatrix. - val s: Vector = svd.s // The singular values are stored in a local dense vector. - val V: Matrix = svd.V // The V factor is a local dense matrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. // $example off$ val collect = U.rows.collect() println("U factor is:") diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 600655c912ca6..4cb802514be52 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -28,14 +28,13 @@ from pyspark import RDD, since from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector, Matrix, QRDecomposition +from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition from pyspark.mllib.stat import MultivariateStatisticalSummary from pyspark.storagelevel import StorageLevel -__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', - 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', - 'BlockMatrix'] +__all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow', + 'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition'] class DistributedMatrix(object): @@ -301,6 +300,136 @@ def tallSkinnyQR(self, computeQ=False): R = decomp.call("R") return QRDecomposition(Q, R) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the RowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a RowMatrix whose + columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the Scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: :py:class:`SingularValueDecomposition` + + >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]]) + >>> rm = RowMatrix(rows) + + >>> svd_model = rm.computeSVD(2, True) + >>> svd_model.U.rows.collect() + [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def computePrincipalComponents(self, k): + """ + Computes the k principal components of the given row matrix + + .. note:: This cannot be computed on matrices with more than 65535 columns. + + :param k: Number of principal components to keep. + :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix` + + >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]]) + >>> rm = RowMatrix(rows) + + >>> # Returns the two principal components of rm + >>> pca = rm.computePrincipalComponents(2) + >>> pca + DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0) + + >>> # Transform into new dimensions with the greatest variance. + >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE + [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \ + DenseVector([-4.6102, -4.9745])] + """ + return self._java_matrix_wrapper.call("computePrincipalComponents", k) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`RowMatrix` + + >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]])) + >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + j_model = self._java_matrix_wrapper.call("multiply", matrix) + return RowMatrix(j_model) + + +class SingularValueDecomposition(JavaModelWrapper): + """ + Represents singular value decomposition (SVD) factors. + + .. versionadded:: 2.2.0 + """ + + @property + @since('2.2.0') + def U(self): + """ + Returns a distributed matrix whose columns are the left + singular vectors of the SingularValueDecomposition if computeU was set to be True. + """ + u = self.call("U") + if u is not None: + mat_name = u.getClass().getSimpleName() + if mat_name == "RowMatrix": + return RowMatrix(u) + elif mat_name == "IndexedRowMatrix": + return IndexedRowMatrix(u) + else: + raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name) + + @property + @since('2.2.0') + def s(self): + """ + Returns a DenseVector with singular values in descending order. + """ + return self.call("s") + + @property + @since('2.2.0') + def V(self): + """ + Returns a DenseMatrix whose columns are the right singular + vectors of the SingularValueDecomposition. + """ + return self.call("V") + class IndexedRow(object): """ @@ -528,6 +657,68 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): colsPerBlock) return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the IndexedRowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a IndexedRowMatrix + whose columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: SingularValueDecomposition object + + >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))] + >>> irm = IndexedRowMatrix(sc.parallelize(rows)) + >>> svd_model = irm.computeSVD(2, True) + >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE + [IndexedRow(0, [-0.707106781187,0.707106781187]),\ + IndexedRow(1, [-0.707106781187,-0.707106781187])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`IndexedRowMatrix` + + >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))])) + >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + return IndexedRowMatrix(self._java_matrix_wrapper.call("multiply", matrix)) + class MatrixEntry(object): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 523b3f1113317..1037bab7f1088 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,6 +23,7 @@ import sys import tempfile import array as pyarray +from math import sqrt from time import time, sleep from shutil import rmtree @@ -54,6 +55,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD @@ -1699,6 +1701,67 @@ def test_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(output[i])) +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. + rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: From 16fab6b0ef3dcb33f92df30e17680922ad5fb672 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 3 May 2017 10:18:35 +0100 Subject: [PATCH 0396/1765] [SPARK-20523][BUILD] Clean up build warnings for 2.2.0 release ## What changes were proposed in this pull request? Fix build warnings primarily related to Breeze 0.13 operator changes, Java style problems ## How was this patch tested? Existing tests Author: Sean Owen Closes #17803 from srowen/SPARK-20523. --- .../spark/network/yarn/YarnShuffleService.java | 4 ++-- .../java/org/apache/spark/unsafe/Platform.java | 3 ++- .../org/apache/spark/memory/TaskMemoryManager.java | 3 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 11 ++++++----- .../storage/BlockReplicationPolicySuite.scala | 1 + dev/checkstyle-suppressions.xml | 4 ++++ .../streaming/JavaStructuredSessionization.java | 2 -- .../org/apache/spark/graphx/lib/PageRank.scala | 14 +++++++------- .../org/apache/spark/ml/ann/LossFunction.scala | 4 ++-- .../spark/ml/clustering/GaussianMixture.scala | 2 +- .../spark/mllib/clustering/GaussianMixture.scala | 2 +- .../apache/spark/mllib/clustering/LDAModel.scala | 8 ++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 12 ++++++------ .../apache/spark/mllib/clustering/LDAUtils.scala | 2 +- .../spark/ml/classification/NaiveBayesSuite.scala | 2 +- pom.xml | 4 ---- .../cluster/YarnSchedulerBackendSuite.scala | 2 ++ .../spark/sql/streaming/GroupStateTimeout.java | 5 ++++- .../expressions/JsonExpressionsSuite.scala | 2 +- .../parquet/SpecificParquetRecordReaderBase.java | 5 +++-- .../spark/sql/execution/QueryExecutionSuite.scala | 2 ++ .../streaming/StreamingQueryListenerSuite.scala | 1 + .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 23 files changed, 54 insertions(+), 43 deletions(-) diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 4acc203153e5a..fd50e3a4bfb9b 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -363,9 +363,9 @@ protected File initRecoveryDb(String dbName) { // If another DB was initialized first just make sure all the DBs are in the same // location. Path newLoc = new Path(_recoveryPath, dbName); - Path copyFrom = new Path(f.toURI()); + Path copyFrom = new Path(f.toURI()); if (!newLoc.equals(copyFrom)) { - logger.info("Moving " + copyFrom + " to: " + newLoc); + logger.info("Moving " + copyFrom + " to: " + newLoc); try { // The move here needs to handle moving non-empty directories across NFS mounts FileSystem fs = FileSystem.getLocal(_conf); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1321b83181150..4ab5b6889c212 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -48,7 +48,8 @@ public final class Platform { boolean _unaligned; String arch = System.getProperty("os.arch", ""); if (arch.equals("ppc64le") || arch.equals("ppc64")) { - // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it _unaligned = true; } else { try { diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index aa0b373231327..5f91411749167 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -155,7 +155,8 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ca6b8b0fe635..db14c9acfdce5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1070,11 +1070,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) - when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - assert(manager.isZombie === true) - } - }) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index dfecd04c1b969..4000218e71a8b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 31656ca0e5a60..bb7d31cad7be3 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -44,4 +44,8 @@ files="src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java"/> + + diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java index d3c8516882fa6..6b8e6554f1bb1 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -28,8 +28,6 @@ import java.sql.Timestamp; import java.util.*; -import scala.Tuple2; - /** * Counts words in UTF8 encoded, '\n' delimited text received from the network. *

    diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 13b2b57719188..fd7b7f7c1c487 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -226,18 +226,18 @@ object PageRank extends Logging { // Propagates the message along outbound edges // and adding start nodes back in with activation resetProb val rankUpdates = rankGraph.aggregateMessages[BV[Double]]( - ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr), - (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src) + ctx => ctx.sendToDst(ctx.srcAttr *:* ctx.attr), + (a : BV[Double], b : BV[Double]) => a +:+ b, TripletFields.Src) rankGraph = rankGraph.outerJoinVertices(rankUpdates) { (vid, oldRank, msgSumOpt) => - val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) :* (1.0 - resetProb) + val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) *:* (1.0 - resetProb) val resetActivations = if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) :* resetProb + sourcesInitMapBC.value(vid) *:* resetProb } else { zero } - popActivations :+ resetActivations + popActivations +:+ resetActivations }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -250,9 +250,9 @@ object PageRank extends Logging { } // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks - val rankSums = rankGraph.vertices.values.fold(zero)(_ :+ _) + val rankSums = rankGraph.vertices.values.fold(zero)(_ +:+ _) rankGraph.mapVertices { (vid, attr) => - Vectors.fromBreeze(attr :/ rankSums) + Vectors.fromBreeze(attr /:/ rankSums) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala index 32d78e9b226eb..3aea568cd6527 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala @@ -56,7 +56,7 @@ private[ann] class SigmoidLayerModelWithSquaredError extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction { override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - val error = Bsum(delta :* delta) / 2 / output.cols + val error = Bsum(delta *:* delta) / 2 / output.cols ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o)) error } @@ -119,6 +119,6 @@ private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - -Bsum( target :* brzlog(output)) / output.cols + -Bsum( target *:* brzlog(output)) / output.cols } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a9c1a7ba0bc8a..5259ee419445f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -472,7 +472,7 @@ class GaussianMixture @Since("2.0.0") ( */ val cov = { val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze - slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0) + slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) ^:^ 2.0) val diagVec = Vectors.fromBreeze(ss) BLAS.scal(1.0 / numSamples, diagVec) val covVec = new DenseVector(Array.fill[Double]( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 051ec2404fb6e..4d952ac88c9be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -271,7 +271,7 @@ class GaussianMixture private ( private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) val ss = BDV.zeros[Double](x(0).length) - x.foreach(xi => ss += (xi - mu) :^ 2.0) + x.foreach(xi => ss += (xi - mu) ^:^ 2.0) diag(ss / x.length.toDouble) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 15b723dadcff7..663f63c25a940 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -314,7 +314,7 @@ class LocalLDAModel private[spark] ( docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t) } // E[log p(theta | alpha) - log q(theta | gamma)] - docBound += sum((brzAlpha - gammad) :* Elogthetad) + docBound += sum((brzAlpha - gammad) *:* Elogthetad) docBound += sum(lgamma(gammad) - lgamma(brzAlpha)) docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) @@ -324,7 +324,7 @@ class LocalLDAModel private[spark] ( // Bound component for prob(topic-term distributions): // E[log p(beta | eta) - log q(beta | lambda)] val sumEta = eta * vocabSize - val topicsPart = sum((eta - lambda) :* Elogbeta) + + val topicsPart = sum((eta - lambda) *:* Elogbeta) + sum(lgamma(lambda) - lgamma(eta)) + sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) @@ -721,7 +721,7 @@ class DistributedLDAModel private[clustering] ( val N_wj = edgeContext.attr val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) edgeContext.sendToDst(tokenLogLikelihood) @@ -748,7 +748,7 @@ class DistributedLDAModel private[clustering] ( if (isTermVertex(vertex)) { val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 3697a9b46dd84..d633893e55f55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -482,7 +482,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) stats.unpersist() expElogbetaBc.destroy(false) - val batchResult = statsSum :* expElogbeta.t + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) @@ -522,7 +522,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val dalpha = -(gradf - b) / q - if (all((weight * dalpha + alpha) :> 0D)) { + if (all((weight * dalpha + alpha) >:> 0D)) { alpha :+= weight * dalpha this.alpha = Vectors.dense(alpha.toArray) } @@ -584,7 +584,7 @@ private[clustering] object OnlineLDAOptimizer { val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids @@ -592,14 +592,14 @@ private[clustering] object OnlineLDAOptimizer { while (meanGammaChange > 1e-3) { val lastgamma = gammad.copy // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha + gammad := (expElogthetad *:* (expElogbetad.t * (ctsVector /:/ phiNorm))) +:+ alpha expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) // TODO: Keep more values in log space, and only exponentiate when needed. - phiNorm := expElogbetad * expElogthetad :+ 1e-100 + phiNorm := expElogbetad * expElogthetad +:+ 1e-100 meanGammaChange = sum(abs(gammad - lastgamma)) / k } - val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector /:/ phiNorm).asDenseMatrix (gammad, sstatsd, ids) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index 1f6e1a077f923..c4bbe51a46c32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -29,7 +29,7 @@ private[clustering] object LDAUtils { */ private[clustering] def logSumExp(x: BDV[Double]): Double = { val a = max(x) - a + log(sum(exp(x :- a))) + a + log(sum(exp(x -:- a))) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index b56f8e19ca53c..3a2be236f1257 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -168,7 +168,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(m1.pi ~== m2.pi relTol 0.01) assert(m1.theta ~== m2.theta relTol 0.01) } - val testParams = Seq( + val testParams = Seq[(String, Dataset[_])]( ("bernoulli", bernoulliDataset), ("multinomial", dataset) ) diff --git a/pom.xml b/pom.xml index 517ebc5c83fc6..a1a1817e2f7d3 100644 --- a/pom.xml +++ b/pom.xml @@ -58,10 +58,6 @@ https://issues.apache.org/jira/browse/SPARK - - ${maven.version} - - Dev Mailing List diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index 4079d9e40fc41..0a413b2c23de1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster +import scala.language.reflectiveCalls + import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index bd5e2d7ecca9b..5f1032d1229da 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -37,7 +37,9 @@ public class GroupStateTimeout { * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation * on `GroupState` for more details. */ - public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + public static GroupStateTimeout ProcessingTimeTimeout() { + return ProcessingTimeTimeout$.MODULE$; + } /** * Timeout based on event-time. The event-time timestamp for timeout can be set for each @@ -51,4 +53,5 @@ public class GroupStateTimeout { /** No timeout. */ public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 65d5c3a582b16..f892e80204603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -41,7 +41,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { /* invalid json with leading nulls would trigger java.io.CharConversionException in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ - val badJson = "\0\0\0A\1AAA" + val badJson = "\u0000\u0000\u0000A\u0001AAA" test("$.store.bicycle") { checkEvaluation( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 0bab321a657d6..5a810cae1e184 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -66,7 +66,6 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; import org.apache.spark.util.AccumulatorV2; -import org.apache.spark.util.LongAccumulator; /** * Base class for custom RecordReaders for Parquet that directly materialize to `T`. @@ -160,7 +159,9 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont if (taskContext != null) { Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { - ((AccumulatorV2)accu.get()).add(blocks.size()); + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + intAccum.add(blocks.size()); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 1c1931b6a6daf..05637821f71f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution import java.util.Locale +import scala.language.reflectiveCalls + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b8a694c177310..59c6a6fade175 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.collection.mutable import scala.concurrent.duration._ +import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.concurrent.AsyncAssertions.Waiter diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 341e03b5e57fb..c3d734e5a0366 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -183,7 +183,7 @@ class HiveDDLSuite if (dbPath.isEmpty) { hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table) + new Path(new Path(dbPath.get), tableIdentifier.table).toUri } val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) From 7f96f2d7f2d5abf81dd7f8ca27fea35cf798fd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 3 May 2017 10:54:40 +0100 Subject: [PATCH 0397/1765] [SPARK-16957][MLLIB] Use midpoints for split values. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Use midpoints for split values now, and maybe later to make it weighted. ## How was this patch tested? + [x] add unit test. + [x] revise Split's unit test. Author: Yan Facai (颜发才) Author: 颜发才(Yan Facai) Closes #17556 from facaiy/ENH/decision_tree_overflow_and_precision_in_aggregation. --- .../spark/ml/tree/impl/RandomForest.scala | 15 ++++--- .../ml/tree/impl/RandomForestSuite.scala | 41 ++++++++++++++++--- python/pyspark/mllib/tree.py | 12 +++--- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 008dd19c2498d..82e1ed85a0a14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = if (featureSamples.isEmpty) { + val splits: Array[Double] = if (featureSamples.isEmpty) { Array.empty[Double] } else { val numSplits = metadata.numSplits(featureIndex) @@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.map(_._1).init + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 targetCount += stride } index += 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e1ab7c2d6520b..df155b464c64b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.distinct.length === splits.length) } + // SPARK-16957: Use midpoints for split values. + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + + // possibleSplits <= numSplits + { + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2) + assert(splits === expectedSplits) + } + + // possibleSplits > numSplits + { + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) + } + } + // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { @@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0, 2.0)) + val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) + .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(2.0, 3.0)) + val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) + assert(splits === expectedSplits) } // find splits when most samples close to the maximum @@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0)) + val expectedSplits = Array((1.0 + 2.0) / 2) + assert(splits === expectedSplits) } // find splits for constant feature diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a6089fc8b9d32..619fa16d463f5 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -199,9 +199,9 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.0) + If (feature 0 <= 0.5) Predict: 0.0 - Else (feature 0 > 0.0) + Else (feature 0 > 0.5) Predict: 1.0 >>> model.predict(array([1.0])) @@ -383,14 +383,14 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Tree 0: Predict: 1.0 Tree 1: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 Tree 2: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 >>> model.predict([2.0]) From 27f543b15f2f493f6f8373e46b4c9564b0a1bf81 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 3 May 2017 08:55:02 -0700 Subject: [PATCH 0398/1765] [SPARK-20441][SPARK-20432][SS] Within the same streaming query, one StreamingRelation should only be transformed to one StreamingExecutionRelation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Within the same streaming query, when one `StreamingRelation` is referred multiple times – e.g. `df.union(df)` – we should transform it only to one `StreamingExecutionRelation`, instead of two or more different `StreamingExecutionRelation`s (each of which would have a separate set of source, source logs, ...). ## How was this patch tested? Added two test cases, each of which would fail without this patch. Author: Liwei Lin Closes #17735 from lw-lin/SPARK-20441. --- .../execution/streaming/StreamExecution.scala | 20 ++++---- .../spark/sql/streaming/StreamSuite.scala | 48 +++++++++++++++++++ 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index affc2018c43cb..b6ddf7437ea13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.ReentrantLock +import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -148,15 +149,18 @@ class StreamExecution( "logicalPlan must be initialized in StreamExecutionThread " + s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L + val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { - case StreamingRelation(dataSource, _, output) => - // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. - StreamingExecutionRelation(source, output) + case streamingRelation@StreamingRelation(dataSource, _, output) => + toExecutionRelationMap.getOrElseUpdate(streamingRelation, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 01ea62a9de4d5..1fc062974e185 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -71,6 +71,27 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("SPARK-20432: union one stream with itself") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } + } + } + } + test("union two streams") { val inputData1 = MemoryStream[Int] val inputData2 = MemoryStream[Int] @@ -122,6 +143,33 @@ class StreamSuite extends StreamTest { assertDF(df) } + test("Within the same streaming query, one StreamingRelation should only be transformed to one " + + "StreamingExecutionRelation") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() + var query: StreamExecution = null + try { + query = + df.union(df) + .writeStream + .format("memory") + .queryName("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + query.awaitInitialization(streamingTimeout.toMillis) + val executionRelations = + query + .logicalPlan + .collect { case ser: StreamingExecutionRelation => ser } + assert(executionRelations.size === 2) + assert(executionRelations.distinct.size === 1) + } finally { + if (query != null) { + query.stop() + } + } + } + test("unsupported queries") { val streamInput = MemoryStream[Int] val batchInput = Seq(1, 2, 3).toDS() From 527fc5d0c990daaacad4740f62cfe6736609b77b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 May 2017 09:22:25 -0700 Subject: [PATCH 0399/1765] [SPARK-20576][SQL] Support generic hint function in Dataset/DataFrame ## What changes were proposed in this pull request? We allow users to specify hints (currently only "broadcast" is supported) in SQL and DataFrame. However, while SQL has a standard hint format (/*+ ... */), DataFrame doesn't have one and sometimes users are confused that they can't find how to apply a broadcast hint. This ticket adds a generic hint function on DataFrame that allows using the same hint on DataFrames as well as SQL. As an example, after this patch, the following will apply a broadcast hint on a DataFrame using the new hint function: ``` df1.join(df2.hint("broadcast")) ``` ## How was this patch tested? Added a test case in DataFrameJoinSuite. Author: Reynold Xin Closes #17839 from rxin/SPARK-20576. --- .../sql/catalyst/analysis/ResolveHints.scala | 8 +++++++- .../scala/org/apache/spark/sql/Dataset.scala | 16 ++++++++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 18 +++++++++++++++++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index c4827b81e8b63..df688fa0e58ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -86,7 +86,13 @@ object ResolveHints { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => - applyBroadcastHint(h.child, h.parameters.toSet) + if (h.parameters.isEmpty) { + // If there is no table alias specified, turn the entire subtree into a BroadcastHint. + BroadcastHint(h.child) + } else { + // Otherwise, find within the subtree query plans that should be broadcasted. + applyBroadcastHint(h.child, h.parameters.toSet) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 147e7651ce55b..620c8bd54ba00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1160,6 +1160,22 @@ class Dataset[T] private[sql]( */ def apply(colName: String): Column = col(colName) + /** + * Specifies some hint on the current Dataset. As an example, the following code specifies + * that one of the plan can be broadcasted: + * + * {{{ + * df1.join(df2.hint("broadcast")) + * }}} + * + * @group basic + * @since 2.2.0 + */ + @scala.annotation.varargs + def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { + Hint(name, parameters, logicalPlan) + } + /** * Selects column based on the column name and return it as a [[Column]]. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 541ffb58e727f..4a52af6c32c37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } - test("broadcast join hint") { + test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") @@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } } + test("broadcast join hint using Dataset.hint") { + // make sure a giant join is not broadcastable + val plan1 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // now with a hint it should be broadcasted + val plan2 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong).hint("broadcast"), "id") + .queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") From 6b9e49d12fc4c9b29d497122daa4cc9bf4540b16 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 3 May 2017 11:10:24 -0700 Subject: [PATCH 0400/1765] [SPARK-19965][SS] DataFrame batch reader may fail to infer partitions when reading FileStreamSink's output ## The Problem Right now DataFrame batch reader may fail to infer partitions when reading FileStreamSink's output: ``` [info] - partitioned writing and batch reading with 'basePath' *** FAILED *** (3 seconds, 928 milliseconds) [info] java.lang.AssertionError: assertion failed: Conflicting directory structures detected. Suspicious paths: [info] ***/stream.output-65e3fa45-595a-4d29-b3df-4c001e321637 [info] ***/stream.output-65e3fa45-595a-4d29-b3df-4c001e321637/_spark_metadata [info] [info] If provided paths are partition directories, please set "basePath" in the options of the data source to specify the root directory of the table. If there are multiple root directories, please load them separately and then union them. [info] at scala.Predef$.assert(Predef.scala:170) [info] at org.apache.spark.sql.execution.datasources.PartitioningUtils$.parsePartitions(PartitioningUtils.scala:133) [info] at org.apache.spark.sql.execution.datasources.PartitioningUtils$.parsePartitions(PartitioningUtils.scala:98) [info] at org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning(PartitioningAwareFileIndex.scala:156) [info] at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.partitionSpec(InMemoryFileIndex.scala:54) [info] at org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.partitionSchema(PartitioningAwareFileIndex.scala:55) [info] at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:133) [info] at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:361) [info] at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:160) [info] at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:536) [info] at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:520) [info] at org.apache.spark.sql.streaming.FileStreamSinkSuite$$anonfun$8.apply$mcV$sp(FileStreamSinkSuite.scala:292) [info] at org.apache.spark.sql.streaming.FileStreamSinkSuite$$anonfun$8.apply(FileStreamSinkSuite.scala:268) [info] at org.apache.spark.sql.streaming.FileStreamSinkSuite$$anonfun$8.apply(FileStreamSinkSuite.scala:268) ``` ## What changes were proposed in this pull request? This patch alters `InMemoryFileIndex` to filter out these `basePath`s whose ancestor is the streaming metadata dir (`_spark_metadata`). E.g., the following and other similar dir or files will be filtered out: - (introduced by globbing `basePath/*`) - `basePath/_spark_metadata` - (introduced by globbing `basePath/*/*`) - `basePath/_spark_metadata/0` - `basePath/_spark_metadata/1` - ... ## How was this patch tested? Added unit tests Author: Liwei Lin Closes #17346 from lw-lin/filter-metadata. --- .../datasources/InMemoryFileIndex.scala | 13 +++- .../execution/streaming/FileStreamSink.scala | 20 +++++++ .../datasources/FileSourceStrategySuite.scala | 2 +- .../sql/streaming/FileStreamSinkSuite.scala | 59 ++++++++++++++++++- 4 files changed, 90 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 9897ab73b0da8..91e31650617ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.execution.streaming.FileStreamSink import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -36,20 +37,28 @@ import org.apache.spark.util.SerializableConfiguration * A [[FileIndex]] that generates the list of files to process by recursively listing all the * files present in `paths`. * - * @param rootPaths the list of root table paths to scan + * @param rootPathsSpecified the list of root table paths to scan (some of which might be + * filtered out later) * @param parameters as set of options to control discovery * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, - override val rootPaths: Seq[Path], + rootPathsSpecified: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( sparkSession, parameters, partitionSchema, fileStatusCache) { + // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) + // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain + // such streaming metadata dir or files, e.g. when after globbing "basePath/*" where "basePath" + // is the output of a streaming query. + override val rootPaths = + rootPathsSpecified.filterNot(FileStreamSink.ancestorIsMetadataDirectory(_, hadoopConf)) + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ @volatile private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 07ec4e9429e42..6885d0bf67ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -53,6 +53,26 @@ object FileStreamSink extends Logging { case _ => false } } + + /** + * Returns true if the path is the metadata dir or its ancestor is the metadata dir. + * E.g.: + * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true + * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true + * - ancestorIsMetadataDirectory(/a/b/c) => false + */ + def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = { + val fs = path.getFileSystem(hadoopConf) + var currentPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + while (currentPath != null) { + if (currentPath.getName == FileStreamSink.metadataDir) { + return true + } else { + currentPath = currentPath.getParent + } + } + return false + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8703fe96e5878..fa3c69612704d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - rootPaths = Seq(new Path(tempDir)), + rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1211242b9fbb4..1a2d3a13f3a4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming import java.util.Locale +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -145,6 +147,43 @@ class FileStreamSinkSuite extends StreamTest { } } + test("partitioned writing and batch reading with 'basePath'") { + withTempDir { outputDir => + withTempDir { checkpointDir => + val outputPath = outputDir.getAbsolutePath + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, -i, i * 1000)) + .toDF("id1", "id2", "value") + .writeStream + .partitionBy("id1", "id2") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("parquet") + .start(outputPath) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val readIn = spark.read.option("basePath", outputPath).parquet(s"$outputDir/*/*") + checkDatasetUnorderly( + readIn.as[(Int, Int, Int)], + (1000, 1, -1), (2000, 2, -2), (3000, 3, -3)) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + // This tests whether FileStreamSink works with aggregations. Specifically, it tests // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. @@ -266,4 +305,22 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("FileStreamSink.ancestorIsMetadataDirectory()") { + val hadoopConf = spark.sparkContext.hadoopConfiguration + def assertAncestorIsMetadataDirectory(path: String): Unit = + assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + def assertAncestorIsNotMetadataDirectory(path: String): Unit = + assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/") + + assertAncestorIsNotMetadataDirectory(s"/a/b/c") + assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") + } } From 13eb37c860c8f672d0e9d9065d0333f981db71e3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 3 May 2017 13:08:25 -0700 Subject: [PATCH 0401/1765] [MINOR][SQL] Fix the test title from =!= to <=>, remove a duplicated test and add a test for =!= ## What changes were proposed in this pull request? This PR proposes three things as below: - This test looks not testing `<=>` and identical with the test above, `===`. So, it removes the test. ```diff - test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } ``` - Replace the test title from `=!=` to `<=>`. It looks the test actually testing `<=>`. ```diff + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + ... - test("=!=") { + test("<=>") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), ... ``` - Add the tests for `=!=` which looks not existing. ```diff + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) + } ``` ## How was this patch tested? Manually running the tests. Author: hyukjinkwon Closes #17842 from HyukjinKwon/minor-test-fix. --- .../spark/sql/ColumnExpressionSuite.scala | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab7455..bc708ca88d7e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) } + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } - - test("=!=") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), Row(1, 1) :: Nil) @@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData2.filter($"a" <=> null), Row(null) :: Nil) + } + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) } test(">") { From 02bbe73118a39e2fb378aa2002449367a92f6d67 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 3 May 2017 19:15:28 -0700 Subject: [PATCH 0402/1765] [SPARK-20584][PYSPARK][SQL] Python generic hint support ## What changes were proposed in this pull request? Adds `hint` method to PySpark `DataFrame`. ## How was this patch tested? Unit tests, doctests. Author: zero323 Closes #17850 from zero323/SPARK-20584. --- python/pyspark/sql/dataframe.py | 29 +++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 16 ++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ab6d35bfa7c5c..7b67985f2b320 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -380,6 +380,35 @@ def withWatermark(self, eventTime, delayThreshold): jdf = self._jdf.withWatermark(eventTime, delayThreshold) return DataFrame(jdf, self.sql_ctx) + @since(2.2) + def hint(self, name, *parameters): + """Specifies some hint on the current DataFrame. + + :param name: A name of the hint. + :param parameters: Optional parameters. + :return: :class:`DataFrame` + + >>> df.join(df2.hint("broadcast"), "name").show() + +----+---+------+ + |name|age|height| + +----+---+------+ + | Bob| 5| 85| + +----+---+------+ + """ + if len(parameters) == 1 and isinstance(parameters[0], list): + parameters = parameters[0] + + if not isinstance(name, str): + raise TypeError("name should be provided as str, got {0}".format(type(name))) + + for p in parameters: + if not isinstance(p, str): + raise TypeError( + "all parameters should be str, got {0} of type {1}".format(p, type(p))) + + jdf = self._jdf.hint(name, self._jseq(parameters)) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ce4abf8fb7e5c..f644624f7f317 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1906,6 +1906,22 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_generic_hints(self): + from pyspark.sql import DataFrame + + df1 = self.spark.range(10e10).toDF("id") + df2 = self.spark.range(10e10).toDF("id") + + self.assertIsInstance(df1.hint("broadcast"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", []), DataFrame) + + # Dummy rules + self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) + + plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) From fc472bddd1d9c6a28e57e31496c0166777af597e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 May 2017 21:40:18 -0700 Subject: [PATCH 0403/1765] [SPARK-20543][SPARKR] skip tests when running on CRAN ## What changes were proposed in this pull request? General rule on skip or not: skip if - RDD tests - tests could run long or complicated (streaming, hivecontext) - tests on error conditions - tests won't likely change/break ## How was this patch tested? unit tests, `R CMD check --as-cran`, `R CMD check` Author: Felix Cheung Closes #17817 from felixcheung/rskiptest. --- R/pkg/inst/tests/testthat/test_Serde.R | 6 + R/pkg/inst/tests/testthat/test_Windows.R | 2 + R/pkg/inst/tests/testthat/test_binaryFile.R | 8 ++ .../tests/testthat/test_binary_function.R | 6 + R/pkg/inst/tests/testthat/test_broadcast.R | 4 + R/pkg/inst/tests/testthat/test_client.R | 8 ++ R/pkg/inst/tests/testthat/test_context.R | 16 +++ .../inst/tests/testthat/test_includePackage.R | 4 + .../tests/testthat/test_mllib_clustering.R | 4 + .../tests/testthat/test_mllib_regression.R | 12 ++ .../tests/testthat/test_parallelize_collect.R | 8 ++ R/pkg/inst/tests/testthat/test_rdd.R | 106 +++++++++++++++++- R/pkg/inst/tests/testthat/test_shuffle.R | 24 ++++ R/pkg/inst/tests/testthat/test_sparkR.R | 2 + R/pkg/inst/tests/testthat/test_sparkSQL.R | 61 +++++++++- R/pkg/inst/tests/testthat/test_streaming.R | 12 ++ R/pkg/inst/tests/testthat/test_take.R | 2 + R/pkg/inst/tests/testthat/test_textFile.R | 18 +++ R/pkg/inst/tests/testthat/test_utils.R | 6 + R/run-tests.sh | 2 +- 20 files changed, 307 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b5f6f1b54fa85..518fb7bd94043 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -20,6 +20,8 @@ context("SerDe functionality") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("SerDe of primitive types", { + skip_on_cran() + x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -38,6 +40,8 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { + skip_on_cran() + x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -65,6 +69,8 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { + skip_on_cran() + x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 1d777ddb286df..919b063bf0693 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -17,6 +17,8 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { + skip_on_cran() + if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b5c279e3156e5..63f54e1af02b1 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -38,6 +40,8 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -50,6 +54,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -74,6 +80,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 59cb2e6204405..25bb2b84266dd 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -29,6 +29,8 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { + skip_on_cran() + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -51,6 +53,8 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -69,6 +73,8 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 65f204d096f43..504ded4fc8623 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -26,6 +26,8 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) @@ -38,6 +40,8 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index 0cf25fe1dbf39..3d53bebab6300 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -18,6 +18,8 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -26,16 +28,22 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { + skip_on_cran() + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c64fe6edcd49e..632a90d68177f 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -18,6 +18,8 @@ context("test functions in sparkR.R") test_that("Check masked functions", { + skip_on_cran() + # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. @@ -55,6 +57,8 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { + skip_on_cran() + for (i in 1:4) { sc <- suppressWarnings(sparkR.init()) rdd <- parallelize(sc, 1:20, 2L) @@ -73,6 +77,8 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { + skip_on_cran() + sc <- sparkR.sparkContext() # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -96,6 +102,8 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { + skip_on_cran() + sc <- sparkR.sparkContext() setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -108,12 +116,16 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { + skip_on_cran() + sparkR.sparkContext() setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + skip_on_cran() + e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -141,6 +153,8 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { + skip_on_cran() + expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -168,6 +182,8 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { + skip_on_cran() + sparkR.sparkContext() # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 563ea298c2dd8..f823ad8e9c985 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -26,6 +26,8 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -42,6 +44,8 @@ test_that("include inside function", { }) test_that("use include package", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1661e987b730f..478012e8828cd 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -255,6 +255,8 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -297,6 +299,8 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 3e9ad77198073..58924f952c6bf 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -23,6 +23,8 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("formula of spark.glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -195,6 +197,8 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -222,6 +226,8 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . - Species + 0, data = training) @@ -248,6 +254,8 @@ test_that("formula of glm", { }) test_that("glm and predict", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -292,6 +300,8 @@ test_that("glm and predict", { }) test_that("glm summary", { + skip_on_cran() + # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -341,6 +351,8 @@ test_that("glm summary", { }) test_that("glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 55972e1ba4693..1f7f387de08ce 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -39,6 +39,8 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -66,6 +68,8 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -86,6 +90,8 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { + skip_on_cran() + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -95,6 +101,8 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + skip_on_cran() + # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b72c801dd958d..a3b1631e1d119 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -29,22 +29,30 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { + skip_on_cran() + expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { + skip_on_cran() + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + skip_on_cran() + + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { + skip_on_cran() + mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -56,30 +64,40 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { + skip_on_cran() + multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { + skip_on_cran() + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { + skip_on_cran() + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { + skip_on_cran() + flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { + skip_on_cran() + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -95,6 +113,8 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { + skip_on_cran() + vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -103,6 +123,8 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + skip_on_cran() + rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -117,6 +139,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + skip_on_cran() + # RDD rdd2 <- rdd # PipelinedRDD @@ -158,6 +182,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { + skip_on_cran() + sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -167,6 +193,8 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { + skip_on_cran() + fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -175,6 +203,8 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { + skip_on_cran() + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -191,10 +221,14 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { + skip_on_cran() + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { + skip_on_cran() + # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -237,6 +271,8 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { + skip_on_cran() + multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -246,6 +282,8 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { + skip_on_cran() + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -258,6 +296,8 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { + skip_on_cran() + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -271,6 +311,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { + skip_on_cran() + nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -279,21 +321,29 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { + skip_on_cran() + max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { + skip_on_cran() + min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { + skip_on_cran() + sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { + skip_on_cran() + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -301,6 +351,8 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -322,6 +374,8 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { + skip_on_cran() + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -333,6 +387,8 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -345,6 +401,8 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -357,6 +415,8 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { + skip_on_cran() + actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -366,6 +426,8 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -379,6 +441,8 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -393,6 +457,8 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -407,24 +473,32 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { + skip_on_cran() + rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { + skip_on_cran() + keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { + skip_on_cran() + values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { + skip_on_cran() + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -442,6 +516,8 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -471,6 +547,8 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -514,6 +592,8 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { + skip_on_cran() + l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -541,6 +621,8 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { + skip_on_cran() + l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -570,6 +652,8 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { + skip_on_cran() + # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -586,6 +670,8 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -610,6 +696,8 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -640,6 +728,8 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -667,6 +757,8 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -698,6 +790,8 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { + skip_on_cran() + numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -747,6 +841,8 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { + skip_on_cran() + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -765,11 +861,15 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { + skip_on_cran() + rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -794,6 +894,8 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { + skip_on_cran() + rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d38efab0fd1df..cedf4f100c6c4 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -37,6 +37,8 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { + skip_on_cran() + grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -46,6 +48,8 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { + skip_on_cran() + grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -55,6 +59,8 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { + skip_on_cran() + reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -64,6 +70,8 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { + skip_on_cran() + reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -72,6 +80,8 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { + skip_on_cran() + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -81,6 +91,8 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { + skip_on_cran() + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -89,6 +101,8 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { + skip_on_cran() + stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -101,6 +115,8 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { + skip_on_cran() + # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -129,6 +145,8 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { + skip_on_cran() + # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -172,6 +190,8 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { + skip_on_cran() + # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -187,6 +207,8 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { + skip_on_cran() + kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -205,6 +227,8 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { + skip_on_cran() + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R index f73fc6baeccef..a40981c188f7a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -18,6 +18,8 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { + skip_on_cran() + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 12867c15d1f95..a7bb3265d92d7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -97,15 +97,21 @@ mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { + skip_on_cran() + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { + skip_on_cran() + expect_equal(sparkR.session(), sparkSession) }) @@ -203,6 +209,8 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -300,6 +308,8 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { + skip_on_cran() + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -360,6 +370,8 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { + skip_on_cran() + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -414,6 +426,8 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -525,6 +539,8 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { + skip_on_cran() + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -537,6 +553,8 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { + skip_on_cran() + # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -624,6 +642,8 @@ test_that("read/write json files", { }) test_that("read/write json files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -637,6 +657,8 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -693,6 +715,8 @@ test_that( }) test_that("test cache, uncache and clearCache", { + skip_on_cran() + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -746,6 +770,8 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -753,6 +779,8 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -763,6 +791,8 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { + skip_on_cran() + # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -792,6 +822,8 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { + skip_on_cran() + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -804,6 +836,8 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -872,6 +906,8 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { + skip_on_cran() + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -1497,7 +1533,6 @@ test_that("column functions", { collect(select(df, alias(not(df$is_true), "is_false"))), data.frame(is_false = c(FALSE, TRUE, NA)) ) - }) test_that("column binary mathfunctions", { @@ -2306,6 +2341,8 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2327,6 +2364,8 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2373,6 +2412,8 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { + skip_on_cran() + df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2390,6 +2431,8 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { + skip_on_cran() + # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2411,6 +2454,8 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2644,6 +2689,8 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { + skip_on_cran() + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2652,6 +2699,8 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { + skip_on_cran() + expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -3069,6 +3118,8 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -3148,6 +3199,8 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { + skip_on_cran() + setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -3163,6 +3216,8 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { + skip_on_cran() + df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -3189,6 +3244,8 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + skip_on_cran() + # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3313,6 +3370,8 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() + # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index b125cb0591de2..8843991024308 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -47,6 +47,8 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -65,6 +67,8 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { }) test_that("print from explain, lastProgress, status, isActive", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -83,6 +87,8 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { + skip_on_cran() + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -108,6 +114,8 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { + skip_on_cran() + c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -117,6 +125,8 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { + skip_on_cran() + # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -125,6 +135,8 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index aaa532856c3d9..e2130eaac78dd 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -34,6 +34,8 @@ sparkSession <- sparkR.session(enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { + skip_on_cran() + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 3b466066e9390..28b7e8e3183fd 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -36,6 +38,8 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -46,6 +50,8 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -64,6 +70,8 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -78,6 +86,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -92,6 +102,8 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -103,6 +115,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -128,6 +142,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -141,6 +157,8 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 1ca383da26ec2..4a01e875405ff 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -23,6 +23,7 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { + skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. @@ -40,6 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { + skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -167,6 +169,8 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { + skip_on_cran() + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "col", "unknown", TRUE), @@ -177,6 +181,8 @@ test_that("captureJVMException", { }) test_that("hashCode", { + skip_on_cran() + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76da..29764f48bd156 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" From b8302ccd02265f9d7a7895c7b033441fa2d8ffd1 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 4 May 2017 00:27:10 -0700 Subject: [PATCH 0404/1765] [SPARK-20015][SPARKR][SS][DOC][EXAMPLE] Document R Structured Streaming (experimental) in R vignettes and R & SS programming guide, R example ## What changes were proposed in this pull request? Add - R vignettes - R programming guide - SS programming guide - R example Also disable spark.als in vignettes for now since it's failing (SPARK-20402) ## How was this patch tested? manually Author: Felix Cheung Closes #17814 from felixcheung/rdocss. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 79 ++++- docs/sparkr.md | 4 + .../structured-streaming-programming-guide.md | 285 +++++++++++++++--- .../streaming/structured_network_wordcount.R | 57 ++++ 4 files changed, 381 insertions(+), 44 deletions(-) create mode 100644 examples/src/main/r/streaming/structured_network_wordcount.R diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 4b9d6c3806098..d38ec4f1b6f37 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -182,7 +182,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -232,7 +232,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -314,7 +314,7 @@ Use `cube` or `rollup` to compute subtotals across multiple dimensions. mean(cube(carsDF, "cyl", "gear", "am"), "mpg") ``` -generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while +generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while ```{r} mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") @@ -672,6 +672,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -902,7 +903,7 @@ perplexity There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -910,7 +911,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -920,7 +921,7 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` @@ -1002,6 +1003,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes diff --git a/docs/sparkr.md b/docs/sparkr.md index 6dbd02a48890d..569b85e72c3cf 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -593,6 +593,10 @@ The following example shows how to save/load a MLlib model by SparkR.

    (?i)secret|password Regex to decide which Spark configuration properties and environment variables in driver and - executor environments contain sensitive information. When this regex matches a property, its - value is redacted from the environment UI and various logs like YARN and event logs. + executor environments contain sensitive information. When this regex matches a property key or + value, the value is redacted from the environment UI and various logs like YARN and event logs.
    spark.worker.cleanup.appDataTtl7 * 24 * 3600 (7 days)604800 (7 days, 7 * 24 * 3600) The number of seconds to retain application work directories on each worker. This is a Time To Live and should depend on the amount of available disk space you have. Application logs and jars are From 1ee494d0868a85af3154996732817ed63679f382 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 30 Apr 2017 08:24:10 -0700 Subject: [PATCH 0372/1765] [SPARK-20492][SQL] Do not print empty parentheses for invalid primitive types in parser ## What changes were proposed in this pull request? Currently, when the type string is invalid, it looks printing empty parentheses. This PR proposes a small improvement in an error message by removing it in the parse as below: ```scala spark.range(1).select($"col".cast("aa")) ``` **Before** ``` org.apache.spark.sql.catalyst.parser.ParseException: DataType aa() is not supported.(line 1, pos 0) == SQL == aa ^^^ ``` **After** ``` org.apache.spark.sql.catalyst.parser.ParseException: DataType aa is not supported.(line 1, pos 0) == SQL == aa ^^^ ``` ## How was this patch tested? Unit tests in `DataTypeParserSuite`. Author: hyukjinkwon Closes #17784 from HyukjinKwon/SPARK-20492. --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../spark/sql/catalyst/parser/DataTypeParserSuite.scala | 7 ++++++- .../resources/sql-tests/results/json-functions.sql.out | 2 +- .../scala/org/apache/spark/sql/JsonFunctionsSuite.scala | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2cf06d15664d9..a48a693a95c93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1491,8 +1491,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("decimal", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 3964fa3924b24..4490523369006 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -30,7 +30,7 @@ class DataTypeParserSuite extends SparkFunSuite { } } - def intercept(sql: String): Unit = + def intercept(sql: String): ParseException = intercept[ParseException](CatalystSqlParser.parseDataType(sql)) def unsupported(dataTypeString: String): Unit = { @@ -118,6 +118,11 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("struct") + test("Do not print empty parentheses for no params") { + assert(intercept("unkwon").getMessage.contains("unkwon is not supported")) + assert(intercept("unkwon(1,2,3)").getMessage.contains("unkwon(1,2,3) is not supported")) + } + // DataType parser accepts certain reserved keywords. checkDataType( "Struct", diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 315e1730ce7df..fedabaee2237f 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -141,7 +141,7 @@ struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -DataType invalidtype() is not supported.(line 1, pos 2) +DataType invalidtype is not supported.(line 1, pos 2) == SQL == a InvalidType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 8465e8d036a6d..69a500c845a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -274,7 +274,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } - assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported")) + assert(errMsg2.getMessage.contains("DataType invalidtype is not supported")) val errMsg3 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") } From ae3df4e98f160f94d1e52c90363f26eb351d0153 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 30 Apr 2017 12:33:03 -0700 Subject: [PATCH 0373/1765] [SPARK-20535][SPARKR] R wrappers for explode_outer and posexplode_outer ## What changes were proposed in this pull request? Ad R wrappers for - `o.a.s.sql.functions.explode_outer` - `o.a.s.sql.functions.posexplode_outer` ## How was this patch tested? Additional unit tests, manual testing. Author: zero323 Closes #17809 from zero323/SPARK-20535. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 56 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 1 + 4 files changed, 67 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2800461658483..db8e06db18edc 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -234,6 +234,7 @@ exportMethods("%in%", "endsWith", "exp", "explode", + "explode_outer", "expm1", "expr", "factorial", @@ -296,6 +297,7 @@ exportMethods("%in%", "percent_rank", "pmod", "posexplode", + "posexplode_outer", "quarter", "rand", "randn", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 6b91fa5bde671..f4a34fbabe4d7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3803,3 +3803,59 @@ setMethod("repeat_string", jc <- callJStatic("org.apache.spark.sql.functions", "repeat", x@jc, numToInt(n)) column(jc) }) + +#' explode_outer +#' +#' Creates a new row for each element in the given array or map column. +#' Unlike \code{explode}, if the array/map is \code{null} or empty +#' then \code{null} is produced. +#' +#' @param x Column to compute on +#' +#' @rdname explode_outer +#' @name explode_outer +#' @family collection_funcs +#' @aliases explode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) +#' } +#' @note explode_outer since 2.3.0 +setMethod("explode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode_outer", x@jc) + column(jc) + }) + +#' posexplode_outer +#' +#' Creates a new row for each element with position in the given array or map column. +#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' then the row (\code{null}, \code{null}) is produced. +#' +#' @param x Column to compute on +#' +#' @rdname posexplode_outer +#' @name posexplode_outer +#' @family collection_funcs +#' @aliases posexplode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) +#' } +#' @note posexplode_outer since 2.3.0 +setMethod("posexplode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 749ee9b54cc80..e510ff9a2d80f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1016,6 +1016,10 @@ setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) +#' @rdname explode_outer +#' @export +setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) + #' @rdname expr #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) @@ -1175,6 +1179,10 @@ setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @export setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) +#' @rdname posexplode_outer +#' @export +setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) + #' @rdname quarter #' @export setGeneric("quarter", function(x) { standardGeneric("quarter") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 1a3d6df437d7e..1828cddffd27c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1347,6 +1347,7 @@ test_that("column functions", { c18 <- covar_pop(c, c1) + covar_pop("c", "c1") c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") + c21 <- posexplode_outer(c) + explode_outer(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From 6613046c8c2daaf46a8ec13dd0a016aad22af1a4 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Sun, 30 Apr 2017 21:42:05 -0700 Subject: [PATCH 0374/1765] [MINOR][DOCS][PYTHON] Adding missing boolean type for replacement value in fillna ## What changes were proposed in this pull request? Currently pyspark Dataframe.fillna API supports boolean type when we pass dict, but it is missing in documentation. ## How was this patch tested? >>> spark.createDataFrame([Row(a=True),Row(a=None)]).fillna({"a" : True}).show() +----+ | a| +----+ |true| |true| +----+ Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Srinivasa Reddy Vundela Closes #17688 from vundela/fillna_doc_fix. --- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/tests.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ff21bb5d2fb3f..ab6d35bfa7c5c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1247,7 +1247,7 @@ def fillna(self, value, subset=None): Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be - an int, long, float, or string. + an int, long, float, boolean, or string. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2b2444304e04a..cd92148dfa5df 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1711,6 +1711,10 @@ def test_fillna(self): self.assertEqual(row.age, None) self.assertEqual(row.height, None) + # fillna with dictionary for boolean types + row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() + self.assertEqual(row.a, True) + def test_bitwise_operations(self): from pyspark.sql import functions row = Row(a=170, b=75) From 80e9cf1b59ce7186a4506f83e50f4fc7759c938c Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 30 Apr 2017 22:07:12 -0700 Subject: [PATCH 0375/1765] [SPARK-20490][SPARKR] Add R wrappers for eqNullSafe and ! / not ## What changes were proposed in this pull request? - Add null-safe equality operator `%<=>%` (sames as `o.a.s.sql.Column.eqNullSafe`, `o.a.s.sql.Column.<=>`) - Add boolean negation operator `!` and function `not `. ## How was this patch tested? Existing unit tests, additional unit tests, `check-cran.sh`. Author: zero323 Closes #17783 from zero323/SPARK-20490. --- R/pkg/NAMESPACE | 4 +- R/pkg/R/column.R | 55 ++++++++++++++++++++++- R/pkg/R/functions.R | 31 +++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_context.R | 4 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 20 +++++++++ 6 files changed, 117 insertions(+), 5 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index db8e06db18edc..e8de34d9371a0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -182,7 +182,8 @@ exportMethods("arrange", exportClasses("Column") -exportMethods("%in%", +exportMethods("%<=>%", + "%in%", "abs", "acos", "add_months", @@ -291,6 +292,7 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "not", "ntile", "otherwise", "over", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 539d91b0f8797..147ee4b6887b9 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -67,8 +67,7 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or", #, "!" = "unary_$bang" - "^" = "pow" + "&" = "and", "|" = "or", "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "getField", "getItem", "contains") @@ -302,3 +301,55 @@ setMethod("otherwise", jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) + +#' \%<=>\% +#' +#' Equality test that is safe for null values. +#' +#' Can be used, unlike standard equality operator, to perform null-safe joins. +#' Equivalent to Scala \code{Column.<=>} and \code{Column.eqNullSafe}. +#' +#' @param x a Column +#' @param value a value to compare +#' @rdname eq_null_safe +#' @name %<=>% +#' @aliases %<=>%,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df1 <- createDataFrame(data.frame( +#' x = c(1, NA, 3, NA), y = c(2, 6, 3, NA) +#' )) +#' +#' head(select(df1, df1$x == df1$y, df1$x %<=>% df1$y)) +#' +#' df2 <- createDataFrame(data.frame(y = c(3, NA))) +#' count(join(df1, df2, df1$y == df2$y)) +#' +#' count(join(df1, df2, df1$y %<=>% df2$y)) +#' } +#' @note \%<=>\% since 2.3.0 +setMethod("%<=>%", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- if (class(value) == "Column") { value@jc } else { value } + jc <- callJMethod(x@jc, "eqNullSafe", value) + column(jc) + }) + +#' ! +#' +#' Inversion of boolean expression. +#' +#' @rdname not +#' @name not +#' @aliases !,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) +#' +#' head(select(df, !column("x") > 0)) +#' } +#' @note ! since 2.3.0 +setMethod("!", signature(x = "Column"), function(x) not(x)) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f4a34fbabe4d7..f9687d680e7a2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3859,3 +3859,34 @@ setMethod("posexplode_outer", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) column(jc) }) + +#' not +#' +#' Inversion of boolean expression. +#' +#' \code{not} and \code{!} cannot be applied directly to numerical column. +#' To achieve R-like truthiness column has to be casted to \code{BooleanType}. +#' +#' @param x Column to compute on +#' @rdname not +#' @name not +#' @aliases not,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' is_true = c(TRUE, FALSE, NA), +#' flag = c(1, 0, 1) +#' )) +#' +#' head(select(df, not(df$is_true))) +#' +#' # Explicit cast is required when working with numeric column +#' head(select(df, not(cast(df$flag, "boolean")))) +#' } +#' @note not since 2.3.0 +setMethod("not", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e510ff9a2d80f..d4e4958dc078c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -856,6 +856,10 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @export setGeneric("over", function(x, window) { standardGeneric("over") }) +#' @rdname eq_null_safe +#' @export +setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) + ###################### WindowSpec Methods ########################## #' @rdname partitionBy @@ -1154,6 +1158,10 @@ setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @export setGeneric("negate", function(x) { standardGeneric("negate") }) +#' @rdname not +#' @export +setGeneric("not", function(x) { standardGeneric("not") }) + #' @rdname next_day #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c847113491113..c64fe6edcd49e 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -21,10 +21,10 @@ test_that("Check masked functions", { # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. - namesOfMaskedCompletely <- c("cov", "filter", "sample") + namesOfMaskedCompletely <- c("cov", "filter", "sample", "not") namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window", "as.data.frame", "union") + "summary", "transform", "drop", "window", "as.data.frame", "union", "not") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 1828cddffd27c..08296354ca7ed 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1323,6 +1323,8 @@ test_that("column operators", { c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) c5 <- c2 ^ c3 ^ c4 + c6 <- c2 %<=>% c3 + c7 <- !c6 }) test_that("column functions", { @@ -1348,6 +1350,7 @@ test_that("column functions", { c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) + c22 <- not(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1488,6 +1491,13 @@ test_that("column functions", { lapply( list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), as.environment)) + + df <- as.DataFrame(data.frame(is_true = c(TRUE, FALSE, NA))) + expect_equal( + collect(select(df, alias(not(df$is_true), "is_false"))), + data.frame(is_false = c(FALSE, TRUE, NA)) + ) + }) test_that("column binary mathfunctions", { @@ -1973,6 +1983,16 @@ test_that("filter() on a DataFrame", { filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + # test suites for %<=>% + dfNa <- read.json(jsonPathNa) + expect_equal(count(filter(dfNa, dfNa$age %<=>% 60)), 1) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% 60))), 5 - 1) + expect_equal(count(filter(dfNa, dfNa$age %<=>% NULL)), 3) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% NULL))), 5 - 3) + # match NA from two columns + expect_equal(count(filter(dfNa, dfNa$age %<=>% dfNa$height)), 2) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% dfNa$height))), 5 - 2) + # Test stats::filter is working #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) From a355b667a3718d9c5d48a0781e836bf5418ab842 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 30 Apr 2017 23:23:49 -0700 Subject: [PATCH 0376/1765] [SPARK-20541][SPARKR][SS] support awaitTermination without timeout ## What changes were proposed in this pull request? Add without param for timeout - will need this to submit a job that runs until stopped Need this for 2.2 ## How was this patch tested? manually, unit test Author: Felix Cheung Closes #17815 from felixcheung/rssawaitinfinite. --- R/pkg/R/generics.R | 2 +- R/pkg/R/streaming.R | 14 ++++++++++---- R/pkg/inst/tests/testthat/test_streaming.R | 1 + 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index d4e4958dc078c..ef36765a7a725 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1518,7 +1518,7 @@ setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") #' @rdname awaitTermination #' @export -setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") }) +setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive #' @export diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index e353d2dd07c3d..8390bd5e6de72 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -169,8 +169,10 @@ setMethod("isActive", #' immediately. #' #' @param x a StreamingQuery. -#' @param timeout time to wait in milliseconds -#' @return TRUE if query has terminated within the timeout period. +#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} +#' is called or an error has occured. +#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not +#' specified. #' @rdname awaitTermination #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method @@ -182,8 +184,12 @@ setMethod("isActive", #' @note experimental setMethod("awaitTermination", signature(x = "StreamingQuery"), - function(x, timeout) { - handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } }) #' stopQuery diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 1f4054a84df53..b125cb0591de2 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -61,6 +61,7 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { stopQuery(q) expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) }) test_that("print from explain, lastProgress, status, isActive", { From f0169a1c6a1ac06045d57f8aaa2c841bb39e23ac Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 1 May 2017 09:43:32 -0700 Subject: [PATCH 0377/1765] [SPARK-20290][MINOR][PYTHON][SQL] Add PySpark wrapper for eqNullSafe ## What changes were proposed in this pull request? Adds Python bindings for `Column.eqNullSafe` ## How was this patch tested? Manual tests, existing unit tests, doc build. Author: zero323 Closes #17605 from zero323/SPARK-20290. --- python/pyspark/sql/column.py | 55 ++++++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 2 +- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index b8df37f25180f..e753ed402cdd7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -171,6 +171,61 @@ def __init__(self, jc): __ge__ = _bin_op("geq") __gt__ = _bin_op("gt") + _eqNullSafe_doc = """ + Equality test that is safe for null values. + + :param other: a value or :class:`Column` + + >>> from pyspark.sql import Row + >>> df1 = spark.createDataFrame([ + ... Row(id=1, value='foo'), + ... Row(id=2, value=None) + ... ]) + >>> df1.select( + ... df1['value'] == 'foo', + ... df1['value'].eqNullSafe('foo'), + ... df1['value'].eqNullSafe(None) + ... ).show() + +-------------+---------------+----------------+ + |(value = foo)|(value <=> foo)|(value <=> NULL)| + +-------------+---------------+----------------+ + | true| true| false| + | null| false| true| + +-------------+---------------+----------------+ + >>> df2 = spark.createDataFrame([ + ... Row(value = 'bar'), + ... Row(value = None) + ... ]) + >>> df1.join(df2, df1["value"] == df2["value"]).count() + 0 + >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count() + 1 + >>> df2 = spark.createDataFrame([ + ... Row(id=1, value=float('NaN')), + ... Row(id=2, value=42.0), + ... Row(id=3, value=None) + ... ]) + >>> df2.select( + ... df2['value'].eqNullSafe(None), + ... df2['value'].eqNullSafe(float('NaN')), + ... df2['value'].eqNullSafe(42.0) + ... ).show() + +----------------+---------------+----------------+ + |(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)| + +----------------+---------------+----------------+ + | false| true| false| + | false| false| true| + | true| false| false| + +----------------+---------------+----------------+ + + .. note:: Unlike Pandas, PySpark doesn't consider NaN values to be NULL. + See the `NaN Semantics`_ for details. + .. _NaN Semantics: + https://spark.apache.org/docs/latest/sql-programming-guide.html#nan-semantics + .. versionadded:: 2.3.0 + """ + eqNullSafe = _bin_op("eqNullSafe", _eqNullSafe_doc) + # `and`, `or`, `not` cannot be overloaded in Python, # so use bitwise operators as boolean operators __and__ = _bin_op('and') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cd92148dfa5df..ce4abf8fb7e5c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -982,7 +982,7 @@ def test_column_operators(self): cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ - cs.startswith('a'), cs.endswith('a') + cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) self.assertRaisesRegexp(ValueError, From 6b44c4d63ab14162e338c5f1ac77333956870a90 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 1 May 2017 09:46:35 -0700 Subject: [PATCH 0378/1765] [SPARK-20534][SQL] Make outer generate exec return empty rows ## What changes were proposed in this pull request? Generate exec does not produce `null` values if the generator for the input row is empty and the generate operates in outer mode without join. This is caused by the fact that the `join=false` code path is different from the `join=true` code path, and that the `join=false` code path did deal with outer properly. This PR addresses this issue. ## How was this patch tested? Updated `outer*` tests in `GeneratorFunctionSuite`. Author: Herman van Hovell Closes #17810 from hvanhovell/SPARK-20534. --- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../spark/sql/execution/GenerateExec.scala | 33 ++++++++++--------- .../spark/sql/GeneratorFunctionSuite.scala | 12 +++---- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dd768d18e8588..f2b9764b0f088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -441,8 +441,7 @@ object ColumnPruning extends Rule[LogicalPlan] { g.copy(child = prunedChild(g.child, g.references)) // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) - if g.join && !g.outer && p.references.subsetOf(g.generatedSet) => + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) // Eliminate unneeded attributes from right side of a Left Existence Join. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba851..f663d7b8a8f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -83,7 +83,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param qualifier Qualifier for the attributes of generator(UDTF) * @param generatorOutput The output schema of the Generator. * @param child Children logical plan node diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f87d05884b276..1812a1152cb48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) extends Iterator[InternalRow] { - lazy val results = func().toIterator + lazy val results: Iterator[InternalRow] = func().toIterator override def hasNext: Boolean = results.hasNext override def next(): InternalRow = results.next() } @@ -50,7 +50,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param generatorOutput the qualified output attributes of the generator of this node, which * constructed in analysis phase, and we can not change it, as the * parent node bound with it already. @@ -78,15 +78,15 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator = BindReferences.bindReference(generator, child.output) + val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - val rows = if (join) { - child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (index, iter) => + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val rows = if (join) { val joinedRow = new JoinedRow - iter.flatMap { row => // we should always set the left (child output) joinedRow.withLeft(row) @@ -101,18 +101,21 @@ case class GenerateExec( // keep it the same as Hive does joinedRow.withRight(row) } + } else { + iter.flatMap { row => + val outputRows = boundGenerator.eval(row) + if (outer && outputRows.isEmpty) { + Seq(generatorNullRow) + } else { + outputRows + } + } ++ LazyIterator(boundGenerator.terminate) } - } else { - child.execute().mapPartitionsInternal { iter => - iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) - } - } - val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsWithIndexInternal { (index, iter) => + // Convert the rows to unsafe rows. val proj = UnsafeProjection.create(output, output) proj.initialize(index) - iter.map { r => + rows.map { r => numOutputRows += 1 proj(r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index cef5bbf0e85a7..b9871afd59e4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(explode_outer('intList)), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) } test("single posexplode") { @@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(posexplode_outer('intList)), - Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil) } test("explode and other columns") { @@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( df.select(explode('intList).as('int)).select(sum('int)), @@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map)), - Row("a", "b") :: Row("c", "d") :: Nil) + Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil) } test("explode on map with aliases") { @@ -198,7 +198,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b") :: Nil) + Row("a", "b") :: Row(null, null) :: Nil) } test("self join explode") { @@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( df2.selectExpr("inline_outer(col1)"), - Row(3, "4") :: Row(5, "6") :: Nil + Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil ) } From ab30590f448d05fc1864c54a59b6815bdeef8fc7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 1 May 2017 10:25:29 -0700 Subject: [PATCH 0379/1765] [SPARK-20517][UI] Fix broken history UI download link The download link in history server UI is concatenated with: ``` Download{{duration}} {{sparkUser}} {{lastUpdated}}DownloadDownload
    +# Structured Streaming + +SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 5b18cf2f3c2ef..53b3db21da769 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,13 +8,13 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. # Quick Example -Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). +Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
    @@ -63,6 +63,13 @@ spark = SparkSession \ .getOrCreate() {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(appName = "StructuredNetworkWordCount") +{% endhighlight %} +
    @@ -136,6 +143,22 @@ wordCounts = words.groupBy("word").count() This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. + +
    + +{% highlight r %} +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = "localhost", port = 9999) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(group_by(words, "word")) +{% endhighlight %} + +This `lines` SparkDataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have a SQL expression with two SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we name the new column as "word". Finally, we have defined the `wordCounts` SparkDataFrame by grouping by the unique values in the SparkDataFrame and counting them. Note that this is a streaming SparkDataFrame which represents the running word counts of the stream. +
    @@ -181,10 +204,20 @@ query = wordCounts \ query.awaitTermination() {% endhighlight %} + +
    + +{% highlight r %} +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) +{% endhighlight %} +
    -After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `query.awaitTermination()` to prevent the process from exiting while the query is active. +After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `awaitTermination()` to prevent the process from exiting while the query is active. To actually execute this example code, you can either compile the code in your own [Spark application](quick-start.html#self-contained-applications), or simply @@ -211,6 +244,11 @@ $ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetwor $ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 +{% endhighlight %} +
    Then, any lines typed in the terminal running the netcat server will be counted and printed on screen every second. It will look something like the following. @@ -325,6 +363,35 @@ Batch: 0 | spark| 1| +------+-----+ +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING structured_network_wordcount.R + +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + ------------------------------------------- Batch: 1 ------------------------------------------- @@ -409,14 +476,14 @@ to track the read position in the stream. The engine uses checkpointing and writ # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` -([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) +([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession)/[R](api/R/sparkR.session.html) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the [DataFrame/Dataset Programming Guide](sql-programming-guide.html). ## Creating streaming DataFrames and streaming Datasets -Streaming DataFrames can be created through the `DataStreamReader` interface +Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) -returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. +returned by `SparkSession.readStream()`. In [R](api/R/read.stream.html), with the `read.stream()` method. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. #### Input Sources In Spark 2.0, there are a few built-in sources. @@ -445,7 +512,8 @@ Here are the details of all the sources in Spark. path: path to the input directory, and common to all file formats.

    For file-format-specific options, see the related methods in DataStreamReader - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataStreamReader.parquet() Yes Supports glob paths, but does not support multiple comma-separated paths/globs. @@ -483,7 +551,7 @@ Here are some examples. {% highlight scala %} val spark: SparkSession = ... -// Read text from socket +// Read text from socket val socketDF = spark .readStream .format("socket") @@ -493,7 +561,7 @@ val socketDF = spark socketDF.isStreaming // Returns True for DataFrames that have streaming sources -socketDF.printSchema +socketDF.printSchema // Read all the csv files written atomically in a directory val userSchema = new StructType().add("name", "string").add("age", "integer") @@ -510,7 +578,7 @@ val csvDF = spark {% highlight java %} SparkSession spark = ... -// Read text from socket +// Read text from socket Dataset socketDF = spark .readStream() .format("socket") @@ -537,7 +605,7 @@ Dataset csvDF = spark {% highlight python %} spark = SparkSession. ... -# Read text from socket +# Read text from socket socketDF = spark \ .readStream \ .format("socket") \ @@ -547,7 +615,7 @@ socketDF = spark \ socketDF.isStreaming() # Returns True for DataFrames that have streaming sources -socketDF.printSchema() +socketDF.printSchema() # Read all the csv files written atomically in a directory userSchema = StructType().add("name", "string").add("age", "integer") @@ -558,6 +626,25 @@ csvDF = spark \ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory") {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(...) + +# Read text from socket +socketDF <- read.stream("socket", host = hostname, port = port) + +isStreaming(socketDF) # Returns TRUE for SparkDataFrames that have streaming sources + +printSchema(socketDF) + +# Read all the csv files written atomically in a directory +schema <- structType(structField("name", "string"), + structField("age", "integer")) +csvDF <- read.stream("csv", path = "/path/to/directory", schema = schema, sep = ";") +{% endhighlight %} +
    @@ -638,12 +725,24 @@ ds.groupByKey((MapFunction) value -> value.getDeviceType(), df = ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } # Select the devices which have signal more than 10 -df.select("device").where("signal > 10") +df.select("device").where("signal > 10") # Running count of the number of updates for each device type df.groupBy("deviceType").count() {% endhighlight %} +
    + +{% highlight r %} +df <- ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } + +# Select the devices which have signal more than 10 +select(where(df, "signal > 10"), "device") + +# Running count of the number of updates for each device type +count(groupBy(df, "deviceType")) +{% endhighlight %} +
    ### Window Operations on Event Time @@ -840,7 +939,7 @@ Streaming DataFrames can be joined with static DataFrames to create new streamin {% highlight scala %} val staticDf = spark.read. ... -val streamingDf = spark.readStream. ... +val streamingDf = spark.readStream. ... streamingDf.join(staticDf, "type") // inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_join") // right outer join with a static DF @@ -972,7 +1071,7 @@ Once you have defined the final result DataFrame/Dataset, all that is left is fo ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. -- *Details of the output sink:* Data format, location, etc. +- *Details of the output sink:* Data format, location, etc. - *Output mode:* Specify what gets written to the output sink. @@ -1077,7 +1176,7 @@ Here is the compatibility matrix. #### Output Sinks There are a few types of built-in output sinks. -- **File sink** - Stores the output to a directory. +- **File sink** - Stores the output to a directory. {% highlight scala %} writeStream @@ -1145,7 +1244,8 @@ Here are the details of all the sinks in Spark. · "s3a://a/b/c/dataset.txt"

    For file-format-specific options, see the related methods in DataFrameWriter - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataFrameWriter.parquet() Yes @@ -1208,7 +1308,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start() - + // ========== DF with aggregation ========== val aggDF = df.groupBy("device").count() @@ -1219,7 +1319,7 @@ aggDF .format("console") .start() -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream .queryName("aggregates") // this query name will be the table name @@ -1250,7 +1350,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start(); - + // ========== DF with aggregation ========== Dataset aggDF = df.groupBy("device").count(); @@ -1261,7 +1361,7 @@ aggDF .format("console") .start(); -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream() .queryName("aggregates") // this query name will be the table name @@ -1292,7 +1392,7 @@ noAggDF \ .option("checkpointLocation", "path/to/checkpoint/dir") \ .option("path", "path/to/destination/dir") \ .start() - + # ========== DF with aggregation ========== aggDF = df.groupBy("device").count() @@ -1314,6 +1414,35 @@ aggDF \ spark.sql("select * from aggregates").show() # interactively query in-memory table {% endhighlight %} + +
    + +{% highlight r %} +# ========== DF with no aggregations ========== +noAggDF <- select(where(deviceDataDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# ========== DF with aggregation ========== +aggDF <- count(groupBy(df, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +# Interactively query in-memory table +head(sql("select * from aggregates")) +{% endhighlight %} +
    @@ -1351,7 +1480,7 @@ query.name // get the name of the auto-generated or user-specified name query.explain() // print detailed explanations of the query -query.stop() // stop the query +query.stop() // stop the query query.awaitTermination() // block until query is terminated, with stop() or with error @@ -1403,7 +1532,7 @@ query.name() # get the name of the auto-generated or user-specified name query.explain() # print detailed explanations of the query -query.stop() # stop the query +query.stop() # stop the query query.awaitTermination() # block until query is terminated, with stop() or with error @@ -1415,6 +1544,24 @@ query.lastProgress() # the most recent progress update of this streaming quer {% endhighlight %} + +
    + +{% highlight r %} +query <- write.stream(df, "console") # get the query object + +queryName(query) # get the name of the auto-generated or user-specified name + +explain(query) # print detailed explanations of the query + +stopQuery(query) # stop the query + +awaitTermination(query) # block until query is terminated, with stop() or with error + +lastProgress(query) # the most recent progress update of this streaming query + +{% endhighlight %} +
    @@ -1461,6 +1608,12 @@ spark.streams().get(id) # get a query object by its unique id spark.streams().awaitAnyTermination() # block until any one of them terminates {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    @@ -1644,6 +1797,58 @@ Will print something like the following. ''' {% endhighlight %} + +
    + +{% highlight r %} +query <- ... # a StreamingQuery +lastProgress(query) + +''' +Will print something like the following. + +{ + "id" : "8c57e1ec-94b5-4c99-b100-f694162df0b9", + "runId" : "ae505c5a-a64e-4896-8c28-c7cbaf926f16", + "name" : null, + "timestamp" : "2017-04-26T08:27:28.835Z", + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0, + "durationMs" : { + "getOffset" : 0, + "triggerExecution" : 1 + }, + "stateOperators" : [ { + "numRowsTotal" : 4, + "numRowsUpdated" : 0 + } ], + "sources" : [ { + "description" : "TextSocketSource[host: localhost, port: 9999]", + "startOffset" : 1, + "endOffset" : 1, + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0 + } ], + "sink" : { + "description" : "org.apache.spark.sql.execution.streaming.ConsoleSink@76b37531" + } +} +''' + +status(query) +''' +Will print something like the following. + +{ + "message" : "Waiting for data to arrive", + "isDataAvailable" : false, + "isTriggerActive" : false +} +''' +{% endhighlight %} +
    @@ -1703,11 +1908,17 @@ spark.streams().addListener(new StreamingQueryListener() { Not available in Python. {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    @@ -1745,20 +1956,18 @@ aggDF \ .start() {% endhighlight %} +
    +
    + +{% highlight r %} +write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "path/to/HDFS/dir") +{% endhighlight %} +
    # Where to go from here -- Examples: See and run the -[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming) +- Examples: See and run the +[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r/streaming) examples. - Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) - - - - - - - - - diff --git a/examples/src/main/r/streaming/structured_network_wordcount.R b/examples/src/main/r/streaming/structured_network_wordcount.R new file mode 100644 index 0000000000000..cda18ebc072ee --- /dev/null +++ b/examples/src/main/r/streaming/structured_network_wordcount.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Counts words in UTF8 encoded, '\n' delimited text received from the network. + +# To run this on your local machine, you need to first run a Netcat server +# $ nc -lk 9999 +# and then run the example +# ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-Streaming-structured-network-wordcount-example") + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 2) { + print("Usage: structured_network_wordcount.R ") + print(" and describe the TCP server that Structured Streaming") + print("would connect to receive data.") + q("no") +} + +hostname <- args[[1]] +port <- as.integer(args[[2]]) + +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) + +sparkR.session.stop() From 9c36aa27919fb7625e388f5c3c90af62ef902b24 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 4 May 2017 01:41:36 -0700 Subject: [PATCH 0405/1765] [SPARK-20585][SPARKR] R generic hint support ## What changes were proposed in this pull request? Adds support for generic hints on `SparkDataFrame` ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 Closes #17851 from zero323/SPARK-20585. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 30 +++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +++++++++ 4 files changed, 47 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7ecd168137e8d..daa168c87ecd1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -123,6 +123,7 @@ exportMethods("arrange", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 7e57ba6287bb8..1c8869202f677 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3715,3 +3715,33 @@ setMethod("rollup", sgd <- callJMethod(x@sdf, "rollup", jcol) groupedData(sgd) }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) + stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e02d46426a5a6..56ef1bee93536 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -576,6 +576,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index a7bb3265d92d7..82007a5348496 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2182,6 +2182,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) }) test_that("toJSON() on DataFrame", { From f21897fc157ce467f2b2edb5631b31787883accd Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 4 May 2017 01:51:37 -0700 Subject: [PATCH 0406/1765] [SPARK-20544][SPARKR] R wrapper for input_file_name ## What changes were proposed in this pull request? Adds wrapper for `o.a.s.sql.functions.input_file_name` ## How was this patch tested? Existing unit tests, additional unit tests, `check-cran.sh`. Author: zero323 Closes #17818 from zero323/SPARK-20544. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 21 +++++++++++++++++++++ R/pkg/R/generics.R | 6 ++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 +++++ 4 files changed, 33 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index daa168c87ecd1..ba0fe7708bcc3 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -258,6 +258,7 @@ exportMethods("%<=>%", "hypot", "ifelse", "initcap", + "input_file_name", "instr", "isNaN", "isNotNull", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d47b09ce5513..5f9d11475c94b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3975,3 +3975,24 @@ setMethod("grouping_id", jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) column(jc) }) + +#' input_file_name +#' +#' Creates a string column with the input file name for a given row +#' +#' @rdname input_file_name +#' @name input_file_name +#' @family normal_funcs +#' @aliases input_file_name,missing-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, input_file_name())) +#' } +#' @note input_file_name since 2.3.0 +setMethod("input_file_name", signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 56ef1bee93536..e835ef3e4f40d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1080,6 +1080,12 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @export setGeneric("initcap", function(x) { standardGeneric("initcap") }) +#' @param x empty. Should be used with no argument. +#' @rdname input_file_name +#' @export +setGeneric("input_file_name", + function(x = "missing") { standardGeneric("input_file_name") }) + #' @rdname instr #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 82007a5348496..47cc34a6c5b75 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1402,6 +1402,11 @@ test_that("column functions", { expect_equal(collect(df2)[[3, 1]], FALSE) expect_equal(collect(df2)[[3, 2]], TRUE) + # Test that input_file_name() + actual_names <- sort(collect(distinct(select(df, input_file_name())))) + expect_equal(length(actual_names), 1) + expect_equal(basename(actual_names[1, 1]), basename(jsonPath)) + df3 <- select(df, between(df$name, c("Apache", "Spark"))) expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) From 57b64703e66ec8490d8d9dbf6beebc160a61ec29 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 4 May 2017 01:54:59 -0700 Subject: [PATCH 0407/1765] [SPARK-20571][SPARKR][SS] Flaky Structured Streaming tests ## What changes were proposed in this pull request? Make tests more reliable by having it till processed. Increasing timeout value might help but ultimately the flakiness from processing delay when Jenkins is hard to account for. This isn't an actual public API supported ## How was this patch tested? unit tests Author: Felix Cheung Closes #17857 from felixcheung/rsstestrelia. --- R/pkg/inst/tests/testthat/test_streaming.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 8843991024308..91df7ac6f9849 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -55,10 +55,12 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) writeLines(mockLinesNa, jsonPathNa) awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) stopQuery(q) @@ -75,6 +77,7 @@ test_that("print from explain, lastProgress, status, isActive", { q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) @@ -99,6 +102,7 @@ test_that("Stream other format", { q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) expect_equal(queryName(q), "people3") From c5dceb8c65545169bc96628140b5acdaa85dd226 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 4 May 2017 17:56:43 +0800 Subject: [PATCH 0408/1765] [SPARK-20047][FOLLOWUP][ML] Constrained Logistic Regression follow up ## What changes were proposed in this pull request? Address some minor comments for #17715: * Put bound-constrained optimization params under expertParams. * Update some docs. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #17829 from yanboliang/spark-20047-followup. --- .../classification/LogisticRegression.scala | 54 ++++++++++++------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d7dde329ed004..42dc7fbebe4c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -183,14 +183,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The bound matrix must be compatible with the shape (1, number of features) for binomial * regression, or (number of classes, number of features) for multinomial regression. * Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", "The lower bounds on coefficients if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) @@ -199,14 +200,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The bound matrix must be compatible with the shape (1, number of features) for binomial * regression, or (number of classes, number of features) for multinomial regression. * Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", "The upper bounds on coefficients if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) @@ -214,14 +216,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The lower bounds on intercepts if fitting under bound constrained optimization. * The bounds vector size must be equal with 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", "The lower bounds on intercepts if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) @@ -229,14 +232,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The upper bounds on intercepts if fitting under bound constrained optimization. * The bound vector size must be equal with 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", "The upper bounds on intercepts if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) @@ -256,7 +260,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } if (!$(fitIntercept)) { require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), - "Pls don't set bounds on intercepts if fitting without intercept.") + "Please don't set bounds on intercepts if fitting without intercept.") } super.validateAndTransformSchema(schema, fitting, featuresDataType) } @@ -393,7 +397,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the lower bounds on coefficients if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) @@ -401,7 +405,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the upper bounds on coefficients if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) @@ -409,7 +413,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the lower bounds on intercepts if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) @@ -417,7 +421,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the upper bounds on intercepts if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) @@ -427,28 +431,40 @@ class LogisticRegression @Since("1.2.0") ( numFeatures: Int): Unit = { if (isSet(lowerBoundsOnCoefficients)) { require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && - $(lowerBoundsOnCoefficients).numCols == numFeatures) + $(lowerBoundsOnCoefficients).numCols == numFeatures, + "The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).") } if (isSet(upperBoundsOnCoefficients)) { require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && - $(upperBoundsOnCoefficients).numCols == numFeatures) + $(upperBoundsOnCoefficients).numCols == numFeatures, + "The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).") } if (isSet(lowerBoundsOnIntercepts)) { - require($(lowerBoundsOnIntercepts).size == numCoefficientSets) + require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.") } if (isSet(upperBoundsOnIntercepts)) { - require($(upperBoundsOnIntercepts).size == numCoefficientSets) + require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.") } if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) - .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " + + .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always be " + "less than or equal to upperBoundsOnCoefficients, but found: " + s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") } if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) - .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " + + .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always be " + "less than or equal to upperBoundsOnIntercepts, but found: " + s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") From bfc8c79c8dda7668cfded2a728424853a26da035 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 4 May 2017 21:04:15 +0800 Subject: [PATCH 0409/1765] [SPARK-20566][SQL] ColumnVector should support `appendFloats` for array ## What changes were proposed in this pull request? This PR aims to add a missing `appendFloats` API for array into **ColumnVector** class. For double type, there is `appendDoubles` for array [here](https://github.com/apache/spark/blob/master/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java#L818-L824). ## How was this patch tested? Pass the Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #17836 from dongjoon-hyun/SPARK-20566. --- .../execution/vectorized/ColumnVector.java | 8 + .../vectorized/ColumnarBatchSuite.scala | 256 ++++++++++++++++-- 2 files changed, 240 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index b105e60a2d34a..ad267ab0c9c47 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -801,6 +801,14 @@ public final int appendFloats(int count, float v) { return result; } + public final int appendFloats(int length, float[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putFloats(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8184d7d909f4b..e48e3f6402901 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -41,24 +41,49 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(1024, IntegerType, memMode) var idx = 0 assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNotNull() + reference += false + assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNotNulls(3) + (1 to 3).foreach(_ => reference += false) + assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNull() + reference += true + assert(column.anyNullsSet()) + assert(column.numNulls() == 1) + + column.appendNulls(3) + (1 to 3).foreach(_ => reference += true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) + + idx = column.elementsAppended column.putNotNull(idx) reference += false idx += 1 - assert(column.anyNullsSet() == false) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 1) + assert(column.anyNullsSet()) + assert(column.numNulls() == 5) column.putNulls(idx, 3) reference += true reference += true reference += true idx += 3 - assert(column.anyNullsSet() == true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) column.putNotNulls(idx, 4) reference += false @@ -66,8 +91,8 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 4) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => assert(v._1 == column.isNullAt(v._2)) @@ -85,9 +110,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Byte] val column = ColumnVector.allocate(1024, ByteType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray + column.appendBytes(2, values, 0) + reference += 10.toByte + reference += 20.toByte + + column.appendBytes(3, values, 2) + reference += 30.toByte + reference += 40.toByte + reference += 50.toByte + + column.appendBytes(6, 60.toByte) + (1 to 6).foreach(_ => reference += 60.toByte) + + column.appendByte(70.toByte) + reference += 70.toByte + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray column.putBytes(idx, 2, values, 0) reference += 1 reference += 2 @@ -126,9 +168,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Short] val column = ColumnVector.allocate(1024, ShortType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray + column.appendShorts(2, values, 0) + reference += 10.toShort + reference += 20.toShort + + column.appendShorts(3, values, 2) + reference += 30.toShort + reference += 40.toShort + reference += 50.toShort + + column.appendShorts(6, 60.toShort) + (1 to 6).foreach(_ => reference += 60.toShort) + + column.appendShort(70.toShort) + reference += 70.toShort + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray column.putShorts(idx, 2, values, 0) reference += 1 reference += 2 @@ -189,9 +248,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Int] val column = ColumnVector.allocate(1024, IntegerType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray + column.appendInts(2, values, 0) + reference += 10 + reference += 20 + + column.appendInts(3, values, 2) + reference += 30 + reference += 40 + reference += 50 + + column.appendInts(6, 60) + (1 to 6).foreach(_ => reference += 60) + + column.appendInt(70) + reference += 70 + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray column.putInts(idx, 2, values, 0) reference += 1 reference += 2 @@ -257,9 +333,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Long] val column = ColumnVector.allocate(1024, LongType, memMode) - var idx = 0 - val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray + column.appendLongs(2, values, 0) + reference += 10L + reference += 20L + + column.appendLongs(3, values, 2) + reference += 30L + reference += 40L + reference += 50L + + column.appendLongs(6, 60L) + (1 to 6).foreach(_ => reference += 60L) + + column.appendLong(70L) + reference += 70L + + var idx = column.elementsAppended + + values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray column.putLongs(idx, 2, values, 0) reference += 1 reference += 2 @@ -320,6 +413,97 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Float APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Float] + + val column = ColumnVector.allocate(1024, FloatType, memMode) + + var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray + column.appendFloats(2, values, 0) + reference += .1f + reference += .2f + + column.appendFloats(3, values, 2) + reference += .3f + reference += .4f + reference += .5f + + column.appendFloats(6, .6f) + (1 to 6).foreach(_ => reference += .6f) + + column.appendFloat(.7f) + reference += .7f + + var idx = column.elementsAppended + + values = (1.0f :: 2.0f :: 3.0f :: 4.0f :: 5.0f :: Nil).toArray + column.putFloats(idx, 2, values, 0) + reference += 1.0f + reference += 2.0f + idx += 2 + + column.putFloats(idx, 3, values, 2) + reference += 3.0f + reference += 4.0f + reference += 5.0f + idx += 3 + + val buffer = new Array[Byte](8) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234f) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, 1.123f) + + if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + // Ensure array contains Little Endian floats + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getFloat(0)) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, bb.getFloat(4)) + } + + column.putFloats(idx, 1, buffer, 4) + column.putFloats(idx + 1, 1, buffer, 0) + reference += 1.123f + reference += 2.234f + idx += 2 + + column.putFloats(idx, 2, buffer, 0) + reference += 2.234f + reference += 1.123f + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextFloat() + column.putFloat(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + val v = random.nextFloat() + column.putFloats(idx, n, v) + var i = 0 + while (i < n) { + reference += v + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getFloat(v._2), "Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) + } + } + column.close + }} + } + test("Double APIs") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -327,9 +511,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Double] val column = ColumnVector.allocate(1024, DoubleType, memMode) - var idx = 0 - val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray + var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray + column.appendDoubles(2, values, 0) + reference += .1 + reference += .2 + + column.appendDoubles(3, values, 2) + reference += .3 + reference += .4 + reference += .5 + + column.appendDoubles(6, .6) + (1 to 6).foreach(_ => reference += .6) + + column.appendDouble(.7) + reference += .7 + + var idx = column.elementsAppended + + values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray column.putDoubles(idx, 2, values, 0) reference += 1.0 reference += 2.0 @@ -346,8 +547,8 @@ class ColumnarBatchSuite extends SparkFunSuite { Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { - // Ensure array contains Liitle Endian doubles - var bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + // Ensure array contains Little Endian doubles + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getDouble(0)) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, bb.getDouble(8)) } @@ -400,40 +601,47 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) - var idx = 0 + + val str = "string" + column.appendByteArray(str.getBytes(StandardCharsets.UTF_8), + 0, str.getBytes(StandardCharsets.UTF_8).length) + reference += str + assert(column.arrayData().elementsAppended == 6) + + var idx = column.elementsAppended val values = ("Hello" :: "abc" :: Nil).toArray column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 0, values(0).getBytes(StandardCharsets.UTF_8).length) reference += values(0) idx += 1 - assert(column.arrayData().elementsAppended == 5) + assert(column.arrayData().elementsAppended == 11) column.putByteArray(idx, values(1).getBytes(StandardCharsets.UTF_8), 0, values(1).getBytes(StandardCharsets.UTF_8).length) reference += values(1) idx += 1 - assert(column.arrayData().elementsAppended == 8) + assert(column.arrayData().elementsAppended == 14) // Just put llo val offset = column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 2, values(0).getBytes(StandardCharsets.UTF_8).length - 2) reference += "llo" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put the same "ll" at offset. This should not allocate more memory in the column. column.putArray(idx, offset, 2) reference += "ll" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put a long string val s = "abcdefghijklmnopqrstuvwxyz" column.putByteArray(idx, (s + s).getBytes(StandardCharsets.UTF_8)) reference += (s + s) idx += 1 - assert(column.arrayData().elementsAppended == 11 + (s + s).length) + assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) From 0d16faab90e4cd1f73c5b749dbda7bc2a400b26f Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Fri, 5 May 2017 10:23:58 +0800 Subject: [PATCH 0410/1765] [SPARK-20574][ML] Allow Bucketizer to handle non-Double numeric column ## What changes were proposed in this pull request? Bucketizer currently requires input column to be Double, but the logic should work on any numeric data types. Many practical problems have integer/float data types, and it could get very tedious to manually cast them into Double before calling bucketizer. This PR extends bucketizer to handle all numeric types. ## How was this patch tested? New test. Author: Wayne Zhang Closes #17840 from actuaryzhang/bucketizer. --- .../apache/spark/ml/feature/Bucketizer.scala | 4 +-- .../spark/ml/feature/BucketizerSuite.scala | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index d1f3b2af1e482..bb8f2a3aa5f71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(filteredDataset($(inputCol))) + val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) val newField = prepOutputField(filteredDataset.schema) filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } @@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(inputCol)) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index aac29137d7911..420fb17ddce8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(Array(0.1, 0.8, 0.9)) testDefaultReadWrite(t) } + + test("Bucket numeric features") { + val splits = Array(-3.0, 0.0, 3.0) + val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0) + val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0) + val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + + val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, + ByteType, DecimalType(10, 0)) + for (mType <- types) { + val df = dataFrame.withColumn("feature", col("feature").cast(mType)) + bucketizer.transform(df).select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The result is not correct after bucketing in type " + + mType.toString + ". " + s"Expected $y but found $x.") + } + } + } } private object BucketizerSuite extends SparkFunSuite { From 4411ac70524ced901f7807d492fb0ad2480a8841 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 5 May 2017 09:50:40 +0100 Subject: [PATCH 0411/1765] [INFRA] Close stale PRs ## What changes were proposed in this pull request? This PR proposes to close a stale PR, several PRs suggested to be closed by a committer and obviously inappropriate PRs. Closes #11119 Closes #17853 Closes #17732 Closes #17456 Closes #17410 Closes #17314 Closes #17362 Closes #17542 ## How was this patch tested? N/A Author: hyukjinkwon Closes #17855 from HyukjinKwon/close-pr. From 37cdf077cd3f436f777562df311e3827b0727ce7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 5 May 2017 11:31:59 +0100 Subject: [PATCH 0412/1765] [SPARK-19660][SQL] Replace the deprecated property name fs.default.name to fs.defaultFS that newly introduced ## What changes were proposed in this pull request? Replace the deprecated property name `fs.default.name` to `fs.defaultFS` that newly introduced. ## How was this patch tested? Existing tests Author: Yuming Wang Closes #17856 from wangyum/SPARK-19660. --- .../spark/sql/execution/streaming/state/StateStoreSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index ebb7422765ebb..cc09b2d5b7763 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -314,7 +314,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) - conf.set("fs.default.name", "fake:///") + conf.set("fs.defaultFS", "fake:///") val provider = newStoreProvider(hadoopConf = conf) provider.getStore(0).commit() From 5773ab121d5d7cbefeef17ff4ac6f8af36cc1251 Mon Sep 17 00:00:00 2001 From: jyu00 Date: Fri, 5 May 2017 11:36:51 +0100 Subject: [PATCH 0413/1765] [SPARK-20546][DEPLOY] spark-class gets syntax error in posix mode ## What changes were proposed in this pull request? Updated spark-class to turn off posix mode so the process substitution doesn't cause a syntax error. ## How was this patch tested? Existing unit tests, manual spark-shell testing with posix mode on Author: jyu00 Closes #17852 from jyu00/master. --- bin/spark-class | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bin/spark-class b/bin/spark-class index 77ea40cc37946..65d3b9612909a 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? } +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") From 9064f1b04461513a147aeb8179471b05595ddbc4 Mon Sep 17 00:00:00 2001 From: madhu Date: Fri, 5 May 2017 22:44:03 +0800 Subject: [PATCH 0414/1765] [SPARK-20495][SQL][CORE] Add StorageLevel to cacheTable API ## What changes were proposed in this pull request? Currently cacheTable API only supports MEMORY_AND_DISK. This PR adds additional API to take different storage levels. ## How was this patch tested? unit tests Author: madhu Closes #17802 from phatak-dev/cacheTableAPI. --- project/MimaExcludes.scala | 2 ++ .../org/apache/spark/sql/catalog/Catalog.scala | 14 +++++++++++++- .../apache/spark/sql/internal/CatalogImpl.scala | 13 +++++++++++++ .../apache/spark/sql/internal/CatalogSuite.scala | 8 ++++++++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dbf933f28a784..d50882cb1917e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,8 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-20495][SQL] Add StorageLevel to cacheTable API + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable") ) // Exclude rules for 2.2.x diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 7e5da012f84ca..ab81725def3f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType - +import org.apache.spark.storage.StorageLevel /** * Catalog interface for Spark. To access this, use `SparkSession.catalog`. @@ -476,6 +476,18 @@ abstract class Catalog { */ def cacheTable(tableName: String): Unit + /** + * Caches the specified table with the given storage level. + * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. + * @param storageLevel storage level to cache table. + * @since 2.3.0 + */ + def cacheTable(tableName: String, storageLevel: StorageLevel): Unit + + /** * Removes the specified table from the in-memory cache. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 0b8e53868c999..e1049c665a417 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel + /** @@ -419,6 +421,17 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { sparkSession.sharedState.cacheManager.cacheQuery(sparkSession.table(tableName), Some(tableName)) } + /** + * Caches the specified table or view with the given storage level. + * + * @group cachemgmt + * @since 2.3.0 + */ + override def cacheTable(tableName: String, storageLevel: StorageLevel): Unit = { + sparkSession.sharedState.cacheManager.cacheQuery( + sparkSession.table(tableName), Some(tableName), storageLevel) + } + /** * Removes the specified table or view from the in-memory cache. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 8f9c52cb1e031..bc641fd280a15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel /** @@ -535,4 +536,11 @@ class CatalogSuite .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true) assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) } + + test("cacheTable with storage level") { + createTempTable("my_temp_table") + spark.catalog.cacheTable("my_temp_table", StorageLevel.DISK_ONLY) + assert(spark.table("my_temp_table").storageLevel == StorageLevel.DISK_ONLY) + } + } From b9ad2d1916af5091c8585d06ccad8219e437e2bc Mon Sep 17 00:00:00 2001 From: Jarrett Meyer Date: Fri, 5 May 2017 08:30:42 -0700 Subject: [PATCH 0415/1765] [SPARK-20613] Remove excess quotes in Windows executable ## What changes were proposed in this pull request? Quotes are already added to the RUNNER variable on line 54. There is no need to put quotes on line 67. If you do, you will get an error when launching Spark. '""C:\Program' is not recognized as an internal or external command, operable program or batch file. ## How was this patch tested? Tested manually on Windows 10. Author: Jarrett Meyer Closes #17861 from jarrettmeyer/fix-windows-cmd. --- bin/spark-class2.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 9faa7d65f83e4..f6157f42843e8 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -51,7 +51,7 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" ( - set RUNNER="%JAVA_HOME%\bin\java" + set RUNNER=%JAVA_HOME%\bin\java ) else ( where /q "%RUNNER%" if ERRORLEVEL 1 ( From 41439fd52dd263b9f7d92e608f027f193f461777 Mon Sep 17 00:00:00 2001 From: Yucai Date: Fri, 5 May 2017 09:51:57 -0700 Subject: [PATCH 0416/1765] [SPARK-20381][SQL] Add SQL metrics of numOutputRows for ObjectHashAggregateExec ## What changes were proposed in this pull request? ObjectHashAggregateExec is missing numOutputRows, add this metrics for it. ## How was this patch tested? Added unit tests for the new metrics. Author: Yucai Closes #17678 from yucai/objectAgg_numOutputRows. --- .../aggregate/ObjectAggregationIterator.scala | 8 ++++++-- .../aggregate/ObjectHashAggregateExec.scala | 3 ++- .../sql/execution/metric/SQLMetricsSuite.scala | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 3a7fcf1fa9d89..6e47f9d611199 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering} import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator @@ -39,7 +40,8 @@ class ObjectAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputRows: Iterator[InternalRow], - fallbackCountThreshold: Int) + fallbackCountThreshold: Int, + numOutputRows: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -83,7 +85,9 @@ class ObjectAggregationIterator( override final def next(): UnsafeRow = { val entry = aggBufferIterator.next() - generateOutput(entry.groupingKey, entry.aggregationBuffer) + val res = generateOutput(entry.groupingKey, entry.aggregationBuffer) + numOutputRows += 1 + res } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fcb7ec9a6411..b53521b1b6ba2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -117,7 +117,8 @@ case class ObjectHashAggregateExec( newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, - fallbackCountThreshold) + fallbackCountThreshold, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2ce7db6a22c01..e544245588f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -143,6 +143,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("ObjectHashAggregate metrics") { + // Assume the execution plan is + // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> ObjectHashAggregate(nodeId = 0) + val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 2L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).agg(collect_set('a)) + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 4L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 3L))) + ) + } + test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) From bd5788287957d8610a6d19c273b75bd4cdd2d166 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 5 May 2017 11:08:26 -0700 Subject: [PATCH 0417/1765] [SPARK-20603][SS][TEST] Set default number of topic partitions to 1 to reduce the load ## What changes were proposed in this pull request? I checked the logs of https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.2-test-maven-hadoop-2.7/47/ and found it took several seconds to create Kafka internal topic `__consumer_offsets`. As Kafka creates this topic lazily, the topic creation happens in the first test `deserialization of initial offset with Spark 2.1.0` and causes it timeout. This PR changes `offsets.topic.num.partitions` from the default value 50 to 1 to make creating `__consumer_offsets` (50 partitions -> 1 partition) much faster. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17863 from zsxwing/fix-kafka-flaky-test. --- .../scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 2ce2760b7f463..f86b8f586d2a0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -292,6 +292,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") + props.put("offsets.topic.num.partitions", "1") props.putAll(withBrokerProps.asJava) props } From b31648c081e8db34e0d6c71875318f7b0b047c8b Mon Sep 17 00:00:00 2001 From: Jannik Arndt Date: Fri, 5 May 2017 11:42:55 -0700 Subject: [PATCH 0418/1765] [SPARK-20557][SQL] Support for db column type TIMESTAMP WITH TIME ZONE ## What changes were proposed in this pull request? SparkSQL can now read from a database table with column type [TIMESTAMP WITH TIME ZONE](https://docs.oracle.com/javase/8/docs/api/java/sql/Types.html#TIMESTAMP_WITH_TIMEZONE). ## How was this patch tested? Tested against Oracle database. JoshRosen, you seem to know the class, would you look at this? Thanks! Author: Jannik Arndt Closes #17832 from JannikArndt/spark-20557-timestamp-with-timezone. --- .../spark/sql/jdbc/OracleIntegrationSuite.scala | 13 +++++++++++++ .../sql/execution/datasources/jdbc/JdbcUtils.scala | 3 +++ 2 files changed, 16 insertions(+) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 1bb89a361ca75..85d4a4a791e6b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -70,6 +70,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.commit() + conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)") + .executeUpdate() + conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))") + .executeUpdate() + conn.commit() + sql( s""" |CREATE TEMPORARY VIEW datetime @@ -185,4 +191,11 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo sql("INSERT INTO TABLE datetime1 SELECT * FROM datetime where id = 1") checkRow(sql("SELECT * FROM datetime1 where id = 1").head()) } + + test("SPARK-20557: column type TIMEZONE with TIME STAMP should be recognized") { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 0183805d56257..fb877d1ca7639 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -223,6 +223,9 @@ object JdbcUtils extends Logging { case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType case java.sql.Types.TIMESTAMP => TimestampType + case java.sql.Types.TIMESTAMP_WITH_TIMEZONE + => TimestampType + case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType From 5d75b14bf0f4c1f0813287efaabf49797908ed55 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 May 2017 15:31:06 -0700 Subject: [PATCH 0419/1765] [SPARK-20616] RuleExecutor logDebug of batch results should show diff to start of batch ## What changes were proposed in this pull request? Due to a likely typo, the logDebug msg printing the diff of query plans shows a diff to the initial plan, not diff to the start of batch. ## How was this patch tested? Now the debug message prints the diff between start and end of batch. Author: Juliusz Sompolski Closes #17875 from juliuszsompolski/SPARK-20616. --- .../org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6fc828f63f152..85b368c862630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -122,7 +122,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { logDebug( s""" |=== Result of Batch ${batch.name} === - |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") From b433acae74887e59f2e237a6284a4ae04fbbe854 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 5 May 2017 21:26:55 -0700 Subject: [PATCH 0420/1765] [SPARK-20614][PROJECT INFRA] Use the same log4j configuration with Jenkins in AppVeyor ## What changes were proposed in this pull request? Currently, there are flooding logs in AppVeyor (in the console). This has been fine because we can download all the logs. However, (given my observations so far), logs are truncated when there are too many. It has been grown recently and it started to get truncated. For example, see https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark/build/1209-master Even after the log is downloaded, it looks truncated as below: ``` [00:44:21] 17/05/04 18:56:18 INFO TaskSetManager: Finished task 197.0 in stage 601.0 (TID 9211) in 0 ms on localhost (executor driver) (194/200) [00:44:21] 17/05/04 18:56:18 INFO Executor: Running task 199.0 in stage 601.0 (TID 9213) [00:44:21] 17/05/04 18:56:18 INFO Executor: Finished task 198.0 in stage 601.0 (TID 9212). 2473 bytes result sent to driver ... ``` Probably, it looks better to use the same log4j configuration that we are using for SparkR tests in Jenkins(please see https://github.com/apache/spark/blob/fc472bddd1d9c6a28e57e31496c0166777af597e/R/run-tests.sh#L26 and https://github.com/apache/spark/blob/fc472bddd1d9c6a28e57e31496c0166777af597e/R/log4j.properties) ``` # Set everything to be logged to the file target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=R/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN org.eclipse.jetty.LEVEL=WARN ``` ## How was this patch tested? Manually tested with spark-test account - https://ci.appveyor.com/project/spark-test/spark/build/672-r-log4j (there is an example for flaky test here) - https://ci.appveyor.com/project/spark-test/spark/build/673-r-log4j (I re-ran the build). Author: hyukjinkwon Closes #17873 from HyukjinKwon/appveyor-reduce-logs. --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index bbb27589cad09..4d31af70f056e 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -49,7 +49,7 @@ build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email From cafca54c0ea8bd9c3b80dcbc88d9f2b8d708a026 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 6 May 2017 22:21:19 -0700 Subject: [PATCH 0421/1765] [SPARK-20557][SQL] Support JDBC data type Time with Time Zone ### What changes were proposed in this pull request? This PR is to support JDBC data type TIME WITH TIME ZONE. It can be converted to TIMESTAMP In addition, before this PR, for unsupported data types, we simply output the type number instead of the type name. ``` java.sql.SQLException: Unsupported type 2014 ``` After this PR, the message is like ``` java.sql.SQLException: Unsupported type TIMESTAMP_WITH_TIMEZONE ``` - Also upgrade the H2 version to `1.4.195` which has the type fix for "TIMESTAMP WITH TIMEZONE". However, it is not fully supported. Thus, we capture the exception, but we still need it to partially test the support of "TIMESTAMP WITH TIMEZONE", because Docker tests are not regularly run. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17835 from gatorsmile/h2. --- .../sql/jdbc/OracleIntegrationSuite.scala | 2 +- .../sql/jdbc/PostgresIntegrationSuite.scala | 15 ++++++++++++ sql/core/pom.xml | 2 +- .../datasources/jdbc/JdbcUtils.scala | 12 +++++++--- .../spark/sql/internal/CatalogImpl.scala | 1 - .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 24 +++++++++++++++++-- 6 files changed, 48 insertions(+), 8 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 85d4a4a791e6b..f7b1ec34ced76 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -192,7 +192,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo checkRow(sql("SELECT * FROM datetime1 where id = 1").head()) } - test("SPARK-20557: column type TIMEZONE with TIME STAMP should be recognized") { + test("SPARK-20557: column type TIMESTAMP with TIME ZONE should be recognized") { val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) val rows = dfRead.collect() val types = rows(0).toSeq.map(x => x.getClass.toString) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index a1a065a443e67..eb3c458360e7b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -55,6 +55,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + "null, null, null, null, null, " + "null, null, null, null, null, null, null)" ).executeUpdate() + + conn.prepareStatement("CREATE TABLE ts_with_timezone " + + "(id integer, tstz TIMESTAMP WITH TIME ZONE, ttz TIME WITH TIME ZONE)") + .executeUpdate() + conn.prepareStatement("INSERT INTO ts_with_timezone VALUES " + + "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', TIME WITH TIME ZONE '17:22:31.949271+00')") + .executeUpdate() } test("Type mapping for various types") { @@ -126,4 +133,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(schema(0).dataType == FloatType) assert(schema(1).dataType == ShortType) } + + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE should be recognized") { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e170133f0f0bf..fe4be963e8184 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -115,7 +115,7 @@ com.h2database h2 - 1.4.183 + 1.4.195 test diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index fb877d1ca7639..71eaab119d75d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.util.Locale import scala.collection.JavaConverters._ @@ -217,11 +217,14 @@ object JdbcUtils extends Logging { case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType case java.sql.Types.REF => StringType + case java.sql.Types.REF_CURSOR => null case java.sql.Types.ROWID => LongType case java.sql.Types.SMALLINT => IntegerType case java.sql.Types.SQLXML => StringType case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType + case java.sql.Types.TIME_WITH_TIMEZONE + => TimestampType case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType @@ -229,11 +232,14 @@ object JdbcUtils extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => + throw new SQLException("Unrecognized SQL type " + sqlType) // scalastyle:on } - if (answer == null) throw new SQLException("Unsupported type " + sqlType) + if (answer == null) { + throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName) + } answer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e1049c665a417..142b005850a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel - /** * Internal implementation of the user-facing `Catalog`. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5bd36ec25ccb0..d9f3689411ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.{Date, DriverManager, SQLException, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec @@ -141,6 +141,15 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " + + "AS SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'") + .executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " + + "AS SELECT '(1, 2, 3)'") + .executeUpdate() + conn.commit() conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() @@ -919,6 +928,17 @@ class JDBCSuite extends SparkFunSuite assert(res === (foobarCnt, 0L, foobarCnt) :: Nil) } + test("unsupported types") { + var e = intercept[SparkException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMEZONE", new Properties()).collect() + }.getMessage + assert(e.contains("java.lang.UnsupportedOperationException: unimplemented")) + e = intercept[SQLException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() + }.getMessage + assert(e.contains("Unsupported type ARRAY")) + } + test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive From 63d90e7da4913917982c0501d63ccc433a9b6b46 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sat, 6 May 2017 22:28:42 -0700 Subject: [PATCH 0422/1765] [SPARK-18777][PYTHON][SQL] Return UDF from udf.register ## What changes were proposed in this pull request? - Move udf wrapping code from `functions.udf` to `functions.UserDefinedFunction`. - Return wrapped udf from `catalog.registerFunction` and dependent methods. - Update docstrings in `catalog.registerFunction` and `SQLContext.registerFunction`. - Unit tests. ## How was this patch tested? - Existing unit tests and docstests. - Additional tests covering new feature. Author: zero323 Closes #17831 from zero323/SPARK-18777. --- python/pyspark/sql/catalog.py | 11 ++++++++--- python/pyspark/sql/context.py | 12 ++++++++---- python/pyspark/sql/functions.py | 23 ++++++++++++++--------- python/pyspark/sql/tests.py | 9 +++++++++ 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 41e68a45a6159..5f25dce161963 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -237,23 +237,28 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._jsparkSession.udf().registerPython(name, udf._judf) + return udf._wrapped() @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index fdb7abbad4e5f..5197a9e004610 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -185,22 +185,26 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - self.sparkSession.catalog.registerFunction(name, f, returnType) + return self.sparkSession.catalog.registerFunction(name, f, returnType) @ignore_unicode_prefix @since(2.1) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 843ae3816f061..8b3487c3f1083 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1917,6 +1917,19 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + @functools.wraps(self.func) + def wrapper(*args): + return self(*args) + + wrapper.func = self.func + wrapper.returnType = self.returnType + + return wrapper + @since(1.3) def udf(f=None, returnType=StringType()): @@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()): """ def _udf(f, returnType=StringType()): udf_obj = UserDefinedFunction(f, returnType) - - @functools.wraps(f) - def wrapper(*args): - return udf_obj(*args) - - wrapper.func = udf_obj.func - wrapper.returnType = udf_obj.returnType - - return wrapper + return udf_obj._wrapped() # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f644624f7f317..7983bc536fc6c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -436,6 +436,15 @@ def test_udf_with_order_by_and_limit(self): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_udf_registration_returns_udf(self): + df = self.spark.range(10) + add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_three(id) AS plus_three").collect(), + df.select(add_three("id").alias("plus_three")).collect() + ) + def test_wholefile_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", From 37f963ac13ec1bd958c44c7c15b5e8cb6c06cbbc Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sun, 7 May 2017 10:08:06 +0100 Subject: [PATCH 0423/1765] [SPARK-20518][CORE] Supplement the new blockidsuite unit tests ## What changes were proposed in this pull request? This PR adds the new unit tests to support ShuffleDataBlockId , ShuffleIndexBlockId , TempShuffleBlockId , TempLocalBlockId ## How was this patch tested? The new unit test. Author: caoxuewen Closes #17794 from heary-cao/blockidsuite. --- .../apache/spark/storage/BlockIdSuite.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd1..f0c521b00b583 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.UUID + import org.apache.spark.SparkFunSuite class BlockIdSuite extends SparkFunSuite { @@ -67,6 +69,32 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("shuffle data") { + val id = ShuffleDataBlockId(4, 5, 6) + assertSame(id, ShuffleDataBlockId(4, 5, 6)) + assertDifferent(id, ShuffleDataBlockId(6, 5, 6)) + assert(id.name === "shuffle_4_5_6.data") + assert(id.asRDDId === None) + assert(id.shuffleId === 4) + assert(id.mapId === 5) + assert(id.reduceId === 6) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle index") { + val id = ShuffleIndexBlockId(7, 8, 9) + assertSame(id, ShuffleIndexBlockId(7, 8, 9)) + assertDifferent(id, ShuffleIndexBlockId(9, 8, 9)) + assert(id.name === "shuffle_7_8_9.index") + assert(id.asRDDId === None) + assert(id.shuffleId === 7) + assert(id.mapId === 8) + assert(id.reduceId === 9) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("broadcast") { val id = BroadcastBlockId(42) assertSame(id, BroadcastBlockId(42)) @@ -101,6 +129,30 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("temp local") { + val id = TempLocalBlockId(new UUID(5, 2)) + assertSame(id, TempLocalBlockId(new UUID(5, 2))) + assertDifferent(id, TempLocalBlockId(new UUID(5, 3))) + assert(id.name === "temp_local_00000000-0000-0005-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 5) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + + test("temp shuffle") { + val id = TempShuffleBlockId(new UUID(1, 2)) + assertSame(id, TempShuffleBlockId(new UUID(1, 2))) + assertDifferent(id, TempShuffleBlockId(new UUID(1, 3))) + assert(id.name === "temp_shuffle_00000000-0000-0001-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 1) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) From 88e6d75072c23fa99d4df00d087d03d8c38e8c69 Mon Sep 17 00:00:00 2001 From: Daniel Li Date: Sun, 7 May 2017 10:09:58 +0100 Subject: [PATCH 0424/1765] [SPARK-20484][MLLIB] Add documentation to ALS code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR adds documentation to the ALS code. ## How was this patch tested? Existing tests were used. mengxr srowen This contribution is my original work. I have the license to work on this project under the Spark project’s open source license. Author: Daniel Li Closes #17793 from danielyli/spark-20484. --- .../apache/spark/ml/recommendation/ALS.scala | 236 +++++++++++++++--- 1 file changed, 202 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a20ef72446661..1562bf1beb7e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -774,6 +774,28 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: * Implementation of the ALS algorithm. + * + * This implementation of the ALS factorization algorithm partitions the two sets of factors among + * Spark workers so as to reduce network communication by only sending one copy of each factor + * vector to each Spark worker on each iteration, and only if needed. This is achieved by + * precomputing some information about the ratings matrix to determine which users require which + * item factors and vice versa. See the Scaladoc for `InBlock` for a detailed explanation of how + * the precomputation is done. + * + * In addition, since each iteration of calculating the factor matrices depends on the known + * ratings, which are spread across Spark partitions, a naive implementation would incur + * significant network communication overhead between Spark workers, as the ratings RDD would be + * repeatedly shuffled during each iteration. This implementation reduces that overhead by + * performing the shuffling operation up front, precomputing each partition's ratings dependencies + * and duplicating those values to the appropriate workers before starting iterations to solve for + * the factor matrices. See the Scaladoc for `OutBlock` for a detailed explanation of how the + * precomputation is done. + * + * Note that the term "rating block" is a bit of a misnomer, as the ratings are not partitioned by + * contiguous blocks from the ratings matrix but by a hash function on the rating's location in + * the matrix. If it helps you to visualize the partitions, it is easier to think of the term + * "block" as referring to a subset of an RDD containing the ratings rather than a contiguous + * submatrix of the ratings matrix. */ @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore @@ -791,32 +813,43 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(!ratings.isEmpty(), s"No ratings available from $ratings") require(intermediateRDDStorageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") + val sc = ratings.sparkContext + + // Precompute the rating dependencies of each partition val userPart = new ALSPartitioner(numUserBlocks) val itemPart = new ALSPartitioner(numItemBlocks) - val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) - val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) - val solver = if (nonnegative) new NNLSSolver else new CholeskySolver val blockRatings = partitionRatings(ratings, userPart, itemPart) .persist(intermediateRDDStorageLevel) val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel) - // materialize blockRatings and user blocks - userOutBlocks.count() + userOutBlocks.count() // materialize blockRatings and user blocks val swappedBlockRatings = blockRatings.map { case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) } val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel) - // materialize item blocks - itemOutBlocks.count() + itemOutBlocks.count() // materialize item blocks + + // Encoders for storing each user/item's partition ID and index within its partition using a + // single integer; used as an optimization + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + + // These are the user and item factor matrices that, once trained, are multiplied together to + // estimate the rating matrix. The two matrices are stored in RDDs, partitioned by column such + // that each factor column resides on the same Spark worker as its corresponding user or item. val seedGen = new XORShiftRandom(seed) var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + + val solver = if (nonnegative) new NNLSSolver else new CholeskySolver + var previousCheckpointFile: Option[String] = None val shouldCheckpoint: Int => Boolean = (iter) => sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0) @@ -830,6 +863,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { logWarning(s"Cannot delete checkpoint file $file:", e) } } + if (implicitPrefs) { for (iter <- 1 to maxIter) { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) @@ -910,26 +944,154 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { private type FactorBlock = Array[Array[Float]] /** - * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to - * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the - * src factors in this block to send to dst block 0. + * A mapping of the columns of the items factor matrix that are needed when calculating each row + * of the users factor matrix, and vice versa. + * + * Specifically, when calculating a user factor vector, since only those columns of the items + * factor matrix that correspond to the items that that user has rated are needed, we can avoid + * having to repeatedly copy the entire items factor matrix to each worker later in the algorithm + * by precomputing these dependencies for all users, storing them in an RDD of `OutBlock`s. The + * items' dependencies on the columns of the users factor matrix is computed similarly. + * + * =Example= + * + * Using the example provided in the `InBlock` Scaladoc, `userOutBlocks` would look like the + * following: + * + * {{{ + * userOutBlocks.collect() == Seq( + * 0 -> Array(Array(0, 1), Array(0, 1)), + * 1 -> Array(Array(0), Array(0)) + * ) + * }}} + * + * Each value in this map-like sequence is of type `Array[Array[Int]]`. The values in the + * inner array are the ranks of the sorted user IDs in that partition; so in the example above, + * `Array(0, 1)` in partition 0 refers to user IDs 0 and 6, since when all unique user IDs in + * partition 0 are sorted, 0 is the first ID and 6 is the second. The position of each inner + * array in its enclosing outer array denotes the partition number to which item IDs map; in the + * example, the first `Array(0, 1)` is in position 0 of its outer array, denoting item IDs that + * map to partition 0. + * + * In summary, the data structure encodes the following information: + * + * * There are ratings with user IDs 0 and 6 (encoded in `Array(0, 1)`, where 0 and 1 are the + * indices of the user IDs 0 and 6 on partition 0) whose item IDs map to partitions 0 and 1 + * (represented by the fact that `Array(0, 1)` appears in both the 0th and 1st positions). + * + * * There are ratings with user ID 3 (encoded in `Array(0)`, where 0 is the index of the user + * ID 3 on partition 1) whose item IDs map to partitions 0 and 1 (represented by the fact that + * `Array(0)` appears in both the 0th and 1st positions). */ private type OutBlock = Array[Array[Int]] /** - * In-link block for computing src (user/item) factors. This includes the original src IDs - * of the elements within this block as well as encoded dst (item/user) indices and corresponding - * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original - * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. - * For example, if we have an in-link record + * In-link block for computing user and item factor matrices. + * + * The ALS algorithm partitions the columns of the users factor matrix evenly among Spark workers. + * Since each column of the factor matrix is calculated using the known ratings of the correspond- + * ing user, and since the ratings don't change across iterations, the ALS algorithm preshuffles + * the ratings to the appropriate partitions, storing them in `InBlock` objects. + * + * The ratings shuffled by item ID are computed similarly and also stored in `InBlock` objects. + * Note that this means every rating is stored twice, once as shuffled by user ID and once by item + * ID. This is a necessary tradeoff, since in general a rating will not be on the same worker + * when partitioned by user as by item. + * + * =Example= + * + * Say we have a small collection of eight items to offer the seven users in our application. We + * have some known ratings given by the users, as seen in the matrix below: + * + * {{{ + * Items + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * 0 | |0.1| | |0.4| | |0.7| + * +---+---+---+---+---+---+---+---+ + * 1 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * U 2 | | | | | | | | | + * s +---+---+---+---+---+---+---+---+ + * e 3 | |3.1| | |3.4| | |3.7| + * r +---+---+---+---+---+---+---+---+ + * s 4 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * 5 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * 6 | |6.1| | |6.4| | |6.7| + * +---+---+---+---+---+---+---+---+ + * }}} + * + * The ratings are represented as an RDD, passed to the `partitionRatings` method as the `ratings` + * parameter: + * + * {{{ + * ratings.collect() == Seq( + * Rating(0, 1, 0.1f), + * Rating(0, 4, 0.4f), + * Rating(0, 7, 0.7f), + * Rating(3, 1, 3.1f), + * Rating(3, 4, 3.4f), + * Rating(3, 7, 3.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 4, 6.4f), + * Rating(6, 7, 6.7f) + * ) + * }}} * - * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * Say that we are using two partitions to calculate each factor matrix: * - * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which - * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * {{{ + * val userPart = new ALSPartitioner(2) + * val itemPart = new ALSPartitioner(2) + * val blockRatings = partitionRatings(ratings, userPart, itemPart) + * }}} * - * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can - * compute src factors one after another using only one normal equation instance. + * Ratings are mapped to partitions using the user/item IDs modulo the number of partitions. With + * two partitions, ratings with even-valued user IDs are shuffled to partition 0 while those with + * odd-valued user IDs are shuffled to partition 1: + * + * {{{ + * userInBlocks.collect() == Seq( + * 0 -> Seq( + * // Internally, the class stores the ratings in a more optimized format than + * // a sequence of `Rating`s, but for clarity we show it as such here. + * Rating(0, 1, 0.1f), + * Rating(0, 4, 0.4f), + * Rating(0, 7, 0.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 4, 6.4f), + * Rating(6, 7, 6.7f) + * ), + * 1 -> Seq( + * Rating(3, 1, 3.1f), + * Rating(3, 4, 3.4f), + * Rating(3, 7, 3.7f) + * ) + * ) + * }}} + * + * Similarly, ratings with even-valued item IDs are shuffled to partition 0 while those with + * odd-valued item IDs are shuffled to partition 1: + * + * {{{ + * itemInBlocks.collect() == Seq( + * 0 -> Seq( + * Rating(0, 4, 0.4f), + * Rating(3, 4, 3.4f), + * Rating(6, 4, 6.4f) + * ), + * 1 -> Seq( + * Rating(0, 1, 0.1f), + * Rating(0, 7, 0.7f), + * Rating(3, 1, 3.1f), + * Rating(3, 7, 3.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 7, 6.7f) + * ) + * ) + * }}} * * @param srcIds src ids (ordered) * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and @@ -1026,7 +1188,24 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } /** - * Partitions raw ratings into blocks. + * Groups an RDD of [[Rating]]s by the user partition and item partition to which each `Rating` + * maps according to the given partitioners. The returned pair RDD holds the ratings, encoded in + * a memory-efficient format but otherwise unchanged, keyed by the (user partition ID, item + * partition ID) pair. + * + * Performance note: This is an expensive operation that performs an RDD shuffle. + * + * Implementation note: This implementation produces the same result as the following but + * generates fewer intermediate objects: + * + * {{{ + * ratings.map { r => + * ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + * }.aggregateByKey(new RatingBlockBuilder)( + * seqOp = (b, r) => b.add(r), + * combOp = (b0, b1) => b0.merge(b1.build())) + * .mapValues(_.build()) + * }}} * * @param ratings raw ratings * @param srcPart partitioner for src IDs @@ -1037,17 +1216,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { ratings: RDD[Rating[ID]], srcPart: Partitioner, dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = { - - /* The implementation produces the same result as the following but generates less objects. - - ratings.map { r => - ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) - }.aggregateByKey(new RatingBlockBuilder)( - seqOp = (b, r) => b.add(r), - combOp = (b0, b1) => b0.merge(b1.build())) - .mapValues(_.build()) - */ - val numPartitions = srcPart.numPartitions * dstPart.numPartitions ratings.mapPartitions { iter => val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID]) @@ -1135,8 +1303,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { def length: Int = srcIds.length /** - * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a - * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Compresses the block into an `InBlock`. The algorithm is the same as converting a sparse + * matrix from coordinate list (COO) format into compressed sparse column (CSC) format. * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. */ def compress(): InBlock[ID] = { From 2cf83c47838115f71419ba5b9296c69ec1d746cd Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Sun, 7 May 2017 10:15:31 +0100 Subject: [PATCH 0425/1765] [SPARK-7481][BUILD] Add spark-hadoop-cloud module to pull in object store access. ## What changes were proposed in this pull request? Add a new `spark-hadoop-cloud` module and maven profile to pull in object store support from `hadoop-openstack`, `hadoop-aws` and `hadoop-azure` (Hadoop 2.7+) JARs, along with their dependencies, fixing up the dependencies so that everything works, in particular Jackson. It restores `s3n://` access to S3, adds its `s3a://` replacement, OpenStack `swift://` and azure `wasb://`. There's a documentation page, `cloud_integration.md`, which covers the basic details of using Spark with object stores, referring the reader to the supplier's own documentation, with specific warnings on security and the possible mismatch between a store's behavior and that of a filesystem. In particular, users are advised be very cautious when trying to use an object store as the destination of data, and to consult the documentation of the storage supplier and the connector. (this is the successor to #12004; I can't re-open it) ## How was this patch tested? Downstream tests exist in [https://github.com/steveloughran/spark-cloud-examples/tree/master/cloud-examples](https://github.com/steveloughran/spark-cloud-examples/tree/master/cloud-examples) Those verify that the dependencies are sufficient to allow downstream applications to work with s3a, azure wasb and swift storage connectors, and perform basic IO & dataframe operations thereon. All seems well. Manually clean build & verify that assembly contains the relevant aws-* hadoop-* artifacts on Hadoop 2.6; azure on a hadoop-2.7 profile. SBT build: `build/sbt -Phadoop-cloud -Phadoop-2.7 package` maven build `mvn install -Phadoop-cloud -Phadoop-2.7` This PR *does not* update `dev/deps/spark-deps-hadoop-2.7` or `dev/deps/spark-deps-hadoop-2.6`, because unless the hadoop-cloud profile is enabled, no extra JARs show up in the dependency list. The dependency check in Jenkins isn't setting the property, so the new JARs aren't visible. Author: Steve Loughran Author: Steve Loughran Closes #17834 from steveloughran/cloud/SPARK-7481-current. --- assembly/pom.xml | 14 +++ docs/cloud-integration.md | 200 ++++++++++++++++++++++++++++++++ docs/index.md | 1 + docs/rdd-programming-guide.md | 6 +- docs/storage-openstack-swift.md | 38 ++---- hadoop-cloud/pom.xml | 185 +++++++++++++++++++++++++++++ pom.xml | 7 ++ project/SparkBuild.scala | 4 +- 8 files changed, 424 insertions(+), 31 deletions(-) create mode 100644 docs/cloud-integration.md create mode 100644 hadoop-cloud/pom.xml diff --git a/assembly/pom.xml b/assembly/pom.xml index 742a4a1531e71..464af16e46f6e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -226,5 +226,19 @@ provided + + + + hadoop-cloud + + + org.apache.spark + spark-hadoop-cloud_${scala.binary.version} + ${project.version} + + + diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md new file mode 100644 index 0000000000000..751a192da4ffd --- /dev/null +++ b/docs/cloud-integration.md @@ -0,0 +1,200 @@ +--- +layout: global +displayTitle: Integration with Cloud Infrastructures +title: Integration with Cloud Infrastructures +description: Introduction to cloud storage support in Apache Spark SPARK_VERSION_SHORT +--- + + +* This will become a table of contents (this text will be scraped). +{:toc} + +## Introduction + + +All major cloud providers offer persistent data storage in *object stores*. +These are not classic "POSIX" file systems. +In order to store hundreds of petabytes of data without any single points of failure, +object stores replace the classic filesystem directory tree +with a simpler model of `object-name => data`. To enable remote access, operations +on objects are usually offered as (slow) HTTP REST operations. + +Spark can read and write data in object stores through filesystem connectors implemented +in Hadoop or provided by the infrastructure suppliers themselves. +These connectors make the object stores look *almost* like filesystems, with directories and files +and the classic operations on them such as list, delete and rename. + + +### Important: Cloud Object Stores are Not Real Filesystems + +While the stores appear to be filesystems, underneath +they are still object stores, [and the difference is significant](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/filesystem/introduction.html) + +They cannot be used as a direct replacement for a cluster filesystem such as HDFS +*except where this is explicitly stated*. + +Key differences are: + +* Changes to stored objects may not be immediately visible, both in directory listings and actual data access. +* The means by which directories are emulated may make working with them slow. +* Rename operations may be very slow and, on failure, leave the store in an unknown state. +* Seeking within a file may require new HTTP calls, hurting performance. + +How does this affect Spark? + +1. Reading and writing data can be significantly slower than working with a normal filesystem. +1. Some directory structures may be very inefficient to scan during query split calculation. +1. The output of work may not be immediately visible to a follow-on query. +1. The rename-based algorithm by which Spark normally commits work when saving an RDD, DataFrame or Dataset + is potentially both slow and unreliable. + +For these reasons, it is not always safe to use an object store as a direct destination of queries, or as +an intermediate store in a chain of queries. Consult the documentation of the object store and its +connector to determine which uses are considered safe. + +In particular: *without some form of consistency layer, Amazon S3 cannot +be safely used as the direct destination of work with the normal rename-based committer.* + +### Installation + +With the relevant libraries on the classpath and Spark configured with valid credentials, +objects can be can be read or written by using their URLs as the path to data. +For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create +an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. + +To add the relevant libraries to an application's classpath, include the `hadoop-cloud` +module and its dependencies. + +In Maven, add the following to the `pom.xml` file, assuming `spark.version` +is set to the chosen version of Spark: + +{% highlight xml %} + + ... + + org.apache.spark + hadoop-cloud_2.11 + ${spark.version} + + ... + +{% endhighlight %} + +Commercial products based on Apache Spark generally directly set up the classpath +for talking to cloud infrastructures, in which case this module may not be needed. + +### Authenticating + +Spark jobs must authenticate with the object stores to access data within them. + +1. When Spark is running in a cloud infrastructure, the credentials are usually automatically set up. +1. `spark-submit` reads the `AWS_ACCESS_KEY`, `AWS_SECRET_KEY` +and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options +for the `s3n` and `s3a` connectors to Amazon S3. +1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. +1. Authentication details may be manually added to the Spark configuration in `spark-default.conf` +1. Alternatively, they can be programmatically set in the `SparkConf` instance used to configure +the application's `SparkContext`. + +*Important: never check authentication secrets into source code repositories, +especially public ones* + +Consult [the Hadoop documentation](https://hadoop.apache.org/docs/current/) for the relevant +configuration and security options. + +## Configuring + +Each cloud connector has its own set of configuration parameters, again, +consult the relevant documentation. + +### Recommended settings for writing to object stores + +For object stores whose consistency model means that rename-based commits are safe +use the `FileOutputCommitter` v2 algorithm for performance: + +``` +spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version 2 +``` + +This does less renaming at the end of a job than the "version 1" algorithm. +As it still uses `rename()` to commit files, it is unsafe to use +when the object store does not have consistent metadata/listings. + +The committer can also be set to ignore failures when cleaning up temporary +files; this reduces the risk that a transient network problem is escalated into a +job failure: + +``` +spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored true +``` + +As storing temporary files can run up charges; delete +directories called `"_temporary"` on a regular basis to avoid this. + +### Parquet I/O Settings + +For optimal performance when working with Parquet data use the following settings: + +``` +spark.hadoop.parquet.enable.summary-metadata false +spark.sql.parquet.mergeSchema false +spark.sql.parquet.filterPushdown true +spark.sql.hive.metastorePartitionPruning true +``` + +These minimise the amount of data read during queries. + +### ORC I/O Settings + +For best performance when working with ORC data, use these settings: + +``` +spark.sql.orc.filterPushdown true +spark.sql.orc.splits.include.file.footer true +spark.sql.orc.cache.stripe.details.size 10000 +spark.sql.hive.metastorePartitionPruning true +``` + +Again, these minimise the amount of data read during queries. + +## Spark Streaming and Object Storage + +Spark Streaming can monitor files added to object stores, by +creating a `FileInputDStream` to monitor a path in the store through a call to +`StreamingContext.textFileStream()`. + +1. The time to scan for new files is proportional to the number of files +under the path, not the number of *new* files, so it can become a slow operation. +The size of the window needs to be set to handle this. + +1. Files only appear in an object store once they are completely written; there +is no need for a worklow of write-then-rename to ensure that files aren't picked up +while they are still being written. Applications can write straight to the monitored directory. + +1. Streams should only be checkpointed to an store implementing a fast and +atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. + +## Further Reading + +Here is the documentation on the standard connectors both from Apache and the cloud providers. + +* [OpenStack Swift](https://hadoop.apache.org/docs/current/hadoop-openstack/index.html). Hadoop 2.6+ +* [Azure Blob Storage](https://hadoop.apache.org/docs/current/hadoop-aws/tools/hadoop-aws/index.html). Since Hadoop 2.7 +* [Azure Data Lake](https://hadoop.apache.org/docs/current/hadoop-azure-datalake/index.html). Since Hadoop 2.8 +* [Amazon S3 via S3A and S3N](https://hadoop.apache.org/docs/current/hadoop-aws/tools/hadoop-aws/index.html). Hadoop 2.6+ +* [Amazon EMR File System (EMRFS)](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-fs.html). From Amazon +* [Google Cloud Storage Connector for Spark and Hadoop](https://cloud.google.com/hadoop/google-cloud-storage-connector). From Google + + diff --git a/docs/index.md b/docs/index.md index ad4f24ff1a5d1..960b968454d0e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -126,6 +126,7 @@ options for deployment: * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * Integration with other storage systems: + * [Cloud Infrastructures](cloud-integration.html) * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](http://spark.apache.org/contributing.html) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index e2bf2d7ca77ca..52e59df9990e9 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -323,7 +323,7 @@ One important parameter for parallel collections is the number of *partitions* t Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight scala %} scala> val distFile = sc.textFile("data.txt") @@ -356,7 +356,7 @@ Apart from text files, Spark's Scala API also supports several other data format Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight java %} JavaRDD distFile = sc.textFile("data.txt"); @@ -388,7 +388,7 @@ Apart from text files, Spark's Java API also supports several other data formats PySpark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight python %} >>> distFile = sc.textFile("data.txt") diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index c39ef1ce59e1c..f4bb2353e3c49 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -8,7 +8,8 @@ same URI formats as in Hadoop. You can specify a path in Swift as input through URI of the form swift://container.PROVIDER/path. You will also need to set your Swift security credentials, through core-site.xml or via SparkContext.hadoopConfiguration. -Current Swift driver requires Swift to use Keystone authentication method. +The current Swift driver requires Swift to use the Keystone authentication method, or +its Rackspace-specific predecessor. # Configuring Swift for Better Data Locality @@ -19,41 +20,30 @@ Although not mandatory, it is recommended to configure the proxy server of Swift # Dependencies -The Spark application should include hadoop-openstack dependency. +The Spark application should include hadoop-openstack dependency, which can +be done by including the `hadoop-cloud` module for the specific version of spark used. For example, for Maven support, add the following to the pom.xml file: {% highlight xml %} ... - org.apache.hadoop - hadoop-openstack - 2.3.0 + org.apache.spark + hadoop-cloud_2.11 + ${spark.version} ... {% endhighlight %} - # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -There are two main categories of parameters that should to be configured: declaration of the -Swift driver and the parameters that are required by Keystone. +The main category of parameters that should be configured are the authentication parameters +required by Keystone. -Configuration of Hadoop to use Swift File system achieved via - - - - - - - -
    Property NameValue
    fs.swift.implorg.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem
    - -Additional parameters required by Keystone (v2.0) and should be provided to the Swift driver. Those -parameters will be used to perform authentication in Keystone to access Swift. The following table -contains a list of Keystone mandatory parameters. PROVIDER can be any name. +The following table contains a list of Keystone mandatory parameters. PROVIDER can be +any (alphanumeric) name. @@ -94,7 +84,7 @@ contains a list of Keystone mandatory parameters. PROVIDER can be a - +
    Property NameMeaningRequired
    fs.swift.service.PROVIDER.publicIndicates if all URLs are publicIndicates whether to use the public (off cloud) or private (in cloud; no transfer fees) endpoints Mandatory
    @@ -104,10 +94,6 @@ defined for tenant test. Then core-site.xml should inc {% highlight xml %} - - fs.swift.impl - org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem - fs.swift.service.SparkTest.auth.url http://127.0.0.1:5000/v2.0/tokens diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml new file mode 100644 index 0000000000000..aa36dd4774d86 --- /dev/null +++ b/hadoop-cloud/pom.xml @@ -0,0 +1,185 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../pom.xml + + + spark-hadoop-cloud_2.11 + jar + Spark Project Cloud Integration through Hadoop Libraries + + Contains support for cloud infrastructures, specifically the Hadoop JARs and + transitive dependencies needed to interact with the infrastructures, + making everything consistent with Spark's other dependencies. + + + hadoop-cloud + + + + + + org.apache.hadoop + hadoop-aws + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + commons-logging + commons-logging + + + org.codehaus.jackson + jackson-mapper-asl + + + org.codehaus.jackson + jackson-core-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + + org.apache.hadoop + hadoop-openstack + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + commons-logging + commons-logging + + + junit + junit + + + org.mockito + mockito-all + + + + + + + joda-time + joda-time + ${hadoop.deps.scope} + + + + com.fasterxml.jackson.core + jackson-databind + ${hadoop.deps.scope} + + + com.fasterxml.jackson.core + jackson-annotations + ${hadoop.deps.scope} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-cbor + ${fasterxml.jackson.version} + + + + org.apache.httpcomponents + httpclient + ${hadoop.deps.scope} + + + + org.apache.httpcomponents + httpcore + ${hadoop.deps.scope} + + + + + + + hadoop-2.7 + + + + + + org.apache.hadoop + hadoop-azure + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + + + + diff --git a/pom.xml b/pom.xml index a1a1817e2f7d3..0533a8dcf2e0a 100644 --- a/pom.xml +++ b/pom.xml @@ -2546,6 +2546,13 @@ + + hadoop-cloud + + hadoop-cloud + + + scala-2.10 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e52baf51aed1a..b5362ec1ae452 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,9 +57,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(mesos, yarn, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests) = + streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = Seq("mesos", "yarn", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") From 7087e01194964a1aad0b45bdb41506a17100eacf Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 7 May 2017 13:10:10 -0700 Subject: [PATCH 0426/1765] [SPARK-20543][SPARKR][FOLLOWUP] Don't skip tests on AppVeyor ## What changes were proposed in this pull request? add environment ## How was this patch tested? wait for appveyor run Author: Felix Cheung Closes #17878 from felixcheung/appveyorrcran. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- appveyor.yml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 47cc34a6c5b75..232246d6be9b4 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -3387,7 +3387,7 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() + skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. diff --git a/appveyor.yml b/appveyor.yml index 4d31af70f056e..58c2e98289e96 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,6 +48,9 @@ install: build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R @@ -56,4 +59,3 @@ notifications: on_build_success: false on_build_failure: false on_build_status_changed: false - From 500436b4368207db9e9b9cef83f9c11d33e31e1a Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sun, 7 May 2017 13:56:13 -0700 Subject: [PATCH 0427/1765] [MINOR][SQL][DOCS] Improve unix_timestamp's scaladoc (and typo hunting) ## What changes were proposed in this pull request? * Docs are consistent (across different `unix_timestamp` variants and their internal expressions) * typo hunting ## How was this patch tested? local build Author: Jacek Laskowski Closes #17801 from jaceklaskowski/unix_timestamp. --- .../expressions/datetimeExpressions.scala | 6 ++--- .../sql/catalyst/util/DateTimeUtils.scala | 2 +- .../org/apache/spark/sql/functions.scala | 26 ++++++++++++------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index bb8fd5032d63d..a98cd33f2780c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -488,7 +488,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti * Deterministic version of [[UnixTimestamp]], must have at least one parameter. */ @ExpressionDescription( - usage = "_FUNC_(expr[, pattern]) - Returns the UNIX timestamp of the give time.", + usage = "_FUNC_(expr[, pattern]) - Returns the UNIX timestamp of the given time.", extended = """ Examples: > SELECT _FUNC_('2016-04-08', 'yyyy-MM-dd'); @@ -1225,8 +1225,8 @@ case class ParseToTimestamp(left: Expression, format: Expression, child: Express extends RuntimeReplaceable { def this(left: Expression, format: Expression) = { - this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) -} + this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) + } override def flatArguments: Iterator[Any] = Iterator(left, format) override def sql: String = s"$prettyName(${left.sql}, ${format.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index eb6aad5b2d2bb..6c1592fd8881d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -423,7 +423,7 @@ object DateTimeUtils { } /** - * Parses a given UTF8 date string to the corresponding a corresponding [[Int]] value. + * Parses a given UTF8 date string to a corresponding [[Int]] value. * The return type is [[Option]] in order to distinguish between 0 and null. The following * formats are allowed: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f07e04368389f..987011edfe1e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2491,10 +2491,10 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - * pattern letters of `java.text.SimpleDateFormat` can be used. + * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. + * All pattern letters of `java.text.SimpleDateFormat` can be used. * - * @note Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. * * @group datetime_funcs @@ -2647,7 +2647,11 @@ object functions { } /** - * Gets current Unix timestamp in seconds. + * Returns the current Unix timestamp (in seconds). + * + * @note All calls of `unix_timestamp` within the same query return the same value + * (i.e. the current timestamp is calculated at the start of query evaluation). + * * @group datetime_funcs * @since 1.5.0 */ @@ -2657,7 +2661,9 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), - * using the default timezone and the default locale, return null if fail. + * using the default timezone and the default locale. + * Returns `null` if fails. + * * @group datetime_funcs * @since 1.5.0 */ @@ -2666,13 +2672,15 @@ object functions { } /** - * Convert time string with given pattern - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), return null if fail. + * Converts time string with given pattern to Unix timestamp (in seconds). + * Returns `null` if fails. + * + * @see + * Customizing Formats * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } + def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** * Convert time string to a Unix timestamp (in seconds). From 1f73d3589a84b78473598c17ac328a9805896778 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 7 May 2017 16:24:42 -0700 Subject: [PATCH 0428/1765] [SPARK-20550][SPARKR] R wrapper for Dataset.alias ## What changes were proposed in this pull request? - Add SparkR wrapper for `Dataset.alias`. - Adjust roxygen annotations for `functions.alias` (including example usage). ## How was this patch tested? Unit tests, `check_cran.sh`. Author: zero323 Closes #17825 from zero323/SPARK-20550. --- R/pkg/R/DataFrame.R | 24 +++++++++++++++++++++++ R/pkg/R/column.R | 16 +++++++-------- R/pkg/R/generics.R | 11 +++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 10 ++++++++++ 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1c8869202f677..b56dddcb9f2ef 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3745,3 +3745,27 @@ setMethod("hint", jdf <- callJMethod(x@sdf, "hint", name, parameters) dataFrame(jdf) }) + +#' alias +#' +#' @aliases alias,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname alias +#' @name alias +#' @export +#' @examples +#' \dontrun{ +#' df <- alias(createDataFrame(mtcars), "mtcars") +#' avg_mpg <- alias(agg(groupBy(df, df$cyl), avg(df$mpg)), "avg_mpg") +#' +#' head(select(df, column("mtcars.mpg"))) +#' head(join(df, avg_mpg, column("mtcars.cyl") == column("avg_mpg.cyl"))) +#' } +#' @note alias(SparkDataFrame) since 2.3.0 +setMethod("alias", + signature(object = "SparkDataFrame"), + function(object, data) { + stopifnot(is.character(data)) + sdf <- callJMethod(object@sdf, "alias", data) + dataFrame(sdf) + }) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 147ee4b6887b9..574078012adad 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -130,19 +130,19 @@ createMethods <- function() { createMethods() -#' alias -#' -#' Set a new name for a column -#' -#' @param object Column to rename -#' @param data new name to use -#' #' @rdname alias #' @name alias #' @aliases alias,Column-method #' @family colum_func #' @export -#' @note alias since 1.4.0 +#' @examples \dontrun{ +#' df <- createDataFrame(iris) +#' +#' head(select( +#' df, alias(df$Sepal_Length, "slength"), alias(df$Petal_Length, "plength") +#' )) +#' } +#' @note alias(Column) since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e835ef3e4f40d..3c84bf8a4803e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -387,6 +387,17 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +#' alias +#' +#' Returns a new SparkDataFrame or a Column with an alias set. Equivalent to SQL "AS" keyword. +#' +#' @name alias +#' @rdname alias +#' @param object x a SparkDataFrame or a Column +#' @param data new name to use +#' @return a SparkDataFrame or a Column +NULL + #' @rdname arrange #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 232246d6be9b4..0856bab5686c5 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1223,6 +1223,16 @@ test_that("select with column", { expect_equal(columns(df4), c("name", "age")) expect_equal(count(df4), 3) + # Test select with alias + df5 <- alias(df, "table") + + expect_equal(columns(select(df5, column("table.name"))), "name") + expect_equal(columns(select(df5, "table.name")), "name") + + # Test that stats::alias is not masked + expect_is(alias(aov(yield ~ block + N * P * K, npk)), "listof") + + expect_error(select(df, c("name", "age"), "name"), "To select multiple columns, use a character vector or list for col") }) From f53a820721fe0525c275e2bb4415c20909c42dc3 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 8 May 2017 10:58:27 +0800 Subject: [PATCH 0429/1765] [SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy ## What changes were proposed in this pull request? Adds Python wrappers for `DataFrameWriter.bucketBy` and `DataFrameWriter.sortBy` ([SPARK-16931](https://issues.apache.org/jira/browse/SPARK-16931)) ## How was this patch tested? Unit tests covering new feature. __Note__: Based on work of GregBowyer (f49b9a23468f7af32cb53d2b654272757c151725) CC HyukjinKwon Author: zero323 Author: Greg Bowyer Closes #17077 from zero323/SPARK-16931. --- python/pyspark/sql/readwriter.py | 57 ++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 54 ++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 960fb882cf901..90ce8f81eb7fd 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -563,6 +563,63 @@ def partitionBy(self, *cols): self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) return self + @since(2.3) + def bucketBy(self, numBuckets, col, *cols): + """Buckets the output by the given columns.If specified, + the output is laid out on the file system similar to Hive's bucketing scheme. + + :param numBuckets: the number of buckets to save + :param col: a name of a column, or a list of names. + :param cols: additional names (optional). If `col` is a list it should be empty. + + .. note:: Applicable for file-based data sources in combination with + :py:meth:`DataFrameWriter.saveAsTable`. + + >>> (df.write.format('parquet') + ... .bucketBy(100, 'year', 'month') + ... .mode("overwrite") + ... .saveAsTable('bucketed_table')) + """ + if not isinstance(numBuckets, int): + raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets))) + + if isinstance(col, (list, tuple)): + if cols: + raise ValueError("col is a {0} but cols are not empty".format(type(col))) + + col, cols = col[0], col[1:] + + if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + raise TypeError("all names should be `str`") + + self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols)) + return self + + @since(2.3) + def sortBy(self, col, *cols): + """Sorts the output in each bucket by the given columns on the file system. + + :param col: a name of a column, or a list of names. + :param cols: additional names (optional). If `col` is a list it should be empty. + + >>> (df.write.format('parquet') + ... .bucketBy(100, 'year', 'month') + ... .sortBy('day') + ... .mode("overwrite") + ... .saveAsTable('sorted_bucketed_table')) + """ + if isinstance(col, (list, tuple)): + if cols: + raise ValueError("col is a {0} but cols are not empty".format(type(col))) + + col, cols = col[0], col[1:] + + if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + raise TypeError("all names should be `str`") + + self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols)) + return self + @since(1.4) def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7983bc536fc6c..e3fe01eae243f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self): sqlContext2 = SQLContext(self.sc) self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) + def tearDown(self): + super(SQLTests, self).tearDown() + + # tear down test_bucketed_write state + self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket") + def test_row_should_be_read_only(self): row = Row(a=1, b=2) self.assertEqual(1, row.a) @@ -2196,6 +2202,54 @@ def test_BinaryType_serialization(self): df = self.spark.createDataFrame(data, schema=schema) df.collect() + def test_bucketed_write(self): + data = [ + (1, "foo", 3.0), (2, "foo", 5.0), + (3, "bar", -1.0), (4, "bar", 6.0), + ] + df = self.spark.createDataFrame(data, ["x", "y", "z"]) + + def count_bucketed_cols(names, table="pyspark_bucket"): + """Given a sequence of column names and a table name + query the catalog and return number o columns which are + used for bucketing + """ + cols = self.spark.catalog.listColumns(table) + num = len([c for c in cols if c.name in names and c.isBucket]) + return num + + # Test write with one bucketing column + df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write two bucketing columns + df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort + df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with a list of columns + df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with a list of columns + (df.write.bucketBy(2, "x") + .sortBy(["y", "z"]) + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with multiple columns + (df.write.bucketBy(2, "x") + .sortBy("y", "z") + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + class HiveSparkSubmitTests(SparkSubmitTests): From 22691556e5f0dfbac81b8cc9ca0a67c70c1711ca Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 8 May 2017 12:16:00 +0900 Subject: [PATCH 0430/1765] [SPARK-12297][SQL] Hive compatibility for Parquet Timestamps ## What changes were proposed in this pull request? This change allows timestamps in parquet-based hive table to behave as a "floating time", without a timezone, as timestamps are for other file formats. If the storage timezone is the same as the session timezone, this conversion is a no-op. When data is read from a hive table, the table property is *always* respected. This allows spark to not change behavior when reading old data, but read newly written data correctly (whatever the source of the data is). Spark inherited the original behavior from Hive, but Hive is also updating behavior to use the same scheme in HIVE-12767 / HIVE-16231. The default for Spark remains unchanged; created tables do not include the new table property. This will only apply to hive tables; nothing is added to parquet metadata to indicate the timezone, so data that is read or written directly from parquet files will never have any conversions applied. ## How was this patch tested? Added a unit test which creates tables, reads and writes data, under a variety of permutations (different storage timezones, different session timezones, vectorized reading on and off). Author: Imran Rashid Closes #16781 from squito/SPARK-12297. --- .../sql/catalyst/catalog/interface.scala | 4 +- .../sql/catalyst/util/DateTimeUtils.scala | 5 + .../parquet/VectorizedColumnReader.java | 28 +- .../VectorizedParquetRecordReader.java | 6 +- .../spark/sql/execution/command/tables.scala | 8 +- .../parquet/ParquetFileFormat.scala | 2 + .../parquet/ParquetReadSupport.scala | 3 +- .../parquet/ParquetRecordMaterializer.scala | 9 +- .../parquet/ParquetRowConverter.scala | 53 ++- .../parquet/ParquetWriteSupport.scala | 25 +- .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +- .../hive/ParquetHiveCompatibilitySuite.scala | 379 +++++++++++++++++- 13 files changed, 516 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index cc0cbba275b81..c39017ebbfe60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -132,10 +132,10 @@ case class CatalogTablePartition( /** * Given the partition schema, returns a row with that schema holding the partition values. */ - def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { + def toRow(partitionSchema: StructType, defaultTimeZoneId: String): InternalRow = { val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) val timeZoneId = caseInsensitiveProperties.getOrElse( - DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId) InternalRow.fromSeq(partitionSchema.map { field => val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6c1592fd8881d..bf596fa0a89db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -498,6 +498,11 @@ object DateTimeUtils { false } + lazy val validTimezones = TimeZone.getAvailableIDs().toSet + def isValidTimezone(timezoneId: String): Boolean = { + validTimezones.contains(timezoneId) + } + /** * Returns the microseconds since year zero (-17999) from microseconds since epoch. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 9d641b528723a..dabbc2b6387e4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.TimeZone; +import org.apache.hadoop.conf.Configuration; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -30,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -90,11 +93,30 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; + private final TimeZone storageTz; + private final TimeZone sessionTz; - public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader, + Configuration conf) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + // The conf is sometimes null in tests. + String sessionTzString = + conf == null ? null : conf.get(SQLConf.SESSION_LOCAL_TIMEZONE().key()); + if (sessionTzString == null || sessionTzString.isEmpty()) { + sessionTz = DateTimeUtils.defaultTimeZone(); + } else { + sessionTz = TimeZone.getTimeZone(sessionTzString); + } + String storageTzString = + conf == null ? null : conf.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY()); + if (storageTzString == null || storageTzString.isEmpty()) { + storageTz = sessionTz; + } else { + storageTz = TimeZone.getTimeZone(storageTzString); + } this.maxDefLevel = descriptor.getMaxDefinitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); @@ -289,7 +311,7 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, // TODO: Convert dictionary of Binaries to dictionary of Longs if (!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); - column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v, sessionTz, storageTz)); } } } else { @@ -422,7 +444,7 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, // Read 12 bytes for INT96 - ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12))); + ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12), sessionTz, storageTz)); } else { column.putNull(rowId + i); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f2291..d8974ddf24704 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.column.ColumnDescriptor; @@ -95,6 +96,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private boolean returnColumnarBatch; + private Configuration conf; + /** * The default config on whether columnarBatch should be offheap. */ @@ -107,6 +110,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException, UnsupportedOperationException { super.initialize(inputSplit, taskAttemptContext); + this.conf = taskAttemptContext.getConfiguration(); initializeInternal(); } @@ -277,7 +281,7 @@ private void checkEndOfRowGroup() throws IOException { for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), - pages.getPageReader(columns.get(i))); + pages.getPageReader(columns.get(i)), conf); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ebf03e1bf8869..5843c5b56d44c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import scala.util.Try -import org.apache.commons.lang3.StringEscapeUtils import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -37,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -74,6 +73,10 @@ case class CreateTableLikeCommand( sourceTableDesc.provider } + val properties = sourceTableDesc.properties.filter { case (k, _) => + k == ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + } + // If the location is specified, we create an external table internally. // Otherwise create a managed table. val tblType = if (location.isEmpty) CatalogTableType.MANAGED else CatalogTableType.EXTERNAL @@ -86,6 +89,7 @@ case class CreateTableLikeCommand( locationUri = location.map(CatalogUtils.stringToURI(_))), schema = sourceTableDesc.schema, provider = newProvider, + properties = properties, partitionColumnNames = sourceTableDesc.partitionColumnNames, bucketSpec = sourceTableDesc.bucketSpec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2f3a2c62b912c..8113768cd793f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -632,4 +632,6 @@ object ParquetFileFormat extends Logging { Failure(cause) }.toOption } + + val PARQUET_TIMEZONE_TABLE_PROPERTY = "parquet.mr.int96.write.zone" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a6200..bf395a0bef745 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -95,7 +95,8 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo new ParquetRecordMaterializer( parquetRequestedSchema, ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetSchemaConverter(conf)) + new ParquetSchemaConverter(conf), + conf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 4e49a0dac97c0..df041996cdea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import org.apache.hadoop.conf.Configuration import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType @@ -29,13 +30,17 @@ import org.apache.spark.sql.types.StructType * @param parquetSchema Parquet schema of the records to be read * @param catalystSchema Catalyst schema of the rows to be constructed * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters + * @param hadoopConf hadoop Configuration for passing extra params for parquet conversion */ private[parquet] class ParquetRecordMaterializer( - parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetSchemaConverter) + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: ParquetSchemaConverter, + hadoopConf: Configuration) extends RecordMaterializer[UnsafeRow] { private val rootConverter = - new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, NoopUpdater) + new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, hadoopConf, NoopUpdater) override def getCurrentRecord: UnsafeRow = rootConverter.currentRecord diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 32e6c60cd9766..d52ff62d93b26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder +import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} @@ -34,6 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -117,12 +120,14 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined * types should have been expanded. + * @param hadoopConf a hadoop Configuration for passing any extra parameters for parquet conversion * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( schemaConverter: ParquetSchemaConverter, parquetType: GroupType, catalystType: StructType, + hadoopConf: Configuration, updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) with Logging { @@ -261,18 +266,18 @@ private[parquet] class ParquetRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + val sessionTzString = hadoopConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key) + val sessionTz = Option(sessionTzString).map(TimeZone.getTimeZone(_)) + .getOrElse(DateTimeUtils.defaultTimeZone()) + val storageTzString = hadoopConf.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + val storageTz = Option(storageTzString).map(TimeZone.getTimeZone(_)).getOrElse(sessionTz) new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { - assert( - value.length() == 12, - "Timestamps (with nanoseconds) are expected to be stored in 12-byte long binaries, " + - s"but got a ${value.length()}-byte binary.") - - val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)) + val timestamp = ParquetRowConverter.binaryToSQLTimestamp(value, sessionTz = sessionTz, + storageTz = storageTz) + updater.setLong(timestamp) } } @@ -302,7 +307,7 @@ private[parquet] class ParquetRowConverter( case t: StructType => new ParquetRowConverter( - schemaConverter, parquetType.asGroupType(), t, new ParentContainerUpdater { + schemaConverter, parquetType.asGroupType(), t, hadoopConf, new ParentContainerUpdater { override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) @@ -651,6 +656,7 @@ private[parquet] class ParquetRowConverter( } private[parquet] object ParquetRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without @@ -673,12 +679,35 @@ private[parquet] object ParquetRowConverter { unscaled } - def binaryToSQLTimestamp(binary: Binary): SQLTimestamp = { + /** + * Converts an int96 to a SQLTimestamp, given both the storage timezone and the local timezone. + * The timestamp is really meant to be interpreted as a "floating time", but since we + * actually store it as micros since epoch, why we have to apply a conversion when timezones + * change. + * + * @param binary a parquet Binary which holds one int96 + * @param sessionTz the session timezone. This will be used to determine how to display the time, + * and compute functions on the timestamp which involve a timezone, eg. extract + * the hour. + * @param storageTz the timezone which was used to store the timestamp. This should come from the + * timestamp table property, or else assume its the same as the sessionTz + * @return a timestamp (millis since epoch) which will render correctly in the sessionTz + */ + def binaryToSQLTimestamp( + binary: Binary, + sessionTz: TimeZone, + storageTz: TimeZone): SQLTimestamp = { assert(binary.length() == 12, s"Timestamps (with nanoseconds) are expected to be stored in" + s" 12-byte long binaries. Found a ${binary.length()}-byte binary instead.") val buffer = binary.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) val timeOfDayNanos = buffer.getLong val julianDay = buffer.getInt - DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + val utcEpochMicros = DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + // avoid expensive time logic if possible. + if (sessionTz.getID() != storageTz.getID()) { + DateTimeUtils.convertTz(utcEpochMicros, sessionTz, storageTz) + } else { + utcEpochMicros + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 38b0e33937f3c..679ed8e361b74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.{ByteBuffer, ByteOrder} import java.util +import java.util.TimeZone import scala.collection.JavaConverters.mapAsJavaMapConverter @@ -75,6 +76,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Reusable byte array used to write decimal values private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + private var storageTz: TimeZone = _ + private var sessionTz: TimeZone = _ + override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) this.schema = StructType.fromString(schemaString) @@ -91,6 +95,19 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + val sessionTzString = configuration.get(SQLConf.SESSION_LOCAL_TIMEZONE.key) + sessionTz = if (sessionTzString == null || sessionTzString == "") { + TimeZone.getDefault() + } else { + TimeZone.getTimeZone(sessionTzString) + } + val storageTzString = configuration.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + storageTz = if (storageTzString == null || storageTzString == "") { + sessionTz + } else { + TimeZone.getTimeZone(storageTzString) + } val messageType = new ParquetSchemaConverter(configuration).convert(schema) val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava @@ -178,7 +195,13 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. - val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val rawMicros = row.getLong(ordinal) + val adjustedMicros = if (sessionTz.getID() == storageTz.getID()) { + rawMicros + } else { + DateTimeUtils.convertTz(rawMicros, storageTz, sessionTz) + } + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(adjustedMicros) val buf = ByteBuffer.wrap(timestampBuffer) buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ba48facff2933..8fef467f5f5cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -39,9 +39,10 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -224,6 +225,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } + val tableTz = tableDefinition.properties.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + tableTz.foreach { tz => + if (!DateTimeUtils.isValidTimezone(tz)) { + throw new AnalysisException(s"Cannot set" + + s" ${ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY} to invalid timezone $tz") + } + } + if (tableDefinition.tableType == VIEW) { client.createTable(tableDefinition, ignoreIfExists) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b98066cb76c8..e0b565c0d79a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -174,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, fileFormat = fileFormat, - options = options)(sparkSession = sparkSession) + options = options ++ getStorageTzOptions(relation))(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) tableRelationCache.put(tableIdentifier, created) created @@ -201,7 +202,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log userSpecifiedSchema = Option(dataSchema), // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, - options = options, + options = options ++ getStorageTzOptions(relation), className = fileType).resolveRelation(), table = updatedTable) @@ -222,6 +223,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log result.copy(output = newOutput) } + private def getStorageTzOptions(relation: CatalogRelation): Map[String, String] = { + // We add the table timezone to the relation options, which automatically gets injected into the + // hadoopConf for the Parquet Converters + val storageTzKey = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + relation.tableMeta.properties.get(storageTzKey).map(storageTzKey -> _).toMap + } + private def inferIfNeeded( relation: CatalogRelation, options: Map[String, String], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 05b6059472f59..2bfd63d9b56e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,12 +17,22 @@ package org.apache.spark.sql.hive +import java.io.File +import java.net.URLDecoder import java.sql.Timestamp +import java.util.TimeZone -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName + +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.parquet.{ParquetCompatibilityTest, ParquetFileFormat} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StringType, StructType, TimestampType} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** @@ -141,4 +151,369 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi Row(Seq(Row(1))), "ARRAY>") } + + val testTimezones = Seq( + "UTC" -> "UTC", + "LA" -> "America/Los_Angeles", + "Berlin" -> "Europe/Berlin" + ) + // Check creating parquet tables with timestamps, writing data into them, and reading it back out + // under a variety of conditions: + // * tables with explicit tz and those without + // * altering table properties directly + // * variety of timezones, local & non-local + val sessionTimezones = testTimezones.map(_._2).map(Some(_)) ++ Seq(None) + sessionTimezones.foreach { sessionTzOpt => + val sparkSession = spark.newSession() + sessionTzOpt.foreach { tz => sparkSession.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, tz) } + testCreateWriteRead(sparkSession, "no_tz", None, sessionTzOpt) + val localTz = TimeZone.getDefault.getID() + testCreateWriteRead(sparkSession, "local", Some(localTz), sessionTzOpt) + // check with a variety of timezones. The unit tests currently are configured to always use + // America/Los_Angeles, but even if they didn't, we'd be sure to cover a non-local timezone. + testTimezones.foreach { case (tableName, zone) => + if (zone != localTz) { + testCreateWriteRead(sparkSession, tableName, Some(zone), sessionTzOpt) + } + } + } + + private def testCreateWriteRead( + sparkSession: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + testCreateAlterTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + testWriteTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + testReadTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + } + + private def checkHasTz(spark: SparkSession, table: String, tz: Option[String]): Unit = { + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + assert(tableMetadata.properties.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) === tz) + } + + private def testCreateAlterTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + test(s"SPARK-12297: Create and Alter Parquet tables and timezones; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + withTable(baseTable, s"like_$baseTable", s"select_$baseTable", s"partitioned_$baseTable") { + // If we ever add a property to set the table timezone by default, defaultTz would change + val defaultTz = None + // check that created tables have correct TBLPROPERTIES + val tblProperties = explicitTz.map { + tz => s"""TBLPROPERTIES ($key="$tz")""" + }.getOrElse("") + spark.sql( + s"""CREATE TABLE $baseTable ( + | x int + | ) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + val expectedTableTz = explicitTz.orElse(defaultTz) + checkHasTz(spark, baseTable, expectedTableTz) + spark.sql( + s"""CREATE TABLE partitioned_$baseTable ( + | x int + | ) + | PARTITIONED BY (y int) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + checkHasTz(spark, s"partitioned_$baseTable", expectedTableTz) + spark.sql(s"CREATE TABLE like_$baseTable LIKE $baseTable") + checkHasTz(spark, s"like_$baseTable", expectedTableTz) + spark.sql( + s"""CREATE TABLE select_$baseTable + | STORED AS PARQUET + | AS + | SELECT * from $baseTable + """.stripMargin) + checkHasTz(spark, s"select_$baseTable", defaultTz) + + // check alter table, setting, unsetting, resetting the property + spark.sql( + s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="America/Los_Angeles")""") + checkHasTz(spark, baseTable, Some("America/Los_Angeles")) + spark.sql(s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="UTC")""") + checkHasTz(spark, baseTable, Some("UTC")) + spark.sql(s"""ALTER TABLE $baseTable UNSET TBLPROPERTIES ($key)""") + checkHasTz(spark, baseTable, None) + explicitTz.foreach { tz => + spark.sql(s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="$tz")""") + checkHasTz(spark, baseTable, expectedTableTz) + } + } + } + } + + val desiredTimestampStrings = Seq( + "2015-12-31 22:49:59.123", + "2015-12-31 23:50:59.123", + "2016-01-01 00:39:59.123", + "2016-01-01 01:29:59.123" + ) + // We don't want to mess with timezones inside the tests themselves, since we use a shared + // spark context, and then we might be prone to issues from lazy vals for timezones. Instead, + // we manually adjust the timezone just to determine what the desired millis (since epoch, in utc) + // is for various "wall-clock" times in different timezones, and then we can compare against those + // in our tests. + val timestampTimezoneToMillis = { + val originalTz = TimeZone.getDefault + try { + desiredTimestampStrings.flatMap { timestampString => + Seq("America/Los_Angeles", "Europe/Berlin", "UTC").map { tzId => + TimeZone.setDefault(TimeZone.getTimeZone(tzId)) + val timestamp = Timestamp.valueOf(timestampString) + (timestampString, tzId) -> timestamp.getTime() + } + }.toMap + } finally { + TimeZone.setDefault(originalTz) + } + } + + private def createRawData(spark: SparkSession): Dataset[(String, Timestamp)] = { + import spark.implicits._ + val df = desiredTimestampStrings.toDF("display") + // this will get the millis corresponding to the display time given the current *session* + // timezone. + df.withColumn("ts", expr("cast(display as timestamp)")).as[(String, Timestamp)] + } + + private def testWriteTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]) : Unit = { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + test(s"SPARK-12297: Write to Parquet tables with Timestamps; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + + withTable(s"saveAsTable_$baseTable", s"insert_$baseTable", s"partitioned_ts_$baseTable") { + val sessionTzId = sessionTzOpt.getOrElse(TimeZone.getDefault().getID()) + // check that created tables have correct TBLPROPERTIES + val tblProperties = explicitTz.map { + tz => s"""TBLPROPERTIES ($key="$tz")""" + }.getOrElse("") + + val rawData = createRawData(spark) + // Check writing data out. + // We write data into our tables, and then check the raw parquet files to see whether + // the correct conversion was applied. + rawData.write.saveAsTable(s"saveAsTable_$baseTable") + checkHasTz(spark, s"saveAsTable_$baseTable", None) + spark.sql( + s"""CREATE TABLE insert_$baseTable ( + | display string, + | ts timestamp + | ) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + checkHasTz(spark, s"insert_$baseTable", explicitTz) + rawData.write.insertInto(s"insert_$baseTable") + // no matter what, roundtripping via the table should leave the data unchanged + val readFromTable = spark.table(s"insert_$baseTable").collect() + .map { row => (row.getAs[String](0), row.getAs[Timestamp](1)).toString() }.sorted + assert(readFromTable === rawData.collect().map(_.toString()).sorted) + + // Now we load the raw parquet data on disk, and check if it was adjusted correctly. + // Note that we only store the timezone in the table property, so when we read the + // data this way, we're bypassing all of the conversion logic, and reading the raw + // values in the parquet file. + val onDiskLocation = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(s"insert_$baseTable")).location.getPath + // we test reading the data back with and without the vectorized reader, to make sure we + // haven't broken reading parquet from non-hive tables, with both readers. + Seq(false, true).foreach { vectorized => + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized) + val readFromDisk = spark.read.parquet(onDiskLocation).collect() + val storageTzId = explicitTz.getOrElse(sessionTzId) + readFromDisk.foreach { row => + val displayTime = row.getAs[String](0) + val millis = row.getAs[Timestamp](1).getTime() + val expectedMillis = timestampTimezoneToMillis((displayTime, storageTzId)) + assert(expectedMillis === millis, s"Display time '$displayTime' was stored " + + s"incorrectly with sessionTz = ${sessionTzOpt}; Got $millis, expected " + + s"$expectedMillis (delta = ${millis - expectedMillis})") + } + } + + // check tables partitioned by timestamps. We don't compare the "raw" data in this case, + // since they are adjusted even when we bypass the hive table. + rawData.write.partitionBy("ts").saveAsTable(s"partitioned_ts_$baseTable") + val partitionDiskLocation = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(s"partitioned_ts_$baseTable")).location.getPath + // no matter what mix of timezones we use, the dirs should specify the value with the + // same time we use for display. + val parts = new File(partitionDiskLocation).list().collect { + case name if name.startsWith("ts=") => URLDecoder.decode(name.stripPrefix("ts=")) + }.toSet + assert(parts === desiredTimestampStrings.toSet) + } + } + } + + private def testReadTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + test(s"SPARK-12297: Read from Parquet tables with Timestamps; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + withTable(s"external_$baseTable", s"partitioned_$baseTable") { + // we intentionally save this data directly, without creating a table, so we can + // see that the data is read back differently depending on table properties. + // we'll save with adjusted millis, so that it should be the correct millis after reading + // back. + val rawData = createRawData(spark) + // to avoid closing over entire class + val timestampTimezoneToMillis = this.timestampTimezoneToMillis + import spark.implicits._ + val adjustedRawData = (explicitTz match { + case Some(tzId) => + rawData.map { case (displayTime, _) => + val storageMillis = timestampTimezoneToMillis((displayTime, tzId)) + (displayTime, new Timestamp(storageMillis)) + } + case _ => + rawData + }).withColumnRenamed("_1", "display").withColumnRenamed("_2", "ts") + withTempPath { basePath => + val unpartitionedPath = new File(basePath, "flat") + val partitionedPath = new File(basePath, "partitioned") + adjustedRawData.write.parquet(unpartitionedPath.getCanonicalPath) + val options = Map("path" -> unpartitionedPath.getCanonicalPath) ++ + explicitTz.map { tz => Map(key -> tz) }.getOrElse(Map()) + + spark.catalog.createTable( + tableName = s"external_$baseTable", + source = "parquet", + schema = new StructType().add("display", StringType).add("ts", TimestampType), + options = options + ) + + // also write out a partitioned table, to make sure we can access that correctly. + // add a column we can partition by (value doesn't particularly matter). + val partitionedData = adjustedRawData.withColumn("id", monotonicallyIncreasingId) + partitionedData.write.partitionBy("id") + .parquet(partitionedPath.getCanonicalPath) + // unfortunately, catalog.createTable() doesn't let us specify partitioning, so just use + // a "CREATE TABLE" stmt. + val tblOpts = explicitTz.map { tz => s"""TBLPROPERTIES ($key="$tz")""" }.getOrElse("") + spark.sql(s"""CREATE EXTERNAL TABLE partitioned_$baseTable ( + | display string, + | ts timestamp + |) + |PARTITIONED BY (id bigint) + |STORED AS parquet + |LOCATION 'file:${partitionedPath.getCanonicalPath}' + |$tblOpts + """.stripMargin) + spark.sql(s"msck repair table partitioned_$baseTable") + + for { + vectorized <- Seq(false, true) + partitioned <- Seq(false, true) + } { + withClue(s"vectorized = $vectorized; partitioned = $partitioned") { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized) + val sessionTz = sessionTzOpt.getOrElse(TimeZone.getDefault().getID()) + val table = if (partitioned) s"partitioned_$baseTable" else s"external_$baseTable" + val query = s"select display, cast(ts as string) as ts_as_string, ts " + + s"from $table" + val collectedFromExternal = spark.sql(query).collect() + assert( collectedFromExternal.size === 4) + collectedFromExternal.foreach { row => + val displayTime = row.getAs[String](0) + // the timestamp should still display the same, despite the changes in timezones + assert(displayTime === row.getAs[String](1).toString()) + // we'll also check that the millis behind the timestamp has the appropriate + // adjustments. + val millis = row.getAs[Timestamp](2).getTime() + val expectedMillis = timestampTimezoneToMillis((displayTime, sessionTz)) + val delta = millis - expectedMillis + val deltaHours = delta / (1000L * 60 * 60) + assert(millis === expectedMillis, s"Display time '$displayTime' did not have " + + s"correct millis: was $millis, expected $expectedMillis; delta = $delta " + + s"($deltaHours hours)") + } + + // Now test that the behavior is still correct even with a filter which could get + // pushed down into parquet. We don't need extra handling for pushed down + // predicates because (a) in ParquetFilters, we ignore TimestampType and (b) parquet + // does not read statistics from int96 fields, as they are unsigned. See + // scalastyle:off line.size.limit + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L419 + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L348 + // scalastyle:on line.size.limit + // + // Just to be defensive in case anything ever changes in parquet, this test checks + // the assumption on column stats, and also the end-to-end behavior. + + val hadoopConf = sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val parts = if (partitioned) { + val subdirs = fs.listStatus(new Path(partitionedPath.getCanonicalPath)) + .filter(_.getPath().getName().startsWith("id=")) + fs.listStatus(subdirs.head.getPath()) + .filter(_.getPath().getName().endsWith(".parquet")) + } else { + fs.listStatus(new Path(unpartitionedPath.getCanonicalPath)) + .filter(_.getPath().getName().endsWith(".parquet")) + } + // grab the meta data from the parquet file. The next section of asserts just make + // sure the test is configured correctly. + assert(parts.size == 1) + val oneFooter = ParquetFileReader.readFooter(hadoopConf, parts.head.getPath) + assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 2) + assert(oneFooter.getFileMetaData.getSchema.getColumns.get(1).getType() === + PrimitiveTypeName.INT96) + val oneBlockMeta = oneFooter.getBlocks().get(0) + val oneBlockColumnMeta = oneBlockMeta.getColumns().get(1) + val columnStats = oneBlockColumnMeta.getStatistics + // This is the important assert. Column stats are written, but they are ignored + // when the data is read back as mentioned above, b/c int96 is unsigned. This + // assert makes sure this holds even if we change parquet versions (if eg. there + // were ever statistics even on unsigned columns). + assert(columnStats.isEmpty) + + // These queries should return the entire dataset, but if the predicates were + // applied to the raw values in parquet, they would incorrectly filter data out. + Seq( + ">" -> "2015-12-31 22:00:00", + "<" -> "2016-01-01 02:00:00" + ).foreach { case (comparison, value) => + val query = + s"select ts from $table where ts $comparison '$value'" + val countWithFilter = spark.sql(query).count() + assert(countWithFilter === 4, query) + } + } + } + } + } + } + } + + test("SPARK-12297: exception on bad timezone") { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + val badTzException = intercept[AnalysisException] { + spark.sql( + s"""CREATE TABLE bad_tz_table ( + | x int + | ) + | STORED AS PARQUET + | TBLPROPERTIES ($key="Blart Versenwald III") + """.stripMargin) + } + assert(badTzException.getMessage.contains("Blart Versenwald III")) + } } From c24bdaab5a234d18b273544cefc44cc4005bf8fc Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 7 May 2017 23:10:18 -0700 Subject: [PATCH 0431/1765] [SPARK-20626][SPARKR] address date test warning with timezone on windows ## What changes were proposed in this pull request? set timezone on windows ## How was this patch tested? unit test, AppVeyor Author: Felix Cheung Closes #17892 from felixcheung/rtimestamptest. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 0856bab5686c5..f517ce6713133 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -96,6 +96,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { skip_on_cran() From 42cc6d13edbebb7c435ec47c0c12b445e05fdd49 Mon Sep 17 00:00:00 2001 From: sujith71955 Date: Sun, 7 May 2017 23:15:00 -0700 Subject: [PATCH 0432/1765] [SPARK-20380][SQL] Unable to set/unset table comment property using ALTER TABLE SET/UNSET TBLPROPERTIES ddl ### What changes were proposed in this pull request? Table comment was not getting set/unset using **ALTER TABLE SET/UNSET TBLPROPERTIES** query eg: ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment) when user alter the table properties and adds/updates table comment,table comment which is a field of **CatalogTable** instance is not getting updated and old table comment if exists was shown to user, inorder to handle this issue, update the comment field value in **CatalogTable** with the newly added/modified comment along with other table level properties when user executes **ALTER TABLE SET TBLPROPERTIES** query. This pr has also taken care of unsetting the table comment when user executes query **ALTER TABLE UNSET TBLPROPERTIES** inorder to unset or remove table comment. eg: ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') ### How was this patch tested? Added test cases as part of **SQLQueryTestSuite** for verifying table comment using desc formatted table query after adding/modifying table comment as part of **AlterTableSetPropertiesCommand** and unsetting the table comment using **AlterTableUnsetPropertiesCommand**. Author: sujith71955 Closes #17649 from sujith71955/alter_table_comment. --- .../catalyst/catalog/InMemoryCatalog.scala | 8 +- .../spark/sql/execution/command/ddl.scala | 12 +- .../describe-table-after-alter-table.sql | 29 ++++ .../describe-table-after-alter-table.sql.out | 161 ++++++++++++++++++ 4 files changed, 204 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 81dd8efc0015f..8a5319bebe54e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -216,8 +216,8 @@ class InMemoryCatalog( } else { tableDefinition } - - catalog(db).tables.put(table, new TableDesc(tableWithLocation)) + val tableProp = tableWithLocation.properties.filter(_._1 != "comment") + catalog(db).tables.put(table, new TableDesc(tableWithLocation.copy(properties = tableProp))) } } @@ -298,7 +298,9 @@ class InMemoryCatalog( assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) - catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition + val updatedProperties = tableDefinition.properties.filter(kv => kv._1 != "comment") + val newTableDefinition = tableDefinition.copy(properties = updatedProperties) + catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } override def alterTableSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 55540563ef911..793fb9b795596 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -231,8 +231,12 @@ case class AlterTableSetPropertiesCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView) - // This overrides old properties - val newTable = table.copy(properties = table.properties ++ properties) + // This overrides old properties and update the comment parameter of CatalogTable + // with the newly added/modified comment since CatalogTable also holds comment as its + // direct property. + val newTable = table.copy( + properties = table.properties ++ properties, + comment = properties.get("comment")) catalog.alterTable(newTable) Seq.empty[Row] } @@ -267,8 +271,10 @@ case class AlterTableUnsetPropertiesCommand( } } } + // If comment is in the table property, we reset it to None + val tableComment = if (propKeys.contains("comment")) None else table.properties.get("comment") val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } - val newTable = table.copy(properties = newProperties) + val newTable = table.copy(properties = newProperties, comment = tableComment) catalog.alterTable(newTable) Seq.empty[Row] } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql new file mode 100644 index 0000000000000..69bff6656c43a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql @@ -0,0 +1,29 @@ +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added'; + +DESC FORMATTED table_with_comment; + +-- ALTER TABLE BY MODIFYING COMMENT +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet"); + +DESC FORMATTED table_with_comment; + +-- DROP TEST TABLE +DROP TABLE table_with_comment; + +-- CREATE TABLE WITHOUT COMMENT +CREATE TABLE table_comment (a STRING, b INT) USING parquet; + +DESC FORMATTED table_comment; + +-- ALTER TABLE BY ADDING COMMENT +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment"); + +DESC formatted table_comment; + +-- ALTER UNSET PROPERTIES COMMENT +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment'); + +DESC FORMATTED table_comment; + +-- DROP TEST TABLE +DROP TABLE table_comment; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out new file mode 100644 index 0000000000000..1cc11c475bc40 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -0,0 +1,161 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added' +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +DESC FORMATTED table_with_comment +-- !query 1 schema +struct +-- !query 1 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment added +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 2 +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet") +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +DESC FORMATTED table_with_comment +-- !query 3 schema +struct +-- !query 3 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment modified comment +Properties [type=parquet] +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 4 +DROP TABLE table_with_comment +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE table_comment (a STRING, b INT) USING parquet +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DESC FORMATTED table_comment +-- !query 6 schema +struct +-- !query 6 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 7 +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment") +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC formatted table_comment +-- !query 8 schema +struct +-- !query 8 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment added comment +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 9 +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +DESC FORMATTED table_comment +-- !query 10 schema +struct +-- !query 10 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 11 +DROP TABLE table_comment +-- !query 11 schema +struct<> +-- !query 11 output + From 2fdaeb52bbe2ed1a9127ac72917286e505303c85 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Sun, 7 May 2017 23:16:30 -0700 Subject: [PATCH 0433/1765] [SPARKR][DOC] fix typo in vignettes ## What changes were proposed in this pull request? Fix typo in vignettes Author: Wayne Zhang Closes #17884 from actuaryzhang/typo. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d38ec4f1b6f37..49f4ab8f146a8 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -65,7 +65,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -379,7 +379,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -405,7 +405,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -458,20 +458,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -780,7 +780,7 @@ head(predict(isoregModel, newDF)) `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `longley` dataset to train a gradient-boosted tree and make predictions: ```{r, warning=FALSE} df <- createDataFrame(longley) @@ -820,7 +820,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -851,9 +851,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -901,7 +901,7 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. ```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -981,7 +981,7 @@ testSummary ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. ```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -1079,19 +1079,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. From 0f820e2b6c507dc4156703862ce65e598ca41cca Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 8 May 2017 10:00:58 +0100 Subject: [PATCH 0434/1765] [SPARK-20519][SQL][CORE] Modify to prevent some possible runtime exceptions Signed-off-by: liuxian ## What changes were proposed in this pull request? When the input parameter is null, may be a runtime exception occurs ## How was this patch tested? Existing unit tests Author: liuxian Closes #17796 from 10110346/wip_lx_0428. --- .../scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- .../scala/org/apache/spark/deploy/DeployMessage.scala | 8 ++++---- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- .../org/apache/spark/deploy/master/MasterArguments.scala | 4 ++-- .../org/apache/spark/deploy/master/WorkerInfo.scala | 2 +- .../scala/org/apache/spark/deploy/worker/Worker.scala | 2 +- .../org/apache/spark/deploy/worker/WorkerArguments.scala | 4 ++-- .../main/scala/org/apache/spark/executor/Executor.scala | 2 +- .../scala/org/apache/spark/storage/BlockManagerId.scala | 2 +- core/src/main/scala/org/apache/spark/util/RpcUtils.scala | 2 +- core/src/main/scala/org/apache/spark/util/Utils.scala | 9 +++++---- .../deploy/mesos/MesosClusterDispatcherArguments.scala | 2 +- 12 files changed, 21 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b0dd2fc187baf..fb0405b1a69c6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -879,7 +879,7 @@ private[spark] class PythonAccumulatorV2( private val serverPort: Int) extends CollectionAccumulator[Array[Byte]] { - Utils.checkHost(serverHost, "Expected hostname") + Utils.checkHost(serverHost) val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8b..b5cb3f0a0f9dc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -43,7 +43,7 @@ private[deploy] object DeployMessages { memory: Int, workerWebUiUrl: String) extends DeployMessage { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } @@ -131,7 +131,7 @@ private[deploy] object DeployMessages { // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") + Utils.checkHostPort(hostPort) } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -183,7 +183,7 @@ private[deploy] object DeployMessages { completedDrivers: Array[DriverInfo], status: MasterState) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) def uri: String = "spark://" + host + ":" + port @@ -201,7 +201,7 @@ private[deploy] object DeployMessages { drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 816bf37e39fee..e061939623cbb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -80,7 +80,7 @@ private[deploy] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(address.host, "Expected hostname") + Utils.checkHost(address.host) private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c63793c16dcef..615d2533cf085 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -60,12 +60,12 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 4e20c10fd1427..c87d6e24b78c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -32,7 +32,7 @@ private[spark] class WorkerInfo( val webUiAddress: String) extends Serializable { - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 00b9d1af373db..34e3a4c020c80 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -55,7 +55,7 @@ private[deploy] class Worker( private val host = rpcEnv.address.host private val port = rpcEnv.address.port - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) // A scheduled executor used to send messages at the specified time. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 777020d4d5c84..bd07d342e04ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -68,12 +68,12 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 51b6c373c4daf..3bc47b670305b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -71,7 +71,7 @@ private[spark] class Executor( private val conf = env.conf // No ip or host:port - just hostname - Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname) // must not have port specified. assert (0 == Utils.parseHostPort(executorHostname)._2) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c37a3604d28fa..2c3da0ee85e06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -46,7 +46,7 @@ class BlockManagerId private ( def executorId: String = executorId_ if (null != host_) { - Utils.checkHost(host_, "Expected hostname") + Utils.checkHost(host_) assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 46a5cb2cff5a5..e5cccf39f9455 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -28,7 +28,7 @@ private[spark] object RpcUtils { def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") + Utils.checkHost(driverHost) rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4d37db96dfc37..edfe229792323 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -937,12 +937,13 @@ private[spark] object Utils extends Logging { customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) + def checkHost(host: String) { + assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host") } - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) + def checkHostPort(hostPort: String) { + assert(hostPort != null && hostPort.indexOf(':') != -1, + s"Expected host and port but got $hostPort") } // Typically, this will be of order of number of nodes in cluster diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index ef08502ec8dd6..ddea762fdb919 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -59,7 +59,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: @tailrec private def parse(args: List[String]): Unit = args match { case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) From 15526653a93a32cde3c9ea0c0e68e35622b0a590 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 8 May 2017 17:33:47 +0800 Subject: [PATCH 0435/1765] [SPARK-19956][CORE] Optimize a location order of blocks with topology information ## What changes were proposed in this pull request? When call the method getLocations of BlockManager, we only compare the data block host. Random selection for non-local data blocks, this may cause the selected data block to be in a different rack. So in this patch to increase the sort of the rack. ## How was this patch tested? New test case. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #17300 from ConeyLiu/blockmanager. --- .../apache/spark/storage/BlockManager.scala | 11 +++++-- .../spark/storage/BlockManagerSuite.scala | 31 +++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3219969bcd06f..33ce30c58e1ad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -612,12 +612,19 @@ private[spark] class BlockManager( /** * Return a list of locations for the given block, prioritizing the local machine since - * multiple block managers can share the same host. + * multiple block managers can share the same host, followed by hosts on the same rack. */ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { val locs = Random.shuffle(master.getLocations(blockId)) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - preferredLocs ++ otherLocs + blockManagerId.topologyInfo match { + case None => preferredLocs ++ otherLocs + case Some(_) => + val (sameRackLocs, differentRackLocs) = otherLocs.partition { + loc => blockManagerId.topologyInfo == loc.topologyInfo + } + preferredLocs ++ sameRackLocs ++ differentRackLocs + } } /** diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a8b9604899838..1e7bcdb6740f6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -496,8 +496,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } - test("optimize a location order of blocks") { - val localHost = Utils.localHostName() + test("optimize a location order of blocks without topology information") { + val localHost = "localhost" val otherHost = "otherHost" val bmMaster = mock(classOf[BlockManagerMaster]) val bmId1 = BlockManagerId("id1", localHost, 1) @@ -508,7 +508,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) - assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + + test("optimize a location order of blocks with topology information") { + val localHost = "localhost" + val otherHost = "otherHost" + val localRack = "localRack" + val otherRack = "otherRack" + + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1, Some(localRack)) + val bmId2 = BlockManagerId("id2", localHost, 2, Some(localRack)) + val bmId3 = BlockManagerId("id3", otherHost, 3, Some(otherRack)) + val bmId4 = BlockManagerId("id4", otherHost, 4, Some(otherRack)) + val bmId5 = BlockManagerId("id5", otherHost, 5, Some(localRack)) + when(bmMaster.getLocations(mc.any[BlockId])) + .thenReturn(Seq(bmId1, bmId2, bmId5, bmId3, bmId4)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + blockManager.blockManagerId = + BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) + assert(locations.flatMap(_.topologyInfo) + === Seq(localRack, localRack, localRack, otherRack, otherRack)) } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { From 58518d070777fc0665c4d02bad8adf910807df98 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 8 May 2017 12:45:00 +0200 Subject: [PATCH 0436/1765] [SPARK-20596][ML][TEST] Consolidate and improve ALS recommendAll test cases Existing test cases for `recommendForAllX` methods (added in [SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)) test `k < num items` and `k = num items`. Technically we should also test that `k > num items` returns the same results as `k = num items`. ## How was this patch tested? Updated existing unit tests. Author: Nick Pentreath Closes #17860 from MLnick/SPARK-20596-als-rec-tests. --- .../spark/ml/recommendation/ALSSuite.scala | 63 ++++++++----------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 7574af3d77ea8..9d31e792633cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -671,58 +671,45 @@ class ALSSuite .setItemCol("item") } - test("recommendForAllUsers with k < num_items") { - val topItems = getALSModel.recommendForAllUsers(2) - assert(topItems.count() == 3) - assert(topItems.columns.contains("user")) - - val expected = Map( - 0 -> Array((3, 54f), (4, 44f)), - 1 -> Array((3, 39f), (5, 33f)), - 2 -> Array((3, 51f), (5, 45f)) - ) - checkRecommendations(topItems, expected, "item") - } - - test("recommendForAllUsers with k = num_items") { - val topItems = getALSModel.recommendForAllUsers(4) - assert(topItems.count() == 3) - assert(topItems.columns.contains("user")) - + test("recommendForAllUsers with k <, = and > num_items") { + val model = getALSModel + val numUsers = model.userFactors.count + val numItems = model.itemFactors.count val expected = Map( 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) - checkRecommendations(topItems, expected, "item") - } - test("recommendForAllItems with k < num_users") { - val topUsers = getALSModel.recommendForAllItems(2) - assert(topUsers.count() == 4) - assert(topUsers.columns.contains("item")) - - val expected = Map( - 3 -> Array((0, 54f), (2, 51f)), - 4 -> Array((0, 44f), (2, 30f)), - 5 -> Array((2, 45f), (0, 42f)), - 6 -> Array((0, 28f), (2, 18f)) - ) - checkRecommendations(topUsers, expected, "user") + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForAllUsers(k) + assert(topItems.count() == numUsers) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } } - test("recommendForAllItems with k = num_users") { - val topUsers = getALSModel.recommendForAllItems(3) - assert(topUsers.count() == 4) - assert(topUsers.columns.contains("item")) - + test("recommendForAllItems with k <, = and > num_users") { + val model = getALSModel + val numUsers = model.userFactors.count + val numItems = model.itemFactors.count val expected = Map( 3 -> Array((0, 54f), (2, 51f), (1, 39f)), 4 -> Array((0, 44f), (2, 30f), (1, 26f)), 5 -> Array((2, 45f), (0, 42f), (1, 33f)), 6 -> Array((0, 28f), (2, 18f), (1, 16f)) ) - checkRecommendations(topUsers, expected, "user") + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = getALSModel.recommendForAllItems(k) + assert(topUsers.count() == numItems) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } } private def checkRecommendations( From aeb2ecc0cd898f5352df0a04be1014b02ea3e20e Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 8 May 2017 10:25:24 -0700 Subject: [PATCH 0437/1765] [SPARK-20621][DEPLOY] Delete deprecated config parameter in 'spark-env.sh' ## What changes were proposed in this pull request? Currently, `spark.executor.instances` is deprecated in `spark-env.sh`, because we suggest config it in `spark-defaults.conf` or other config file. And also this parameter is useless even if you set it in `spark-env.sh`, so remove it in this patch. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #17881 from ConeyLiu/deprecatedParam. --- conf/spark-env.sh.template | 1 - .../org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 94bd2c477a35b..b7c985ace69cf 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,7 +34,6 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 93578855122cd..0fc994d629ccb 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -280,10 +280,7 @@ object YarnSparkHadoopUtil { initialNumExecutors } else { - val targetNumExecutors = - sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors) - // System property can override environment variable. - conf.get(EXECUTOR_INSTANCES).getOrElse(targetNumExecutors) + conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors) } } } From 829cd7b8b70e65a91aa66e6d626bd45f18e0ad97 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 8 May 2017 14:27:56 -0700 Subject: [PATCH 0438/1765] [SPARK-20605][CORE][YARN][MESOS] Deprecate not used AM and executor port configuration ## What changes were proposed in this pull request? After SPARK-10997, client mode Netty RpcEnv doesn't require to start server, so port configurations are not used any more, here propose to remove these two configurations: "spark.executor.port" and "spark.am.port". ## How was this patch tested? Existing UTs. Author: jerryshao Closes #17866 from jerryshao/SPARK-20605. --- .../scala/org/apache/spark/SparkConf.scala | 4 ++- .../scala/org/apache/spark/SparkEnv.scala | 14 +++----- .../CoarseGrainedExecutorBackend.scala | 5 ++- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 7 ---- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../cluster/mesos/MesosSchedulerUtils.scala | 2 +- .../mesos/MesosSchedulerUtilsSuite.scala | 34 +++++-------------- .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../org/apache/spark/deploy/yarn/config.scala | 5 --- 10 files changed, 22 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 2a2ce0504dbbf..956724b14bba3 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -579,7 +579,9 @@ private[spark] object SparkConf extends Logging { "are no longer accepted. To specify the equivalent now, one may use '64k'."), DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", - "Please use the new blacklisting options, spark.blacklist.*") + "Please use the new blacklisting options, spark.blacklist.*"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f4a59f069a5f9..3196c1ece15eb 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -177,7 +177,7 @@ object SparkEnv extends Logging { SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, - port, + Option(port), isLocal, numCores, ioEncryptionKey, @@ -194,7 +194,6 @@ object SparkEnv extends Logging { conf: SparkConf, executorId: String, hostname: String, - port: Int, numCores: Int, ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { @@ -203,7 +202,7 @@ object SparkEnv extends Logging { executorId, hostname, hostname, - port, + None, isLocal, numCores, ioEncryptionKey @@ -220,7 +219,7 @@ object SparkEnv extends Logging { executorId: String, bindAddress: String, advertiseAddress: String, - port: Int, + port: Option[Int], isLocal: Boolean, numUsableCores: Int, ioEncryptionKey: Option[Array[Byte]], @@ -243,17 +242,12 @@ object SparkEnv extends Logging { } val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, securityManager, clientMode = !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. - // In the non-driver case, the RPC env's address may be null since it may not be listening - // for incoming connections. if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else if (rpcEnv.address != null) { - conf.set("spark.executor.port", rpcEnv.address.port.toString) - logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}") } // Create an instance of the class with the given name, possibly initializing it with our conf diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b2b26ee107c00..a2f1aa22b0063 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -191,11 +191,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf - val port = executorConf.getInt("spark.executor.port", 0) val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, - port, + -1, executorConf, new SecurityManager(executorConf), clientMode = true) @@ -221,7 +220,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) + driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 314a806edf39e..c1344ad99a7d2 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -209,7 +209,7 @@ provide such guarantees on the offer stream. In this mode spark executors will honor port allocation if such is provided from the user. Specifically if the user defines -`spark.executor.port` or `spark.blockManager.port` in Spark configuration, +`spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e9ddaa76a797f..2d56123028f2b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -239,13 +239,6 @@ To use a custom metrics.properties for the application master and executors, upd Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. - - spark.yarn.am.port - (random) - - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. - - spark.yarn.queue default diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index a086ec7ea2da6..61bfa27a84fd8 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -74,9 +74,8 @@ private[spark] class MesosExecutorBackend val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) val conf = new SparkConf(loadDefaults = true).setAll(properties) - val port = conf.getInt("spark.executor.port", 0) val env = SparkEnv.createExecutorEnv( - conf, executorId, slaveInfo.getHostname, port, cpusPerTask, None, isLocal = false) + conf, executorId, slaveInfo.getHostname, cpusPerTask, None, isLocal = false) executor = new Executor( executorId, diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 9d81025a3016b..062ed1f93fa52 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -438,7 +438,7 @@ trait MesosSchedulerUtils extends Logging { } } - val managedPortNames = List("spark.executor.port", BLOCK_MANAGER_PORT.key) + val managedPortNames = List(BLOCK_MANAGER_PORT.key) /** * The values of the non-zero ports to be used by the executor process. diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index ec47ab153177e..5d4bf6d082c4c 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -179,40 +179,25 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("Port reservation is done correctly with user specified ports only") { val conf = new SparkConf() - conf.set("spark.executor.port", "3000" ) conf.set(BLOCK_MANAGER_PORT, 4000) val portResource = createTestPortResource((3000, 5000), Some("my_role")) val (resourcesLeft, resourcesToBeUsed) = utils - .partitionPortResources(List(3000, 4000), List(portResource)) - resourcesToBeUsed.length shouldBe 2 + .partitionPortResources(List(4000), List(portResource)) + resourcesToBeUsed.length shouldBe 1 val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1}.toArray - portsToUse.length shouldBe 2 - arePortsEqual(portsToUse, Array(3000L, 4000L)) shouldBe true + portsToUse.length shouldBe 1 + arePortsEqual(portsToUse, Array(4000L)) shouldBe true val portRangesToBeUsed = rangesResourcesToTuple(resourcesToBeUsed) - val expectedUSed = Array((3000L, 3000L), (4000L, 4000L)) + val expectedUSed = Array((4000L, 4000L)) arePortsEqual(portRangesToBeUsed.toArray, expectedUSed) shouldBe true } - test("Port reservation is done correctly with some user specified ports (spark.executor.port)") { - val conf = new SparkConf() - conf.set("spark.executor.port", "3100" ) - val portResource = createTestPortResource((3000, 5000), Some("my_role")) - - val (resourcesLeft, resourcesToBeUsed) = utils - .partitionPortResources(List(3100), List(portResource)) - - val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} - - portsToUse.length shouldBe 1 - portsToUse.contains(3100) shouldBe true - } - test("Port reservation is done correctly with all random ports") { val conf = new SparkConf() val portResource = createTestPortResource((3000L, 5000L), Some("my_role")) @@ -226,21 +211,20 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("Port reservation is done correctly with user specified ports only - multiple ranges") { val conf = new SparkConf() - conf.set("spark.executor.port", "2100" ) conf.set("spark.blockManager.port", "4000") val portResourceList = List(createTestPortResource((3000, 5000), Some("my_role")), createTestPortResource((2000, 2500), Some("other_role"))) val (resourcesLeft, resourcesToBeUsed) = utils - .partitionPortResources(List(2100, 4000), portResourceList) + .partitionPortResources(List(4000), portResourceList) val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} - portsToUse.length shouldBe 2 + portsToUse.length shouldBe 1 val portsRangesLeft = rangesResourcesToTuple(resourcesLeft) val portRangesToBeUsed = rangesResourcesToTuple(resourcesToBeUsed) - val expectedUsed = Array((2100L, 2100L), (4000L, 4000L)) + val expectedUsed = Array((4000L, 4000L)) - arePortsEqual(portsToUse.toArray, Array(2100L, 4000L)) shouldBe true + arePortsEqual(portsToUse.toArray, Array(4000L)) shouldBe true arePortsEqual(portRangesToBeUsed.toArray, expectedUsed) shouldBe true } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 864c834d110fd..6da2c0b5f330a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -429,8 +429,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - val port = sparkConf.get(AM_PORT) - rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr, + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr, clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index d8c96c35ca71c..d4108caab28c1 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -40,11 +40,6 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createOptional - private[spark] val AM_PORT = - ConfigBuilder("spark.yarn.am.port") - .intConf - .createWithDefault(0) - private[spark] val EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS = ConfigBuilder("spark.yarn.executor.failuresValidityInterval") .doc("Interval after which Executor failures will be considered independent and not " + From 2abfee18b6511482b916c36f00bf3abf68a59e19 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 8 May 2017 14:48:11 -0700 Subject: [PATCH 0439/1765] [SPARK-20661][SPARKR][TEST] SparkR tableNames() test fails ## What changes were proposed in this pull request? Cleaning existing temp tables before running tableNames tests ## How was this patch tested? SparkR Unit tests Author: Hossein Closes #17903 from falaki/SPARK-20661. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f517ce6713133..ab6888ea34fdd 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -677,6 +677,8 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + # Making sure there are no registered temp tables from previous tests + suppressWarnings(sapply(tableNames(), function(tname) { dropTempTable(tname) })) df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") expect_equal(length(tableNames()), 1) From b952b44af4d243f1e3ad88bccf4af7d04df3fc81 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 May 2017 22:49:40 -0700 Subject: [PATCH 0440/1765] [SPARK-20661][SPARKR][TEST][FOLLOWUP] SparkR tableNames() test fails ## What changes were proposed in this pull request? Change it to check for relative count like in this test https://github.com/apache/spark/blame/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L3355 for catalog APIs ## How was this patch tested? unit tests, this needs to combine with another commit with SQL change to check Author: Felix Cheung Closes #17905 from felixcheung/rtabletests. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ab6888ea34fdd..19aa61e9a56c3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -677,26 +677,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { - # Making sure there are no registered temp tables from previous tests - suppressWarnings(sapply(tableNames(), function(tname) { dropTempTable(tname) })) + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - expect_equal(length(tableNames("default")), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + tables <- listTables() - expect_equal(count(tables), 1) + expect_equal(count(tables), count + 1) expect_equal(count(tables()), count(tables)) expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) tables <- listTables() - expect_equal(count(tables), 2) + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) tables <- listTables() - expect_equal(count(tables), 0) + expect_equal(count(tables), count + 0) }) test_that( From 8079424763c2043264f30a6898ce964379bd9b56 Mon Sep 17 00:00:00 2001 From: Peng Date: Tue, 9 May 2017 10:05:49 +0200 Subject: [PATCH 0441/1765] [SPARK-11968][MLLIB] Optimize MLLIB ALS recommendForAll The recommendForAll of MLLIB ALS is very slow. GC is a key problem of the current method. The task use the following code to keep temp result: val output = new Array[(Int, (Int, Double))](m*n) m = n = 4096 (default value, no method to set) so output is about 4k * 4k * (4 + 4 + 8) = 256M. This is a large memory and cause serious GC problem, and it is frequently OOM. Actually, we don't need to save all the temp result. Support we recommend topK (topK is about 10, or 20) product for each user, we only need 4k * topK * (4 + 4 + 8) memory to save the temp result. The Test Environment: 3 workers: each work 10 core, each work 30G memory, each work 1 executor. The Data: User 480,000, and Item 17,000 BlockSize: 1024 2048 4096 8192 Old method: 245s 332s 488s OOM This solution: 121s 118s 117s 120s The existing UT. Author: Peng Author: Peng Meng Closes #17742 from mpjlu/OptimizeAls. --- .../MatrixFactorizationModel.scala | 81 ++++++++++++------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 23045fa2b6863..d45866c016d91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.BoundedPriorityQueue /** * Model representing the result of matrix factorization. @@ -274,46 +275,64 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { srcFeatures: RDD[(Int, Array[Double])], dstFeatures: RDD[(Int, Array[Double])], num: Int): RDD[(Int, Array[(Int, Double)])] = { - val srcBlocks = blockify(rank, srcFeatures) - val dstBlocks = blockify(rank, dstFeatures) - val ratings = srcBlocks.cartesian(dstBlocks).flatMap { - case ((srcIds, srcFactors), (dstIds, dstFactors)) => - val m = srcIds.length - val n = dstIds.length - val ratings = srcFactors.transpose.multiply(dstFactors) - val output = new Array[(Int, (Int, Double))](m * n) - var k = 0 - ratings.foreachActive { (i, j, r) => - output(k) = (srcIds(i), (dstIds(j), r)) - k += 1 + val srcBlocks = blockify(srcFeatures) + val dstBlocks = blockify(dstFeatures) + /** + * The previous approach used for computing top-k recommendations aimed to group + * individual factor vectors into blocks, so that Level 3 BLAS operations (gemm) could + * be used for efficiency. However, this causes excessive GC pressure due to the large + * arrays required for intermediate result storage, as well as a high sensitivity to the + * block size used. + * The following approach still groups factors into blocks, but instead computes the + * top-k elements per block, using a simple dot product (instead of gemm) and an efficient + * [[BoundedPriorityQueue]]. This avoids any large intermediate data structures and results + * in significantly reduced GC pressure as well as shuffle data, which far outweighs + * any cost incurred from not using Level 3 BLAS operations. + */ + val ratings = srcBlocks.cartesian(dstBlocks).flatMap { case (srcIter, dstIter) => + val m = srcIter.size + val n = math.min(dstIter.size, num) + val output = new Array[(Int, (Int, Double))](m * n) + var j = 0 + val pq = new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(_._2)) + srcIter.foreach { case (srcId, srcFactor) => + dstIter.foreach { case (dstId, dstFactor) => + /* + * The below code is equivalent to + * `val score = blas.ddot(rank, srcFactor, 1, dstFactor, 1)` + * This handwritten version is as or more efficient as BLAS calls in this case. + */ + var score: Double = 0 + var k = 0 + while (k < rank) { + score += srcFactor(k) * dstFactor(k) + k += 1 + } + pq += dstId -> score + } + val pqIter = pq.iterator + var i = 0 + while (i < n) { + output(j + i) = (srcId, pqIter.next()) + i += 1 } - output.toSeq + j += n + pq.clear() + } + output.toSeq } ratings.topByKey(num)(Ordering.by(_._2)) } /** - * Blockifies features to use Level-3 BLAS. + * Blockifies features to improve the efficiency of cartesian product + * TODO: SPARK-20443 - expose blockSize as a param? */ private def blockify( - rank: Int, - features: RDD[(Int, Array[Double])]): RDD[(Array[Int], DenseMatrix)] = { - val blockSize = 4096 // TODO: tune the block size - val blockStorage = rank * blockSize + features: RDD[(Int, Array[Double])], + blockSize: Int = 4096): RDD[Seq[(Int, Array[Double])]] = { features.mapPartitions { iter => - iter.grouped(blockSize).map { grouped => - val ids = mutable.ArrayBuilder.make[Int] - ids.sizeHint(blockSize) - val factors = mutable.ArrayBuilder.make[Double] - factors.sizeHint(blockStorage) - var i = 0 - grouped.foreach { case (id, factor) => - ids += id - factors ++= factor - i += 1 - } - (ids.result(), new DenseMatrix(rank, i, factors.result())) - } + iter.grouped(blockSize) } } From 10b00abadf4a3473332eef996db7b66f491316f2 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Tue, 9 May 2017 10:13:15 +0200 Subject: [PATCH 0442/1765] [SPARK-20587][ML] Improve performance of ML ALS recommendForAll This PR is a `DataFrame` version of #17742 for [SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968), for improving the performance of `recommendAll` methods. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath Closes #17845 from MLnick/ml-als-perf. --- .../apache/spark/ml/recommendation/ALS.scala | 71 +++++++++++++++++-- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1562bf1beb7e1..d626f04599670 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom @@ -356,6 +356,19 @@ class ALSModel private[ml] ( /** * Makes recommendations for all users (or items). + * + * Note: the previous approach used for computing top-k recommendations + * used a cross-join followed by predicting a score for each row of the joined dataset. + * However, this results in exploding the size of intermediate data. While Spark SQL makes it + * relatively efficient, the approach implemented here is significantly more efficient. + * + * This approach groups factors into blocks and computes the top-k elements per block, + * using a simple dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]]. + * It then computes the global top-k by aggregating the per block top-k elements with + * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data. + * This is the DataFrame equivalent to the approach used in + * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]]. + * * @param srcFactors src factors for which to generate recommendations * @param dstFactors dst factors used to make recommendations * @param srcOutputColumn name of the column for the source ID in the output DataFrame @@ -372,11 +385,43 @@ class ALSModel private[ml] ( num: Int): DataFrame = { import srcFactors.sparkSession.implicits._ - val ratings = srcFactors.crossJoin(dstFactors) - .select( - srcFactors("id"), - dstFactors("id"), - predict(srcFactors("features"), dstFactors("features"))) + val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])]) + val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])]) + val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked) + .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])] + .flatMap { case (srcIter, dstIter) => + val m = srcIter.size + val n = math.min(dstIter.size, num) + val output = new Array[(Int, Int, Float)](m * n) + var j = 0 + val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2)) + srcIter.foreach { case (srcId, srcFactor) => + dstIter.foreach { case (dstId, dstFactor) => + /* + * The below code is equivalent to + * `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)` + * This handwritten version is as or more efficient as BLAS calls in this case. + */ + var score = 0.0f + var k = 0 + while (k < rank) { + score += srcFactor(k) * dstFactor(k) + k += 1 + } + pq += dstId -> score + } + val pqIter = pq.iterator + var i = 0 + while (i < n) { + val (dstId, score) = pqIter.next() + output(j + i) = (srcId, dstId, score) + i += 1 + } + j += n + pq.clear() + } + output.toSeq + } // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) @@ -387,8 +432,20 @@ class ALSModel private[ml] ( .add(dstOutputColumn, IntegerType) .add("rating", FloatType) ) - recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType)) } + + /** + * Blockifies factors to improve the efficiency of cross join + * TODO: SPARK-20443 - expose blockSize as a param? + */ + private def blockify( + factors: Dataset[(Int, Array[Float])], + blockSize: Int = 4096): Dataset[Seq[(Int, Array[Float])]] = { + import factors.sparkSession.implicits._ + factors.mapPartitions(_.grouped(blockSize)) + } + } @Since("1.6.0") From be53a78352ae7c70d8a07d0df24574b3e3129b4a Mon Sep 17 00:00:00 2001 From: Jon McLean Date: Tue, 9 May 2017 09:47:50 +0100 Subject: [PATCH 0443/1765] [SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException ## What changes were proposed in this pull request? Added a check for for the number of defined values. Previously the argmax function assumed that at least one value was defined if the vector size was greater than zero. ## How was this patch tested? Tests were added to the existing VectorsSuite to cover this case. Author: Jon McLean Closes #17877 from jonmclean/vectorArgmaxIndexBug. --- .../main/scala/org/apache/spark/ml/linalg/Vectors.scala | 2 ++ .../scala/org/apache/spark/ml/linalg/VectorsSuite.scala | 7 +++++++ .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 7 +++++++ 4 files changed, 18 insertions(+) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 8e166ba0ff51a..3fbc0958a0f11 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index dfbdaf19d374b..4cd91afd6d7fc 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 723addc7150dd..f063420bec143 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 71a3ceac1b947..6172cffee861c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { From b8733e0ad9f5a700f385e210450fd2c10137293e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 9 May 2017 17:30:37 +0800 Subject: [PATCH 0444/1765] [SPARK-20606][ML] ML 2.2 QA: Remove deprecated methods for ML ## What changes were proposed in this pull request? Remove ML methods we deprecated in 2.1. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #17867 from yanboliang/spark-20606. --- .../DecisionTreeClassifier.scala | 18 +-- .../ml/classification/GBTClassifier.scala | 24 ++-- .../RandomForestClassifier.scala | 24 ++-- .../ml/regression/DecisionTreeRegressor.scala | 18 +-- .../spark/ml/regression/GBTRegressor.scala | 24 ++-- .../ml/regression/RandomForestRegressor.scala | 24 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 105 ------------------ .../org/apache/spark/ml/util/ReadWrite.scala | 16 --- project/MimaExcludes.scala | 68 ++++++++++++ python/pyspark/ml/util.py | 32 ------ 10 files changed, 134 insertions(+), 219 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec52..5fb105c6aff60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -54,27 +54,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -86,15 +86,15 @@ class DecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ade0960f87a0d..263ed10f19855 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -70,27 +70,27 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -102,7 +102,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. @@ -111,7 +111,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this } @@ -120,21 +120,21 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTClassifierParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab4c235209289..441cfda899276 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -56,27 +56,27 @@ class RandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -88,31 +88,31 @@ class RandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 01c5cc1c7efa9..c2b0358e8405d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -53,27 +53,27 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S // Override parameter setters from parent trait for Java API compatibility. /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -85,15 +85,15 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 08d175cb94442..8d9b519efb142 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -68,27 +68,27 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -100,7 +100,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. @@ -109,7 +109,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } @@ -118,21 +118,21 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTRegressorParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a58da50fad972..7b9ddf6e9521a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -55,27 +55,27 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -87,31 +87,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index cd1950bd76c05..5526d4d75bd73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -109,80 +109,24 @@ private[ml] trait DecisionTreeParams extends PredictorParams setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) - /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) - /** @group getParam */ final def getMaxBins: Int = $(maxBins) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) - /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setSeed(value: Long): this.type = set(seed, value) - - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) - /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) - /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -225,13 +169,6 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -276,13 +213,6 @@ private[ml] trait TreeRegressorParams extends Params { setDefault(impurity -> "variance") - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -338,13 +268,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) - /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) @@ -382,13 +305,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - /** @group getParam */ final def getNumTrees: Int = $(numTrees) @@ -430,13 +346,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - /** @group getParam */ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } @@ -471,13 +380,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") // validationTol -> 1e-5 - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) - /** * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking * the contribution of each estimator. @@ -491,13 +393,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group getParam */ final def getStepSize: Double = $(stepSize) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a8b80031faf86..f7e570fd5cc94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -42,16 +42,6 @@ import org.apache.spark.util.Utils private[util] sealed trait BaseReadWrite { private var optionSparkSession: Option[SparkSession] = None - /** - * Sets the Spark SQLContext to use for saving/loading. - */ - @Since("1.6.0") - @deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0") - def context(sqlContext: SQLContext): this.type = { - optionSparkSession = Option(sqlContext.sparkSession) - this - } - /** * Sets the Spark Session to use for saving/loading. */ @@ -130,9 +120,6 @@ abstract class MLWriter extends BaseReadWrite with Logging { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** @@ -188,9 +175,6 @@ abstract class MLReader[T] extends BaseReadWrite { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d50882cb1917e..d8b37aebb5d1d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -1005,6 +1005,74 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") + ) ++ Seq( + // [SPARK-20606] ML 2.2 QA: Remove deprecated methods for ML + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context") ) } diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 02016f172aebc..688109ab11fd2 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -76,13 +76,6 @@ def overwrite(self): """Overwrites if the output path already exists.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) - def context(self, sqlContext): - """ - Sets the SQL context to use for saving. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) - def session(self, sparkSession): """Sets the Spark Session to use for saving.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) @@ -110,15 +103,6 @@ def overwrite(self): self._jwrite.overwrite() return self - def context(self, sqlContext): - """ - Sets the SQL context to use for saving. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") - self._jwrite.context(sqlContext._ssql_ctx) - return self - def session(self, sparkSession): """Sets the Spark Session to use for saving.""" self._jwrite.session(sparkSession._jsparkSession) @@ -165,13 +149,6 @@ def load(self, path): """Load the ML instance from the input path.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) - def context(self, sqlContext): - """ - Sets the SQL context to use for loading. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) - def session(self, sparkSession): """Sets the Spark Session to use for loading.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @@ -197,15 +174,6 @@ def load(self, path): % self._clazz) return self._clazz._from_java(java_obj) - def context(self, sqlContext): - """ - Sets the SQL context to use for loading. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") - self._jread.context(sqlContext._ssql_ctx) - return self - def session(self, sparkSession): """Sets the Spark Session to use for loading.""" self._jread.session(sparkSession._jsparkSession) From 0d00c768a860fc03402c8f0c9081b8147c29133e Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 9 May 2017 20:10:50 +0800 Subject: [PATCH 0445/1765] [SPARK-20667][SQL][TESTS] Cleanup the cataloged metadata after completing the package of sql/core and sql/hive ## What changes were proposed in this pull request? So far, we do not drop all the cataloged objects after each package. Sometimes, we might hit strange test case errors because the previous test suite did not drop the cataloged/temporary objects (tables/functions/database). At least, we can first clean up the environment when completing the package of `sql/core` and `sql/hive`. ## How was this patch tested? N/A Author: Xiao Li Closes #17908 from gatorsmile/reset. --- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 3 ++- .../scala/org/apache/spark/sql/test/SharedSQLContext.scala | 1 + .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 7 +------ 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6c6d600190b66..18e514681e811 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1251,9 +1251,10 @@ class SessionCatalog( dropTempFunction(func.funcName, ignoreIfNotExists = false) } } - tempTables.clear() + clearTempTables() globalTempViewManager.clear() functionRegistry.clear() + tableRelationCache.invalidateAll() // restore built-in functions FunctionRegistry.builtin.listFunction().foreach { f => val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 81c69a338abcc..7cea4c02155ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -74,6 +74,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventua protected override def afterAll(): Unit = { super.afterAll() if (_spark != null) { + _spark.sessionState.catalog.reset() _spark.stop() _spark = null } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index d9bb1f8c7edcc..ee9ac21a738dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -488,14 +488,9 @@ private[hive] class TestHiveSparkSession( sharedState.cacheManager.clearCache() loadedTables.clear() - sessionState.catalog.clearTempTables() - sessionState.catalog.tableRelationCache.invalidateAll() - + sessionState.catalog.reset() metadataHive.reset() - FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). - foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } - // HDFS root scratch dir requires the write all (733) permission. For each connecting user, // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to From 714811d0b5bcb5d47c39782ff74f898d276ecc59 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 May 2017 20:22:51 +0800 Subject: [PATCH 0446/1765] [SPARK-20311][SQL] Support aliases for table value functions ## What changes were proposed in this pull request? This pr added parsing rules to support aliases in table value functions. ## How was this patch tested? Added tests in `PlanParserSuite`. Author: Takeshi Yamamuro Closes #17666 from maropu/SPARK-20311. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 20 ++++++++++++----- .../ResolveTableValuedFunctions.scala | 22 ++++++++++++++++--- .../sql/catalyst/analysis/unresolved.scala | 10 +++++++-- .../sql/catalyst/parser/AstBuilder.scala | 17 ++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++++++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 13 ++++++++++- 6 files changed, 79 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 14c511f670606..41daf58a98fd9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,15 +472,23 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction ; inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : identifier '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? identifier identifierList?)? ; rowFormat diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index de6de24350f23..dad1340571cc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { + val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { @@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } + + // If alias names assigned, add `Project` with the aliases + if (u.outputNames.nonEmpty) { + val outputAttrs = resolvedFunc.output + // Checks if the number of the aliases is equal to expected one + if (u.outputNames.size != outputAttrs.size) { + u.failAnalysis(s"expected ${outputAttrs.size} columns but " + + s"found ${u.outputNames.size} columns") + } + val aliases = outputAttrs.zip(u.outputNames).map { + case (attr, name) => Alias(attr, name)() + } + Project(aliases, resolvedFunc) + } else { + resolvedFunc + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 262b894e2a0a3..51bef6e20b9fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,10 +66,16 @@ case class UnresolvedInlineTable( /** * A table-valued function, e.g. * {{{ - * select * from range(10); + * select id from range(10); + * + * // Assign alias names + * select t.a from range(10) t(a); * }}} */ -case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) +case class UnresolvedTableValuedFunction( + functionName: String, + functionArgs: Seq[Expression], + outputNames: Seq[String]) extends LeafNode { override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2a9b4a9a9f59..e03fe2ccb8d89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + + val tvf = UnresolvedTableValuedFunction( + func.identifier.getText, func.expression.asScala.map(expression), aliases) + tvf.optionalMap(func.tableAlias.identifier)(aliasPlan) } /** @@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - val aliases = if (ctx.identifierList != null) { - visitIdentifierList(ctx.identifierList) + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) } else { Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } val table = UnresolvedInlineTable(aliases, rows) - table.optionalMap(ctx.identifier)(aliasPlan) + table.optionalMap(ctx.tableAlias.identifier)(aliasPlan) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 893bb1b74cea7..31047f688600b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + + test("SPARK-20311 range(N) as alias") { + def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) + .select(star()) + } + assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) + assertAnalysisError( + rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), + Seq("expected 1 columns but found 2 columns")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411777d6e85a2..4c2476296c049 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest { test("table valued function") { assertEqual( "select * from range(2)", - UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star())) + } + + test("SPARK-20311 range(N) as alias") { + assertEqual( + "select * from range(10) AS t", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty)) + .select(star())) + assertEqual( + "select * from range(7) AS t(a)", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil)) + .select(star())) } test("inline table") { From 181261a81d592b93181135a8267570e0c9ab2243 Mon Sep 17 00:00:00 2001 From: Sanket Date: Tue, 9 May 2017 09:30:09 -0500 Subject: [PATCH 0447/1765] [SPARK-20355] Add per application spark version on the history server headerpage ## What changes were proposed in this pull request? Spark Version for a specific application is not displayed on the history page now. It should be nice to switch the spark version on the UI when we click on the specific application. Currently there seems to be way as SparkListenerLogStart records the application version. So, it should be trivial to listen to this event and provision this change on the UI. For Example screen shot 2017-04-06 at 3 23 41 pm screen shot 2017-04-17 at 9 59 33 am {"Event":"SparkListenerLogStart","Spark Version":"2.0.0"} (Please fill in changes proposed in this fix) Modified the SparkUI for History server to listen to SparkLogListenerStart event and extract the version and print it. ## How was this patch tested? Manual testing of UI page. Attaching the UI screenshot changes here (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Sanket Closes #17658 from redsanket/SPARK-20355. --- .../history/ApplicationHistoryProvider.scala | 3 ++- .../deploy/history/FsHistoryProvider.scala | 17 ++++++++++++----- .../scheduler/ApplicationEventListener.scala | 7 +++++++ .../spark/scheduler/EventLoggingListener.scala | 13 ++++++++++--- .../apache/spark/scheduler/SparkListener.scala | 4 ++-- .../spark/scheduler/SparkListenerBus.scala | 1 - .../status/api/v1/ApplicationListResource.scala | 3 ++- .../org/apache/spark/status/api/v1/api.scala | 3 ++- .../scala/org/apache/spark/ui/SparkUI.scala | 6 +++++- .../scala/org/apache/spark/ui/UIUtils.scala | 2 +- .../application_list_json_expectation.json | 10 ++++++++++ .../completed_app_list_json_expectation.json | 11 +++++++++++ .../limit_app_list_json_expectation.json | 3 +++ .../maxDate2_app_list_json_expectation.json | 1 + .../maxDate_app_list_json_expectation.json | 2 ++ .../maxEndDate_app_list_json_expectation.json | 7 +++++++ ...nd_maxEndDate_app_list_json_expectation.json | 4 ++++ .../minDate_app_list_json_expectation.json | 8 ++++++++ ...nd_maxEndDate_app_list_json_expectation.json | 4 ++++ .../minEndDate_app_list_json_expectation.json | 6 +++++- .../one_app_json_expectation.json | 1 + .../one_app_multi_attempt_json_expectation.json | 2 ++ .../deploy/history/ApplicationCacheSuite.scala | 2 +- .../deploy/history/FsHistoryProviderSuite.scala | 4 ++-- project/MimaExcludes.scala | 3 +++ 25 files changed, 107 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 6d8758a3d3b1d..5cb48ca3e60b0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -30,7 +30,8 @@ private[spark] case class ApplicationAttemptInfo( endTime: Long, lastUpdated: Long, sparkUser: String, - completed: Boolean = false) + completed: Boolean = false, + appSparkVersion: String) private[spark] case class ApplicationHistoryInfo( id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f4235df245128..d05ca142b618b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -248,7 +248,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, - HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) + HistoryServer.getAttemptURI(appId, attempt.attemptId), + attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } @@ -257,6 +258,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) if (appListener.appId.isDefined) { + ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) // make sure to set admin acls before view acls so they are properly picked up val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") @@ -443,7 +445,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || - eventString.startsWith(APPL_END_EVENT_PREFIX) + eventString.startsWith(APPL_END_EVENT_PREFIX) || + eventString.startsWith(LOG_START_EVENT_PREFIX) } val logPath = fileStatus.getPath() @@ -469,7 +472,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) lastUpdated, appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted, - fileStatus.getLen() + fileStatus.getLen(), + appListener.appSparkVersion.getOrElse("") ) fileToAppInfo(logPath) = attemptInfo logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") @@ -735,6 +739,8 @@ private[history] object FsHistoryProvider { private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" + + private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" } /** @@ -762,9 +768,10 @@ private class FsApplicationAttemptInfo( lastUpdated: Long, sparkUser: String, completed: Boolean, - val fileSize: Long) + val fileSize: Long, + appSparkVersion: String) extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed) { + attemptId, startTime, endTime, lastUpdated, sparkUser, completed, appSparkVersion) { /** extend the superclass string value with the extra attributes of this class */ override def toString: String = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 28c45d800ed06..6da8865cd10d3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -34,6 +34,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var adminAcls: Option[String] = None var viewAclsGroups: Option[String] = None var adminAclsGroups: Option[String] = None + var appSparkVersion: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { appName = Some(applicationStart.appName) @@ -57,4 +58,10 @@ private[spark] class ApplicationEventListener extends SparkListener { adminAclsGroups = allProperties.get("spark.admin.acls.groups") } } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + appSparkVersion = Some(sparkVersion) + case _ => + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a7dbf87915b27..f481436332249 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -119,7 +119,7 @@ private[spark] class EventLoggingListener( val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) val bstream = new BufferedOutputStream(cstream, outputBufferSize) - EventLoggingListener.initEventLog(bstream) + EventLoggingListener.initEventLog(bstream, testing, loggedEvents) fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) writer = Some(new PrintWriter(bstream)) logInfo("Logging events to %s".format(logPath)) @@ -283,10 +283,17 @@ private[spark] object EventLoggingListener extends Logging { * * @param logStream Raw output stream to the event log file. */ - def initEventLog(logStream: OutputStream): Unit = { + def initEventLog( + logStream: OutputStream, + testing: Boolean, + loggedEvents: ArrayBuffer[JValue]): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + val eventJson = JsonProtocol.logStartToJson(metadata) + val metadataJson = compact(eventJson) + "\n" logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) + if (testing && loggedEvents != null) { + loggedEvents += eventJson + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index bc2e530716686..59f89a82a1da8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -160,9 +160,9 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent /** * An internal class that describes the metadata of an event log. - * This event is not meant to be posted to listeners downstream. */ -private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * Interface for creating history listeners defined in other modules like SQL, which are used to diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 3ff363321e8c9..3b0d3b1b150fe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -71,7 +71,6 @@ private[spark] trait SparkListenerBus listener.onNodeUnblacklisted(nodeUnblacklisted) case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) - case logStart: SparkListenerLogStart => // ignore event log metadata case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index a0239266d8756..f039744e7f67f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -90,7 +90,8 @@ private[spark] object ApplicationsListResource { }, lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, - completed = internalAttemptInfo.completed + completed = internalAttemptInfo.completed, + appSparkVersion = internalAttemptInfo.appSparkVersion ) } ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 56d8e51732ffd..f6203271f3cd2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -38,7 +38,8 @@ class ApplicationAttemptInfo private[spark]( val lastUpdated: Date, val duration: Long, val sparkUser: String, - val completed: Boolean = false) { + val completed: Boolean = false, + val appSparkVersion: String) { def getStartTimeEpoch: Long = startTime.getTime def getEndTimeEpoch: Long = endTime.getTime def getLastUpdatedEpoch: Long = lastUpdated.getTime diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bf4cf79e9faa3..f271c56021e95 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,6 +60,8 @@ private[spark] class SparkUI private ( var appId: String = _ + var appSparkVersion = org.apache.spark.SPARK_VERSION + private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. */ @@ -118,7 +120,8 @@ private[spark] class SparkUI private ( duration = 0, lastUpdated = new Date(startTime), sparkUser = getSparkUser, - completed = false + completed = false, + appSparkVersion = appSparkVersion )) )) } @@ -139,6 +142,7 @@ private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) def appName: String = parent.appName + def appSparkVersion: String = parent.appSparkVersion } private[spark] object SparkUI { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 79b0d81af52b5..8e1aafa448bc4 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -228,7 +228,7 @@ private[spark] object UIUtils extends Logging {
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • * * @@ -537,7 +537,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 76f121c0c955f..eadc6c94f4b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -111,8 +111,8 @@ abstract class CSVDataSource extends Serializable { object CSVDataSource { def apply(options: CSVOptions): CSVDataSource = { - if (options.wholeFile) { - WholeFileCSVDataSource + if (options.multiLine) { + MultiLineCSVDataSource } else { TextInputCSVDataSource } @@ -197,7 +197,7 @@ object TextInputCSVDataSource extends CSVDataSource { } } -object WholeFileCSVDataSource extends CSVDataSource { +object MultiLineCSVDataSource extends CSVDataSource { override val isSplitable: Boolean = false override def readFile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 78c16b75ee684..a13a5a34b4a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -128,7 +128,7 @@ class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 4f2963da9ace9..5a92a71d19e78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -86,8 +86,8 @@ abstract class JsonDataSource extends Serializable { object JsonDataSource { def apply(options: JSONOptions): JsonDataSource = { - if (options.wholeFile) { - WholeFileJsonDataSource + if (options.multiLine) { + MultiLineJsonDataSource } else { TextInputJsonDataSource } @@ -147,7 +147,7 @@ object TextInputJsonDataSource extends JsonDataSource { } } -object WholeFileJsonDataSource extends JsonDataSource { +object MultiLineJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 766776230257d..7e8e6394b4862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -163,7 +163,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a JSON file stream and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -205,7 +205,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • * * @@ -276,7 +276,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 352dba79a4c08..89d9b69dec7ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -261,10 +261,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val cars = spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -284,11 +284,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val exception = intercept[SparkException] { spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() } @@ -990,13 +990,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read .option("mode", "abcd") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema) .csv(testFile(valueMalformedFile)) checkAnswer(df1, @@ -1011,7 +1011,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, @@ -1028,7 +1028,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, @@ -1041,7 +1041,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) .csv(testFile(valueMalformedFile)) .collect @@ -1073,7 +1073,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read .option("header", true) - .option("wholeFile", true) + .option("multiLine", true) .csv(path.getAbsolutePath) // Check if headers have new lines in the names. @@ -1096,10 +1096,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Empty file produces empty dataframe with empty schema") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 65472cda9c1c0..704823ad516c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1814,7 +1814,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .option("compression", "gZiP") @@ -1836,7 +1836,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write.json(jsonDir) @@ -1865,7 +1865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) // no corrupt record column should be created assert(jsonDF.schema === StructType(Seq())) // only the first object should be read @@ -1886,7 +1886,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) assert(jsonDF.count() === corruptRecordCount) assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) @@ -1917,7 +1917,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "DROPMALFORMED").json(path) checkAnswer(jsonDF, Seq(Row("test"))) } } @@ -1940,7 +1940,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .json(path) } @@ -1949,7 +1949,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val exceptionTwo = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .schema(schema) .json(path) From b32b2123ddca66e00acf4c9d956232e07f779f9f Mon Sep 17 00:00:00 2001 From: ALeksander Eskilson Date: Thu, 15 Jun 2017 13:45:08 +0800 Subject: [PATCH 0721/1765] [SPARK-18016][SQL][CATALYST] Code Generation: Constant Pool Limit - Class Splitting ## What changes were proposed in this pull request? This pull-request exclusively includes the class splitting feature described in #16648. When code for a given class would grow beyond 1600k bytes, a private, nested sub-class is generated into which subsequent functions are inlined. Additional sub-classes are generated as the code threshold is met subsequent times. This code includes 3 changes: 1. Includes helper maps, lists, and functions for keeping track of sub-classes during code generation (included in the `CodeGenerator` class). These helper functions allow nested classes and split functions to be initialized/declared/inlined to the appropriate locations in the various projection classes. 2. Changes `addNewFunction` to return a string to support instances where a split function is inlined to a nested class and not the outer class (and so must be invoked using the class-qualified name). Uses of `addNewFunction` throughout the codebase are modified so that the returned name is properly used. 3. Removes instances of the `this` keyword when used on data inside generated classes. All state declared in the outer class is by default global and accessible to the nested classes. However, if a reference to global state in a nested class is prepended with the `this` keyword, it would attempt to reference state belonging to the nested class (which would not exist), rather than the correct variable belonging to the outer class. ## How was this patch tested? Added a test case to the `GeneratedProjectionSuite` that increases the number of columns tested in various projections to a threshold that would previously have triggered a `JaninoRuntimeException` for the Constant Pool. Note: This PR does not address the second Constant Pool issue with code generation (also mentioned in #16648): excess global mutable state. A second PR may be opened to resolve that issue. Author: ALeksander Eskilson Closes #18075 from bdrillard/class_splitting_only. --- sql/catalyst/pom.xml | 7 + .../sql/catalyst/expressions/ScalaUDF.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 140 +++++++++++++++--- .../codegen/GenerateMutableProjection.scala | 17 ++- .../codegen/GenerateOrdering.scala | 3 + .../codegen/GeneratePredicate.scala | 3 + .../codegen/GenerateSafeProjection.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 9 +- .../expressions/complexTypeCreator.scala | 6 +- .../expressions/conditionalExpressions.scala | 4 +- .../sql/catalyst/expressions/generators.scala | 6 +- .../expressions/objects/objects.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 72 +++++++-- sql/core/pom.xml | 7 + .../sql/execution/ColumnarBatchScan.scala | 6 +- .../apache/spark/sql/execution/SortExec.scala | 4 +- .../sql/execution/WholeStageCodegenExec.scala | 3 + .../aggregate/HashAggregateExec.scala | 8 +- .../execution/basicPhysicalOperators.scala | 11 +- .../columnar/GenerateColumnAccessor.scala | 13 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/limit.scala | 2 +- 22 files changed, 259 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 8d80f8eca5dba..36948ba52b064 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -131,6 +131,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index af1eba26621bd..a54f6d0e11147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -988,7 +988,7 @@ case class ScalaUDF( val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, converterTerm, - s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s"$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") converterTerm @@ -1005,7 +1005,7 @@ case class ScalaUDF( // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") ctx.addMutableState(converterClassName, catalystConverterTerm, - s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1019,7 +1019,7 @@ case class ScalaUDF( val funcTerm = ctx.freshName("udf") ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") + s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fd9780245fcfb..5158949b95629 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -28,7 +28,6 @@ import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} -import org.apache.commons.lang3.exception.ExceptionUtils import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} import org.codehaus.janino.util.ClassFile @@ -113,7 +112,7 @@ class CodegenContext { val idx = references.length references += obj val clsName = Option(className).getOrElse(obj.getClass.getName) - addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + addMutableState(clsName, term, s"$term = ($clsName) references[$idx];") term } @@ -202,16 +201,6 @@ class CodegenContext { partitionInitializationStatements.mkString("\n") } - /** - * Holding all the functions those will be added into generated class. - */ - val addedFunctions: mutable.Map[String, String] = - mutable.Map.empty[String, String] - - def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFunctions += ((funcName, funcCode)) - } - /** * Holds expressions that are equivalent. Used to perform subexpression elimination * during codegen. @@ -233,10 +222,118 @@ class CodegenContext { // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] - def declareAddedFunctions(): String = { - addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + val outerClassName = "OuterClass" + + /** + * Holds the class and instance names to be generated, where `OuterClass` is a placeholder + * standing for whichever class is generated as the outermost class and which will contain any + * nested sub-classes. All other classes and instance names in this list will represent private, + * nested sub-classes. + */ + private val classes: mutable.ListBuffer[(String, String)] = + mutable.ListBuffer[(String, String)](outerClassName -> null) + + // A map holding the current size in bytes of each class to be generated. + private val classSize: mutable.Map[String, Int] = + mutable.Map[String, Int](outerClassName -> 0) + + // Nested maps holding function names and their code belonging to each class. + private val classFunctions: mutable.Map[String, mutable.Map[String, String]] = + mutable.Map(outerClassName -> mutable.Map.empty[String, String]) + + // Returns the size of the most recently added class. + private def currClassSize(): Int = classSize(classes.head._1) + + // Returns the class name and instance name for the most recently added class. + private def currClass(): (String, String) = classes.head + + // Adds a new class. Requires the class' name, and its instance name. + private def addClass(className: String, classInstance: String): Unit = { + classes.prepend(className -> classInstance) + classSize += className -> 0 + classFunctions += className -> mutable.Map.empty[String, String] + } + + /** + * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the + * function will be inlined into a new private, nested class, and a class-qualified name for the + * function will be returned. Otherwise, the function will be inined to the `OuterClass` the + * simple `funcName` will be returned. + * + * @param funcName the class-unqualified name of the function + * @param funcCode the body of the function + * @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This + * can be necessary when a function is declared outside of the context + * it is eventually referenced and a returned qualified function name + * cannot otherwise be accessed. + * @return the name of the function, qualified by class if it will be inlined to a private, + * nested sub-class + */ + def addNewFunction( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean = false): String = { + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a + // threshold of 1600k bytes to determine when a function should be inlined to a private, nested + // sub-class. + val (className, classInstance) = if (inlineToOuterClass) { + outerClassName -> "" + } else if (currClassSize > 1600000) { + val className = freshName("NestedClass") + val classInstance = freshName("nestedClassInstance") + + addClass(className, classInstance) + + className -> classInstance + } else { + currClass() + } + + classSize(className) += funcCode.length + classFunctions(className) += funcName -> funcCode + + if (className == outerClassName) { + funcName + } else { + + s"$classInstance.$funcName" + } + } + + /** + * Instantiates all nested, private sub-classes as objects to the `OuterClass` + */ + private[sql] def initNestedClasses(): String = { + // Nested, private sub-classes have no mutable state (though they do reference the outer class' + // mutable state), so we declare and initialize them inline to the OuterClass. + classes.filter(_._1 != outerClassName).map { + case (className, classInstance) => + s"private $className $classInstance = new $className();" + }.mkString("\n") + } + + /** + * Declares all function code that should be inlined to the `OuterClass`. + */ + private[sql] def declareAddedFunctions(): String = { + classFunctions(outerClassName).values.mkString("\n") } + /** + * Declares all nested, private sub-classes and the function code that should be inlined to them. + */ + private[sql] def declareNestedClasses(): String = { + classFunctions.filterKeys(_ != outerClassName).map { + case (className, functions) => + s""" + |private class $className { + | ${functions.values.mkString("\n")} + |} + """.stripMargin + } + }.mkString("\n") + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -556,8 +653,7 @@ class CodegenContext { return 0; } """ - addNewFunction(compareFunc, funcCode) - s"this.$compareFunc($c1, $c2)" + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -573,8 +669,7 @@ class CodegenContext { return 0; } """ - addNewFunction(compareFunc, funcCode) - s"this.$compareFunc($c1, $c2)" + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => @@ -629,7 +724,9 @@ class CodegenContext { /** * Splits the generated code of expressions into multiple functions, because function has - * 64kb code size limit in JVM + * 64kb code size limit in JVM. If the class to which the function would be inlined would grow + * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it + * instead, because classes have a constant pool limit of 65,536 named values. * * @param row the variable name of row that is used by expressions * @param expressions the codes to evaluate expressions. @@ -689,7 +786,6 @@ class CodegenContext { |} """.stripMargin addNewFunction(name, code) - name } foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) @@ -773,8 +869,6 @@ class CodegenContext { |} """.stripMargin - addNewFunction(fnName, fn) - // Add a state and a mapping of the common subexpressions that are associate with this // state. Adding this expression to subExprEliminationExprMap means it will call `fn` // when it is code generated. This decision should be a cost based one. @@ -792,7 +886,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subexprFunctions += s"$fnName($INPUT_ROW);" + subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4d732445544a8..635766835029b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState("boolean", isNull, s"$isNull = true;") ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$isNull = ${ev.isNull}; - this.$value = ${ev.value}; + $isNull = ${ev.isNull}; + $value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$value = ${ev.value}; + $value = ${ev.value}; """ } } @@ -87,7 +87,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(index).map { case (e, i) => - val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + val ev = ExprCode("", s"isNull_$i", s"value_$i") ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } @@ -135,6 +135,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP $allUpdates return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f7fc2d54a047b..a31943255b995 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -179,6 +179,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR $comparisons return 0; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dcd1ed96a298e..b400783bb5e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -72,6 +72,9 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${eval.code} return !${eval.isNull} && ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b1cb6edefb852..f708aeff2b146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"this.$values = null;") + ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -65,10 +65,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - this.$values = new Object[${schema.length}]; + $values = new Object[${schema.length}]; $allFields final InternalRow $output = new $rowClass($values); - this.$values = null; + $values = null; """ ExprCode(code, "false", output) @@ -184,6 +184,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $allExpressions return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index efbbc038bd33b..6be69d119bf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -82,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") ctx.addMutableState(rowWriterClass, rowWriter, - s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -182,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, - s"this.$arrayWriter = new $arrayWriterClass();") + s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val element = ctx.freshName("element") @@ -321,7 +321,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, holder, - s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + s"$holder = new $holderClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" @@ -402,6 +402,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${eval.code.trim} return ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b6675a84ece48..98c4cbee38dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -93,7 +93,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"this.$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[${numElements}];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -340,7 +340,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") + ctx.addMutableState("Object[]", values, s"$values = null;") ev.copy(code = s""" $values = new Object[${valExprs.size}];""" + @@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc }) + s""" final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; + $values = null; """, isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ee365fe636614..ae8efb673f91c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -131,8 +131,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi | $globalValue = ${ev.value}; |} """.stripMargin - ctx.addNewFunction(funcName, funcBody) - (funcName, globalIsNull, globalValue) + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + (fullFuncName, globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e023f0567ea87..c217aa875d9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -200,7 +200,7 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => @@ -209,7 +209,7 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + s"${eval.code}\n$rowData[$row] = ${eval.value};" }) // Create the collection. @@ -217,7 +217,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") + s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5bb0febc943f2..073993cccdf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1163,7 +1163,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val code = s""" ${instanceGen.code} - this.${javaBeanInstance} = ${instanceGen.value}; + ${javaBeanInstance} = ${instanceGen.value}; if (!${instanceGen.isNull}) { $initializeCode } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index b69b74b4240bd..58ea5b9cb52d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -33,10 +33,10 @@ class GeneratedProjectionSuite extends SparkFunSuite { test("generated projections on wider table") { val N = 1000 - val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) val wideRow2 = new GenericInternalRow( - (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) val schema2 = StructType((1 to N).map(i => StructField("", StringType))) val joined = new JoinedRow(wideRow1, wideRow2) val joinedSchema = StructType(schema1 ++ schema2) @@ -48,12 +48,12 @@ class GeneratedProjectionSuite extends SparkFunSuite { val unsafeProj = UnsafeProjection.create(nestedSchema) val unsafe: UnsafeRow = unsafeProj(nested) (0 until N).foreach { i => - val s = UTF8String.fromString((i + 1).toString) - assert(i + 1 === unsafe.getInt(i + 2)) + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) assert(s === unsafe.getUTF8String(i + 2 + N)) - assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) } @@ -62,13 +62,63 @@ class GeneratedProjectionSuite extends SparkFunSuite { val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => - val r = i + 1 - val s = UTF8String.fromString((i + 1).toString) - assert(r === result.getInt(i + 2)) + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) assert(s === result.getUTF8String(i + 2 + N)) - assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(i === result.getStruct(0, N * 2).getInt(i)) assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) - assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(i === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } + + test("SPARK-18016: generated projections on wider table requiring class-splitting") { + val N = 4000 + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(i === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === result.getStruct(1, N * 2).getInt(i)) assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fe4be963e8184..7327c9b0c9c50 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -183,6 +183,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.codehaus.mojo build-helper-maven-plugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index e86116680a57a..74a47da2deef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -93,7 +93,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val nextBatch = ctx.freshName("nextBatch") - ctx.addNewFunction(nextBatch, + val nextBatchFuncName = ctx.addNewFunction(nextBatch, s""" |private void $nextBatch() throws java.io.IOException { | long getBatchStart = System.nanoTime(); @@ -121,7 +121,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } s""" |if ($batch == null) { - | $nextBatch(); + | $nextBatchFuncName(); |} |while ($batch != null) { | int $numRows = $batch.numRows(); @@ -133,7 +133,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | } | $idx = $numRows; | $batch = null; - | $nextBatch(); + | $nextBatchFuncName(); |} |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); |$scanTimeTotalNs = 0; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index f98ae82574d20..ff71fd4dc7bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -141,7 +141,7 @@ case class SortExec( ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") val addToSorter = ctx.freshName("addToSorter") - ctx.addNewFunction(addToSorter, + val addToSorterFuncName = ctx.addNewFunction(addToSorter, s""" | private void $addToSorter() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} @@ -160,7 +160,7 @@ case class SortExec( s""" | if ($needToSort) { | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorter(); + | $addToSorterFuncName(); | $sortedIterator = $sorterVariable.sort(); | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index ac30b11557adb..0bd28e36135c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -357,6 +357,9 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co protected void processNext() throws java.io.IOException { ${code.trim} } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9df5e58f70add..5027a615ced7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -212,7 +212,7 @@ case class HashAggregateExec( } val doAgg = ctx.freshName("doAggregateWithoutKey") - ctx.addNewFunction(doAgg, + val doAggFuncName = ctx.addNewFunction(doAgg, s""" | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer @@ -229,7 +229,7 @@ case class HashAggregateExec( | while (!$initAgg) { | $initAgg = true; | long $beforeAgg = System.nanoTime(); - | $doAgg(); + | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); | | // output the result @@ -600,7 +600,7 @@ case class HashAggregateExec( } else "" } - ctx.addNewFunction(doAgg, + val doAggFuncName = ctx.addNewFunction(doAgg, s""" ${generateGenerateCode} private void $doAgg() throws java.io.IOException { @@ -681,7 +681,7 @@ case class HashAggregateExec( if (!$initAgg) { $initAgg = true; long $beforeAgg = System.nanoTime(); - $doAgg(); + $doAggFuncName(); $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index bd7a5c5d914c1..f3ca8397047fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -281,10 +281,8 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") ctx.copyResult = true - ctx.addMutableState(s"$samplerClass", sampler, - s"$initSampler();") - ctx.addNewFunction(initSampler, + val initSamplerFuncName = ctx.addNewFunction(initSampler, s""" | private void $initSampler() { | $sampler = new $samplerClass($upperBound - $lowerBound, false); @@ -299,6 +297,9 @@ case class SampleExec( | } """.stripMargin.trim) + ctx.addMutableState(s"$samplerClass", sampler, + s"$initSamplerFuncName();") + val samplingCount = ctx.freshName("samplingCount") s""" | int $samplingCount = $sampler.sample(); @@ -394,7 +395,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // The default size of a batch, which must be positive integer val batchSize = 1000 - ctx.addNewFunction("initRange", + val initRangeFuncName = ctx.addNewFunction("initRange", s""" | private void initRange(int idx) { | $BigInt index = $BigInt.valueOf(idx); @@ -451,7 +452,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | // initialize Range | if (!$initTerm) { | $initTerm = true; - | initRange(partitionIndex); + | $initRangeFuncName(partitionIndex); | } | | while (true) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 14024d6c10558..d3fa0dcd2d7c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -128,9 +128,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } else { val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) - var groupedAccessorsLength = 0 - groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => - groupedAccessorsLength += 1 + val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) => val funcName = s"accessors$i" val funcCode = s""" |private void $funcName() { @@ -139,7 +137,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => + val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { @@ -148,8 +146,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), - (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) + (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"), + extractorNames.map { extractorName => s"$extractorName();"}.mkString("\n")) } val codeBody = s""" @@ -224,6 +222,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 26fb6103953fc..8445c26eeee58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -478,7 +478,7 @@ case class SortMergeJoinExec( | } | return false; // unreachable |} - """.stripMargin) + """.stripMargin, inlineToOuterClass = true) (leftRow, matches) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 757fe2185d302..73a0f8735ed45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -75,7 +75,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { protected boolean stopEarly() { return $stopEarly; } - """) + """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") ctx.addMutableState("int", countTerm, s"$countTerm = 0;") s""" From 1bf55e396c7b995a276df61d9a4eb8e60bcee334 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 14 Jun 2017 23:08:05 -0700 Subject: [PATCH 0722/1765] [SPARK-20980][DOCS] update doc to reflect multiLine change ## What changes were proposed in this pull request? doc only change ## How was this patch tested? manually Author: Felix Cheung Closes #18312 from felixcheung/sqljsonwholefiledoc. --- docs/sql-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 314ff6ef80d29..8e722ae6adca6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -998,7 +998,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` option to `true`. +For a regular multi-line JSON file, set the `multiLine` option to `true`. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -1012,7 +1012,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` option to `true`. +For a regular multi-line JSON file, set the `multiLine` option to `true`. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} @@ -1025,7 +1025,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` parameter to `True`. +For a regular multi-line JSON file, set the `multiLine` parameter to `True`. {% include_example json_dataset python/sql/datasource.py %} @@ -1039,7 +1039,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set a named parameter `wholeFile` to `TRUE`. +For a regular multi-line JSON file, set a named parameter `multiLine` to `TRUE`. {% include_example json_dataset r/RSparkSQLExample.R %} From 7dc3e697c74864a4e3cca7342762f1427058b3c3 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 16 Jun 2017 00:06:54 +0800 Subject: [PATCH 0723/1765] [SPARK-16251][SPARK-20200][CORE][TEST] Flaky test: org.apache.spark.rdd.LocalCheckpointSuite.missing checkpoint block fails with informative message ## What changes were proposed in this pull request? Currently we don't wait to confirm the removal of the block from the slave's BlockManager, if the removal takes too much time, we will fail the assertion in this test case. The failure can be easily reproduced if we sleep for a while before we remove the block in BlockManagerSlaveEndpoint.receiveAndReply(). ## How was this patch tested? N/A Author: Xingbo Jiang Closes #18314 from jiangxb1987/LocalCheckpointSuite. --- .../scala/org/apache/spark/rdd/LocalCheckpointSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 2802cd975292c..9e204f5cc33fe 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.rdd +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} + import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -168,6 +172,10 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { // Collecting the RDD should now fail with an informative exception val blockId = RDDBlockId(rdd.id, numPartitions - 1) bmm.removeBlock(blockId) + // Wait until the block has been removed successfully. + eventually(timeout(1 seconds), interval(100 milliseconds)) { + assert(bmm.getBlockStatus(blockId).isEmpty) + } try { rdd.collect() fail("Collect should have failed if local checkpoint block is removed...") From a18d637112b97d2caaca0a8324bdd99086664b24 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Thu, 15 Jun 2017 11:46:00 -0700 Subject: [PATCH 0724/1765] [SPARK-20434][YARN][CORE] Move Hadoop delegation token code from yarn to core ## What changes were proposed in this pull request? Move Hadoop delegation token code from `spark-yarn` to `spark-core`, so that other schedulers (such as Mesos), may use it. In order to avoid exposing Hadoop interfaces in spark-core, the new Hadoop delegation token classes are kept private. In order to provider backward compatiblity, and to allow YARN users to continue to load their own delegation token providers via Java service loading, the old YARN interfaces, as well as the client code that uses them, have been retained. Summary: - Move registered `yarn.security.ServiceCredentialProvider` classes from `spark-yarn` to `spark-core`. Moved them into a new, private hierarchy under `HadoopDelegationTokenProvider`. Client code in `HadoopDelegationTokenManager` now loads credentials from a whitelist of three providers (`HadoopFSDelegationTokenProvider`, `HiveDelegationTokenProvider`, `HBaseDelegationTokenProvider`), instead of service loading, which means that users are not able to implement their own delegation token providers, as they are in the `spark-yarn` module. - The `yarn.security.ServiceCredentialProvider` interface has been kept for backwards compatibility, and to continue to allow YARN users to implement their own delegation token provider implementations. Client code in YARN now fetches tokens via the new `YARNHadoopDelegationTokenManager` class, which fetches tokens from the core providers through `HadoopDelegationTokenManager`, as well as service loads them from `yarn.security.ServiceCredentialProvider`. Old Hierarchy: ``` yarn.security.ServiceCredentialProvider (service loaded) HadoopFSCredentialProvider HiveCredentialProvider HBaseCredentialProvider yarn.security.ConfigurableCredentialManager ``` New Hierarchy: ``` HadoopDelegationTokenManager HadoopDelegationTokenProvider (not service loaded) HadoopFSDelegationTokenProvider HiveDelegationTokenProvider HBaseDelegationTokenProvider yarn.security.ServiceCredentialProvider (service loaded) yarn.security.YARNHadoopDelegationTokenManager ``` ## How was this patch tested? unit tests Author: Michael Gummelt Author: Dr. Stefan Schimanski Closes #17723 from mgummelt/SPARK-20434-refactor-kerberos. --- core/pom.xml | 28 ++++ .../HBaseDelegationTokenProvider.scala | 11 +- .../HadoopDelegationTokenManager.scala | 119 ++++++++++++++ .../HadoopDelegationTokenProvider.scala | 50 ++++++ .../HadoopFSDelegationTokenProvider.scala | 126 +++++++++++++++ .../HiveDelegationTokenProvider.scala | 78 ++++----- .../HadoopDelegationTokenManagerSuite.scala | 116 ++++++++++++++ dev/.rat-excludes | 5 +- docs/running-on-yarn.md | 12 +- resource-managers/yarn/pom.xml | 14 +- ...oy.yarn.security.ServiceCredentialProvider | 3 - .../spark/deploy/yarn/ApplicationMaster.scala | 10 +- .../org/apache/spark/deploy/yarn/Client.scala | 9 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 31 +++- .../yarn/security/AMCredentialRenewer.scala | 6 +- .../ConfigurableCredentialManager.scala | 107 ------------- .../yarn/security/CredentialUpdater.scala | 2 +- .../security/HadoopFSCredentialProvider.scala | 120 -------------- .../security/ServiceCredentialProvider.scala | 3 +- .../YARNHadoopDelegationTokenManager.scala | 83 ++++++++++ ...oy.yarn.security.ServiceCredentialProvider | 2 +- .../ConfigurableCredentialManagerSuite.scala | 150 ------------------ .../HadoopFSCredentialProviderSuite.scala | 70 -------- ...ARNHadoopDelegationTokenManagerSuite.scala | 66 ++++++++ 24 files changed, 689 insertions(+), 532 deletions(-) rename resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala => core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala (88%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala rename resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala => core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala (54%) create mode 100644 core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala delete mode 100644 resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala delete mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala delete mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 7f245b5b6384a..326dde4f274bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -357,6 +357,34 @@ org.apache.commons commons-crypto
    + + + + ${hive.group} + hive-exec + provided + + + ${hive.group} + hive-metastore + provided + + + org.apache.thrift + libthrift + provided + + + org.apache.thrift + libfb303 + provided + + target/scala-${scala.binary.version}/classes diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala similarity index 88% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 5adeb8e605ff4..35621daf9c0d7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import scala.reflect.runtime.universe import scala.util.control.NonFatal @@ -24,17 +24,16 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.token.{Token, TokenIdentifier} -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HBaseDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hbase" - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, - sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) @@ -55,7 +54,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide None } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..89b6f52ba4bca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging + +/** + * Manages all the registered HadoopDelegationTokenProviders and offer APIs for other modules to + * obtain delegation tokens and their renewal time. By default [[HadoopFSDelegationTokenProvider]], + * [[HiveDelegationTokenProvider]] and [[HBaseDelegationTokenProvider]] will be loaded in if not + * explicitly disabled. + * + * Also, each HadoopDelegationTokenProvider is controlled by + * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * enabled/disabled by the configuration spark.security.credentials.hive.enabled. + * + * @param sparkConf Spark configuration + * @param hadoopConf Hadoop configuration + * @param fileSystems Delegation tokens will be fetched for these Hadoop filesystems. + */ +private[spark] class HadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Set[FileSystem]) + extends Logging { + + private val deprecatedProviderEnabledConfigs = List( + "spark.yarn.security.tokens.%s.enabled", + "spark.yarn.security.credentials.%s.enabled") + private val providerEnabledConfig = "spark.security.credentials.%s.enabled" + + // Maintain all the registered delegation token providers + private val delegationTokenProviders = getDelegationTokenProviders + logDebug(s"Using the following delegation token providers: " + + s"${delegationTokenProviders.keys.mkString(", ")}.") + + private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { + val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), + new HiveDelegationTokenProvider, + new HBaseDelegationTokenProvider) + + // Filter out providers for which spark.security.credentials.{service}.enabled is false. + providers + .filter { p => isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + def isServiceEnabled(serviceName: String): Boolean = { + val key = providerEnabledConfig.format(serviceName) + + deprecatedProviderEnabledConfigs.foreach { pattern => + val deprecatedKey = pattern.format(serviceName) + if (sparkConf.contains(deprecatedKey)) { + logWarning(s"${deprecatedKey} is deprecated. Please use ${key} instead.") + } + } + + val isEnabledDeprecated = deprecatedProviderEnabledConfigs.forall { pattern => + sparkConf + .getOption(pattern.format(serviceName)) + .map(_.toBoolean) + .getOrElse(true) + } + + sparkConf + .getOption(key) + .map(_.toBoolean) + .getOrElse(isEnabledDeprecated) + } + + /** + * Get delegation token provider for the specified service. + */ + def getServiceDelegationTokenProvider(service: String): Option[HadoopDelegationTokenProvider] = { + delegationTokenProviders.get(service) + } + + /** + * Writes delegation tokens to creds. Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Long = { + delegationTokenProviders.values.flatMap { provider => + if (provider.delegationTokensRequired(hadoopConf)) { + provider.obtainDelegationTokens(hadoopConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(Long.MaxValue)(math.min) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala new file mode 100644 index 0000000000000..f162e7e58c53a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials + +/** + * Hadoop delegation token provider. + */ +private[spark] trait HadoopDelegationTokenProvider { + + /** + * Name of the service to provide delegation tokens. This name should be unique. Spark will + * internally use this name to differentiate delegation token providers. + */ + def serviceName: String + + /** + * Returns true if delegation tokens are required for this service. By default, it is based on + * whether Hadoop security is enabled. + */ + def delegationTokensRequired(hadoopConf: Configuration): Boolean + + /** + * Obtain delegation tokens for this service and get the time of the next renewal. + * @param hadoopConf Configuration of current Hadoop Compatible system. + * @param creds Credentials to add tokens and security keys to. + * @return If the returned tokens are renewable and can be renewed, return the time of the next + * renewal, otherwise None should be returned. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Option[Long] +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala new file mode 100644 index 0000000000000..13157f33e2bf9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred.Master +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Set[FileSystem]) + extends HadoopDelegationTokenProvider with Logging { + + // This tokenRenewalInterval will be set in the first call to obtainDelegationTokens. + // If None, no token renewer is specified or no token can be renewed, + // so we cannot get the token renewal interval. + private var tokenRenewalInterval: Option[Long] = null + + override val serviceName: String = "hadoopfs" + + override def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Option[Long] = { + + val newCreds = fetchDelegationTokens( + getTokenRenewer(hadoopConf), + fileSystems) + + // Get the token renewal interval if it is not set. It will only be called once. + if (tokenRenewalInterval == null) { + tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, fileSystems) + } + + // Get the time of next renewal. + val nextRenewalDate = tokenRenewalInterval.flatMap { interval => + val nextRenewalDates = newCreds.getAllTokens.asScala + .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) + .map { token => + val identifier = token + .decodeIdentifier() + .asInstanceOf[AbstractDelegationTokenIdentifier] + identifier.getIssueDate + interval + } + if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) + } + + creds.addAll(newCreds) + nextRenewalDate + } + + def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + UserGroupInformation.isSecurityEnabled + } + + private def getTokenRenewer(hadoopConf: Configuration): String = { + val tokenRenewer = Master.getMasterPrincipal(hadoopConf) + logDebug("Delegation token renewer is: " + tokenRenewer) + + if (tokenRenewer == null || tokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer." + logError(errorMessage) + throw new SparkException(errorMessage) + } + + tokenRenewer + } + + private def fetchDelegationTokens( + renewer: String, + filesystems: Set[FileSystem]): Credentials = { + + val creds = new Credentials() + + filesystems.foreach { fs => + logInfo("getting token for: " + fs) + fs.addDelegationTokens(renewer, creds) + } + + creds + } + + private def getTokenRenewalInterval( + hadoopConf: Configuration, + filesystems: Set[FileSystem]): Option[Long] = { + // We cannot use the tokens generated with renewer yarn. Trying to renew + // those will fail with an access control issue. So create new tokens with the logged in + // user as renewer. + val creds = fetchDelegationTokens( + UserGroupInformation.getCurrentUser.getUserName, + filesystems) + + val renewIntervals = creds.getAllTokens.asScala.filter { + _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] + }.flatMap { token => + Try { + val newExpiration = token.renew(hadoopConf) + val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}") + interval + }.toOption + } + if (renewIntervals.isEmpty) None else Some(renewIntervals.min) + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala similarity index 54% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index 16d8fc32bb42d..53b9f898c6e7d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -15,97 +15,89 @@ * limitations under the License. */ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import java.lang.reflect.UndeclaredThrowableException import java.security.PrivilegedExceptionAction -import scala.reflect.runtime.universe import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HiveCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HiveDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" + private val classNotFoundErrorStr = s"You are attempting to use the " + + s"${getClass.getCanonicalName}, but your Spark distribution is not built with Hive libraries." + private def hiveConf(hadoopConf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down - // to a Configuration and used without reflection - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - // using the (Configuration, Class) constructor allows the current configuration to be - // included in the hive config. - val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], - classOf[Object].getClass) - ctor.newInstance(hadoopConf, hiveConfClass).asInstanceOf[Configuration] + new HiveConf(hadoopConf, classOf[HiveConf]) } catch { case NonFatal(e) => logDebug("Fail to create Hive Configuration", e) hadoopConf + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + hadoopConf } } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled && hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty } - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, - sparkConf: SparkConf, creds: Credentials): Option[Long] = { - val conf = hiveConf(hadoopConf) - - val principalKey = "hive.metastore.kerberos.principal" - val principal = conf.getTrimmed(principalKey, "") - require(principal.nonEmpty, s"Hive principal $principalKey undefined") - val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") - require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - - val currentUser = UserGroupInformation.getCurrentUser() - logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + - s"$principal at $metastoreUri") + try { + val conf = hiveConf(hadoopConf) - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - val closeCurrent = hiveClass.getMethod("closeCurrent") + val principalKey = "hive.metastore.kerberos.principal" + val principal = conf.getTrimmed(principalKey, "") + require(principal.nonEmpty, s"Hive principal $principalKey undefined") + val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") + require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - try { - // get all the instance methods before invoking any - val getDelegationToken = hiveClass.getMethod("getDelegationToken", - classOf[String], classOf[String]) - val getHive = hiveClass.getMethod("get", hiveConfClass) + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") doAsRealUser { - val hive = getHive.invoke(null, conf) - val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) - .asInstanceOf[String] + val hive = Hive.get(conf, classOf[HiveConf]) + val tokenStr = hive.getDelegationToken(currentUser.getUserName(), principal) + val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } + + None } catch { case NonFatal(e) => - logDebug(s"Fail to get token from service $serviceName", e) + logDebug(s"Failed to get token from service $serviceName", e) + None + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + None } finally { Utils.tryLogNonFatalError { - closeCurrent.invoke(null) + Hive.closeCurrent() } } - - None } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala new file mode 100644 index 0000000000000..335f3449cb782 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var delegationTokenManager: HadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + test("Correctly load default credential providers") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("bogus") should be (None) + } + + test("disable hive credential provider") { + sparkConf.set("spark.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + } + + test("using deprecated configurations") { + sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") + sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + } + + test("verify no credentials are obtained") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + val creds = new Credentials() + + // Tokens cannot be obtained from HDFS, Hive, HBase in unit tests. + delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + val tokens = creds.getAllTokens + tokens.size() should be (0) + } + + test("obtain tokens For HiveMetastore") { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.kerberos.principal", "bob") + // thrift picks up on port 0 and bails out, without trying to talk to endpoint + hadoopConf.set("hive.metastore.uris", "http://localhost:0") + + val hiveCredentialProvider = new HiveDelegationTokenProvider() + val credentials = new Credentials() + hiveCredentialProvider.obtainDelegationTokens(hadoopConf, credentials) + + credentials.getAllTokens.size() should be (0) + } + + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + + val hbaseTokenProvider = new HBaseDelegationTokenProvider() + val creds = new Credentials() + hbaseTokenProvider.obtainDelegationTokens(hadoopConf, creds) + + creds.getAllTokens.size should be (0) + } + + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { + Set(FileSystem.get(hadoopConf)) + } +} diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 2355d40d1e6fe..607234b4068d0 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -93,16 +93,13 @@ INDEX .lintr gen-java.* .*avpr -org.apache.spark.sql.sources.DataSourceRegister -org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet spark-deps-.* .*csv .*tsv -org.apache.spark.scheduler.ExternalClusterManager .*\.sql .Rbuildignore -org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2d56123028f2b..e4a74556d4f26 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -419,7 +419,7 @@ To use a custom metrics.properties for the application master and executors, upd - spark.yarn.security.credentials.${service}.enabled + spark.security.credentials.${service}.enabled true Controls whether to obtain credentials for services when security is enabled. @@ -482,11 +482,11 @@ token for the cluster's default Hadoop filesystem, and potentially for HBase and An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.yarn.security.credentials.hbase.enabled` is not set to `false`. +and `spark.security.credentials.hbase.enabled` is not set to `false`. Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.yarn.security.credentials.hive.enabled` is not set to `false`. +`spark.security.credentials.hive.enabled` is not set to `false`. If an application needs to interact with other secure Hadoop filesystems, then the tokens needed to access these clusters must be explicitly requested at @@ -500,7 +500,7 @@ Spark supports integrating with other security-aware services through Java Servi `java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` should be available to Spark by listing their names in the corresponding file in the jar's `META-INF/services` directory. These plug-ins can be disabled by setting -`spark.yarn.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of +`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of credential provider. ## Configuring the External Shuffle Service @@ -564,8 +564,8 @@ the Spark configuration must be set to disable token collection for the services The Spark configuration must include the lines: ``` -spark.yarn.security.credentials.hive.enabled false -spark.yarn.security.credentials.hbase.enabled false +spark.security.credentials.hive.enabled false +spark.security.credentials.hbase.enabled false ``` The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 71d4ad681e169..43a7ce95bd3de 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -167,29 +167,27 @@ ${jersey-1.version}
    - + ${hive.group} hive-exec - test + provided ${hive.group} hive-metastore - test + provided org.apache.thrift libthrift - test + provided org.apache.thrift libfb303 - test + provided diff --git a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider deleted file mode 100644 index f5a807ecac9d7..0000000000000 --- a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ /dev/null @@ -1,3 +0,0 @@ -org.apache.spark.deploy.yarn.security.HadoopFSCredentialProvider -org.apache.spark.deploy.yarn.security.HBaseCredentialProvider -org.apache.spark.deploy.yarn.security.HiveCredentialProvider diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6da2c0b5f330a..4f71a1606312d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -38,7 +38,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, ConfigurableCredentialManager} +import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -247,8 +247,12 @@ private[spark] class ApplicationMaster( if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) { // If a principal and keytab have been set, use that to create new credentials for executors // periodically - credentialRenewer = - new ConfigurableCredentialManager(sparkConf, yarnConf).credentialRenewer() + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + yarnConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, yarnConf)) + + val credentialRenewer = new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) credentialRenewer.scheduleLoginFromKeytab() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1fb7edf2a6e30..e5131e636dc04 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -49,7 +49,7 @@ import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} @@ -121,7 +121,10 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) + private val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) @@ -368,7 +371,7 @@ private[spark] class Client( val fs = destDir.getFileSystem(hadoopConf) // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials) + val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 0fc994d629ccb..4522071bd92e2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -24,8 +24,9 @@ import java.util.regex.Pattern import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.{JobConf, Master} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.ApplicationConstants @@ -35,11 +36,14 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.security.{ConfigurableCredentialManager, CredentialUpdater} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.deploy.yarn.security.CredentialUpdater +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils + /** * Contains util methods to interact with Hadoop from spark. */ @@ -87,8 +91,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } private[spark] override def startCredentialUpdater(sparkConf: SparkConf): Unit = { - credentialUpdater = - new ConfigurableCredentialManager(sparkConf, newConfiguration(sparkConf)).credentialUpdater() + val hadoopConf = newConfiguration(sparkConf) + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) + credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) credentialUpdater.start() } @@ -103,6 +111,21 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) ConverterUtils.toContainerId(containerIdString) } + + /** The filesystems for which YARN should fetch delegation tokens. */ + private[spark] def hadoopFSsToAccess( + sparkConf: SparkConf, + hadoopConf: Configuration): Set[FileSystem] = { + val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) + .map(new Path(_).getFileSystem(hadoopConf)) + .toSet + + val stagingFS = sparkConf.get(STAGING_DIR) + .map(new Path(_).getFileSystem(hadoopConf)) + .getOrElse(FileSystem.get(hadoopConf)) + + filesystemsToAccess + stagingFS + } } object YarnSparkHadoopUtil { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index 7e76f402db249..68a2e9e70a78b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -54,7 +54,7 @@ import org.apache.spark.util.ThreadUtils private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, hadoopConf: Configuration, - credentialManager: ConfigurableCredentialManager) extends Logging { + credentialManager: YARNHadoopDelegationTokenManager) extends Logging { private var lastCredentialsFileSuffix = 0 @@ -174,7 +174,9 @@ private[yarn] class AMCredentialRenewer( keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainCredentials(freshHadoopConf, tempCreds) + nearestNextRenewalTime = credentialManager.obtainDelegationTokens( + freshHadoopConf, + tempCreds) null } }) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala deleted file mode 100644 index 4f4be52a0d691..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - -/** - * A ConfigurableCredentialManager to manage all the registered credential providers and offer - * APIs for other modules to obtain credentials as well as renewal time. By default - * [[HadoopFSCredentialProvider]], [[HiveCredentialProvider]] and [[HBaseCredentialProvider]] will - * be loaded in if not explicitly disabled, any plugged-in credential provider wants to be - * managed by ConfigurableCredentialManager needs to implement [[ServiceCredentialProvider]] - * interface and put into resources/META-INF/services to be loaded by ServiceLoader. - * - * Also each credential provider is controlled by - * spark.yarn.security.credentials.{service}.enabled, it will not be loaded in if set to false. - * For example, Hive's credential provider [[HiveCredentialProvider]] can be enabled/disabled by - * the configuration spark.yarn.security.credentials.hive.enabled. - */ -private[yarn] final class ConfigurableCredentialManager( - sparkConf: SparkConf, hadoopConf: Configuration) extends Logging { - private val deprecatedProviderEnabledConfig = "spark.yarn.security.tokens.%s.enabled" - private val providerEnabledConfig = "spark.yarn.security.credentials.%s.enabled" - - // Maintain all the registered credential providers - private val credentialProviders = { - val providers = ServiceLoader.load(classOf[ServiceCredentialProvider], - Utils.getContextOrSparkClassLoader).asScala - - // Filter out credentials in which spark.yarn.security.credentials.{service}.enabled is false. - providers.filter { p => - sparkConf.getOption(providerEnabledConfig.format(p.serviceName)) - .orElse { - sparkConf.getOption(deprecatedProviderEnabledConfig.format(p.serviceName)).map { c => - logWarning(s"${deprecatedProviderEnabledConfig.format(p.serviceName)} is deprecated, " + - s"using ${providerEnabledConfig.format(p.serviceName)} instead") - c - } - }.map(_.toBoolean).getOrElse(true) - }.map { p => (p.serviceName, p) }.toMap - } - - /** - * Get credential provider for the specified service. - */ - def getServiceCredentialProvider(service: String): Option[ServiceCredentialProvider] = { - credentialProviders.get(service) - } - - /** - * Obtain credentials from all the registered providers. - * @return nearest time of next renewal, Long.MaxValue if all the credentials aren't renewable, - * otherwise the nearest renewal time of any credentials will be returned. - */ - def obtainCredentials(hadoopConf: Configuration, creds: Credentials): Long = { - credentialProviders.values.flatMap { provider => - if (provider.credentialsRequired(hadoopConf)) { - provider.obtainCredentials(hadoopConf, sparkConf, creds) - } else { - logDebug(s"Service ${provider.serviceName} does not require a token." + - s" Check your configuration to see if security is disabled or not.") - None - } - }.foldLeft(Long.MaxValue)(math.min) - } - - /** - * Create an [[AMCredentialRenewer]] instance, caller should be responsible to stop this - * instance when it is not used. AM will use it to renew credentials periodically. - */ - def credentialRenewer(): AMCredentialRenewer = { - new AMCredentialRenewer(sparkConf, hadoopConf, this) - } - - /** - * Create an [[CredentialUpdater]] instance, caller should be resposible to stop this intance - * when it is not used. Executors and driver (client mode) will use it to update credentials. - * periodically. - */ - def credentialUpdater(): CredentialUpdater = { - new CredentialUpdater(sparkConf, hadoopConf, this) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 41b7b5d60b038..fe173dffc22a8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class CredentialUpdater( sparkConf: SparkConf, hadoopConf: Configuration, - credentialManager: ConfigurableCredentialManager) extends Logging { + credentialManager: YARNHadoopDelegationTokenManager) extends Logging { @volatile private var lastCredentialsFileSuffix = 0 diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala deleted file mode 100644 index f65c886db944e..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapred.Master -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ - -private[security] class HadoopFSCredentialProvider - extends ServiceCredentialProvider with Logging { - // Token renewal interval, this value will be set in the first call, - // if None means no token renewer specified or no token can be renewed, - // so cannot get token renewal interval. - private var tokenRenewalInterval: Option[Long] = null - - override val serviceName: String = "hadoopfs" - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = { - // NameNode to access, used to get tokens from different FileSystems - val tmpCreds = new Credentials() - val tokenRenewer = getTokenRenewer(hadoopConf) - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - logInfo("getting token for: " + dst) - dstFs.addDelegationTokens(tokenRenewer, tmpCreds) - } - - // Get the token renewal interval if it is not set. It will only be called once. - if (tokenRenewalInterval == null) { - tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf) - } - - // Get the time of next renewal. - val nextRenewalDate = tokenRenewalInterval.flatMap { interval => - val nextRenewalDates = tmpCreds.getAllTokens.asScala - .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) - .map { t => - val identifier = t.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] - identifier.getIssueDate + interval - } - if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) - } - - creds.addAll(tmpCreds) - nextRenewalDate - } - - private def getTokenRenewalInterval( - hadoopConf: Configuration, sparkConf: SparkConf): Option[Long] = { - // We cannot use the tokens generated with renewer yarn. Trying to renew - // those will fail with an access control issue. So create new tokens with the logged in - // user as renewer. - sparkConf.get(PRINCIPAL).flatMap { renewer => - val creds = new Credentials() - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - dstFs.addDelegationTokens(renewer, creds) - } - - val renewIntervals = creds.getAllTokens.asScala.filter { - _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] - }.flatMap { token => - Try { - val newExpiration = token.renew(hadoopConf) - val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] - val interval = newExpiration - identifier.getIssueDate - logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}") - interval - }.toOption - } - if (renewIntervals.isEmpty) None else Some(renewIntervals.min) - } - } - - private def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - - delegTokenRenewer - } - - private def hadoopFSsToAccess(hadoopConf: Configuration, sparkConf: SparkConf): Set[Path] = { - sparkConf.get(FILESYSTEMS_TO_ACCESS).map(new Path(_)).toSet + - sparkConf.get(STAGING_DIR).map(new Path(_)) - .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala index 4e3fcce8dbb1d..cc24ac4d9bcf6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala @@ -35,7 +35,7 @@ trait ServiceCredentialProvider { def serviceName: String /** - * To decide whether credential is required for this service. By default it based on whether + * Returns true if credentials are required by this service. By default, it is based on whether * Hadoop security is enabled. */ def credentialsRequired(hadoopConf: Configuration): Boolean = { @@ -44,6 +44,7 @@ trait ServiceCredentialProvider { /** * Obtain credentials for this service and get the time of the next renewal. + * * @param hadoopConf Configuration of current Hadoop Compatible system. * @param sparkConf Spark configuration. * @param creds Credentials to add tokens and security keys to. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..bbd17c8fc1272 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn.security + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * This class loads delegation token providers registered under the YARN-specific + * [[ServiceCredentialProvider]] interface, as well as the builtin providers defined + * in [[HadoopDelegationTokenManager]]. + */ +private[yarn] class YARNHadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Set[FileSystem]) extends Logging { + + private val delegationTokenManager = + new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + + // public for testing + val credentialProviders = getCredentialProviders + + /** + * Writes delegation tokens to creds. Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens(hadoopConf: Configuration, creds: Credentials): Long = { + val superInterval = delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + + credentialProviders.values.flatMap { provider => + if (provider.credentialsRequired(hadoopConf)) { + provider.obtainCredentials(hadoopConf, sparkConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(superInterval)(math.min) + } + + private def getCredentialProviders: Map[String, ServiceCredentialProvider] = { + val providers = loadCredentialProviders + + providers. + filter { p => delegationTokenManager.isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + private def loadCredentialProviders: List[ServiceCredentialProvider] = { + ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) + .asScala + .toList + } +} diff --git a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider index d0ef5efa36e86..f31c232693133 100644 --- a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider @@ -1 +1 @@ -org.apache.spark.deploy.yarn.security.TestCredentialProvider +org.apache.spark.deploy.yarn.security.YARNTestCredentialProvider diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala deleted file mode 100644 index b0067aa4517c7..0000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.Text -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.token.Token -import org.scalatest.{BeforeAndAfter, Matchers} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.config._ - -class ConfigurableCredentialManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { - private var credentialManager: ConfigurableCredentialManager = null - private var sparkConf: SparkConf = null - private var hadoopConf: Configuration = null - - override def beforeAll(): Unit = { - super.beforeAll() - - sparkConf = new SparkConf() - hadoopConf = new Configuration() - System.setProperty("SPARK_YARN_MODE", "true") - } - - override def afterAll(): Unit = { - System.clearProperty("SPARK_YARN_MODE") - - super.afterAll() - } - - test("Correctly load default credential providers") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - credentialManager.getServiceCredentialProvider("hive") should not be (None) - } - - test("disable hive credential provider") { - sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - credentialManager.getServiceCredentialProvider("hive") should be (None) - } - - test("using deprecated configurations") { - sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") - sparkConf.set("spark.yarn.security.tokens.hive.enabled", "false") - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should be (None) - credentialManager.getServiceCredentialProvider("hive") should be (None) - credentialManager.getServiceCredentialProvider("test") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - } - - test("verify obtaining credentials from provider") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - val creds = new Credentials() - - // Tokens can only be obtained from TestTokenProvider, for hdfs, hbase and hive tokens cannot - // be obtained. - credentialManager.obtainCredentials(hadoopConf, creds) - val tokens = creds.getAllTokens - tokens.size() should be (1) - tokens.iterator().next().getService should be (new Text("test")) - } - - test("verify getting credential renewal info") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - val creds = new Credentials() - - val testCredentialProvider = credentialManager.getServiceCredentialProvider("test").get - .asInstanceOf[TestCredentialProvider] - // Only TestTokenProvider can get the time of next token renewal - val nextRenewal = credentialManager.obtainCredentials(hadoopConf, creds) - nextRenewal should be (testCredentialProvider.timeOfNextTokenRenewal) - } - - test("obtain tokens For HiveMetastore") { - val hadoopConf = new Configuration() - hadoopConf.set("hive.metastore.kerberos.principal", "bob") - // thrift picks up on port 0 and bails out, without trying to talk to endpoint - hadoopConf.set("hive.metastore.uris", "http://localhost:0") - - val hiveCredentialProvider = new HiveCredentialProvider() - val credentials = new Credentials() - hiveCredentialProvider.obtainCredentials(hadoopConf, sparkConf, credentials) - - credentials.getAllTokens.size() should be (0) - } - - test("Obtain tokens For HBase") { - val hadoopConf = new Configuration() - hadoopConf.set("hbase.security.authentication", "kerberos") - - val hbaseTokenProvider = new HBaseCredentialProvider() - val creds = new Credentials() - hbaseTokenProvider.obtainCredentials(hadoopConf, sparkConf, creds) - - creds.getAllTokens.size should be (0) - } -} - -class TestCredentialProvider extends ServiceCredentialProvider { - val tokenRenewalInterval = 86400 * 1000L - var timeOfNextTokenRenewal = 0L - - override def serviceName: String = "test" - - override def credentialsRequired(conf: Configuration): Boolean = true - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = { - if (creds == null) { - // Guard out other unit test failures. - return None - } - - val emptyToken = new Token() - emptyToken.setService(new Text("test")) - creds.addToken(emptyToken.getService, emptyToken) - - val currTime = System.currentTimeMillis() - timeOfNextTokenRenewal = (currTime - currTime % tokenRenewalInterval) + tokenRenewalInterval - - Some(timeOfNextTokenRenewal) - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala deleted file mode 100644 index f50ee193c258f..0000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.scalatest.{Matchers, PrivateMethodTester} - -import org.apache.spark.{SparkException, SparkFunSuite} - -class HadoopFSCredentialProviderSuite - extends SparkFunSuite - with PrivateMethodTester - with Matchers { - private val _getTokenRenewer = PrivateMethod[String]('getTokenRenewer) - - private def getTokenRenewer( - fsCredentialProvider: HadoopFSCredentialProvider, conf: Configuration): String = { - fsCredentialProvider invokePrivate _getTokenRenewer(conf) - } - - private var hadoopFsCredentialProvider: HadoopFSCredentialProvider = null - - override def beforeAll() { - super.beforeAll() - - if (hadoopFsCredentialProvider == null) { - hadoopFsCredentialProvider = new HadoopFSCredentialProvider() - } - } - - override def afterAll() { - if (hadoopFsCredentialProvider != null) { - hadoopFsCredentialProvider = null - } - - super.afterAll() - } - - test("check token renewer") { - val hadoopConf = new Configuration() - hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") - hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") - val renewer = getTokenRenewer(hadoopFsCredentialProvider, hadoopConf) - renewer should be ("yarn/myrm:8032@SPARKTEST.COM") - } - - test("check token renewer default") { - val hadoopConf = new Configuration() - val caught = - intercept[SparkException] { - getTokenRenewer(hadoopFsCredentialProvider, hadoopConf) - } - assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala new file mode 100644 index 0000000000000..2b226eff5ce19 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil + +class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var credentialManager: YARNHadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + System.setProperty("SPARK_YARN_MODE", "true") + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + override def afterAll(): Unit = { + super.afterAll() + + System.clearProperty("SPARK_YARN_MODE") + } + + test("Correctly loads credential providers") { + credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) + + credentialManager.credentialProviders.get("yarn-test") should not be (None) + } +} + +class YARNTestCredentialProvider extends ServiceCredentialProvider { + override def serviceName: String = "yarn-test" + + override def credentialsRequired(conf: Configuration): Boolean = true + + override def obtainCredentials( + hadoopConf: Configuration, + sparkConf: SparkConf, + creds: Credentials): Option[Long] = None +} From 5d35d5c15c63debaa79202708c6e6481980a6a7f Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 16 Jun 2017 10:11:23 +0800 Subject: [PATCH 0725/1765] [SPARK-21112][SQL] ALTER TABLE SET TBLPROPERTIES should not overwrite COMMENT ### What changes were proposed in this pull request? `ALTER TABLE SET TBLPROPERTIES` should not overwrite `COMMENT` even if the input property does not have the property of `COMMENT`. This PR is to fix the issue. ### How was this patch tested? Covered by the existing tests. Author: Xiao Li Closes #18318 from gatorsmile/fixTableComment. --- .../main/scala/org/apache/spark/sql/execution/command/ddl.scala | 2 +- sql/core/src/test/resources/sql-tests/results/describe.sql.out | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 5a7f8cf1eb59e..f924b3d914635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -235,7 +235,7 @@ case class AlterTableSetPropertiesCommand( // direct property. val newTable = table.copy( properties = table.properties ++ properties, - comment = properties.get("comment")) + comment = properties.get("comment").orElse(table.comment)) catalog.alterTable(newTable) Seq.empty[Row] } diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 329532cd7c842..ab9f2783f06bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -127,6 +127,7 @@ Provider parquet Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] +Comment table_comment Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] @@ -157,6 +158,7 @@ Provider parquet Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] +Comment table_comment Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] From 87ab0cec65b50584a627037b9d1b6fdecaee725c Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Fri, 16 Jun 2017 12:10:09 +0800 Subject: [PATCH 0726/1765] [SPARK-21072][SQL] TreeNode.mapChildren should only apply to the children node. ## What changes were proposed in this pull request? Just as the function name and comments of `TreeNode.mapChildren` mentioned, the function should be apply to all currently node children. So, the follow code should judge whether it is the children node. https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L342 ## How was this patch tested? Existing tests. Author: Xianyang Liu Closes #18284 from ConeyLiu/treenode. --- .../spark/sql/catalyst/trees/TreeNode.scala | 14 +++++++++++-- .../sql/catalyst/trees/TreeNodeSuite.scala | 21 ++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index df66f9a082aee..7375a0bcbae75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = f(arg1.asInstanceOf[BaseType]) - val newChild2 = f(arg2.asInstanceOf[BaseType]) + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { changed = true (newChild1, newChild2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 712841835acd5..819078218c546 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } -case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { +case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable { override def children: Seq[Expression] = map.values.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true } +case class SeqTupleExpression(sons: Seq[(Expression, Expression)], + nonSons: Seq[(Expression, Expression)]) extends Unevaluable { + override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2)) + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + case class JsonTestTreeNode(arg: Any) extends LeafNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } @@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite { assert(actual === Dummy(None)) } + test("mapChildren should only works on children") { + val children = Seq((Literal(1), Literal(2))) + val nonChildren = Seq((Literal(3), Literal(4))) + val before = SeqTupleExpression(children, nonChildren) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren) + + val actual = before mapChildren toZero + assert(actual === expect) + } + test("preserves origin") { CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) From 7a3e5dc28b67ac1630c5a578a27a5a5acf80aa51 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 15 Jun 2017 23:06:58 -0700 Subject: [PATCH 0727/1765] [SPARK-20749][SQL] Built-in SQL Function Support - all variants of LEN[GTH] ## What changes were proposed in this pull request? This PR adds built-in SQL function `BIT_LENGTH()`, `CHAR_LENGTH()`, and `OCTET_LENGTH()` functions. `BIT_LENGTH()` returns the bit length of the given string or binary expression. `CHAR_LENGTH()` returns the length of the given string or binary expression. (i.e. equal to `LENGTH()`) `OCTET_LENGTH()` returns the byte length of the given string or binary expression. ## How was this patch tested? Added new test suites for these three functions Author: Kazuaki Ishizaki Closes #18046 from kiszk/SPARK-20749. --- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../expressions/stringExpressions.scala | 61 ++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 20 ++++++ .../resources/sql-tests/inputs/operators.sql | 5 ++ .../sql-tests/results/operators.sql.out | 26 +++++++- 5 files changed, 112 insertions(+), 3 deletions(-) mode change 100644 => 100755 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 877328164a8a9..e4e9918a3a887 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -305,6 +305,8 @@ object FunctionRegistry { expression[Chr]("char"), expression[Chr]("chr"), expression[Base64]("base64"), + expression[BitLength]("bit_length"), + expression[Length]("char_length"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), @@ -321,6 +323,7 @@ object FunctionRegistry { expression[Levenshtein]("levenshtein"), expression[Like]("like"), expression[Lower]("lower"), + expression[OctetLength]("octet_length"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala old mode 100644 new mode 100755 index 717ada225a4f1..908fdb8f7e68f --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1199,15 +1199,18 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string or binary expression. + * A function that returns the char length of the given string expression or + * number of bytes of the given binary expression. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the character length of `expr` or number of bytes in binary data.", extended = """ Examples: > SELECT _FUNC_('Spark SQL'); 9 """) +// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1225,6 +1228,60 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn } } +/** + * A function that returns the bit length of the given string or binary expression. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the bit length of `expr` or number of bits in binary data.", + extended = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + 72 + """) +case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) + + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numBytes * 8 + case BinaryType => value.asInstanceOf[Array[Byte]].length * 8 + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numBytes() * 8") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") + } + } +} + +/** + * A function that returns the byte length of the given string or binary expression. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the byte length of `expr` or number of bytes in binary data.", + extended = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + 9 + """) +case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) + + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numBytes + case BinaryType => value.asInstanceOf[Array[Byte]].length + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numBytes()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } + } +} + /** * A function that return the Levenshtein distance between the two given strings. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4bdb43bfed8b5..4f08031153ab0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -558,20 +558,40 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + checkEvaluation(OctetLength(Literal("a花花c")), 8, create_row(string)) + checkEvaluation(BitLength(Literal("a花花c")), 8 * 8, create_row(string)) // scalastyle:on checkEvaluation(Length(Literal(bytes)), 5, create_row(Array.empty[Byte])) + checkEvaluation(OctetLength(Literal(bytes)), 5, create_row(Array.empty[Byte])) + checkEvaluation(BitLength(Literal(bytes)), 5 * 8, create_row(Array.empty[Byte])) checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(OctetLength(a), 5, create_row(string)) + checkEvaluation(BitLength(a), 5 * 8, create_row(string)) checkEvaluation(Length(b), 5, create_row(bytes)) + checkEvaluation(OctetLength(b), 5, create_row(bytes)) + checkEvaluation(BitLength(b), 5 * 8, create_row(bytes)) checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(OctetLength(a), 0, create_row("")) + checkEvaluation(BitLength(a), 0, create_row("")) checkEvaluation(Length(b), 0, create_row(Array.empty[Byte])) + checkEvaluation(OctetLength(b), 0, create_row(Array.empty[Byte])) + checkEvaluation(BitLength(b), 0, create_row(Array.empty[Byte])) checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(OctetLength(a), null, create_row(null)) + checkEvaluation(BitLength(a), null, create_row(null)) checkEvaluation(Length(b), null, create_row(null)) + checkEvaluation(OctetLength(b), null, create_row(null)) + checkEvaluation(BitLength(b), null, create_row(null)) checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(OctetLength(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(BitLength(Literal.create(null, StringType)), null, create_row(string)) checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + checkEvaluation(OctetLength(Literal.create(null, BinaryType)), null, create_row(bytes)) + checkEvaluation(BitLength(Literal.create(null, BinaryType)), null, create_row(bytes)) } test("format_number / FormatNumber") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 3934620577e99..a8de23e73892c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -80,3 +80,8 @@ select 1 > 0.00001; -- mod select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null); + +-- length +select BIT_LENGTH('abc'); +select CHAR_LENGTH('abc'); +select OCTET_LENGTH('abc'); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 51ccf764d952f..85ee10b4d274f 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 51 +-- Number of queries: 54 -- !query 0 @@ -420,3 +420,27 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> -- !query 50 output 1 NULL 0 NULL NULL NULL + + +-- !query 51 +select BIT_LENGTH('abc') +-- !query 51 schema +struct +-- !query 51 output +24 + + +-- !query 52 +select CHAR_LENGTH('abc') +-- !query 52 schema +struct +-- !query 52 output +3 + + +-- !query 53 +select OCTET_LENGTH('abc') +-- !query 53 schema +struct +-- !query 53 output +3 From 2837b14cdc42f096dce07e383caa30c7469c5d6b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Jun 2017 14:24:15 +0800 Subject: [PATCH 0728/1765] [SPARK-12552][FOLLOWUP] Fix flaky test for "o.a.s.deploy.master.MasterSuite.master correctly recover the application" ## What changes were proposed in this pull request? Due to the RPC asynchronous event processing, The test "correctly recover the application" could potentially be failed. The issue could be found in here: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/78126/testReport/org.apache.spark.deploy.master/MasterSuite/master_correctly_recover_the_application/. So here fixing this flaky test. ## How was this patch tested? Existing UT. CC cloud-fan jiangxb1987 , please help to review, thanks! Author: jerryshao Closes #18321 from jerryshao/SPARK-12552-followup. --- .../test/scala/org/apache/spark/deploy/master/MasterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 6bb0eec040787..a2232126787f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -214,7 +214,7 @@ class MasterSuite extends SparkFunSuite master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) // Wait until Master recover from checkpoint data. eventually(timeout(5 seconds), interval(100 milliseconds)) { - master.idToApp.size should be(1) + master.workers.size should be(1) } master.idToApp.keySet should be(Set(fakeAppInfo.id)) From 45824fb608930eb461e7df53bb678c9534c183a9 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 16 Jun 2017 11:03:54 +0100 Subject: [PATCH 0729/1765] [MINOR][DOCS] Improve Running R Tests docs ## What changes were proposed in this pull request? Update Running R Tests dependence packages to: ```bash R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #18271 from wangyum/building-spark. --- R/README.md | 6 +----- R/WINDOWS.md | 3 +-- docs/building-spark.md | 8 +++++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/R/README.md b/R/README.md index 4c40c5963db70..1152b1e8e5f9f 100644 --- a/R/README.md +++ b/R/README.md @@ -66,11 +66,7 @@ To run one of them, use `./bin/spark-submit `. For example: ```bash ./bin/spark-submit examples/src/main/r/dataframe.R ``` -You can also run the unit tests for SparkR by running. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: -```bash -R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' -./R/run-tests.sh -``` +You can run R unit tests by following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests). ### Running on YARN diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 9ca7e58e20cd2..124bc631be9cd 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -34,10 +34,9 @@ To run the SparkR unit tests on Windows, the following steps are required —ass 4. Set the environment variable `HADOOP_HOME` to the full path to the newly created `hadoop` directory. -5. Run unit tests for SparkR by running the command below. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: +5. Run unit tests for SparkR by running the command below. You need to install the needed packages following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests) first: ``` - R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/docs/building-spark.md b/docs/building-spark.md index 0f551bc66b8c9..777635a64f83c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -218,9 +218,11 @@ The run-tests script also can be limited to a specific Python version or a speci ## Running R Tests -To run the SparkR tests you will need to install the R package `testthat` -(run `install.packages(testthat)` from R shell). You can run just the SparkR tests using -the command: +To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: + + R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + +You can run just the SparkR tests using the command: ./R/run-tests.sh From 93dd0c518d040155b04e5ab258c5835aec7776fc Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 16 Jun 2017 20:09:45 +0800 Subject: [PATCH 0730/1765] [SPARK-20994] Remove redundant characters in OpenBlocks to save memory for shuffle service. ## What changes were proposed in this pull request? In current code, blockIds in `OpenBlocks` are stored in the iterator on shuffle service. There are some redundant characters in blockId(`"shuffle_" + shuffleId + "_" + mapId + "_" + reduceId`). This pr proposes to improve the footprint and alleviate the memory pressure on shuffle service. Author: jinxing Closes #18231 from jinxing64/SPARK-20994-v2. --- .../shuffle/ExternalShuffleBlockHandler.java | 70 +++++++++++++------ .../shuffle/ExternalShuffleBlockResolver.java | 23 +++--- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 11 +-- .../ExternalShuffleBlockResolverSuite.java | 10 +-- .../ExternalShuffleIntegrationSuite.java | 8 +-- 6 files changed, 73 insertions(+), 51 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index c0f1da50f5e65..fc7bba41185f0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -44,7 +44,6 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; import org.apache.spark.network.util.TransportConf; - /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. * @@ -91,26 +90,8 @@ protected void handleMessage( try { OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - - Iterator iter = new Iterator() { - private int index = 0; - - @Override - public boolean hasNext() { - return index < msg.blockIds.length; - } - - @Override - public ManagedBuffer next() { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, - msg.blockIds[index]); - index++; - metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); - return block; - } - }; - - long streamId = streamManager.registerStream(client.getClientId(), iter); + long streamId = streamManager.registerStream(client.getClientId(), + new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -209,4 +190,51 @@ public Map getMetrics() { } } + private class ManagedBufferIterator implements Iterator { + + private int index = 0; + private final String appId; + private final String execId; + private final int shuffleId; + // An array containing mapId and reduceId pairs. + private final int[] mapIdAndReduceIds; + + ManagedBufferIterator(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + String[] blockId0Parts = blockIds[0].split("_"); + if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) { + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]); + } + this.shuffleId = Integer.parseInt(blockId0Parts[1]); + mapIdAndReduceIds = new int[2 * blockIds.length]; + for (int i = 0; i < blockIds.length; i++) { + String[] blockIdParts = blockIds[i].split("_"); + if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]); + } + if (Integer.parseInt(blockIdParts[1]) != shuffleId) { + throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + + ", got:" + blockIds[i]); + } + mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]); + mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); + } + } + + @Override + public boolean hasNext() { + return index < mapIdAndReduceIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId, + mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + index += 2; + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + } + } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 62d58aba4c1e7..d7ec0e299dead 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -150,27 +150,20 @@ public void registerExecutor( } /** - * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the - * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make - * assumptions about how the hash and sort based shuffles store their data. + * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions + * about how the hash and sort based shuffles store their data. */ - public ManagedBuffer getBlockData(String appId, String execId, String blockId) { - String[] blockIdParts = blockId.split("_"); - if (blockIdParts.length < 4) { - throw new IllegalArgumentException("Unexpected block id format: " + blockId); - } else if (!blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); - } - int shuffleId = Integer.parseInt(blockIdParts[1]); - int mapId = Integer.parseInt(blockIdParts[2]); - int reduceId = Integer.parseInt(blockIdParts[3]); - + public ManagedBuffer getBlockData( + String appId, + String execId, + int shuffleId, + int mapId, + int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 0c054fc5db8f4..8110f1e004c73 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -202,7 +202,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { } }; - String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" }; + String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" }; OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); fetcher.start(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 4d48b18970386..7846b71d5a8b1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -83,9 +83,10 @@ public void testOpenShuffleBlocks() { ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); - when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); @@ -105,8 +106,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index bc97594903bef..23438a08fa094 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -65,7 +65,7 @@ public void testBadRequests() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor try { - resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); + resolver.getBlockData("app0", "exec1", 1, 1, 0); fail("Should have failed"); } catch (RuntimeException e) { assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); @@ -74,7 +74,7 @@ public void testBadRequests() throws IOException { // Invalid shuffle manager try { resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); - resolver.getBlockData("app0", "exec2", "shuffle_1_1_0"); + resolver.getBlockData("app0", "exec2", 1, 1, 0); fail("Should have failed"); } catch (UnsupportedOperationException e) { // pass @@ -84,7 +84,7 @@ public void testBadRequests() throws IOException { resolver.registerExecutor("app0", "exec3", dataContext.createExecutorInfo(SORT_MANAGER)); try { - resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); + resolver.getBlockData("app0", "exec3", 1, 1, 0); fail("Should have failed"); } catch (Exception e) { // pass @@ -98,14 +98,14 @@ public void testSortShuffleBlocks() throws IOException { dataContext.createExecutorInfo(SORT_MANAGER)); InputStream block0Stream = - resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream(); String block0 = CharStreams.toString( new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(sortBlock0, block0); InputStream block1Stream = - resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream(); String block1 = CharStreams.toString( new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index d1d8f5b4e188a..4391e3023491b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -214,10 +214,10 @@ public void testFetchNonexistent() throws Exception { @Test public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); - assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); + FetchResult execFetch0 = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */}); + FetchResult execFetch1 = fetchBlocks("exec-0", new String[] { "shuffle_1_0_0" /* wrong */ }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch0.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch1.failedBlocks); } @Test From d1c333ac77e2554832477fd9ec56fb0b2015cde6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 16 Jun 2017 08:05:43 -0700 Subject: [PATCH 0731/1765] [SPARK-21119][SQL] unset table properties should keep the table comment ## What changes were proposed in this pull request? Previous code mistakenly use `table.properties.get("comment")` to read the existing table comment, we should use `table.comment` ## How was this patch tested? new regression test Author: Wenchen Fan Closes #18325 from cloud-fan/unset. --- .../spark/sql/execution/command/ddl.scala | 4 +- .../resources/sql-tests/inputs/describe.sql | 8 + .../sql-tests/results/describe.sql.out | 201 ++++++++++++------ 3 files changed, 148 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index f924b3d914635..413f5f3ba539c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -264,14 +264,14 @@ case class AlterTableUnsetPropertiesCommand( DDLUtils.verifyAlterTableType(catalog, table, isView) if (!ifExists) { propKeys.foreach { k => - if (!table.properties.contains(k)) { + if (!table.properties.contains(k) && k != "comment") { throw new AnalysisException( s"Attempted to unset non-existent property '$k' in table '${table.identifier}'") } } } // If comment is in the table property, we reset it to None - val tableComment = if (propKeys.contains("comment")) None else table.properties.get("comment") + val tableComment = if (propKeys.contains("comment")) None else table.comment val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } val newTable = table.copy(properties = newProperties, comment = tableComment) catalog.alterTable(newTable) diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 91b966829f8fb..a222e11916cda 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -28,6 +28,14 @@ DESC FORMATTED t; DESC EXTENDED t; +ALTER TABLE t UNSET TBLPROPERTIES (e); + +DESC EXTENDED t; + +ALTER TABLE t UNSET TBLPROPERTIES (comment); + +DESC EXTENDED t; + DESC t PARTITION (c='Us', d=1); DESC EXTENDED t PARTITION (c='Us', d=1); diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index ab9f2783f06bb..e2b79e8f7801d 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 36 -- !query 0 @@ -166,10 +166,85 @@ Partition Provider Catalog -- !query 11 -DESC t PARTITION (c='Us', d=1) +ALTER TABLE t UNSET TBLPROPERTIES (e) -- !query 11 schema -struct +struct<> -- !query 11 output + + + +-- !query 12 +DESC EXTENDED t +-- !query 12 schema +struct +-- !query 12 output +a string +b int +c string +d string +# Partition Information +# col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] +Partition Provider Catalog + + +-- !query 13 +ALTER TABLE t UNSET TBLPROPERTIES (comment) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +DESC EXTENDED t +-- !query 14 schema +struct +-- !query 14 output +a string +b int +c string +d string +# Partition Information +# col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] +Partition Provider Catalog + + +-- !query 15 +DESC t PARTITION (c='Us', d=1) +-- !query 15 schema +struct +-- !query 15 output a string b int c string @@ -180,11 +255,11 @@ c string d string --- !query 12 +-- !query 16 DESC EXTENDED t PARTITION (c='Us', d=1) --- !query 12 schema +-- !query 16 schema struct --- !query 12 output +-- !query 16 output a string b int c string @@ -209,11 +284,11 @@ Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] --- !query 13 +-- !query 17 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 13 schema +-- !query 17 schema struct --- !query 13 output +-- !query 17 output a string b int c string @@ -238,31 +313,31 @@ Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] --- !query 14 +-- !query 18 DESC t PARTITION (c='Us', d=2) --- !query 14 schema +-- !query 18 schema struct<> --- !query 14 output +-- !query 18 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 15 +-- !query 19 DESC t PARTITION (c='Us') --- !query 15 schema +-- !query 19 schema struct<> --- !query 15 output +-- !query 19 output org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 16 +-- !query 20 DESC t PARTITION (c='Us', d) --- !query 16 schema +-- !query 20 schema struct<> --- !query 16 output +-- !query 20 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -272,55 +347,55 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 17 +-- !query 21 DESC temp_v --- !query 17 schema +-- !query 21 schema struct --- !query 17 output +-- !query 21 output a string b int c string d string --- !query 18 +-- !query 22 DESC TABLE temp_v --- !query 18 schema +-- !query 22 schema struct --- !query 18 output +-- !query 22 output a string b int c string d string --- !query 19 +-- !query 23 DESC FORMATTED temp_v --- !query 19 schema +-- !query 23 schema struct --- !query 19 output +-- !query 23 output a string b int c string d string --- !query 20 +-- !query 24 DESC EXTENDED temp_v --- !query 20 schema +-- !query 24 schema struct --- !query 20 output +-- !query 24 output a string b int c string d string --- !query 21 +-- !query 25 DESC temp_Data_Source_View --- !query 21 schema +-- !query 25 schema struct --- !query 21 output +-- !query 25 output intType int test comment test1 stringType string dateType date @@ -339,42 +414,42 @@ arrayType array structType struct --- !query 22 +-- !query 26 DESC temp_v PARTITION (c='Us', d=1) --- !query 22 schema +-- !query 26 schema struct<> --- !query 22 output +-- !query 26 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a temporary view: temp_v; --- !query 23 +-- !query 27 DESC v --- !query 23 schema +-- !query 27 schema struct --- !query 23 output +-- !query 27 output a string b int c string d string --- !query 24 +-- !query 28 DESC TABLE v --- !query 24 schema +-- !query 28 schema struct --- !query 24 output +-- !query 28 output a string b int c string d string --- !query 25 +-- !query 29 DESC FORMATTED v --- !query 25 schema +-- !query 29 schema struct --- !query 25 output +-- !query 29 output a string b int c string @@ -392,11 +467,11 @@ View Query Output Columns [a, b, c, d] Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 26 +-- !query 30 DESC EXTENDED v --- !query 26 schema +-- !query 30 schema struct --- !query 26 output +-- !query 30 output a string b int c string @@ -414,42 +489,42 @@ View Query Output Columns [a, b, c, d] Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 27 +-- !query 31 DESC v PARTITION (c='Us', d=1) --- !query 27 schema +-- !query 31 schema struct<> --- !query 27 output +-- !query 31 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a view: v; --- !query 28 +-- !query 32 DROP TABLE t --- !query 28 schema +-- !query 32 schema struct<> --- !query 28 output +-- !query 32 output --- !query 29 +-- !query 33 DROP VIEW temp_v --- !query 29 schema +-- !query 33 schema struct<> --- !query 29 output +-- !query 33 output --- !query 30 +-- !query 34 DROP VIEW temp_Data_Source_View --- !query 30 schema +-- !query 34 schema struct<> --- !query 30 output +-- !query 34 output --- !query 31 +-- !query 35 DROP VIEW v --- !query 31 schema +-- !query 35 schema struct<> --- !query 31 output +-- !query 35 output From 53e48f73e42bb3eea075894ff08494e0abe9d60a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 16 Jun 2017 09:40:58 -0700 Subject: [PATCH 0732/1765] [SPARK-20931][SQL] ABS function support string type. ## What changes were proposed in this pull request? ABS function support string type. Hive/MySQL support this feature. Ref: https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java#L93 ## How was this patch tested? unit tests Author: Yuming Wang Closes #18153 from wangyum/SPARK-20931. --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + .../analysis/ExpressionTypeCheckingSuite.scala | 1 - .../src/test/resources/sql-tests/inputs/operators.sql | 3 +++ .../test/resources/sql-tests/results/operators.sql.out | 10 +++++++++- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1f217390518a6..6082c58e2c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -357,6 +357,7 @@ object TypeCoercion { val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) + case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 744057b7c5f4c..2239bf815de71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -57,7 +57,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") - assertError(Abs('stringField), "requires numeric type") assertError(BitwiseNot('stringField), "requires integral type") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index a8de23e73892c..a1e8a32ed8f66 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -85,3 +85,6 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu select BIT_LENGTH('abc'); select CHAR_LENGTH('abc'); select OCTET_LENGTH('abc'); + +-- abs +select abs(-3.13), abs('-2.19'); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 85ee10b4d274f..eac3080bec67d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -444,3 +444,11 @@ select OCTET_LENGTH('abc') struct -- !query 53 output 3 + + +-- !query 54 +select abs(-3.13), abs('-2.19') +-- !query 54 schema +struct +-- !query 54 output +3.13 2.19 From edcb878e2fbd0d85bf70614fed37f4cbf0caa95e Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Fri, 16 Jun 2017 10:34:52 -0700 Subject: [PATCH 0733/1765] [SPARK-20338][CORE] Spaces in spark.eventLog.dir are not correctly handled MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? “spark.eventLog.dir” supports with space characters. 1. Update EventLoggingListenerSuite like `testDir = Utils.createTempDir(namePrefix = s"history log")` 2. Fix EventLoggingListenerSuite tests ## How was this patch tested? update unit tests Author: zuotingbing Closes #18285 from zuotingbing/spark-resolveURI. --- .../org/apache/spark/scheduler/EventLoggingListener.scala | 4 ++-- .../spark/deploy/history/FsHistoryProviderSuite.scala | 5 ++--- .../apache/spark/scheduler/EventLoggingListenerSuite.scala | 7 +++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index f481436332249..35690b2783ad3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -96,8 +96,8 @@ private[spark] class EventLoggingListener( } val workingPath = logPath + IN_PROGRESS - val uri = new URI(workingPath) val path = new Path(workingPath) + val uri = path.toUri val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme val isDefaultLocal = defaultFs == null || defaultFs == "file" @@ -320,7 +320,7 @@ private[spark] object EventLoggingListener extends Logging { appId: String, appAttemptId: Option[String], compressionCodecName: Option[String] = None): String = { - val base = logBaseDir.toString.stripSuffix("/") + "/" + sanitize(appId) + val base = new Path(logBaseDir).toString.stripSuffix("/") + "/" + sanitize(appId) val codec = compressionCodecName.map("." + _).getOrElse("") if (appAttemptId.isDefined) { base + "_" + sanitize(appAttemptId.get) + codec diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 9b3e4ec793825..7109146ece371 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import java.io._ -import java.net.URI import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} @@ -27,7 +26,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -63,7 +62,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId, appAttemptId) - val logPath = new URI(logUri).getPath + ip + val logPath = new Path(logUri).toUri.getPath + ip new File(logPath) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 4cae6c61118a8..0afd07b851cf9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} -import java.net.URI import scala.collection.mutable import scala.io.Source @@ -52,7 +51,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit private var testDirPath: Path = _ before { - testDir = Utils.createTempDir() + testDir = Utils.createTempDir(namePrefix = s"history log") testDir.deleteOnExit() testDirPath = new Path(testDir.getAbsolutePath()) } @@ -111,7 +110,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit test("Log overwriting") { val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test", None) - val logPath = new URI(logUri).getPath + val logPath = new Path(logUri).toUri.getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() // Expected IOException, since we haven't enabled log overwrite. @@ -293,7 +292,7 @@ object EventLoggingListenerSuite { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") - conf.set("spark.eventLog.dir", logDir.toUri.toString) + conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => conf.set("spark.eventLog.compress", "true") conf.set("spark.io.compression.codec", codec) From 0d8604bb849b3370cc21966cdd773238f3a29f84 Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Sun, 18 Jun 2017 08:32:29 +0100 Subject: [PATCH 0734/1765] [SPARK-21126] The configuration which named "spark.core.connection.auth.wait.timeout" hasn't been used in spark [https://issues.apache.org/jira/browse/SPARK-21126](https://issues.apache.org/jira/browse/SPARK-21126) The configuration which named "spark.core.connection.auth.wait.timeout" hasn't been used in spark,so I think it should be removed from configuration.md. Author: liuzhaokun Closes #18333 from liu-zhaokun/new3. --- docs/configuration.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index f777811a93f62..c1464741ebb6f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1774,14 +1774,6 @@ Apart from these, the following properties are also available, and may be useful you can set larger value. - - spark.core.connection.auth.wait.timeout - 30s - - How long for the connection to wait for authentication to occur before timing - out and giving up. - - spark.modify.acls Empty From 75a6d05853fea13f88e3c941b1959b24e4640824 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Jun 2017 08:43:47 +0100 Subject: [PATCH 0735/1765] [MINOR][R] Add knitr and rmarkdown packages/improve output for version info in AppVeyor tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR proposes three things as below: **Install packages per documentation** - this does not affect the tests itself (but CRAN which we are not doing via AppVeyor) up to my knowledge. This adds `knitr` and `rmarkdown` per https://github.com/apache/spark/blob/45824fb608930eb461e7df53bb678c9534c183a9/R/WINDOWS.md#unit-tests (please see https://github.com/apache/spark/commit/45824fb608930eb461e7df53bb678c9534c183a9) **Improve logs/shorten logs** - actually, long logs can be a problem on AppVeyor (e.g., see https://github.com/apache/spark/pull/17873) `R -e ...` repeats printing R information for each invocation as below: ``` R version 3.3.1 (2016-06-21) -- "Bug in Your Hair" Copyright (C) 2016 The R Foundation for Statistical Computing Platform: i386-w64-mingw32/i386 (32-bit) R is free software and comes with ABSOLUTELY NO WARRANTY. You are welcome to redistribute it under certain conditions. Type 'license()' or 'licence()' for distribution details. Natural language support but running in an English locale R is a collaborative project with many contributors. Type 'contributors()' for more information and 'citation()' on how to cite R or R packages in publications. Type 'demo()' for some demos, 'help()' for on-line help, or 'help.start()' for an HTML browser interface to help. Type 'q()' to quit R. ``` It looks reducing the call might be slightly better and print out the versions together looks more readable. Before: ``` # R information ... > packageVersion('testthat') [1] '1.0.2' > > # R information ... > packageVersion('e1071') [1] '1.6.8' > > ... 3 more times ``` After: ``` # R information ... > packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival') [1] ‘1.16’ [1] ‘1.6’ [1] ‘1.0.2’ [1] ‘1.6.8’ [1] ‘2.41.3’ ``` **Add`appveyor.yml`/`dev/appveyor-install-dependencies.ps1` for triggering the test** Changing this file might break the test, e.g., https://github.com/apache/spark/pull/16927 ## How was this patch tested? Before (please see https://ci.appveyor.com/project/HyukjinKwon/spark/build/169-master) After (please see the AppVeyor build in this PR): Author: hyukjinkwon Closes #18336 from HyukjinKwon/minor-add-knitr-and-rmarkdown. --- appveyor.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 58c2e98289e96..43dad9bce60ac 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -26,6 +26,8 @@ branches: only_commits: files: + - appveyor.yml + - dev/appveyor-install-dependencies.ps1 - R/ - sql/core/src/main/scala/org/apache/spark/sql/api/r/ - core/src/main/scala/org/apache/spark/api/r/ @@ -38,12 +40,8 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('testthat')" - - cmd: R -e "install.packages('e1071', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('e1071')" - - cmd: R -e "install.packages('survival', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('survival')" + - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package From 05f83c532a96ead8dec1c046f985164b7f7205c0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Jun 2017 11:26:27 -0700 Subject: [PATCH 0736/1765] [SPARK-21128][R] Remove both "spark-warehouse" and "metastore_db" before listing files in R tests ## What changes were proposed in this pull request? This PR proposes to list the files in test _after_ removing both "spark-warehouse" and "metastore_db" so that the next run of R tests pass fine. This is sometimes a bit annoying. ## How was this patch tested? Manually running multiple times R tests via `./R/run-tests.sh`. **Before** Second run: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../sparkailed ------------------------------------------------------------------------- 1. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3384) length(list1) not equal to length(list2). 1/1 mismatches [1] 25 - 23 == 2 2. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3384) sort(list1, na.last = TRUE) not equal to sort(list2, na.last = TRUE). 10/25 mismatches x[16]: "metastore_db" y[16]: "pkg" x[17]: "pkg" y[17]: "R" x[18]: "R" y[18]: "README.md" x[19]: "README.md" y[19]: "run-tests.sh" x[20]: "run-tests.sh" y[20]: "SparkR_2.2.0.tar.gz" x[21]: "metastore_db" y[21]: "pkg" x[22]: "pkg" y[22]: "R" x[23]: "R" y[23]: "README.md" x[24]: "README.md" y[24]: "run-tests.sh" x[25]: "run-tests.sh" y[25]: "SparkR_2.2.0.tar.gz" 3. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3388) length(list1) not equal to length(list2). 1/1 mismatches [1] 25 - 23 == 2 4. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3388) sort(list1, na.last = TRUE) not equal to sort(list2, na.last = TRUE). 10/25 mismatches x[16]: "metastore_db" y[16]: "pkg" x[17]: "pkg" y[17]: "R" x[18]: "R" y[18]: "README.md" x[19]: "README.md" y[19]: "run-tests.sh" x[20]: "run-tests.sh" y[20]: "SparkR_2.2.0.tar.gz" x[21]: "metastore_db" y[21]: "pkg" x[22]: "pkg" y[22]: "R" x[23]: "R" y[23]: "README.md" x[24]: "README.md" y[24]: "run-tests.sh" x[25]: "run-tests.sh" y[25]: "SparkR_2.2.0.tar.gz" DONE =========================================================================== ``` **After** Second run: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../spark``` Author: hyukjinkwon Closes #18335 from HyukjinKwon/SPARK-21128. --- R/pkg/tests/run-all.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index f00a610679752..0aefd8006caa4 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -30,10 +30,10 @@ if (.Platform$OS.type == "windows") { install.spark() sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") -sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRTestMaster <- "local[1]" if (identical(Sys.getenv("NOT_CRAN"), "true")) { From 110ce1f27b66905afada6b5fd63c34fbf7602739 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Sun, 18 Jun 2017 18:00:27 -0700 Subject: [PATCH 0737/1765] [SPARK-20892][SPARKR] Add SQL trunc function to SparkR ## What changes were proposed in this pull request? Add SQL trunc function ## How was this patch tested? standard test Author: actuaryzhang Closes #18291 from actuaryzhang/sparkRTrunc2. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 29 +++++++++++++++++++++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 2 ++ 3 files changed, 32 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 4e3fe00a2e9bd..229de4a997eef 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -357,6 +357,7 @@ exportMethods("%<=>%", "to_utc_timestamp", "translate", "trim", + "trunc", "unbase64", "unhex", "unix_timestamp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 06a90192bb12f..7128c3b9adff4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4015,3 +4015,32 @@ setMethod("input_file_name", signature("missing"), jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") column(jc) }) + +#' trunc +#' +#' Returns date truncated to the unit specified by the format. +#' +#' @param x Column to compute on. +#' @param format string used for specify the truncation method. For example, "year", "yyyy", +#' "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. +#' +#' @rdname trunc +#' @name trunc +#' @family date time functions +#' @aliases trunc,Column-method +#' @export +#' @examples +#' \dontrun{ +#' trunc(df$c, "year") +#' trunc(df$c, "yy") +#' trunc(df$c, "month") +#' trunc(df$c, "mon") +#' } +#' @note trunc since 2.3.0 +setMethod("trunc", + signature(x = "Column"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "trunc", + x@jc, as.character(format)) + column(jc) + }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index af529067f43e0..911b73b9ee551 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1382,6 +1382,8 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) + c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From ce49428ef7d640c1734e91ffcddc49dbc8547ba7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 18 Jun 2017 18:56:53 -0700 Subject: [PATCH 0738/1765] [SPARK-20749][SQL][FOLLOWUP] Support character_length ## What changes were proposed in this pull request? The function `char_length` is shorthand for `character_length` function. Both Hive and Postgresql support `character_length`, This PR add support for `character_length`. Ref: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-StringFunctions https://www.postgresql.org/docs/current/static/functions-string.html ## How was this patch tested? unit tests Author: Yuming Wang Closes #18330 from wangyum/SPARK-20749-character_length. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringExpressions.scala | 4 ++++ .../resources/sql-tests/inputs/operators.sql | 1 + .../sql-tests/results/operators.sql.out | 18 +++++++++++++----- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e4e9918a3a887..f4b3e86052d8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -307,6 +307,7 @@ object FunctionRegistry { expression[Base64]("base64"), expression[BitLength]("bit_length"), expression[Length]("char_length"), + expression[Length]("character_length"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 908fdb8f7e68f..83fdcfce9c3bd 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1209,6 +1209,10 @@ case class Substring(str: Expression, pos: Expression, len: Expression) Examples: > SELECT _FUNC_('Spark SQL'); 9 + > SELECT CHAR_LENGTH('Spark SQL'); + 9 + > SELECT CHARACTER_LENGTH('Spark SQL'); + 9 """) // scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index a1e8a32ed8f66..9841ec4b65983 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -84,6 +84,7 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu -- length select BIT_LENGTH('abc'); select CHAR_LENGTH('abc'); +select CHARACTER_LENGTH('abc'); select OCTET_LENGTH('abc'); -- abs diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index eac3080bec67d..4a6ef27c3be42 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 56 -- !query 0 @@ -439,16 +439,24 @@ struct -- !query 53 -select OCTET_LENGTH('abc') +select CHARACTER_LENGTH('abc') -- !query 53 schema -struct +struct -- !query 53 output 3 -- !query 54 -select abs(-3.13), abs('-2.19') +select OCTET_LENGTH('abc') -- !query 54 schema -struct +struct -- !query 54 output +3 + + +-- !query 55 +select abs(-3.13), abs('-2.19') +-- !query 55 schema +struct +-- !query 55 output 3.13 2.19 From f913f158ec41bd3de9dc229b908aaab0dbd60d27 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 18 Jun 2017 20:14:05 -0700 Subject: [PATCH 0739/1765] [SPARK-20948][SQL] Built-in SQL Function UnaryMinus/UnaryPositive support string type ## What changes were proposed in this pull request? Built-in SQL Function UnaryMinus/UnaryPositive support string type, if it's string type, convert it to double type, after this PR: ```sql spark-sql> select positive('-1.11'), negative('-1.11'); -1.11 1.11 spark-sql> ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #18173 from wangyum/SPARK-20948. --- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 ++ .../catalyst/analysis/ExpressionTypeCheckingSuite.scala | 1 - .../src/test/resources/sql-tests/inputs/operators.sql | 3 +++ .../test/resources/sql-tests/results/operators.sql.out | 8 ++++++++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6082c58e2c53a..a78e1c98e89de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -362,6 +362,8 @@ object TypeCoercion { case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) + case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 2239bf815de71..30459f173ab52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -56,7 +56,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") assertError(BitwiseNot('stringField), "requires integral type") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 9841ec4b65983..a766275192492 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -89,3 +89,6 @@ select OCTET_LENGTH('abc'); -- abs select abs(-3.13), abs('-2.19'); + +-- positive/negative +select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 4a6ef27c3be42..5cb6ed3e27bf2 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -460,3 +460,11 @@ select abs(-3.13), abs('-2.19') struct -- !query 55 output 3.13 2.19 + + +-- !query 55 +select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) +-- !query 55 schema +struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> +-- !query 55 output +-1.11 -1.11 1.11 1.11 From 112bd9bfc5b9729f6f86518998b5d80c5e79fe5e Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 19 Jun 2017 11:46:58 +0800 Subject: [PATCH 0740/1765] [SPARK-21090][CORE] Optimize the unified memory manager code ## What changes were proposed in this pull request? 1.In `acquireStorageMemory`, when the Memory Mode is OFF_HEAP ,the `maxOffHeapMemory` should be modified to `maxOffHeapStorageMemory`. after this PR,it will same as ON_HEAP Memory Mode. Because when acquire memory is between `maxOffHeapStorageMemory` and `maxOffHeapMemory`,it will fail surely, so if acquire memory is greater than `maxOffHeapStorageMemory`(not greater than `maxOffHeapMemory`),we should fail fast. 2. Borrow memory from execution, `numBytes` modified to `numBytes - storagePool.memoryFree` will be more reasonable. Because we just acquire `(numBytes - storagePool.memoryFree)`, unnecessary borrowed `numBytes` from execution ## How was this patch tested? added unit test case Author: liuxian Closes #18296 from 10110346/wip-lx-0614. --- .../spark/memory/UnifiedMemoryManager.scala | 5 +-- .../spark/memory/MemoryManagerSuite.scala | 2 +- .../memory/UnifiedMemoryManagerSuite.scala | 32 +++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index fea2808218a53..df193552bed3c 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -160,7 +160,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( case MemoryMode.OFF_HEAP => ( offHeapExecutionMemoryPool, offHeapStorageMemoryPool, - maxOffHeapMemory) + maxOffHeapStorageMemory) } if (numBytes > maxMemory) { // Fail fast if the block simply won't fit @@ -171,7 +171,8 @@ private[spark] class UnifiedMemoryManager private[memory] ( if (numBytes > storagePool.memoryFree) { // There is not enough free memory in the storage pool, so try to borrow free memory from // the execution pool. - val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, numBytes) + val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, + numBytes - storagePool.memoryFree) executionPool.decrementPoolSize(memoryBorrowedFromExecution) storagePool.incrementPoolSize(memoryBorrowedFromExecution) } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index eb2b3ffd1509a..85eeb5055ae03 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -117,7 +117,7 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft evictBlocksToFreeSpaceCalled.set(numBytesToFree) if (numBytesToFree <= mm.storageMemoryUsed) { // We can evict enough blocks to fulfill the request for space - mm.releaseStorageMemory(numBytesToFree, MemoryMode.ON_HEAP) + mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode) evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L)) numBytesToFree } else { diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index c821054412d7d..02b04cdbb2a5f 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -303,4 +303,36 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes mm.invokePrivate[Unit](assertInvariants()) } + test("not enough free memory in the storage pool --OFF_HEAP") { + val conf = new SparkConf() + .set("spark.memory.offHeap.size", "1000") + .set("spark.testing.memory", "1000") + .set("spark.memory.offHeap.enabled", "true") + val taskAttemptId = 0L + val mm = UnifiedMemoryManager(conf, numCores = 1) + val ms = makeMemoryStore(mm) + val memoryMode = MemoryMode.OFF_HEAP + + assert(mm.acquireExecutionMemory(400L, taskAttemptId, memoryMode) === 400L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 400L) + + // Fail fast + assert(!mm.acquireStorageMemory(dummyBlock, 700L, memoryMode)) + assert(mm.storageMemoryUsed === 0L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assert(mm.storageMemoryUsed === 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + + // Borrow 50 from execution memory + assert(mm.acquireStorageMemory(dummyBlock, 450L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 550L) + + // Borrow 50 from execution memory and evict 50 to free space + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 50) + assert(mm.storageMemoryUsed === 600L) + } } From ea542d29b2ae99cfff47fed40b7a9ab77d41b391 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 18 Jun 2017 22:05:06 -0700 Subject: [PATCH 0741/1765] [SPARK-19824][CORE] Update JsonProtocol to keep consistent with the UI ## What changes were proposed in this pull request? Fix any inconsistent part in JsonProtocol with the UI. This PR also contains the modifications in #17181 ## How was this patch tested? Updated JsonProtocolSuite. Before this change, localhost:8080/json shows: ``` { "url" : "spark://xingbos-MBP.local:7077", "workers" : [ { "id" : "worker-20170615172946-192.168.0.101-49450", "host" : "192.168.0.101", "port" : 49450, "webuiaddress" : "http://192.168.0.101:8081", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497519481722 }, { "id" : "worker-20170615172948-192.168.0.101-49452", "host" : "192.168.0.101", "port" : 49452, "webuiaddress" : "http://192.168.0.101:8082", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497519484160 }, { "id" : "worker-20170615172951-192.168.0.101-49469", "host" : "192.168.0.101", "port" : 49469, "webuiaddress" : "http://192.168.0.101:8083", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497519486905 } ], "cores" : 24, "coresused" : 24, "memory" : 46080, "memoryused" : 3072, "activeapps" : [ { "starttime" : 1497519426990, "id" : "app-20170615173706-0001", "name" : "Spark shell", "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:37:06 CST 2017", "state" : "RUNNING", "duration" : 65362 } ], "completedapps" : [ { "starttime" : 1497519250893, "id" : "app-20170615173410-0000", "name" : "Spark shell", "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:34:10 CST 2017", "state" : "FINISHED", "duration" : 116895 } ], "activedrivers" : [ ], "status" : "ALIVE" } ``` After the change: ``` { "url" : "spark://xingbos-MBP.local:7077", "workers" : [ { "id" : "worker-20170615175032-192.168.0.101-49951", "host" : "192.168.0.101", "port" : 49951, "webuiaddress" : "http://192.168.0.101:8081", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497520292900 }, { "id" : "worker-20170615175034-192.168.0.101-49953", "host" : "192.168.0.101", "port" : 49953, "webuiaddress" : "http://192.168.0.101:8082", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497520280301 }, { "id" : "worker-20170615175037-192.168.0.101-49955", "host" : "192.168.0.101", "port" : 49955, "webuiaddress" : "http://192.168.0.101:8083", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497520282884 } ], "aliveworkers" : 3, "cores" : 24, "coresused" : 24, "memory" : 46080, "memoryused" : 3072, "activeapps" : [ { "id" : "app-20170615175122-0001", "starttime" : 1497520282115, "name" : "Spark shell", "cores" : 24, "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:51:22 CST 2017", "state" : "RUNNING", "duration" : 10805 } ], "completedapps" : [ { "id" : "app-20170615175058-0000", "starttime" : 1497520258766, "name" : "Spark shell", "cores" : 24, "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:50:58 CST 2017", "state" : "FINISHED", "duration" : 9876 } ], "activedrivers" : [ ], "completeddrivers" : [ ], "status" : "ALIVE" } ``` Author: Xingbo Jiang Closes #18303 from jiangxb1987/json-protocol. --- .../apache/spark/deploy/JsonProtocol.scala | 158 +++++++++++++++--- .../apache/spark/deploy/DeployTestUtils.scala | 4 +- .../spark/deploy/JsonProtocolSuite.scala | 15 +- 3 files changed, 149 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 220b20bf7cbd1..7212696166570 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -21,30 +21,65 @@ import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.master._ +import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.ExecutorRunner private[deploy] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo): JObject = { - ("id" -> obj.id) ~ - ("host" -> obj.host) ~ - ("port" -> obj.port) ~ - ("webuiaddress" -> obj.webUiAddress) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("coresfree" -> obj.coresFree) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("memoryfree" -> obj.memoryFree) ~ - ("state" -> obj.state.toString) ~ - ("lastheartbeat" -> obj.lastHeartbeat) - } + /** + * Export the [[WorkerInfo]] to a Json object. A [[WorkerInfo]] consists of the information of a + * worker. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the worker + * `host` the host that the worker is running on + * `port` the port that the worker is bound to + * `webuiaddress` the address used in web UI + * `cores` total cores of the worker + * `coresused` allocated cores of the worker + * `coresfree` free cores of the worker + * `memory` total memory of the worker + * `memoryused` allocated memory of the worker + * `memoryfree` free memory of the worker + * `state` state of the worker, see [[WorkerState]] + * `lastheartbeat` time in milliseconds that the latest heart beat message from the + * worker is received + */ + def writeWorkerInfo(obj: WorkerInfo): JObject = { + ("id" -> obj.id) ~ + ("host" -> obj.host) ~ + ("port" -> obj.port) ~ + ("webuiaddress" -> obj.webUiAddress) ~ + ("cores" -> obj.cores) ~ + ("coresused" -> obj.coresUsed) ~ + ("coresfree" -> obj.coresFree) ~ + ("memory" -> obj.memory) ~ + ("memoryused" -> obj.memoryUsed) ~ + ("memoryfree" -> obj.memoryFree) ~ + ("state" -> obj.state.toString) ~ + ("lastheartbeat" -> obj.lastHeartbeat) + } + /** + * Export the [[ApplicationInfo]] to a Json objec. An [[ApplicationInfo]] consists of the + * information of an application. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the application + * `starttime` time in milliseconds that the application starts + * `name` the description of the application + * `cores` total cores granted to the application + * `user` name of the user who submitted the application + * `memoryperslave` minimal memory in MB required to each executor + * `submitdate` time in Date that the application is submitted + * `state` state of the application, see [[ApplicationState]] + * `duration` time in milliseconds that the application has been running + */ def writeApplicationInfo(obj: ApplicationInfo): JObject = { - ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ + ("starttime" -> obj.startTime) ~ ("name" -> obj.desc.name) ~ - ("cores" -> obj.desc.maxCores) ~ + ("cores" -> obj.coresGranted) ~ ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ @@ -52,14 +87,36 @@ private[deploy] object JsonProtocol { ("duration" -> obj.duration) } + /** + * Export the [[ApplicationDescription]] to a Json object. An [[ApplicationDescription]] consists + * of the description of an application. + * + * @return a Json object containing the following fields: + * `name` the description of the application + * `cores` max cores that can be allocated to the application, 0 means unlimited + * `memoryperslave` minimal memory in MB required to each executor + * `user` name of the user who submitted the application + * `command` the command string used to submit the application + */ def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ - ("cores" -> obj.maxCores) ~ + ("cores" -> obj.maxCores.getOrElse(0)) ~ ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } + /** + * Export the [[ExecutorRunner]] to a Json object. An [[ExecutorRunner]] consists of the + * information of an executor. + * + * @return a Json object containing the following fields: + * `id` an integer identifier of the executor + * `memory` memory in MB allocated to the executor + * `appid` a string identifier of the application that the executor is working on + * `appdesc` a Json object of the [[ApplicationDescription]] of the application that the + * executor is working on + */ def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ @@ -67,18 +124,59 @@ private[deploy] object JsonProtocol { ("appdesc" -> writeApplicationDescription(obj.appDesc)) } + /** + * Export the [[DriverInfo]] to a Json object. A [[DriverInfo]] consists of the information of a + * driver. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the driver + * `starttime` time in milliseconds that the driver starts + * `state` state of the driver, see [[DriverState]] + * `cores` cores allocated to the driver + * `memory` memory in MB allocated to the driver + * `submitdate` time in Date that the driver is created + * `worker` identifier of the worker that the driver is running on + * `mainclass` main class of the command string that started the driver + */ def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ ("cores" -> obj.desc.cores) ~ - ("memory" -> obj.desc.mem) + ("memory" -> obj.desc.mem) ~ + ("submitdate" -> obj.submitDate.toString) ~ + ("worker" -> obj.worker.map(_.id).getOrElse("None")) ~ + ("mainclass" -> obj.desc.command.arguments(2)) } + /** + * Export the [[MasterStateResponse]] to a Json object. A [[MasterStateResponse]] consists the + * information of a master node. + * + * @return a Json object containing the following fields: + * `url` the url of the master node + * `workers` a list of Json objects of [[WorkerInfo]] of the workers allocated to the + * master + * `aliveworkers` size of alive workers allocated to the master + * `cores` total cores available of the master + * `coresused` cores used by the master + * `memory` total memory available of the master + * `memoryused` memory used by the master + * `activeapps` a list of Json objects of [[ApplicationInfo]] of the active applications + * running on the master + * `completedapps` a list of Json objects of [[ApplicationInfo]] of the applications + * completed in the master + * `activedrivers` a list of Json objects of [[DriverInfo]] of the active drivers of the + * master + * `completeddrivers` a list of Json objects of [[DriverInfo]] of the completed drivers + * of the master + * `status` status of the master, see [[MasterState]] + */ def writeMasterState(obj: MasterStateResponse): JObject = { val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ + ("aliveworkers" -> aliveWorkers.length) ~ ("cores" -> aliveWorkers.map(_.cores).sum) ~ ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ ("memory" -> aliveWorkers.map(_.memory).sum) ~ @@ -86,9 +184,27 @@ private[deploy] object JsonProtocol { ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ + ("completeddrivers" -> obj.completedDrivers.toList.map(writeDriverInfo)) ~ ("status" -> obj.status.toString) } + /** + * Export the [[WorkerStateResponse]] to a Json object. A [[WorkerStateResponse]] consists the + * information of a worker node. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the worker node + * `masterurl` url of the master node of the worker + * `masterwebuiurl` the address used in web UI of the master node of the worker + * `cores` total cores of the worker + * `coreused` used cores of the worker + * `memory` total memory of the worker + * `memoryused` used memory of the worker + * `executors` a list of Json objects of [[ExecutorRunner]] of the executors running on + * the worker + * `finishedexecutors` a list of Json objects of [[ExecutorRunner]] of the finished + * executors of the worker + */ def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ @@ -97,7 +213,7 @@ private[deploy] object JsonProtocol { ("coresused" -> obj.coresUsed) ~ ("memory" -> obj.memory) ~ ("memoryused" -> obj.memoryUsed) ~ - ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~ - ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner)) + ("executors" -> obj.executors.map(writeExecutorRunner)) ~ + ("finishedexecutors" -> obj.finishedExecutors.map(writeExecutorRunner)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 9c13c15281a42..55a541d60ea3c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -39,7 +39,7 @@ private[deploy] object DeployTestUtils { } def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + "org.apache.spark.FakeClass", Seq("WORKER_URL", "USER_JAR", "mainClass"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) @@ -47,7 +47,7 @@ private[deploy] object DeployTestUtils { new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) + createDriverDesc(), JsonConstants.submitDate) def createWorkerInfo(): WorkerInfo = { val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, "http://publicAddress:80") diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 7093dad05c5f6..1903130cb694a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -104,8 +104,8 @@ object JsonConstants { val submitDate = new Date(123456789) val appInfoJsonStr = """ - |{"starttime":3,"id":"id","name":"name", - |"cores":4,"user":"%s", + |{"id":"id","starttime":3,"name":"name", + |"cores":0,"user":"%s", |"memoryperslave":1234,"submitdate":"%s", |"state":"WAITING","duration":%d} """.format(System.getProperty("user.name", ""), @@ -134,19 +134,24 @@ object JsonConstants { val driverInfoJsonStr = """ - |{"id":"driver-3","starttime":"3","state":"SUBMITTED","cores":3,"memory":100} - """.stripMargin + |{"id":"driver-3","starttime":"3", + |"state":"SUBMITTED","cores":3,"memory":100, + |"submitdate":"%s","worker":"None", + |"mainclass":"mainClass"} + """.format(submitDate.toString).stripMargin val masterStateJsonStr = """ |{"url":"spark://host:8080", |"workers":[%s,%s], + |"aliveworkers":2, |"cores":8,"coresused":0,"memory":2468,"memoryused":0, |"activeapps":[%s],"completedapps":[], |"activedrivers":[%s], + |"completeddrivers":[%s], |"status":"ALIVE"} """.format(workerInfoJsonStr, workerInfoJsonStr, - appInfoJsonStr, driverInfoJsonStr).stripMargin + appInfoJsonStr, driverInfoJsonStr, driverInfoJsonStr).stripMargin val workerStateJsonStr = """ From 9413b84b5a99e264816c61f72905b392c2f9cd35 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 19 Jun 2017 15:51:21 +0800 Subject: [PATCH 0742/1765] [SPARK-21132][SQL] DISTINCT modifier of function arguments should not be silently ignored ### What changes were proposed in this pull request? We should not silently ignore `DISTINCT` when they are not supported in the function arguments. This PR is to block these cases and issue the error messages. ### How was this patch tested? Added test cases for both regular functions and window functions Author: Xiao Li Closes #18340 from gatorsmile/firstCount. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++++++++++-- .../catalyst/analysis/AnalysisErrorSuite.scala | 15 +++++++++++++-- .../sql/catalyst/analysis/AnalysisTest.scala | 8 ++++++-- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 196b4a9bada3c..647fc0b9342c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1206,11 +1206,21 @@ class Analyzer( // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. - case wf: AggregateWindowFunction => wf + case wf: AggregateWindowFunction => + if (isDistinct) { + failAnalysis(s"${wf.prettyName} does not support the modifier DISTINCT") + } else { + wf + } // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. - case other => other + case other => + if (isDistinct) { + failAnalysis(s"${other.prettyName} does not support the modifier DISTINCT") + } else { + other + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d2ebca5a83dd3..5050318d96358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} -import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -152,7 +153,7 @@ class AnalysisErrorSuite extends AnalysisTest { "not supported within a window function" :: Nil) errorTest( - "distinct window function", + "distinct aggregate function in window", testRelation2.select( WindowExpression( AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), @@ -162,6 +163,16 @@ class AnalysisErrorSuite extends AnalysisTest { UnspecifiedFrame)).as('window)), "Distinct window functions are not supported" :: Nil) + errorTest( + "distinct function", + CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"), + "hex does not support the modifier DISTINCT" :: Nil) + + errorTest( + "distinct window function", + CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"), + "percent_rank does not support the modifier DISTINCT" :: Nil) + errorTest( "nested aggregate functions", testRelation.groupBy('a)( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index afc7ce4195a8b..edfa8c45f9867 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.net.URI import java.util.Locale import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf @@ -32,7 +33,10 @@ trait AnalysisTest extends PlanTest { private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) - val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf) + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) From 9a145fd796145d1386fd75c01e4103deadb97ac9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Jun 2017 11:13:03 +0100 Subject: [PATCH 0743/1765] [MINOR] Bump SparkR and PySpark version to 2.3.0. ## What changes were proposed in this pull request? #17753 bumps master branch version to 2.3.0-SNAPSHOT, but it seems SparkR and PySpark version were omitted. ditto of https://github.com/apache/spark/pull/16488 / https://github.com/apache/spark/pull/17523 ## How was this patch tested? N/A Author: hyukjinkwon Closes #18341 from HyukjinKwon/r-version. --- R/pkg/DESCRIPTION | 2 +- python/pyspark/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 879c1f80f2c5d..b739d423a36cc 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.2.0 +Version: 2.3.0 Title: R Frontend for Apache Spark Description: The SparkR package provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 41bf8c269b795..12dd53b9d2902 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.2.0.dev0" +__version__ = "2.3.0.dev0" From e92ffe6f1771e3fe9ea2e62ba552c1b5cf255368 Mon Sep 17 00:00:00 2001 From: saturday_s Date: Mon, 19 Jun 2017 10:24:29 -0700 Subject: [PATCH 0744/1765] [SPARK-19688][STREAMING] Not to read `spark.yarn.credentials.file` from checkpoint. ## What changes were proposed in this pull request? Reload the `spark.yarn.credentials.file` property when restarting a streaming application from checkpoint. ## How was this patch tested? Manual tested with 1.6.3 and 2.1.1. I didn't test this with master because of some compile problems, but I think it will be the same result. ## Notice This should be merged into maintenance branches too. jira: [SPARK-21008](https://issues.apache.org/jira/browse/SPARK-21008) Author: saturday_s Closes #18230 from saturday-shi/SPARK-21008. --- .../src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5cbad8bf3ce6e..b8c780db07c98 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -55,6 +55,9 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.master", "spark.yarn.keytab", "spark.yarn.principal", + "spark.yarn.credentials.file", + "spark.yarn.credentials.renewalTime", + "spark.yarn.credentials.updateTime", "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) From 66a792cd88c63cc0a1d20cbe14ac5699afbb3662 Mon Sep 17 00:00:00 2001 From: assafmendelson Date: Mon, 19 Jun 2017 10:58:58 -0700 Subject: [PATCH 0745/1765] [SPARK-21123][DOCS][STRUCTURED STREAMING] Options for file stream source are in a wrong table ## What changes were proposed in this pull request? The description for several options of File Source for structured streaming appeared in the File Sink description instead. This pull request has two commits: The first includes changes to the version as it appeared in spark 2.1 and the second handled an additional option added for spark 2.2 ## How was this patch tested? Built the documentation by SKIP_API=1 jekyll build and visually inspected the structured streaming programming guide. The original documentation was written by tdas and lw-lin Author: assafmendelson Closes #18342 from assafmendelson/spark-21123. --- .../structured-streaming-programming-guide.md | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 9b9177d44145f..d478042dea5c8 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -510,7 +510,20 @@ Here are the details of all the sources in Spark. File source path: path to the input directory, and common to all file formats. -

    +
    + maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max) +
    + latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) +
    + fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same: +
    + · "file:///dataset.txt"
    + · "s3://a/dataset.txt"
    + · "s3n://a/b/dataset.txt"
    + · "s3a://a/b/c/dataset.txt"
    +
    + +
    For file-format-specific options, see the related methods in DataStreamReader (Scala/Java/Python/R). @@ -1234,18 +1247,7 @@ Here are the details of all the sinks in Spark. Append path: path to the output directory, must be specified. -
    - maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max) -
    - latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) -
    - fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same: -
    - · "file:///dataset.txt"
    - · "s3://a/dataset.txt"
    - · "s3n://a/b/dataset.txt"
    - · "s3a://a/b/c/dataset.txt"
    -
    +

    For file-format-specific options, see the related methods in DataFrameWriter (Scala/Java/Python/R). From e5387018e76a9af1318e78c4133ee68232e6a159 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 19 Jun 2017 11:40:07 -0700 Subject: [PATCH 0746/1765] [SPARK-19975][PYTHON][SQL] Add map_keys and map_values functions to Python ## What changes were proposed in this pull request? This fix tries to address the issue in SPARK-19975 where we have `map_keys` and `map_values` functions in SQL yet there is no Python equivalent functions. This fix adds `map_keys` and `map_values` functions to Python. ## How was this patch tested? This fix is tested manually (See Python docs for examples). Author: Yong Tang Closes #17328 from yongtang/SPARK-19975. --- python/pyspark/sql/functions.py | 40 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++++ 2 files changed, 54 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d9b86aff63fa0..240ae65a61785 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1855,6 +1855,46 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.3) +def map_keys(col): + """ + Collection function: Returns an unordered array containing the keys of the map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_keys + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_keys("data").alias("keys")).show() + +------+ + | keys| + +------+ + |[1, 2]| + +------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_keys(_to_java_column(col))) + + +@since(2.3) +def map_values(col): + """ + Collection function: Returns an unordered array containing the values of the map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_values + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_values("data").alias("values")).show() + +------+ + |values| + +------+ + |[a, b]| + +------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_values(_to_java_column(col))) + + # ---------------------------- User Defined Function ---------------------------------- def _wrap_function(sc, func, returnType): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8d2e1f32da059..9a35a5c4658e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3161,6 +3161,20 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns an unordered array containing the keys of the map. + * @group collection_funcs + * @since 2.3.0 + */ + def map_keys(e: Column): Column = withExpr { MapKeys(e.expr) } + + /** + * Returns an unordered array containing the values of the map. + * @group collection_funcs + * @since 2.3.0 + */ + def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// From ecc5631351e81bbee4befb213f3053a4f31532a7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 19 Jun 2017 20:17:54 +0100 Subject: [PATCH 0747/1765] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up a few Java linter errors for Apache Spark 2.2 release. ## How was this patch tested? ```bash $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` We can check the result at Travis CI, [here](https://travis-ci.org/dongjoon-hyun/spark/builds/244297894). Author: Dongjoon Hyun Closes #18345 from dongjoon-hyun/fix_lint_java_2. --- .../src/main/java/org/apache/spark/kvstore/KVIndex.java | 2 +- .../src/main/java/org/apache/spark/kvstore/KVStore.java | 7 ++----- .../main/java/org/apache/spark/kvstore/KVStoreView.java | 3 --- .../main/java/org/apache/spark/kvstore/KVTypeInfo.java | 2 -- .../src/main/java/org/apache/spark/kvstore/LevelDB.java | 1 - .../java/org/apache/spark/kvstore/LevelDBIterator.java | 1 - .../java/org/apache/spark/kvstore/LevelDBTypeInfo.java | 5 ----- .../java/org/apache/spark/kvstore/DBIteratorSuite.java | 4 +--- .../test/java/org/apache/spark/kvstore/LevelDBSuite.java | 2 -- .../spark/network/shuffle/OneForOneBlockFetcher.java | 2 +- .../apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 8 +++++--- .../java/org/apache/spark/examples/ml/JavaALSExample.java | 2 +- .../spark/examples/sql/JavaSQLDataSourceExample.java | 6 +++++- .../java/org/apache/spark/sql/streaming/OutputMode.java | 1 - 14 files changed, 16 insertions(+), 30 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java index 8b8899023c938..0cffefe07c25d 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java @@ -50,7 +50,7 @@ @Target({ElementType.FIELD, ElementType.METHOD}) public @interface KVIndex { - public static final String NATURAL_INDEX_NAME = "__main__"; + String NATURAL_INDEX_NAME = "__main__"; /** * The name of the index to be created for the annotated entity. Must be unique within diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java index 3be4b829b4d8d..c7808ea3c3881 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java @@ -18,9 +18,6 @@ package org.apache.spark.kvstore; import java.io.Closeable; -import java.util.Iterator; -import java.util.Map; -import java.util.NoSuchElementException; /** * Abstraction for a local key/value store for storing app data. @@ -84,7 +81,7 @@ public interface KVStore extends Closeable { * * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys * are not allowed. - * @throws NoSuchElementException If an element with the given key does not exist. + * @throws java.util.NoSuchElementException If an element with the given key does not exist. */ T read(Class klass, Object naturalKey) throws Exception; @@ -107,7 +104,7 @@ public interface KVStore extends Closeable { * @param type The object's type. * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys * are not allowed. - * @throws NoSuchElementException If an element with the given key does not exist. + * @throws java.util.NoSuchElementException If an element with the given key does not exist. */ void delete(Class type, Object naturalKey) throws Exception; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java index b761640e6da8b..8cd1f52892293 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -17,9 +17,6 @@ package org.apache.spark.kvstore; -import java.util.Iterator; -import java.util.Map; - import com.google.common.base.Preconditions; /** diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java index 90f2ff0079b8a..e1cc0ba3f5aa7 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java @@ -19,8 +19,6 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index 08b22fd8265d8..27141358dc0f2 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -29,7 +29,6 @@ import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import org.fusesource.leveldbjni.JniDBFactory; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index a5d0f9f4fb373..263d45c242106 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -18,7 +18,6 @@ package org.apache.spark.kvstore; import java.io.IOException; -import java.util.Arrays; import java.util.ArrayList; import java.util.List; import java.util.Map; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index 3ab17dbd03ca7..722f54e6f9c66 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -18,17 +18,12 @@ package org.apache.spark.kvstore; import java.lang.reflect.Array; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.Map; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import org.iq80.leveldb.WriteBatch; /** diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java index 8549712213393..3a418189ecfec 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -25,11 +25,9 @@ import java.util.List; import java.util.Random; -import com.google.common.base.Predicate; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; -import org.apache.commons.io.FileUtils; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -50,7 +48,7 @@ public abstract class DBIteratorSuite { private static List clashingEntries; private static KVStore db; - private static interface BaseComparator extends Comparator { + private interface BaseComparator extends Comparator { /** * Returns a comparator that falls back to natural order if this comparator's ordering * returns equality for two elements. Used to mimic how the index sorts things internally. diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index ee1c397c08573..42bff610457e7 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -20,9 +20,7 @@ import java.io.File; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.NoSuchElementException; -import static java.nio.charset.StandardCharsets.UTF_8; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 5f428759252aa..d46ce2e0e6b78 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -157,7 +157,7 @@ private class DownloadCallback implements StreamCallback { private File targetFile = null; private int chunkIndex; - public DownloadCallback(File targetFile, int chunkIndex) throws IOException { + DownloadCallback(File targetFile, int chunkIndex) throws IOException { this.targetFile = targetFile; this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 857ec8a4dadd2..34c179990214f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -364,7 +364,8 @@ private long[] mergeSpillsWithFileStream( // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - final int inputBufferSizeInBytes = (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + final int inputBufferSizeInBytes = + (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; boolean threwException = true; try { @@ -375,8 +376,9 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() and flush() calls, so that we can close the higher - // level streams to make sure all data is really flushed and internal state is cleaned. + // Shield the underlying output stream from close() and flush() calls, so that we can close + // the higher level streams to make sure all data is really flushed and internal state is + // cleaned. OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 60ef03d89d17b..fe4d6bc83f04a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -121,7 +121,7 @@ public static void main(String[] args) { // $example off$ userRecs.show(); movieRecs.show(); - + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 706856b5215e4..95859c52c2aeb 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -124,7 +124,11 @@ private static void runBasicDataSourceExample(SparkSession spark) { peopleDF.write().bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed"); // $example off:write_sorting_and_bucketing$ // $example on:write_partitioning$ - usersDF.write().partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet"); + usersDF + .write() + .partitionBy("favorite_color") + .format("parquet") + .save("namesPartByColor.parquet"); // $example off:write_partitioning$ // $example on:write_partition_and_bucket$ peopleDF diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 8410abd14fd59..2800b3068f87b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.streaming; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; From 0a4b7e4f81109cff651d2afb94f9f8bf734abdeb Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 19 Jun 2017 20:35:58 +0100 Subject: [PATCH 0748/1765] [MINOR] Fix some typo of the document ## What changes were proposed in this pull request? Fix some typo of the document. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #18350 from ConeyLiu/fixtypo. --- dev/change-version-to-2.10.sh | 2 +- dev/change-version-to-2.11.sh | 2 +- python/pyspark/__init__.py | 2 +- .../apache/spark/sql/catalyst/expressions/ExpressionSet.scala | 2 +- .../apache/spark/sql/execution/streaming/BatchCommitLog.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- .../sql/execution/datasources/FileSourceStrategySuite.scala | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh index 0962d34c52f28..b718d94f849dd 100755 --- a/dev/change-version-to-2.10.sh +++ b/dev/change-version-to-2.10.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# This script exists for backwards compability. Use change-scala-version.sh instead. +# This script exists for backwards compatibility. Use change-scala-version.sh instead. echo "This script is deprecated. Please instead run: change-scala-version.sh 2.10" $(dirname $0)/change-scala-version.sh 2.10 diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh index 4ccfeef09fd04..93087959a38dd 100755 --- a/dev/change-version-to-2.11.sh +++ b/dev/change-version-to-2.11.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# This script exists for backwards compability. Use change-scala-version.sh instead. +# This script exists for backwards compatibility. Use change-scala-version.sh instead. echo "This script is deprecated. Please instead run: change-scala-version.sh 2.11" $(dirname $0)/change-scala-version.sh 2.11 diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 14c51a306e1c2..4d142c91629cc 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -35,7 +35,7 @@ - :class:`StorageLevel`: Finer-grained cache persistence levels. - :class:`TaskContext`: - Information about the current running task, avaialble on the workers and experimental. + Information about the current running task, available on the workers and experimental. """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index f93e5736de401..ede0b1654bbd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -39,7 +39,7 @@ object ExpressionSet { * guaranteed to see at least one such expression. For example: * * {{{ - * val set = AttributeSet(a + 1, 1 + a) + * val set = ExpressionSet(a + 1, 1 + a) * * set.iterator => Iterator(a + 1) * set.contains(a + 1) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala index a34938f911f76..5e24e8fc4e3cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.SparkSession * - process batch 1 * - write batch 1 to completion log * - trigger batch 2 - * - obtain bactch 2 offsets and write to offset log + * - obtain batch 2 offsets and write to offset log * - process batch 2 * - write batch 2 to completion log * .... diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8569c2d76b694..5db354d79bb6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -507,7 +507,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } - test("SQL decimal test (used for catching certain demical handling bugs in aggregates)") { + test("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") { checkAnswer( decimalData.groupBy('a cast DecimalType(10, 2)).agg(avg('b cast DecimalType(10, 2))), Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.5)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 9a2dcafb5e4b3..d77f0c298ffe3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -244,7 +244,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val df2 = table.where("(p1 + c2) = 2 AND c1 = 1") // Filter on data only are advisory so we have to reevaluate. assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1")) - // Need to evalaute filters that are not pushed down. + // Need to evaluate filters that are not pushed down. assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2")) } From 581565dd871ca51507603d19b2d4203993c2636d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 19 Jun 2017 14:41:58 -0700 Subject: [PATCH 0749/1765] [SPARK-21124][UI] Show correct application user in UI. The jobs page currently shows the application user, but it assumes the OS user is the same as the user running the application, which may not be true in all scenarios (e.g., kerberos). While it might be useful to show both in the UI, this change just chooses the application user over the OS user, since the latter can be found in the environment page if needed. Tested in live application and in history server. Author: Marcelo Vanzin Closes #18331 from vanzin/SPARK-21124. --- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 4 +++- .../main/scala/org/apache/spark/ui/env/EnvironmentTab.scala | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f271c56021e95..589f811145519 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -86,7 +86,9 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.getOrElse("user.name", "") + environmentListener.sparkUser + .orElse(environmentListener.systemProperties.toMap.get("user.name")) + .getOrElse("") } def getAppName: String = appName diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 8c18464e6477a..61b12aaa32bb6 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -34,11 +34,16 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en @DeveloperApi @deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { + var sparkUser: Option[String] = None var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() var systemProperties = Seq[(String, String)]() var classpathEntries = Seq[(String, String)]() + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + sparkUser = Some(event.sparkUser) + } + override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val environmentDetails = environmentUpdate.environmentDetails From 3d4d11a80fe8953d48d8bfac2ce112e37d38dc90 Mon Sep 17 00:00:00 2001 From: sharkdtu Date: Mon, 19 Jun 2017 14:54:54 -0700 Subject: [PATCH 0750/1765] [SPARK-21138][YARN] Cannot delete staging dir when the clusters of "spark.yarn.stagingDir" and "spark.hadoop.fs.defaultFS" are different MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When I set different clusters for "spark.hadoop.fs.defaultFS" and "spark.yarn.stagingDir" as follows: ``` spark.hadoop.fs.defaultFS hdfs://tl-nn-tdw.tencent-distribute.com:54310 spark.yarn.stagingDir hdfs://ss-teg-2-v2/tmp/spark ``` The staging dir can not be deleted, it will prompt following message: ``` java.lang.IllegalArgumentException: Wrong FS: hdfs://ss-teg-2-v2/tmp/spark/.sparkStaging/application_1496819138021_77618, expected: hdfs://tl-nn-tdw.tencent-distribute.com:54310 ``` ## How was this patch tested? Existing tests Author: sharkdtu Closes #18352 from sharkdtu/master. --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4f71a1606312d..4868180569778 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -209,8 +209,6 @@ private[spark] class ApplicationMaster( logInfo("ApplicationAttemptId: " + appAttemptId) - val fs = FileSystem.get(yarnConf) - // This shutdown hook should run *after* the SparkContext is shut down. val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 ShutdownHookManager.addShutdownHook(priority) { () => @@ -232,7 +230,7 @@ private[spark] class ApplicationMaster( // we only want to unregister if we don't want the RM to retry if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { unregister(finalStatus, finalMsg) - cleanupStagingDir(fs) + cleanupStagingDir() } } } @@ -533,7 +531,7 @@ private[spark] class ApplicationMaster( /** * Clean up the staging directory. */ - private def cleanupStagingDir(fs: FileSystem) { + private def cleanupStagingDir(): Unit = { var stagingDirPath: Path = null try { val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) @@ -544,6 +542,7 @@ private[spark] class ApplicationMaster( return } logInfo("Deleting staging directory " + stagingDirPath) + val fs = stagingDirPath.getFileSystem(yarnConf) fs.delete(stagingDirPath, true) } } catch { From 9eacc5e4384de26eaf1d6475bcc698c4e86c996d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 19 Jun 2017 15:14:33 -0700 Subject: [PATCH 0751/1765] [INFRA] Close stale PRs. Closes #18311 Closes #18278 From 9b57cd8d5c594731a7b3c90ce59bcddb05193d79 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 20 Jun 2017 09:22:30 +0800 Subject: [PATCH 0752/1765] [SPARK-21133][CORE] Fix HighlyCompressedMapStatus#writeExternal throws NPE ## What changes were proposed in this pull request? Fix HighlyCompressedMapStatus#writeExternal NPE: ``` 17/06/18 15:00:27 ERROR Utils: Exception encountered java.lang.NullPointerException at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply$mcV$sp(MapStatus.scala:171) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1303) at org.apache.spark.scheduler.HighlyCompressedMapStatus.writeExternal(MapStatus.scala:167) at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174) at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply$mcV$sp(MapOutputTracker.scala:617) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337) at org.apache.spark.MapOutputTracker$.serializeMapStatuses(MapOutputTracker.scala:619) at org.apache.spark.MapOutputTrackerMaster.getSerializedMapOutputStatuses(MapOutputTracker.scala:562) at org.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:351) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) 17/06/18 15:00:27 ERROR MapOutputTrackerMaster: java.lang.NullPointerException java.io.IOException: java.lang.NullPointerException at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1310) at org.apache.spark.scheduler.HighlyCompressedMapStatus.writeExternal(MapStatus.scala:167) at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174) at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply$mcV$sp(MapOutputTracker.scala:617) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337) at org.apache.spark.MapOutputTracker$.serializeMapStatuses(MapOutputTracker.scala:619) at org.apache.spark.MapOutputTrackerMaster.getSerializedMapOutputStatuses(MapOutputTracker.scala:562) at org.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:351) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.NullPointerException at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply$mcV$sp(MapStatus.scala:171) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1303) ... 17 more 17/06/18 15:00:27 INFO MapOutputTrackerMasterEndpoint: Asked to send map output locations for shuffle 0 to 10.17.47.20:50188 17/06/18 15:00:27 ERROR Utils: Exception encountered java.lang.NullPointerException at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply$mcV$sp(MapStatus.scala:171) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1303) at org.apache.spark.scheduler.HighlyCompressedMapStatus.writeExternal(MapStatus.scala:167) at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174) at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply$mcV$sp(MapOutputTracker.scala:617) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337) at org.apache.spark.MapOutputTracker$.serializeMapStatuses(MapOutputTracker.scala:619) at org.apache.spark.MapOutputTrackerMaster.getSerializedMapOutputStatuses(MapOutputTracker.scala:562) at org.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:351) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #18343 from wangyum/SPARK-21133. --- .../org/apache/spark/scheduler/MapStatus.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 1 + .../spark/scheduler/MapStatusSuite.scala | 18 ++++++++++++++++-- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 048e0d0186594..5e45b375ddd45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -141,7 +141,7 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - @transient private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e15166d11c243..4f03e54e304f6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -175,6 +175,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(None.getClass) kryo.register(Nil.getClass) kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(Utils.classForName("scala.collection.immutable.Map$EmptyMap$")) kryo.register(classOf[ArrayBuffer[Any]]) kryo.setClassLoader(classLoader) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 3ec37f674c77b..e6120139f4958 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -24,9 +24,9 @@ import scala.util.Random import org.mockito.Mockito._ import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.internal.config -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -154,4 +154,18 @@ class MapStatusSuite extends SparkFunSuite { case part => assert(status2.getSizeForBlock(part) >= sizes(part)) } } + + test("SPARK-21133 HighlyCompressedMapStatus#writeExternal throws NPE") { + val conf = new SparkConf() + .set("spark.serializer", classOf[KryoSerializer].getName) + .setMaster("local") + .setAppName("SPARK-21133") + val sc = new SparkContext(conf) + try { + val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length + assert(count === 3000) + } finally { + sc.stop() + } + } } From 8965fe764a4218d944938aa4828072f1ad9dbda7 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Mon, 19 Jun 2017 19:41:24 -0700 Subject: [PATCH 0753/1765] [SPARK-20889][SPARKR] Grouped documentation for AGGREGATE column methods ## What changes were proposed in this pull request? Grouped documentation for the aggregate functions for Column. Author: actuaryzhang Closes #18025 from actuaryzhang/sparkRDoc4. --- R/pkg/R/functions.R | 427 ++++++++++++++++++-------------------------- R/pkg/R/generics.R | 56 ++++-- R/pkg/R/stats.R | 22 +-- 3 files changed, 219 insertions(+), 286 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7128c3b9adff4..01ca8b8c4527d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -18,6 +18,22 @@ #' @include generics.R column.R NULL +#' Aggregate functions for Column operations +#' +#' Aggregate functions defined for \code{Column}. +#' +#' @param x Column to compute on. +#' @param y,na.rm,use currently not used. +#' @param ... additional argument(s). For example, it could be used to pass additional Columns. +#' @name column_aggregate_functions +#' @rdname column_aggregate_functions +#' @family aggregate functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -85,17 +101,20 @@ setMethod("acos", column(jc) }) -#' Returns the approximate number of distinct items in a group +#' @details +#' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' -#' Returns the approximate number of distinct items in a group. This is a column -#' aggregate function. -#' -#' @rdname approxCountDistinct -#' @name approxCountDistinct -#' @return the approximate number of distinct items in a group. +#' @rdname column_aggregate_functions #' @export -#' @aliases approxCountDistinct,Column-method -#' @examples \dontrun{approxCountDistinct(df$c)} +#' @aliases approxCountDistinct approxCountDistinct,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, approxCountDistinct(df$gear))) +#' head(select(df, approxCountDistinct(df$gear, 0.02))) +#' head(select(df, countDistinct(df$gear, df$cyl))) +#' head(select(df, n_distinct(df$gear))) +#' head(distinct(select(df, "gear")))} #' @note approxCountDistinct(Column) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -342,10 +361,13 @@ setMethod("column", #' #' @rdname corr #' @name corr -#' @family math functions +#' @family aggregate functions #' @export #' @aliases corr,Column-method -#' @examples \dontrun{corr(df$c, df$d)} +#' @examples +#' \dontrun{ +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' head(select(df, corr(df$mpg, df$hp)))} #' @note corr since 1.6.0 setMethod("corr", signature(x = "Column"), function(x, col2) { @@ -356,20 +378,22 @@ setMethod("corr", signature(x = "Column"), #' cov #' -#' Compute the sample covariance between two expressions. +#' Compute the covariance between two expressions. +#' +#' @details +#' \code{cov}: Compute the sample covariance between two expressions. #' #' @rdname cov #' @name cov -#' @family math functions +#' @family aggregate functions #' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ -#' cov(df$c, df$d) -#' cov("c", "d") -#' covar_samp(df$c, df$d) -#' covar_samp("c", "d") -#' } +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' head(select(df, cov(df$mpg, df$hp), cov("mpg", "hp"), +#' covar_samp(df$mpg, df$hp), covar_samp("mpg", "hp"), +#' covar_pop(df$mpg, df$hp), covar_pop("mpg", "hp")))} #' @note cov since 1.6.0 setMethod("cov", signature(x = "characterOrColumn"), function(x, col2) { @@ -377,6 +401,9 @@ setMethod("cov", signature(x = "characterOrColumn"), covar_samp(x, col2) }) +#' @details +#' \code{covar_sample}: Alias for \code{cov}. +#' #' @rdname cov #' #' @param col1 the first Column. @@ -395,23 +422,13 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO column(jc) }) -#' covar_pop +#' @details +#' \code{covar_pop}: Computes the population covariance between two expressions. #' -#' Compute the population covariance between two expressions. -#' -#' @param col1 First column to compute cov_pop. -#' @param col2 Second column to compute cov_pop. -#' -#' @rdname covar_pop +#' @rdname cov #' @name covar_pop -#' @family math functions #' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method -#' @examples -#' \dontrun{ -#' covar_pop(df$c, df$d) -#' covar_pop("c", "d") -#' } #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), function(col1, col2) { @@ -823,18 +840,16 @@ setMethod("isnan", column(jc) }) -#' kurtosis -#' -#' Aggregate function: returns the kurtosis of the values in a group. +#' @details +#' \code{kurtosis}: Returns the kurtosis of the values in a group. #' -#' @param x Column to compute on. -#' -#' @rdname kurtosis -#' @name kurtosis -#' @aliases kurtosis,Column-method -#' @family aggregate functions +#' @rdname column_aggregate_functions +#' @aliases kurtosis kurtosis,Column-method #' @export -#' @examples \dontrun{kurtosis(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, mean(df$mpg), sd(df$mpg), skewness(df$mpg), kurtosis(df$mpg)))} #' @note kurtosis since 1.6.0 setMethod("kurtosis", signature(x = "Column"), @@ -1040,18 +1055,11 @@ setMethod("ltrim", column(jc) }) -#' max -#' -#' Aggregate function: returns the maximum value of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{max}: Returns the maximum value of the expression in a group. #' -#' @rdname max -#' @name max -#' @family aggregate functions -#' @aliases max,Column-method -#' @export -#' @examples \dontrun{max(df$c)} +#' @rdname column_aggregate_functions +#' @aliases max max,Column-method #' @note max since 1.5.0 setMethod("max", signature(x = "Column"), @@ -1081,19 +1089,24 @@ setMethod("md5", column(jc) }) -#' mean +#' @details +#' \code{mean}: Returns the average of the values in a group. Alias for \code{avg}. #' -#' Aggregate function: returns the average of the values in a group. -#' Alias for avg. +#' @rdname column_aggregate_functions +#' @aliases mean mean,Column-method +#' @export +#' @examples #' -#' @param x Column to compute on. +#' \dontrun{ +#' head(select(df, avg(df$mpg), mean(df$mpg), sum(df$mpg), min(df$wt), max(df$qsec))) #' -#' @rdname mean -#' @name mean -#' @family aggregate functions -#' @aliases mean,Column-method -#' @export -#' @examples \dontrun{mean(df$c)} +#' # metrics by num of cylinders +#' tmp <- agg(groupBy(df, "cyl"), avg(df$mpg), avg(df$hp), avg(df$wt), avg(df$qsec)) +#' head(orderBy(tmp, "cyl")) +#' +#' # car with the max mpg +#' mpg_max <- as.numeric(collect(agg(df, max(df$mpg)))) +#' head(where(df, df$mpg == mpg_max))} #' @note mean since 1.5.0 setMethod("mean", signature(x = "Column"), @@ -1102,18 +1115,12 @@ setMethod("mean", column(jc) }) -#' min -#' -#' Aggregate function: returns the minimum value of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{min}: Returns the minimum value of the expression in a group. #' -#' @rdname min -#' @name min -#' @aliases min,Column-method -#' @family aggregate functions +#' @rdname column_aggregate_functions +#' @aliases min min,Column-method #' @export -#' @examples \dontrun{min(df$c)} #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1338,24 +1345,17 @@ setMethod("rtrim", column(jc) }) -#' sd -#' -#' Aggregate function: alias for \link{stddev_samp} + +#' @details +#' \code{sd}: Alias for \code{stddev_samp}. #' -#' @param x Column to compute on. -#' @param na.rm currently not used. -#' @rdname sd -#' @name sd -#' @family aggregate functions -#' @aliases sd,Column-method -#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @rdname column_aggregate_functions +#' @aliases sd sd,Column-method #' @export #' @examples -#'\dontrun{ -#'stddev(df$c) -#'select(df, stddev(df$age)) -#'agg(df, sd(df$age)) -#'} +#' +#' \dontrun{ +#' head(select(df, sd(df$mpg), stddev(df$mpg), stddev_pop(df$wt), stddev_samp(df$qsec)))} #' @note sd since 1.6.0 setMethod("sd", signature(x = "Column"), @@ -1465,18 +1465,12 @@ setMethod("sinh", column(jc) }) -#' skewness -#' -#' Aggregate function: returns the skewness of the values in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{skewness}: Returns the skewness of the values in a group. #' -#' @rdname skewness -#' @name skewness -#' @family aggregate functions -#' @aliases skewness,Column-method +#' @rdname column_aggregate_functions +#' @aliases skewness skewness,Column-method #' @export -#' @examples \dontrun{skewness(df$c)} #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1527,9 +1521,11 @@ setMethod("spark_partition_id", column(jc) }) -#' @rdname sd -#' @aliases stddev,Column-method -#' @name stddev +#' @details +#' \code{stddev}: Alias for \code{std_dev}. +#' +#' @rdname column_aggregate_functions +#' @aliases stddev stddev,Column-method #' @note stddev since 1.6.0 setMethod("stddev", signature(x = "Column"), @@ -1538,19 +1534,12 @@ setMethod("stddev", column(jc) }) -#' stddev_pop -#' -#' Aggregate function: returns the population standard deviation of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{stddev_pop}: Returns the population standard deviation of the expression in a group. #' -#' @rdname stddev_pop -#' @name stddev_pop -#' @family aggregate functions -#' @aliases stddev_pop,Column-method -#' @seealso \link{sd}, \link{stddev_samp} +#' @rdname column_aggregate_functions +#' @aliases stddev_pop stddev_pop,Column-method #' @export -#' @examples \dontrun{stddev_pop(df$c)} #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1559,19 +1548,12 @@ setMethod("stddev_pop", column(jc) }) -#' stddev_samp -#' -#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{stddev_samp}: Returns the unbiased sample standard deviation of the expression in a group. #' -#' @rdname stddev_samp -#' @name stddev_samp -#' @family aggregate functions -#' @aliases stddev_samp,Column-method -#' @seealso \link{stddev_pop}, \link{sd} +#' @rdname column_aggregate_functions +#' @aliases stddev_samp stddev_samp,Column-method #' @export -#' @examples \dontrun{stddev_samp(df$c)} #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1630,18 +1612,12 @@ setMethod("sqrt", column(jc) }) -#' sum -#' -#' Aggregate function: returns the sum of all values in the expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{sum}: Returns the sum of all values in the expression. #' -#' @rdname sum -#' @name sum -#' @family aggregate functions -#' @aliases sum,Column-method +#' @rdname column_aggregate_functions +#' @aliases sum sum,Column-method #' @export -#' @examples \dontrun{sum(df$c)} #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1650,18 +1626,17 @@ setMethod("sum", column(jc) }) -#' sumDistinct -#' -#' Aggregate function: returns the sum of distinct values in the expression. +#' @details +#' \code{sumDistinct}: Returns the sum of distinct values in the expression. #' -#' @param x Column to compute on. -#' -#' @rdname sumDistinct -#' @name sumDistinct -#' @family aggregate functions -#' @aliases sumDistinct,Column-method +#' @rdname column_aggregate_functions +#' @aliases sumDistinct sumDistinct,Column-method #' @export -#' @examples \dontrun{sumDistinct(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, sumDistinct(df$gear))) +#' head(distinct(select(df, "gear")))} #' @note sumDistinct since 1.4.0 setMethod("sumDistinct", signature(x = "Column"), @@ -1952,24 +1927,16 @@ setMethod("upper", column(jc) }) -#' var -#' -#' Aggregate function: alias for \link{var_samp}. +#' @details +#' \code{var}: Alias for \code{var_samp}. #' -#' @param x a Column to compute on. -#' @param y,na.rm,use currently not used. -#' @rdname var -#' @name var -#' @family aggregate functions -#' @aliases var,Column-method -#' @seealso \link{var_pop}, \link{var_samp} +#' @rdname column_aggregate_functions +#' @aliases var var,Column-method #' @export #' @examples +#' #'\dontrun{ -#'variance(df$c) -#'select(df, var_pop(df$age)) -#'agg(df, var(df$age)) -#'} +#'head(agg(df, var(df$mpg), variance(df$mpg), var_pop(df$mpg), var_samp(df$mpg)))} #' @note var since 1.6.0 setMethod("var", signature(x = "Column"), @@ -1978,9 +1945,9 @@ setMethod("var", var_samp(x) }) -#' @rdname var -#' @aliases variance,Column-method -#' @name variance +#' @rdname column_aggregate_functions +#' @aliases variance variance,Column-method +#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1989,19 +1956,12 @@ setMethod("variance", column(jc) }) -#' var_pop +#' @details +#' \code{var_pop}: Returns the population variance of the values in a group. #' -#' Aggregate function: returns the population variance of the values in a group. -#' -#' @param x Column to compute on. -#' -#' @rdname var_pop -#' @name var_pop -#' @family aggregate functions -#' @aliases var_pop,Column-method -#' @seealso \link{var}, \link{var_samp} +#' @rdname column_aggregate_functions +#' @aliases var_pop var_pop,Column-method #' @export -#' @examples \dontrun{var_pop(df$c)} #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -2010,19 +1970,12 @@ setMethod("var_pop", column(jc) }) -#' var_samp +#' @details +#' \code{var_samp}: Returns the unbiased variance of the values in a group. #' -#' Aggregate function: returns the unbiased variance of the values in a group. -#' -#' @param x Column to compute on. -#' -#' @rdname var_samp -#' @name var_samp -#' @aliases var_samp,Column-method -#' @family aggregate functions -#' @seealso \link{var_pop}, \link{var} +#' @rdname column_aggregate_functions +#' @aliases var_samp var_samp,Column-method #' @export -#' @examples \dontrun{var_samp(df$c)} #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -2235,17 +2188,11 @@ setMethod("pmod", signature(y = "Column"), column(jc) }) - -#' @rdname approxCountDistinct -#' @name approxCountDistinct -#' -#' @param x Column to compute on. #' @param rsd maximum estimation error allowed (default = 0.05) -#' @param ... further arguments to be passed to or from other methods. #' +#' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method #' @export -#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2254,18 +2201,12 @@ setMethod("approxCountDistinct", column(jc) }) -#' Count Distinct Values +#' @details +#' \code{countDistinct}: Returns the number of distinct items in a group. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family aggregate functions -#' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct,Column-method -#' @return the number of distinct items in a group. +#' @rdname column_aggregate_functions +#' @aliases countDistinct countDistinct,Column-method #' @export -#' @examples \dontrun{countDistinct(df$c)} #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2384,15 +2325,12 @@ setMethod("sign", signature(x = "Column"), signum(x) }) -#' n_distinct -#' -#' Aggregate function: returns the number of distinct items in a group. +#' @details +#' \code{n_distinct}: Returns the number of distinct items in a group. #' -#' @rdname countDistinct -#' @name n_distinct -#' @aliases n_distinct,Column-method +#' @rdname column_aggregate_functions +#' @aliases n_distinct n_distinct,Column-method #' @export -#' @examples \dontrun{n_distinct(df$c)} #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -3717,18 +3655,18 @@ setMethod("create_map", column(jc) }) -#' collect_list +#' @details +#' \code{collect_list}: Creates a list of objects with duplicates. #' -#' Creates a list of objects with duplicates. -#' -#' @param x Column to compute on -#' -#' @rdname collect_list -#' @name collect_list -#' @family aggregate functions -#' @aliases collect_list,Column-method +#' @rdname column_aggregate_functions +#' @aliases collect_list collect_list,Column-method #' @export -#' @examples \dontrun{collect_list(df$x)} +#' @examples +#' +#' \dontrun{ +#' df2 = df[df$mpg > 20, ] +#' collect(select(df2, collect_list(df2$gear))) +#' collect(select(df2, collect_set(df2$gear)))} #' @note collect_list since 2.3.0 setMethod("collect_list", signature(x = "Column"), @@ -3737,18 +3675,12 @@ setMethod("collect_list", column(jc) }) -#' collect_set -#' -#' Creates a list of objects with duplicate elements eliminated. +#' @details +#' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. #' -#' @param x Column to compute on -#' -#' @rdname collect_set -#' @name collect_set -#' @family aggregate functions -#' @aliases collect_set,Column-method +#' @rdname column_aggregate_functions +#' @aliases collect_set collect_set,Column-method #' @export -#' @examples \dontrun{collect_set(df$x)} #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3908,24 +3840,17 @@ setMethod("not", column(jc) }) -#' grouping_bit -#' -#' Indicates whether a specified column in a GROUP BY list is aggregated or not, -#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' @details +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL +#' and \code{grouping} function in Scala. #' -#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. -#' -#' @param x Column to compute on -#' -#' @rdname grouping_bit -#' @name grouping_bit -#' @family aggregate functions -#' @aliases grouping_bit,Column-method +#' @rdname column_aggregate_functions +#' @aliases grouping_bit grouping_bit,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) #' +#' \dontrun{ #' # With cube #' agg( #' cube(df, "cyl", "gear", "am"), @@ -3938,8 +3863,7 @@ setMethod("not", #' rollup(df, "cyl", "gear", "am"), #' mean(df$mpg), #' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) -#' ) -#' } +#' )} #' @note grouping_bit since 2.3.0 setMethod("grouping_bit", signature(x = "Column"), @@ -3948,26 +3872,18 @@ setMethod("grouping_bit", column(jc) }) -#' grouping_id -#' -#' Returns the level of grouping. -#' +#' @details +#' \code{grouping_id}: Returns the level of grouping. #' Equals to \code{ #' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) #' } #' -#' @param x Column to compute on -#' @param ... additional Column(s) (optional). -#' -#' @rdname grouping_id -#' @name grouping_id -#' @family aggregate functions -#' @aliases grouping_id,Column-method +#' @rdname column_aggregate_functions +#' @aliases grouping_id grouping_id,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) #' +#' \dontrun{ #' # With cube #' agg( #' cube(df, "cyl", "gear", "am"), @@ -3980,8 +3896,7 @@ setMethod("grouping_bit", #' rollup(df, "cyl", "gear", "am"), #' mean(df$mpg), #' grouping_id(df$cyl, df$gear, df$am) -#' ) -#' } +#' )} #' @note grouping_id since 2.3.0 setMethod("grouping_id", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5630d0c8a0df9..b3cc4868a0b33 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -479,7 +479,7 @@ setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) -#' @rdname covar_pop +#' @rdname cov #' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) @@ -907,8 +907,9 @@ setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy" #' @export setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) -#' @rdname approxCountDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) #' @rdname array_contains @@ -949,12 +950,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) -#' @rdname collect_list +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) -#' @rdname collect_set +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column @@ -973,8 +976,9 @@ setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @export setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) -#' @rdname countDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname crc32 @@ -1071,12 +1075,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) -#' @rdname grouping_bit +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) -#' @rdname grouping_id +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname hex @@ -1109,8 +1115,9 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @export setGeneric("isnan", function(x) { standardGeneric("isnan") }) -#' @rdname kurtosis +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag @@ -1203,8 +1210,9 @@ setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @export setGeneric("ntile", function(x) { standardGeneric("ntile") }) -#' @rdname countDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @param x empty. Should be used with no argument. @@ -1274,8 +1282,9 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) -#' @rdname sd +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname second @@ -1310,8 +1319,9 @@ setGeneric("signum", function(x) { standardGeneric("signum") }) #' @export setGeneric("size", function(x) { standardGeneric("size") }) -#' @rdname skewness +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname sort_array @@ -1331,16 +1341,19 @@ setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @export setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) -#' @rdname sd +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) -#' @rdname stddev_pop +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) -#' @rdname stddev_samp +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname struct @@ -1351,8 +1364,9 @@ setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) -#' @rdname sumDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname toDegrees @@ -1403,20 +1417,24 @@ setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timesta #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) -#' @rdname var +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) -#' @rdname var +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) -#' @rdname var_pop +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) -#' @rdname var_samp +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname weekofyear diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index d78a10893f92e..9a9fa84044ce6 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -52,22 +52,23 @@ setMethod("crosstab", collect(dataFrame(sct)) }) -#' Calculate the sample covariance of two numerical columns of a SparkDataFrame. +#' @details +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical +#' columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column #' @return The covariance of the two columns. #' #' @rdname cov -#' @name cov #' @aliases cov,SparkDataFrame-method #' @family stat functions #' @export #' @examples -#'\dontrun{ -#' df <- read.json("/path/to/file.json") -#' cov <- cov(df, "title", "gender") -#' } +#' +#' \dontrun{ +#' cov(df, "mpg", "hp") +#' cov(df, df$mpg, df$hp)} #' @note cov since 1.6.0 setMethod("cov", signature(x = "SparkDataFrame"), @@ -93,11 +94,10 @@ setMethod("cov", #' @family stat functions #' @export #' @examples -#'\dontrun{ -#' df <- read.json("/path/to/file.json") -#' corr <- corr(df, "title", "gender") -#' corr <- corr(df, "title", "gender", method = "pearson") -#' } +#' +#' \dontrun{ +#' corr(df, "mpg", "hp") +#' corr(df, "mpg", "hp", method = "pearson")} #' @note corr since 1.6.0 setMethod("corr", signature(x = "SparkDataFrame"), From cc67bd573264c9046c4a034927ed8deb2a732110 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 19 Jun 2017 23:04:17 -0700 Subject: [PATCH 0754/1765] [SPARK-20929][ML] LinearSVC should use its own threshold param ## What changes were proposed in this pull request? LinearSVC should use its own threshold param, rather than the shared one, since it applies to rawPrediction instead of probability. This PR changes the param in the Scala, Python and R APIs. ## How was this patch tested? New unit test to make sure the threshold can be set to any Double value. Author: Joseph K. Bradley Closes #18151 from jkbradley/ml-2.2-linearsvc-cleanup. --- R/pkg/R/mllib_classification.R | 4 ++- .../spark/ml/classification/LinearSVC.scala | 25 +++++++++++-- .../ml/classification/LinearSVCSuite.scala | 35 ++++++++++++++++++- python/pyspark/ml/classification.py | 20 ++++++++++- 4 files changed, 79 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 306a9b8676539..bdcc0818d139d 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -62,7 +62,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' of models will be always returned on the original scale, so it will be transparent for #' users. Note that with/without standardization, the models should be always converged #' to the same solution when no regularization is applied. -#' @param threshold The threshold in binary classification, in range [0, 1]. +#' @param threshold The threshold in binary classification applied to the linear model prediction. +#' This threshold can be any real number, where Inf will make all predictions 0.0 +#' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 9900fbc9edda7..d6ed6a4570a4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -42,7 +42,23 @@ import org.apache.spark.sql.functions.{col, lit} /** Params for linear SVM Classifier. */ private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol - with HasThreshold with HasAggregationDepth + with HasAggregationDepth { + + /** + * Param for threshold in binary classification prediction. + * For LinearSVC, this threshold is applied to the rawPrediction, rather than a probability. + * This threshold can be any real number, where Inf will make all predictions 0.0 + * and -Inf will make all predictions 1.0. + * Default: 0.0 + * + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", + "threshold in binary classification prediction applied to rawPrediction") + + /** @group getParam */ + def getThreshold: Double = $(threshold) +} /** * :: Experimental :: @@ -126,7 +142,7 @@ class LinearSVC @Since("2.2.0") ( def setWeightCol(value: String): this.type = set(weightCol, value) /** - * Set threshold in binary classification, in range [0, 1]. + * Set threshold in binary classification. * * @group setParam */ @@ -284,6 +300,7 @@ class LinearSVCModel private[classification] ( @Since("2.2.0") def setThreshold(value: Double): this.type = set(threshold, value) + setDefault(threshold, 0.0) @Since("2.2.0") def setWeightCol(value: Double): this.type = set(threshold, value) @@ -301,6 +318,10 @@ class LinearSVCModel private[classification] ( Vectors.dense(-m, m) } + override protected def raw2prediction(rawPrediction: Vector): Double = { + if (rawPrediction(1) > $(threshold)) 1.0 else 0.0 + } + @Since("2.2.0") override def copy(extra: ParamMap): LinearSVCModel = { copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 2f87afc23fe7e..f2b00d0bae1d6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -127,6 +127,39 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau MLTestingUtils.checkCopyAndUids(lsvc, model) } + test("LinearSVC threshold acts on rawPrediction") { + val lsvc = + new LinearSVCModel(uid = "myLSVCM", coefficients = Vectors.dense(1.0), intercept = 0.0) + val df = spark.createDataFrame(Seq( + (1, Vectors.dense(1e-7)), + (0, Vectors.dense(0.0)), + (-1, Vectors.dense(-1e-7)))).toDF("id", "features") + + def checkOneResult( + model: LinearSVCModel, + threshold: Double, + expected: Set[(Int, Double)]): Unit = { + model.setThreshold(threshold) + val results = model.transform(df).select("id", "prediction").collect() + .map(r => (r.getInt(0), r.getDouble(1))) + .toSet + assert(results === expected, s"Failed for threshold = $threshold") + } + + def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { + // Check via code path using Classifier.raw2prediction + lsvc.setRawPredictionCol("rawPrediction") + checkOneResult(lsvc, threshold, expected) + // Check via code path using Classifier.predict + lsvc.setRawPredictionCol("") + checkOneResult(lsvc, threshold, expected) + } + + checkResults(0.0, Set((1, 1.0), (0, 0.0), (-1, 0.0))) + checkResults(Double.PositiveInfinity, Set((1, 0.0), (0, 0.0), (-1, 0.0))) + checkResults(Double.NegativeInfinity, Set((1, 1.0), (0, 1.0), (-1, 1.0))) + } + test("linear svc doesn't fit intercept when fitIntercept is off") { val lsvc = new LinearSVC().setFitIntercept(false).setMaxIter(5) val model = lsvc.fit(smallBinaryDataset) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 60bdeedd6a144..9b345ac73f3d9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -63,7 +63,7 @@ def numClasses(self): @inherit_doc class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization, - HasThreshold, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): + HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -109,6 +109,12 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha .. versionadded:: 2.2.0 """ + threshold = Param(Params._dummy(), "threshold", + "The threshold in binary classification applied to the linear model" + " prediction. This threshold can be any real number, where Inf will make" + " all predictions 0.0 and -Inf will make all predictions 1.0.", + typeConverter=TypeConverters.toFloat) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", @@ -147,6 +153,18 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearSVCModel(java_model) + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + return self._set(threshold=value) + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.threshold) + class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ From ef1622899ffc6ab136102ffc6bcc714402e6f334 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 20 Jun 2017 17:17:21 +0800 Subject: [PATCH 0755/1765] [SPARK-20989][CORE] Fail to start multiple workers on one host if external shuffle service is enabled in standalone mode ## What changes were proposed in this pull request? In standalone mode, if we enable external shuffle service by setting `spark.shuffle.service.enabled` to true, and then we try to start multiple workers on one host(by setting `SPARK_WORKER_INSTANCES=3` in spark-env.sh, and then run `sbin/start-slaves.sh`), we can only launch one worker on each host successfully and the rest of the workers fail to launch. The reason is the port of external shuffle service if configed by `spark.shuffle.service.port`, so currently we could start no more than one external shuffle service on each host. In our case, each worker tries to start a external shuffle service, and only one of them succeeded doing this. We should give explicit reason of failure instead of fail silently. ## How was this patch tested? Manually test by the following steps: 1. SET `SPARK_WORKER_INSTANCES=1` in `conf/spark-env.sh`; 2. SET `spark.shuffle.service.enabled` to `true` in `conf/spark-defaults.conf`; 3. Run `sbin/start-all.sh`. Before the change, you will see no error in the command line, as the following: ``` starting org.apache.spark.deploy.master.Master, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.master.Master-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-2-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-3-xxx.local.out ``` And you can see in the webUI that only one worker is running. After the change, you get explicit error messages in the command line: ``` starting org.apache.spark.deploy.master.Master, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.master.Master-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-1-xxx.local.out localhost: failed to launch: nice -n 0 /Users/xxx/workspace/spark/bin/spark-class org.apache.spark.deploy.worker.Worker --webui-port 8081 spark://xxx.local:7077 localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing view acls to: xxx localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing modify acls to: xxx localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing view acls groups to: localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing modify acls groups to: localhost: 17/06/13 23:24:53 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(xxx); groups with view permissions: Set(); users with modify permissions: Set(xxx); groups with modify permissions: Set() localhost: 17/06/13 23:24:54 INFO Utils: Successfully started service 'sparkWorker' on port 63354. localhost: Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Start multiple worker on one host failed because we may launch no more than one external shuffle service on each host, please set spark.shuffle.service.enabled to false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict. localhost: at scala.Predef$.require(Predef.scala:224) localhost: at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:752) localhost: at org.apache.spark.deploy.worker.Worker.main(Worker.scala) localhost: full log in /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-2-xxx.local.out localhost: failed to launch: nice -n 0 /Users/xxx/workspace/spark/bin/spark-class org.apache.spark.deploy.worker.Worker --webui-port 8082 spark://xxx.local:7077 localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing view acls to: xxx localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing modify acls to: xxx localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing view acls groups to: localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing modify acls groups to: localhost: 17/06/13 23:24:56 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(xxx); groups with view permissions: Set(); users with modify permissions: Set(xxx); groups with modify permissions: Set() localhost: 17/06/13 23:24:56 INFO Utils: Successfully started service 'sparkWorker' on port 63359. localhost: Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Start multiple worker on one host failed because we may launch no more than one external shuffle service on each host, please set spark.shuffle.service.enabled to false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict. localhost: at scala.Predef$.require(Predef.scala:224) localhost: at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:752) localhost: at org.apache.spark.deploy.worker.Worker.main(Worker.scala) localhost: full log in /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-2-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-3-xxx.local.out localhost: failed to launch: nice -n 0 /Users/xxx/workspace/spark/bin/spark-class org.apache.spark.deploy.worker.Worker --webui-port 8083 spark://xxx.local:7077 localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing view acls to: xxx localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing modify acls to: xxx localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing view acls groups to: localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing modify acls groups to: localhost: 17/06/13 23:24:59 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(xxx); groups with view permissions: Set(); users with modify permissions: Set(xxx); groups with modify permissions: Set() localhost: 17/06/13 23:24:59 INFO Utils: Successfully started service 'sparkWorker' on port 63360. localhost: Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Start multiple worker on one host failed because we may launch no more than one external shuffle service on each host, please set spark.shuffle.service.enabled to false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict. localhost: at scala.Predef$.require(Predef.scala:224) localhost: at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:752) localhost: at org.apache.spark.deploy.worker.Worker.main(Worker.scala) localhost: full log in /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-3-xxx.local.out ``` Author: Xingbo Jiang Closes #18290 from jiangxb1987/start-slave. --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 11 +++++++++++ sbin/spark-daemon.sh | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1198e3cb05eaa..bed47455680dd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -742,6 +742,17 @@ private[deploy] object Worker extends Logging { val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir, conf = conf) + // With external shuffle service enabled, if we request to launch multiple workers on one host, + // we can only successfully launch the first worker and the rest fails, because with the port + // bound, we may launch no more than one external shuffle service on each host. + // When this happens, we should give explicit reason of failure instead of fail silently. For + // more detail see SPARK-20989. + val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val sparkWorkerInstances = scala.sys.env.getOrElse("SPARK_WORKER_INSTANCES", "1").toInt + require(externalShuffleServiceEnabled == false || sparkWorkerInstances <= 1, + "Starting multiple workers on one host is failed because we may launch no more than one " + + "external shuffle service on each host, please set spark.shuffle.service.enabled to " + + "false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict.") rpcEnv.awaitTermination() } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index c227c9828e6ac..6de67e039b48f 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -143,7 +143,7 @@ execute_command() { # Check if the process has died; in that case we'll tail the log so the user can see if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then echo "failed to launch: $@" - tail -2 "$log" | sed 's/^/ /' + tail -10 "$log" | sed 's/^/ /' echo "full log in $log" fi else From e862dc904963cf7832bafc1d3d0ea9090bbddd81 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 20 Jun 2017 09:15:33 -0700 Subject: [PATCH 0756/1765] [SPARK-21150][SQL] Persistent view stored in Hive metastore should be case preserving ## What changes were proposed in this pull request? This is a regression in Spark 2.2. In Spark 2.2, we introduced a new way to resolve persisted view: https://issues.apache.org/jira/browse/SPARK-18209 , but this makes the persisted view non case-preserving because we store the schema in hive metastore directly. We should follow data source table and store schema in table properties. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #18360 from cloud-fan/view. --- .../spark/sql/execution/command/views.scala | 4 +- .../spark/sql/execution/SQLViewSuite.scala | 10 +++ .../spark/sql/hive/HiveExternalCatalog.scala | 84 ++++++++++--------- 3 files changed, 56 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 1945d68241343..a6d56ca91a3ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -159,7 +159,9 @@ case class CreateViewCommand( checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) + // Nothing we need to retain from the old view, so just drop and create a new one + catalog.dropTable(viewIdent, ignoreIfNotExists = false, purge = false) + catalog.createTable(prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) } else { // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already // exists. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index d32716c18ddfb..6761f05bb462a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -669,4 +669,14 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { "positive.")) } } + + test("permanent view should be case-preserving") { + withView("v") { + sql("CREATE VIEW v AS SELECT 1 as aBc") + assert(spark.table("v").schema.head.name == "aBc") + + sql("CREATE OR REPLACE VIEW v AS SELECT 2 as cBa") + assert(spark.table("v").schema.head.name == "cBa") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 19453679a30df..6e7c475fa34c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -224,39 +224,36 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } - if (tableDefinition.tableType == VIEW) { - client.createTable(tableDefinition, ignoreIfExists) + // Ideally we should not create a managed table with location, but Hive serde table can + // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have + // to create the table directory and write out data before we create this table, to avoid + // exposing a partial written table. + val needDefaultTableLocation = tableDefinition.tableType == MANAGED && + tableDefinition.storage.locationUri.isEmpty + + val tableLocation = if (needDefaultTableLocation) { + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { - // Ideally we should not create a managed table with location, but Hive serde table can - // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have - // to create the table directory and write out data before we create this table, to avoid - // exposing a partial written table. - val needDefaultTableLocation = tableDefinition.tableType == MANAGED && - tableDefinition.storage.locationUri.isEmpty - - val tableLocation = if (needDefaultTableLocation) { - Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) - } else { - tableDefinition.storage.locationUri - } + tableDefinition.storage.locationUri + } - if (DDLUtils.isHiveTable(tableDefinition)) { - val tableWithDataSourceProps = tableDefinition.copy( - // We can't leave `locationUri` empty and count on Hive metastore to set a default table - // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default - // table location for tables in default database, while we expect to use the location of - // default database. - storage = tableDefinition.storage.copy(locationUri = tableLocation), - // Here we follow data source tables and put table metadata like table schema, partition - // columns etc. in table properties, so that we can work around the Hive metastore issue - // about not case preserving and make Hive serde table support mixed-case column names. - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) - } else { - createDataSourceTable( - tableDefinition.withNewStorage(locationUri = tableLocation), - ignoreIfExists) - } + if (DDLUtils.isDatasourceTable(tableDefinition)) { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) + } else { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like table schema, partition + // columns etc. in table properties, so that we can work around the Hive metastore issue + // about not case preserving and make Hive serde table and view support mixed-case column + // names. + properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) } } @@ -679,16 +676,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var table = inputTable - if (table.tableType != VIEW) { - table.properties.get(DATASOURCE_PROVIDER) match { - // No provider in table properties, which means this is a Hive serde table. - case None => - table = restoreHiveSerdeTable(table) - - // This is a regular data source table. - case Some(provider) => - table = restoreDataSourceTable(table, provider) - } + table.properties.get(DATASOURCE_PROVIDER) match { + case None if table.tableType == VIEW => + // If this is a view created by Spark 2.2 or higher versions, we should restore its schema + // from table properties. + if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { + table = table.copy(schema = getSchemaFromTableProperties(table)) + } + + // No provider in table properties, which means this is a Hive serde table. + case None => + table = restoreHiveSerdeTable(table) + + // This is a regular data source table. + case Some(provider) => + table = restoreDataSourceTable(table, provider) } // Restore Spark's statistics from information in Metastore. From b6b108826a5dd5c889a70180365f9320452557fc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 20 Jun 2017 11:34:22 -0700 Subject: [PATCH 0757/1765] [SPARK-21103][SQL] QueryPlanConstraints should be part of LogicalPlan ## What changes were proposed in this pull request? QueryPlanConstraints should be part of LogicalPlan, rather than QueryPlan, since the constraint framework is only used for query plan rewriting and not for physical planning. ## How was this patch tested? Should be covered by existing tests, since it is a simple refactoring. Author: Reynold Xin Closes #18310 from rxin/SPARK-21103. --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 5 +---- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/{ => logical}/QueryPlanConstraints.scala | 7 ++++--- 3 files changed, 6 insertions(+), 8 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/{ => logical}/QueryPlanConstraints.scala (96%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 9130b14763e24..1f6d05bc8d816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -22,10 +22,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] - extends TreeNode[PlanType] - with QueryPlanConstraints[PlanType] { - +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => def conf: SQLConf = SQLConf.get diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2ebb2ff323c6b..95b4165f6b10d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType -abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { +abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstraints with Logging { private var _analyzed: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b08a009f0dca1..8bffbd0c208cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.plans +package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] => +trait QueryPlanConstraints { self: LogicalPlan => /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For @@ -99,7 +99,8 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { case a: Alias => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints[PlanType]].aliasMap)) + } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) + // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. /** * Infers an additional set of constraints from a given set of equality constraints. From 9ce714dca272315ef7f50d791563f22e8d5922ac Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 20 Jun 2017 22:35:42 -0700 Subject: [PATCH 0758/1765] [SPARK-10655][SQL] Adding additional data type mappings to jdbc DB2dialect. This patch adds DB2 specific data type mappings for decfloat, real, xml , and timestamp with time zone (DB2Z specific type) types on read and for byte, short data types on write to the to jdbc data source DB2 dialect. Default mapping does not work for these types when reading/writing from DB2 database. Added docker test, and a JDBC unit test case. Author: sureshthalamati Closes #9162 from sureshthalamati/db2dialect_enhancements-spark-10655. --- .../spark/sql/jdbc/DB2IntegrationSuite.scala | 47 +++++++++++++++---- .../apache/spark/sql/jdbc/DB2Dialect.scala | 21 ++++++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 ++++ 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 3da34b1b382d7..f5930bc281e8c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -21,10 +21,13 @@ import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.Properties -import org.scalatest._ +import org.scalatest.Ignore +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType} import org.apache.spark.tags.DockerTest + @DockerTest @Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { @@ -47,19 +50,22 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() conn.prepareStatement("CREATE TABLE numbers ( small SMALLINT, med INTEGER, big BIGINT, " - + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE)").executeUpdate() + + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE, real REAL, " + + "decflt DECFLOAT, decflt16 DECFLOAT(16), decflt34 DECFLOAT(34))").executeUpdate() conn.prepareStatement("INSERT INTO numbers VALUES (17, 77777, 922337203685477580, " - + "123456745.56789012345000000000, 42.75, 5.4E-70)").executeUpdate() + + "123456745.56789012345000000000, 42.75, 5.4E-70, " + + "3.4028234663852886e+38, 4.2999, DECFLOAT('9.999999999999999E19', 16), " + + "DECFLOAT('1234567891234567.123456789123456789', 34))").executeUpdate() conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, ts TIMESTAMP )").executeUpdate() conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + "'2009-02-13 23:31:30')").executeUpdate() // TODO: Test locale conversion for strings. - conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB)") - .executeUpdate() - conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'))") + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, e XML)") .executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox')," + + "'Kathy')").executeUpdate() } test("Basic test") { @@ -77,13 +83,17 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 6) + assert(types.length == 10) assert(types(0).equals("class java.lang.Integer")) assert(types(1).equals("class java.lang.Integer")) assert(types(2).equals("class java.lang.Long")) assert(types(3).equals("class java.math.BigDecimal")) assert(types(4).equals("class java.lang.Double")) assert(types(5).equals("class java.lang.Double")) + assert(types(6).equals("class java.lang.Float")) + assert(types(7).equals("class java.math.BigDecimal")) + assert(types(8).equals("class java.math.BigDecimal")) + assert(types(9).equals("class java.math.BigDecimal")) assert(rows(0).getInt(0) == 17) assert(rows(0).getInt(1) == 77777) assert(rows(0).getLong(2) == 922337203685477580L) @@ -91,6 +101,10 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getAs[BigDecimal](3).equals(bd)) assert(rows(0).getDouble(4) == 42.75) assert(rows(0).getDouble(5) == 5.4E-70) + assert(rows(0).getFloat(6) == 3.4028234663852886e+38) + assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000")) + assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000")) + assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789")) } test("Date types") { @@ -112,7 +126,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 4) + assert(types.length == 5) assert(types(0).equals("class java.lang.String")) assert(types(1).equals("class java.lang.String")) assert(types(2).equals("class java.lang.String")) @@ -121,14 +135,27 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getString(1).equals("quick")) assert(rows(0).getString(2).equals("brown")) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](3), Array[Byte](102, 111, 120))) + assert(rows(0).getString(4).equals("""Kathy""")) } test("Basic write test") { - // val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + // cast decflt column with precision value of 38 to DB2 max decimal precision value of 31. + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + .selectExpr("small", "med", "big", "deci", "flt", "dbl", "real", + "cast(decflt as decimal(31, 5)) as decflt") val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) - // df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) df2.write.jdbc(jdbcUrl, "datescopy", new Properties) df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + // spark types that does not have exact matching db2 table types. + val df4 = sqlContext.createDataFrame( + sparkContext.parallelize(Seq(Row("1".toShort, "20".toByte, true))), + new StructType().add("c1", ShortType).add("b", ByteType).add("c3", BooleanType)) + df4.write.jdbc(jdbcUrl, "otherscopy", new Properties) + val rows = sqlContext.read.jdbc(jdbcUrl, "otherscopy", new Properties).collect() + assert(rows(0).getInt(0) == 1) + assert(rows(0).getInt(1) == 20) + assert(rows(0).getString(2) == "1") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 190463df0d928..d160ad82888a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,15 +17,34 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.sql.types.{BooleanType, DataType, StringType} +import java.sql.Types + +import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = sqlType match { + case Types.REAL => Option(FloatType) + case Types.OTHER => + typeName match { + case "DECFLOAT" => Option(DecimalType(38, 18)) + case "XML" => Option(StringType) + case t if (t.startsWith("TIMESTAMP")) => Option(TimestampType) // TIMESTAMP WITH TIMEZONE + case _ => None + } + case _ => None + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case ShortType | ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 70bee929b31da..d1daf860fdfff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -713,6 +713,15 @@ class JDBCSuite extends SparkFunSuite val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + assert(db2Dialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(db2Dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + // test db2 dialect mappings on read + assert(db2Dialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, null) == Option(FloatType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "DECFLOAT", 1, null) == + Option(DecimalType(38, 18))) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "XML", 1, null) == Option(StringType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "TIMESTAMP WITH TIME ZONE", 1, null) == + Option(TimestampType)) } test("PostgresDialect type mapping") { From d107b3b910d8f434fb15b663a9db4c2dfe0a9f43 Mon Sep 17 00:00:00 2001 From: Li Yichao Date: Wed, 21 Jun 2017 21:54:29 +0800 Subject: [PATCH 0759/1765] [SPARK-20640][CORE] Make rpc timeout and retry for shuffle registration configurable. ## What changes were proposed in this pull request? Currently the shuffle service registration timeout and retry has been hardcoded. This works well for small workloads but under heavy workload when the shuffle service is busy transferring large amount of data we see significant delay in responding to the registration request, as a result we often see the executors fail to register with the shuffle service, eventually failing the job. We need to make these two parameters configurable. ## How was this patch tested? * Updated `BlockManagerSuite` to test registration timeout and max attempts configuration actually works. cc sitalkedia Author: Li Yichao Closes #18092 from liyichao/SPARK-20640. --- .../shuffle/ExternalShuffleClient.java | 7 +- .../mesos/MesosExternalShuffleClient.java | 5 +- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../spark/internal/config/package.scala | 13 ++++ .../apache/spark/storage/BlockManager.scala | 7 +- .../spark/storage/BlockManagerSuite.scala | 68 +++++++++++++++++-- docs/configuration.md | 14 ++++ .../MesosCoarseGrainedSchedulerBackend.scala | 4 +- 9 files changed, 109 insertions(+), 15 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 269fa72dad5f5..6ac9302517ee0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -49,6 +49,7 @@ public class ExternalShuffleClient extends ShuffleClient { private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; + private final long registrationTimeoutMs; protected TransportClientFactory clientFactory; protected String appId; @@ -60,10 +61,12 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean authEnabled) { + boolean authEnabled, + long registrationTimeoutMs) { this.conf = conf; this.secretKeyHolder = secretKeyHolder; this.authEnabled = authEnabled; + this.registrationTimeoutMs = registrationTimeoutMs; } protected void checkInit() { @@ -132,7 +135,7 @@ public void registerWithShuffleServer( checkInit(); try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + client.sendRpcSync(registerMessage, registrationTimeoutMs); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index dbc1010847fb1..60179f126bc44 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -60,8 +60,9 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient { public MesosExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean authEnabled) { - super(conf, secretKeyHolder, authEnabled); + boolean authEnabled, + long registrationTimeoutMs) { + super(conf, secretKeyHolder, authEnabled, registrationTimeoutMs); } public void registerDriverWithShuffleService( diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 4391e3023491b..a6a1b8d0ac3f1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -133,7 +133,7 @@ private FetchResult fetchBlocks( final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -242,7 +242,7 @@ public void testFetchNoServer() throws Exception { private static void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, 5000); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index bf20c577ed420..16bad9f1b319d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -97,7 +97,7 @@ private void validate(String appId, String secretKey, boolean encrypt) } ExternalShuffleClient client = - new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true); + new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000); client.init(appId); // Registration either succeeds or throws an exception. client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 84ef57f2d271b..615497d36fd14 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -303,6 +303,19 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) + private[spark] val SHUFFLE_REGISTRATION_TIMEOUT = + ConfigBuilder("spark.shuffle.registration.timeout") + .doc("Timeout in milliseconds for registration to the external shuffle service.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(5000) + + private[spark] val SHUFFLE_REGISTRATION_MAX_ATTEMPTS = + ConfigBuilder("spark.shuffle.registration.maxAttempts") + .doc("When we fail to register to the external shuffle service, we will " + + "retry for maxAttempts times.") + .intConf + .createWithDefault(3) + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1689baa832d52..74be70348305c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -31,7 +31,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer @@ -174,7 +174,8 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + new ExternalShuffleClient(transConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) } else { blockTransferService } @@ -254,7 +255,7 @@ private[spark] class BlockManager( diskBlockManager.subDirsPerLocalDir, shuffleManager.getClass.getName) - val MAX_ATTEMPTS = 3 + val MAX_ATTEMPTS = conf.get(config.SHUFFLE_REGISTRATION_MAX_ATTEMPTS) val SLEEP_TIME_SECS = 5 for (i <- 1 to MAX_ATTEMPTS) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9d52b488b223e..88f18294aa015 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,13 +20,15 @@ package org.apache.spark.storage import java.io.File import java.nio.ByteBuffer +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.Future -import scala.language.implicitConversions -import scala.language.postfixOps +import scala.language.{implicitConversions, postfixOps} import scala.reflect.ClassTag +import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ @@ -38,10 +40,13 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager -import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -1281,6 +1286,61 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("item").isEmpty) } + test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { + val tryAgainMsg = "test_spark_20640_try_again" + // a server which delays response 50ms and must try twice for success. + def newShuffleServer(port: Int): (TransportServer, Int) = { + val attempts = new mutable.HashMap[String, Int]() + val handler = new NoOpRpcHandler { + override def receive( + client: TransportClient, + message: ByteBuffer, + callback: RpcResponseCallback): Unit = { + val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) + msgObj match { + case exec: RegisterExecutor => + Thread.sleep(50) + val attempt = attempts.getOrElse(exec.execId, 0) + 1 + attempts(exec.execId) = attempt + if (attempt < 2) { + callback.onFailure(new Exception(tryAgainMsg)) + return + } + callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + } + } + } + + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0) + val transCtx = new TransportContext(transConf, handler, true) + (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) + } + val candidatePort = RandomUtils.nextInt(1024, 65536) + val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, + newShuffleServer, conf, "ShuffleServer") + + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.shuffle.service.port", shufflePort.toString) + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + var e = intercept[SparkException]{ + makeBlockManager(8000, "executor1") + }.getMessage + assert(e.contains("TimeoutException")) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + e = intercept[SparkException]{ + makeBlockManager(8000, "executor2") + }.getMessage + assert(e.contains(tryAgainMsg)) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") + makeBlockManager(8000, "executor3") + server.close() + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 diff --git a/docs/configuration.md b/docs/configuration.md index c1464741ebb6f..f1c6d04115ab0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -638,6 +638,20 @@ Apart from these, the following properties are also available, and may be useful underestimating shuffle block size when fetch shuffle blocks. + + spark.shuffle.registration.timeout + 5000 + + Timeout in milliseconds for registration to the external shuffle service. + + + + spark.shuffle.registration.maxAttempts + 3 + + When we fail to register to the external shuffle service, we will retry for maxAttempts times. + + spark.io.encryption.enabled false diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 871685c6cccc0..7dd42c41aa7c2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} +import org.apache.spark.internal.config import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointAddress @@ -150,7 +151,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( new MesosExternalShuffleClient( SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, - securityManager.isAuthenticationEnabled()) + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) } private var nextMesosTaskId = 0 From 987eb8faddbb533e006c769d382a3e4fda3dd6ee Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 21 Jun 2017 15:30:31 +0100 Subject: [PATCH 0760/1765] [MINOR][DOCS] Add lost tag for configuration.md ## What changes were proposed in this pull request? Add lost `` tag for `configuration.md`. ## How was this patch tested? N/A Author: Yuming Wang Closes #18372 from wangyum/docs-missing-tr. --- docs/configuration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index f1c6d04115ab0..f4bec589208be 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1566,6 +1566,8 @@ Apart from these, the following properties are also available, and may be useful of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering an executor unusable. + + spark.stage.maxConsecutiveAttempts 4 From e92befcb4b57c3e4afe57b6de1622ac72e7d819c Mon Sep 17 00:00:00 2001 From: Marcos P Date: Wed, 21 Jun 2017 15:34:10 +0100 Subject: [PATCH 0761/1765] [MINOR][DOC] modified issue link and updated status ## What changes were proposed in this pull request? This PR aims to clarify some outdated comments that i found at **spark-catalyst** and **spark-sql** pom files. Maven bug still happening and in order to track it I have updated the issue link and also the status of the issue. Author: Marcos P Closes #18374 from mpenate/fix/mng-3559-comment. --- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 36948ba52b064..0bbf7a95124cf 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -109,7 +109,7 @@ so that the tests classes of external modules can use them. The two execution profiles are necessary - first one for 'mvn package', second one for 'mvn test-compile'. Ideally, 'mvn compile' should not compile test classes and therefore should not need this. - However, an open Maven bug (http://jira.codehaus.org/browse/MNG-3559) + However, a closed due to "Cannot Reproduce" Maven bug (https://issues.apache.org/jira/browse/MNG-3559) causes the compilation to fail if catalyst test-jar is not generated. Hence, the second execution profile for 'mvn test-compile'. --> diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7327c9b0c9c50..1bc34a6b069d9 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -161,7 +161,7 @@ so that the tests classes of external modules can use them. The two execution profiles are necessary - first one for 'mvn package', second one for 'mvn test-compile'. Ideally, 'mvn compile' should not compile test classes and therefore should not need this. - However, an open Maven bug (http://jira.codehaus.org/browse/MNG-3559) + However, a closed due to "Cannot Reproduce" Maven bug (https://issues.apache.org/jira/browse/MNG-3559) causes the compilation to fail if catalyst test-jar is not generated. Hence, the second execution profile for 'mvn test-compile'. --> From cad88f17e87e6cb96550b70e35d3ed75305dc59d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 21 Jun 2017 09:40:06 -0700 Subject: [PATCH 0762/1765] [SPARK-17851][SQL][TESTS] Make sure all test sqls in catalyst pass checkAnalysis ## What changes were proposed in this pull request? Currently we have several tens of test sqls in catalyst will fail at `SimpleAnalyzer.checkAnalysis`, we should make sure they are valid. This PR makes the following changes: 1. Apply `checkAnalysis` on plans that tests `Optimizer` rules, but don't require the testcases for `Parser`/`Analyzer` pass `checkAnalysis`; 2. Fix testcases for `Optimizer` that would have fall. ## How was this patch tested? Apply `SimpleAnalyzer.checkAnalysis` on plans in `PlanTest.comparePlans`, update invalid test cases. Author: Xingbo Jiang Author: jiangxingbo Closes #15417 from jiangxb1987/cptest. --- .../sql/catalyst/analysis/AnalysisTest.scala | 8 +++ .../analysis/DecimalPrecisionSuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 2 +- .../optimizer/AggregateOptimizeSuite.scala | 4 +- .../BooleanSimplificationSuite.scala | 57 ++++++++++--------- .../optimizer/ColumnPruningSuite.scala | 4 +- .../optimizer/ConstantPropagationSuite.scala | 9 ++- .../optimizer/FilterPushdownSuite.scala | 11 ++-- .../optimizer/LimitPushdownSuite.scala | 12 ++-- .../optimizer/OptimizeCodegenSuite.scala | 4 +- .../optimizer/OuterJoinEliminationSuite.scala | 4 +- .../optimizer/SimplifyCastsSuite.scala | 9 ++- .../sql/catalyst/parser/PlanParserSuite.scala | 6 +- .../spark/sql/catalyst/plans/PlanTest.scala | 14 ++++- .../apache/spark/sql/DataFrameHintSuite.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 5 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 20 +++---- 18 files changed, 101 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index edfa8c45f9867..549a4355dfba3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -59,6 +59,14 @@ trait AnalysisTest extends PlanTest { comparePlans(actualPlan, expectedPlan) } + protected override def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = false): Unit = { + // Analysis tests may have not been fully resolved, so skip checkAnalysis. + super.comparePlans(plan1, plan2, checkAnalysis) + } + protected def assertAnalysisSuccess( inputPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 8f43171f309a9..ccf3c3fb0949d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Unio import org.apache.spark.sql.types._ -class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { +class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) private val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 7358f401ed520..b3994ab0828ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -class TypeCoercionSuite extends PlanTest { +class TypeCoercionSuite extends AnalysisTest { // scalastyle:off line.size.limit // The following table shows all implicit data type conversions that are not visible to the user. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index dce73b3635e72..a6dc21b03d446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -44,7 +44,7 @@ class InMemorySessionCatalogSuite extends SessionCatalogSuite { * signatures but do not extend a common parent. This is largely by design but * unfortunately leads to very similar test code in two places. */ -abstract class SessionCatalogSuite extends PlanTest { +abstract class SessionCatalogSuite extends AnalysisTest { protected val utils: CatalogTestUtils protected val isHiveExternalCatalog = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index e6132ab2e4d17..a3184a4266c7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -59,9 +59,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("Remove aliased literals") { - val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) + val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze + val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1df0a89cf0bf1..c6345b60b744b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -41,7 +41,8 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string, + 'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean) val testRelationWithData = LocalRelation.fromExternalRows( testRelation.output, Seq(Row(1, 2, 3, "abc")) @@ -101,52 +102,52 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) } - test("a && (!a || b)") { - checkCondition('a && (!'a || 'b ), 'a && 'b) + test("e && (!e || f)") { + checkCondition('e && (!'e || 'f ), 'e && 'f) - checkCondition('a && ('b || !'a ), 'a && 'b) + checkCondition('e && ('f || !'e ), 'e && 'f) - checkCondition((!'a || 'b ) && 'a, 'b && 'a) + checkCondition((!'e || 'f ) && 'e, 'f && 'e) - checkCondition(('b || !'a ) && 'a, 'b && 'a) + checkCondition(('f || !'e ) && 'e, 'f && 'e) } - test("a < 1 && (!(a < 1) || b)") { - checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b) + test("a < 1 && (!(a < 1) || f)") { + checkCondition('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f) + checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f) - checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b) + checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f) + checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f) - checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b) + checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f) + checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b) + checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f) + checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f) } - test("a < 1 && ((a >= 1) || b)") { - checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b) + test("a < 1 && ((a >= 1) || f)") { + checkCondition('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f) + checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f) - checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f) + checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f) - checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b) + checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f) + checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b) + checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f) + checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f) } test("DeMorgan's law") { - checkCondition(!('a && 'b), !'a || !'b) + checkCondition(!('e && 'f), !'e || !'f) - checkCondition(!('a || 'b), !'a && !'b) + checkCondition(!('e || 'f), !'e && !'f) - checkCondition(!(('a && 'b) || ('c && 'd)), (!'a || !'b) && (!'c || !'d)) + checkCondition(!(('e && 'f) || ('g && 'h)), (!'e || !'f) && (!'g || !'h)) - checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) + checkCondition(!(('e || 'f) && ('g || 'h)), (!'e && !'f) || (!'g && !'h)) } private val caseInsensitiveConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index a0a0daea7d075..0b419e9631b29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -266,8 +266,8 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on Window with useless aggregate functions") { val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) - val winSpec = windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winSpec = windowSpec('a :: Nil, 'd.asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count('d), winSpec) val originalQuery = input.groupBy('a, 'c, 'd)('a, 'c, 'd, winExpr.as('window)).select('a, 'c) val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 81d2f3667e2d0..94174eec8fd0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -35,7 +35,6 @@ class ConstantPropagationSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantPropagation", FixedPoint(10), - ColumnPruning, ConstantPropagation, ConstantFolding, BooleanSimplification) :: Nil @@ -43,9 +42,9 @@ class ConstantPropagationSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - private val columnA = 'a.int - private val columnB = 'b.int - private val columnC = 'c.int + private val columnA = 'a + private val columnB = 'b + private val columnC = 'c test("basic test") { val query = testRelation @@ -160,7 +159,7 @@ class ConstantPropagationSuite extends PlanTest { val correctAnswer = testRelation .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)) + .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d4d281e7e05db..3553d23560dad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -629,14 +629,14 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('c > 6)) + .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where('b >= 5) .generate(Explode('c_arr), true, false, Some("arr")) - .where('a + Rand(10).as("rnd") > 6 && 'c > 6) + .where('a + Rand(10).as("rnd") > 6 && 'col > 6) .analyze } @@ -676,7 +676,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('c > 6) || ('b > 5)).analyze + .where(('col > 6) || ('b > 5)).analyze } val optimized = Optimize.execute(originalQuery) @@ -1129,6 +1129,9 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) - comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + // CheckAnalysis will ensure nondeterministic expressions not appear in join condition. + // TODO support nondeterministic expressions in join condition. + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, + checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 2885fd6841e9d..fb34c82de468b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -70,19 +70,21 @@ class LimitPushdownSuite extends PlanTest { } test("Union: no limit to both sides if children having smaller limit values") { - val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2) + val unionQuery = + Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + Limit(2, Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("Union: limit to each sides if children having larger limit values") { - val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) - val unionQuery = testLimitUnion.limit(2) + val unionQuery = + Union(testRelation.limit(3), testRelation2.select('d, 'e, 'f).limit(4)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze + Limit(2, Union( + LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d, 'e, 'f)))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index f3b65cc797ec4..9dc6738ba04b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -50,10 +50,10 @@ class OptimizeCodegenSuite extends PlanTest { test("Nested CaseWhen Codegen.") { assertEquivalent( CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))), + Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))), CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))), + Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))), CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index a37bc4bca2422..623ff3d446a5f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -201,7 +201,7 @@ class OuterJoinEliminationSuite extends PlanTest { val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) - .where(Coalesce("y.e".attr :: "x.a".attr :: Nil)) + .where(Coalesce("y.e".attr :: "x.a".attr :: Nil) === 0) val optimized = Optimize.execute(originalQuery.analyze) @@ -209,7 +209,7 @@ class OuterJoinEliminationSuite extends PlanTest { val right = testRelation1 val correctAnswer = left.join(right, FullOuter, Option("a".attr === "d".attr)) - .where(Coalesce("e".attr :: "a".attr :: Nil)).analyze + .where(Coalesce("e".attr :: "a".attr :: Nil) === 0).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index e84f11272d214..7b3f5b084b015 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -44,7 +44,9 @@ class SimplifyCastsSuite extends PlanTest { val input = LocalRelation('a.array(ArrayType(IntegerType, true))) val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze val optimized = Optimize.execute(plan) - comparePlans(optimized, plan) + // Though cast from `ArrayType(IntegerType, true)` to `ArrayType(IntegerType, false)` is not + // allowed, here we just ensure that `SimplifyCasts` rule respect the plan. + comparePlans(optimized, plan, checkAnalysis = false) } test("non-nullable value map to nullable value map cast") { @@ -61,7 +63,10 @@ class SimplifyCastsSuite extends PlanTest { val plan = input.select('m.cast(MapType(StringType, StringType, false)) .as("casted")).analyze val optimized = Optimize.execute(plan) - comparePlans(optimized, plan) + // Though cast from `MapType(StringType, StringType, true)` to + // `MapType(StringType, StringType, false)` is not allowed, here we just ensure that + // `SimplifyCasts` rule respect the plan. + comparePlans(optimized, plan, checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fef39a5b6a32f..0a4ae098d65cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -29,13 +29,13 @@ import org.apache.spark.sql.types.IntegerType * * There is also SparkSqlParserSuite in sql/core module for parser rules defined in sql/core module. */ -class PlanParserSuite extends PlanTest { +class PlanParserSuite extends AnalysisTest { import CatalystSqlParser._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan) + comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) } private def intercept(sqlCommand: String, messages: String*): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index f44428c3512a9..25313af2be184 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ @@ -90,7 +91,16 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { + protected def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = true): Unit = { + if (checkAnalysis) { + // Make sure both plan pass checkAnalysis. + SimpleAnalyzer.checkAnalysis(plan1) + SimpleAnalyzer.checkAnalysis(plan2) + } + val normalized1 = normalizePlan(normalizeExprIds(plan1)) val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { @@ -104,7 +114,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { /** Fails the test if the two expressions do not match */ protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation), checkAnalysis = false) } /** Fails the test if the join order in the two plans do not match */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala index 60f6f23860ed9..0dd5bdcba2e4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.test.SharedSQLContext -class DataFrameHintSuite extends PlanTest with SharedSQLContext { +class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { import testImplicits._ lazy val df = spark.range(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index b32fb90e10072..bd9c2ebd6fab9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable @@ -36,7 +35,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType * See [[org.apache.spark.sql.catalyst.parser.PlanParserSuite]] for rules * defined in the Catalyst module. */ -class SparkSqlParserSuite extends PlanTest { +class SparkSqlParserSuite extends AnalysisTest { val newConf = new SQLConf private lazy val parser = new SparkSqlParser(newConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index d97b11e447fe2..bee470d8e1382 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, ScriptTransformation} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} @@ -59,6 +59,11 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle }.head } + private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { + val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) + comparePlans(plan, expected, checkAnalysis = false) + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view @@ -253,22 +258,15 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle } test("transform query spec") { - val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val p = ScriptTransformation( Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), "func", Seq.empty, plans.table("e"), null) - comparePlans(plan1, + compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - comparePlans(plan2, + compareTransformQuery("map a, b using 'func' as c, d from e", p.copy(output = Seq('c.string, 'd.string))) - comparePlans(plan3, + compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) } From ad459cfb1d169d8dd7b9e039ca135ba5cafcab83 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 21 Jun 2017 10:35:16 -0700 Subject: [PATCH 0763/1765] [SPARK-20917][ML][SPARKR] SparkR supports string encoding consistent with R ## What changes were proposed in this pull request? Add `stringIndexerOrderType` to `spark.glm` and `spark.survreg` to support string encoding that is consistent with default R. ## How was this patch tested? new tests Author: actuaryzhang Closes #18140 from actuaryzhang/sparkRFormula. --- R/pkg/R/mllib_regression.R | 52 +++++++++++++--- R/pkg/tests/fulltests/test_mllib_regression.R | 62 +++++++++++++++++++ .../ml/r/AFTSurvivalRegressionWrapper.scala | 4 +- .../GeneralizedLinearRegressionWrapper.scala | 6 +- 4 files changed, 115 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index d59c890f3e5fd..9ecd887f2c127 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -70,6 +70,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the relationship between the variance and mean of the distribution. Only #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -79,7 +85,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' t <- as.data.frame(Titanic) +#' t <- as.data.frame(Titanic, stringsAsFactors = FALSE) #' df <- createDataFrame(t) #' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") #' summary(model) @@ -96,6 +102,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' savedModel <- read.ml(path) #' summary(savedModel) #' +#' # note that the default string encoding is different from R's glm +#' model2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = t) +#' summary(model2) +#' # use stringIndexerOrderType = "alphabetDesc" to force string encoding +#' # to be consistent with R +#' model3 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian", +#' stringIndexerOrderType = "alphabetDesc") +#' summary(model3) +#' #' # fit tweedie model #' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", #' var.power = 1.2, link.power = 0) @@ -110,8 +125,11 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { + stringIndexerOrderType <- match.arg(stringIndexerOrderType) if (is.character(family)) { # Handle when family = "tweedie" if (tolower(family) == "tweedie") { @@ -145,7 +163,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, tol, as.integer(maxIter), weightCol, regParam, - as.double(var.power), as.double(link.power)) + as.double(var.power), as.double(link.power), + stringIndexerOrderType) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -167,6 +186,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param maxit integer giving the maximal number of IRLS iterations. #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -182,9 +207,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, - var.power = 0.0, link.power = 1.0 - var.power) { + var.power = 0.0, link.power = 1.0 - var.power, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, - var.power = var.power, link.power = link.power) + var.power = var.power, link.power = link.power, + stringIndexerOrderType = stringIndexerOrderType) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -418,6 +446,12 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg @@ -443,10 +477,14 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, aggregationDepth = 2) { + function(data, formula, aggregationDepth = 2, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { + stringIndexerOrderType <- match.arg(stringIndexerOrderType) formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf, as.integer(aggregationDepth)) + "fit", formula, data@sdf, as.integer(aggregationDepth), + stringIndexerOrderType) new("AFTSurvivalRegressionModel", jobj = jobj) }) diff --git a/R/pkg/tests/fulltests/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R index 82472c92b9965..6b72a09b200d6 100644 --- a/R/pkg/tests/fulltests/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -367,6 +367,49 @@ test_that("glm save/load", { unlink(modelPath) }) +test_that("spark.glm and glm with string encoding", { + t <- as.data.frame(Titanic, stringsAsFactors = FALSE) + df <- createDataFrame(t) + + # base R + rm <- stats::glm(Freq ~ Sex + Age, family = "gaussian", data = t) + # spark.glm with default stringIndexerOrderType = "frequencyDesc" + sm0 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") + # spark.glm with stringIndexerOrderType = "alphabetDesc" + sm1 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian", + stringIndexerOrderType = "alphabetDesc") + # glm with stringIndexerOrderType = "alphabetDesc" + sm2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = df, + stringIndexerOrderType = "alphabetDesc") + + rStats <- summary(rm) + rCoefs <- rStats$coefficients + sStats <- lapply(list(sm0, sm1, sm2), summary) + # order by coefficient size since column rendering may be different + o <- order(rCoefs[, 1]) + + # default encoding does not produce same results as R + expect_false(all(abs(rCoefs[o, ] - sStats[[1]]$coefficients[o, ]) < 1e-4)) + + # all estimates should be the same as R with stringIndexerOrderType = "alphabetDesc" + test <- lapply(sStats[2:3], function(stats) { + expect_true(all(abs(rCoefs[o, ] - stats$coefficients[o, ]) < 1e-4)) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + }) + + # fitted values should be equal regardless of string encoding + rVals <- predict(rm, t) + test <- lapply(list(sm0, sm1, sm2), function(sm) { + vals <- collect(select(predict(sm, df), "prediction")) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + }) +}) + test_that("spark.isoreg", { label <- c(7.0, 5.0, 3.0, 5.0, 1.0) feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) @@ -462,6 +505,25 @@ test_that("spark.survreg", { model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), NA) expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + + # Test stringIndexerOrderType + rData <- as.data.frame(rData) + rData$sex2 <- c("female", "male")[rData$sex + 1] + df <- createDataFrame(rData) + expect_error( + rModel <- survival::survreg(survival::Surv(time, status) ~ x + sex2, rData), NA) + rCoefs <- as.numeric(summary(rModel)$table[, 1]) + model <- spark.survreg(df, Surv(time, status) ~ x + sex2) + coefs <- as.vector(summary(model)$coefficients[, 1]) + o <- order(rCoefs) + # stringIndexerOrderType = "frequencyDesc" produces different estimates from R + expect_false(all(abs(rCoefs[o] - coefs[o]) < 1e-4)) + + # stringIndexerOrderType = "alphabetDesc" produces the same estimates as R + model <- spark.survreg(df, Surv(time, status) ~ x + sex2, + stringIndexerOrderType = "alphabetDesc") + coefs <- as.vector(summary(model)$coefficients[, 1]) + expect_true(all(abs(rCoefs[o] - coefs[o]) < 1e-4)) } }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 0bf543d88894e..80d03ab03c87d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -85,11 +85,13 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg def fit( formula: String, data: DataFrame, - aggregationDepth: Int): AFTSurvivalRegressionWrapper = { + aggregationDepth: Int, + stringIndexerOrderType: String): AFTSurvivalRegressionWrapper = { val (rewritedFormula, censorCol) = formulaRewrite(formula) val rFormula = new RFormula().setFormula(rewritedFormula) + .setStringIndexerOrderType(stringIndexerOrderType) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 4bd4aa7113f68..ee1fc9b14ceaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -65,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private ( private[r] object GeneralizedLinearRegressionWrapper extends MLReadable[GeneralizedLinearRegressionWrapper] { + // scalastyle:off def fit( formula: String, data: DataFrame, @@ -75,8 +76,11 @@ private[r] object GeneralizedLinearRegressionWrapper weightCol: String, regParam: Double, variancePower: Double, - linkPower: Double): GeneralizedLinearRegressionWrapper = { + linkPower: Double, + stringIndexerOrderType: String): GeneralizedLinearRegressionWrapper = { + // scalastyle:on val rFormula = new RFormula().setFormula(formula) + .setStringIndexerOrderType(stringIndexerOrderType) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema From 7a00c658d44139d950b7d3ecd670d79f76e2e747 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 21 Jun 2017 10:51:17 -0700 Subject: [PATCH 0764/1765] [SPARK-21147][SS] Throws an analysis exception when a user-specified schema is given in socket/rate sources ## What changes were proposed in this pull request? This PR proposes to throw an exception if a schema is provided by user to socket source as below: **socket source** ```scala import org.apache.spark.sql.types._ val userSpecifiedSchema = StructType( StructField("name", StringType) :: StructField("area", StringType) :: Nil) val df = spark.readStream.format("socket").option("host", "localhost").option("port", 9999).schema(userSpecifiedSchema).load df.printSchema ``` Before ``` root |-- value: string (nullable = true) ``` After ``` org.apache.spark.sql.AnalysisException: The socket source does not support a user-specified schema.; at org.apache.spark.sql.execution.streaming.TextSocketSourceProvider.sourceSchema(socket.scala:199) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:192) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:87) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:87) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:30) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:150) ... 50 elided ``` **rate source** ```scala spark.readStream.format("rate").schema(spark.range(1).schema).load().printSchema() ``` Before ``` root |-- timestamp: timestamp (nullable = true) |-- value: long (nullable = true)` ``` After ``` org.apache.spark.sql.AnalysisException: The rate source does not support a user-specified schema.; at org.apache.spark.sql.execution.streaming.RateSourceProvider.sourceSchema(RateSourceProvider.scala:57) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:192) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:87) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:87) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:30) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:150) ... 48 elided ``` ## How was this patch tested? Unit test in `TextSocketStreamSuite` and `RateSourceSuite`. Author: hyukjinkwon Closes #18365 from HyukjinKwon/SPARK-21147. --- .../execution/streaming/RateSourceProvider.scala | 9 +++++++-- .../spark/sql/execution/streaming/socket.scala | 8 ++++++-- .../sql/execution/streaming/RateSourceSuite.scala | 12 ++++++++++++ .../streaming/TextSocketStreamSuite.scala | 15 +++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index e61a8eb628891..e76d4dc6125df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -25,7 +25,7 @@ import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} @@ -52,8 +52,13 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { sqlContext: SQLContext, schema: Option[StructType], providerName: String, - parameters: Map[String, String]): (String, StructType) = + parameters: Map[String, String]): (String, StructType) = { + if (schema.nonEmpty) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + (shortName(), RateSourceProvider.SCHEMA) + } override def createSource( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index 58bff27a05bf3..8e63207959575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -195,13 +195,17 @@ class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegis if (!parameters.contains("port")) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - val schema = + if (schema.nonEmpty) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + val sourceSchema = if (parseIncludeTimestamp(parameters)) { TextSocketSource.SCHEMA_TIMESTAMP } else { TextSocketSource.SCHEMA_REGULAR } - ("textSocket", schema) + ("textSocket", sourceSchema) } override def createSource( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala index bdba536425a43..03d0f63fa4d7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.TimeUnit +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.util.ManualClock @@ -179,4 +180,15 @@ class RateSourceSuite extends StreamTest { testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala index 5174a0415304c..9ebf4d2835266 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala @@ -148,6 +148,21 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val exception = intercept[AnalysisException] { + provider.sourceSchema( + sqlContext, Some(userSpecifiedSchema), + "", + Map("host" -> "localhost", "port" -> "1234")) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + test("no server up") { val provider = new TextSocketSourceProvider val parameters = Map("host" -> "localhost", "port" -> "0") From ba78514da7bf2132873270b8bf39b50e54f4b094 Mon Sep 17 00:00:00 2001 From: sjarvie Date: Wed, 21 Jun 2017 10:51:45 -0700 Subject: [PATCH 0765/1765] [SPARK-21125][PYTHON] Extend setJobDescription to PySpark and JavaSpark APIs ## What changes were proposed in this pull request? Extend setJobDescription to PySpark and JavaSpark APIs SPARK-21125 ## How was this patch tested? Testing was done by running a local Spark shell on the built UI. I originally had added a unit test but the PySpark context cannot easily access the Scala Spark Context's private variable with the Job Description key so I omitted the test, due to the simplicity of this addition. Also ran the existing tests. # Misc This contribution is my original work and that I license the work to the project under the project's open source license. Author: sjarvie Closes #18332 from sjarvie/add_python_set_job_description. --- .../scala/org/apache/spark/api/java/JavaSparkContext.scala | 6 ++++++ python/pyspark/context.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 9481156bc93a5..f1936bf587282 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -757,6 +757,12 @@ class JavaSparkContext(val sc: SparkContext) */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) + /** + * Set a human readable description of the current job. + * @since 2.3.0 + */ + def setJobDescription(value: String): Unit = sc.setJobDescription(value) + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 3be07325f4162..c4b7e6372d1a2 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -942,6 +942,12 @@ def getLocalProperty(self, key): """ return self._jsc.getLocalProperty(key) + def setJobDescription(self, value): + """ + Set a human readable description of the current job. + """ + self._jsc.setJobDescription(value) + def sparkUser(self): """ Get SPARK_USER for user who is running SparkContext. From 215281d88ed664547088309cb432da2fed18b8b7 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 21 Jun 2017 14:59:52 -0700 Subject: [PATCH 0766/1765] [SPARK-20830][PYSPARK][SQL] Add posexplode and posexplode_outer ## What changes were proposed in this pull request? Add Python wrappers for `o.a.s.sql.functions.explode_outer` and `o.a.s.sql.functions.posexplode_outer`. ## How was this patch tested? Unit tests, doctests. Author: zero323 Closes #18049 from zero323/SPARK-20830. --- python/pyspark/sql/functions.py | 65 +++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 20 +++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 240ae65a61785..3416c4b118a07 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1727,6 +1727,71 @@ def posexplode(col): return Column(jc) +@since(2.3) +def explode_outer(col): + """Returns a new row for each element in the given array or map. + Unlike explode, if the array/map is null or empty then null is produced. + + >>> df = spark.createDataFrame( + ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], + ... ("id", "an_array", "a_map") + ... ) + >>> df.select("id", "an_array", explode_outer("a_map")).show() + +---+----------+----+-----+ + | id| an_array| key|value| + +---+----------+----+-----+ + | 1|[foo, bar]| x| 1.0| + | 2| []|null| null| + | 3| null|null| null| + +---+----------+----+-----+ + + >>> df.select("id", "a_map", explode_outer("an_array")).show() + +---+-------------+----+ + | id| a_map| col| + +---+-------------+----+ + | 1|Map(x -> 1.0)| foo| + | 1|Map(x -> 1.0)| bar| + | 2| Map()|null| + | 3| null|null| + +---+-------------+----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode_outer(_to_java_column(col)) + return Column(jc) + + +@since(2.3) +def posexplode_outer(col): + """Returns a new row for each element with position in the given array or map. + Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced. + + >>> df = spark.createDataFrame( + ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], + ... ("id", "an_array", "a_map") + ... ) + >>> df.select("id", "an_array", posexplode_outer("a_map")).show() + +---+----------+----+----+-----+ + | id| an_array| pos| key|value| + +---+----------+----+----+-----+ + | 1|[foo, bar]| 0| x| 1.0| + | 2| []|null|null| null| + | 3| null|null|null| null| + +---+----------+----+----+-----+ + >>> df.select("id", "a_map", posexplode_outer("an_array")).show() + +---+-------------+----+----+ + | id| a_map| pos| col| + +---+-------------+----+----+ + | 1|Map(x -> 1.0)| 0| foo| + | 1|Map(x -> 1.0)| 1| bar| + | 2| Map()|null|null| + | 3| null|null|null| + +---+-------------+----+----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.6) def get_json_object(col, path): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 31f932a363225..3b308579a3778 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -258,8 +258,12 @@ def test_column_name_encoding(self): self.assertTrue(isinstance(columns[1], str)) def test_explode(self): - from pyspark.sql.functions import explode - d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + from pyspark.sql.functions import explode, explode_outer, posexplode_outer + d = [ + Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), + Row(a=1, intlist=[], mapfield={}), + Row(a=1, intlist=None, mapfield=None), + ] rdd = self.sc.parallelize(d) data = self.spark.createDataFrame(rdd) @@ -272,6 +276,18 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()] + self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) + + result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()] + self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)]) + + result = [x[0] for x in data.select(explode_outer("intlist")).collect()] + self.assertEqual(result, [1, 2, 3, None, None]) + + result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()] + self.assertEqual(result, [('a', 'b'), (None, None), (None, None)]) + def test_and_in_expression(self): self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) From 53543374ce0cf0cec26de2382fbc85b7d5c7e9d6 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Wed, 21 Jun 2017 20:42:45 -0700 Subject: [PATCH 0767/1765] [SPARK-20906][SPARKR] Constrained Logistic Regression for SparkR ## What changes were proposed in this pull request? PR https://github.com/apache/spark/pull/17715 Added Constrained Logistic Regression for ML. We should add it to SparkR. ## How was this patch tested? Add new unit tests. Author: wangmiao1981 Closes #18128 from wangmiao1981/test. --- R/pkg/R/mllib_classification.R | 61 ++++++++++++++++++- .../fulltests/test_mllib_classification.R | 40 ++++++++++++ .../classification/LogisticRegression.scala | 8 +-- .../ml/r/LogisticRegressionWrapper.scala | 34 ++++++++++- 4 files changed, 135 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index bdcc0818d139d..82d2428f3c444 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -204,6 +204,20 @@ function(object, path, overwrite = FALSE) { #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. +#' The bound matrix must be compatible with the shape (1, number of features) for binomial +#' regression, or (number of classes, number of features) for multinomial regression. +#' It is a R matrix. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. +#' The bound matrix must be compatible with the shape (1, number of features) for binomial +#' regression, or (number of classes, number of features) for multinomial regression. +#' It is a R matrix. +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. +#' The bounds vector size must be equal to 1 for binomial regression, or the number +#' of classes for multinomial regression. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. +#' The bound vector size must be equal to 1 for binomial regression, or the number +#' of classes for multinomial regression. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -241,8 +255,12 @@ function(object, path, overwrite = FALSE) { setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) { + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, + lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL, + lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) { formula <- paste(deparse(formula), collapse = "") + row <- 0 + col <- 0 if (!is.null(weightCol) && weightCol == "") { weightCol <- NULL @@ -250,12 +268,51 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") weightCol <- as.character(weightCol) } + if (!is.null(lowerBoundsOnIntercepts)) { + lowerBoundsOnIntercepts <- as.array(lowerBoundsOnIntercepts) + } + + if (!is.null(upperBoundsOnIntercepts)) { + upperBoundsOnIntercepts <- as.array(upperBoundsOnIntercepts) + } + + if (!is.null(lowerBoundsOnCoefficients)) { + if (class(lowerBoundsOnCoefficients) != "matrix") { + stop("lowerBoundsOnCoefficients must be a matrix.") + } + row <- nrow(lowerBoundsOnCoefficients) + col <- ncol(lowerBoundsOnCoefficients) + lowerBoundsOnCoefficients <- as.array(as.vector(lowerBoundsOnCoefficients)) + } + + if (!is.null(upperBoundsOnCoefficients)) { + if (class(upperBoundsOnCoefficients) != "matrix") { + stop("upperBoundsOnCoefficients must be a matrix.") + } + + if (!is.null(lowerBoundsOnCoefficients) && (row != nrow(upperBoundsOnCoefficients) + || col != ncol(upperBoundsOnCoefficients))) { + stop(paste0("dimension of upperBoundsOnCoefficients ", + "is not the same as lowerBoundsOnCoefficients", sep = "")) + } + + if (is.null(lowerBoundsOnCoefficients)) { + row <- nrow(upperBoundsOnCoefficients) + col <- ncol(upperBoundsOnCoefficients) + } + + upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients)) + } + jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - weightCol, as.integer(aggregationDepth)) + weightCol, as.integer(aggregationDepth), + as.integer(row), as.integer(col), + lowerBoundsOnCoefficients, upperBoundsOnCoefficients, + lowerBoundsOnIntercepts, upperBoundsOnIntercepts) new("LogisticRegressionModel", jobj = jobj) }) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index 726e9d9a20b1c..3d75f4ce11ec8 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -223,6 +223,46 @@ test_that("spark.logit", { model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") prediction2 <- collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) + + # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # and upperBoundsOnIntercepts + u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) + model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, + upperBoundsOnIntercepts = 1.0) + summary <- summary(model) + coefsR <- c(-11.13331, 1.00000, 0.00000, 1.00000, 0.00000) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + # Test upperBoundsOnCoefficients should be matrix + expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), + upperBoundsOnIntercepts = 1.0)) + + # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # and lowerBoundsOnIntercepts + l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) + model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, + lowerBoundsOnIntercepts = 0.0) + summary <- summary(model) + coefsR <- c(0, 0, -1, 0, 1.902192) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + # Test lowerBoundsOnCoefficients should be matrix + expect_error(spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = as.array(c(1, 2)), + lowerBoundsOnIntercepts = 0.0)) + + # Test multinomial logistic regression with lowerBoundsOnCoefficients + # and lowerBoundsOnIntercepts + l <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) + model <- spark.logit(training, Species ~ ., family = "multinomial", + lowerBoundsOnCoefficients = l, + lowerBoundsOnIntercepts = as.array(c(0.0, 0.0))) + summary <- summary(model) + versicolorCoefsR <- c(42.639465, 7.258104, 14.330814, 16.298243, 11.716429) + virginicaCoefsR <- c(0.0002970796, 4.79274, 7.65047, 25.72793, 30.0021) + versicolorCoefs <- summary$coefficients[, "versicolor"] + virginicaCoefs <- summary$coefficients[, "virginica"] + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) }) test_that("spark.mlp", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 567af0488e1b4..b234bc4c2df4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -214,7 +214,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * The lower bounds on intercepts if fitting under bound constrained optimization. - * The bounds vector size must be equal with 1 for binomial regression, or the number + * The bounds vector size must be equal to 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. * Default is none. * @@ -230,7 +230,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * The upper bounds on intercepts if fitting under bound constrained optimization. - * The bound vector size must be equal with 1 for binomial regression, or the number + * The bound vector size must be equal to 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. * Default is none. * @@ -451,12 +451,12 @@ class LogisticRegression @Since("1.2.0") ( } if (isSet(lowerBoundsOnIntercepts)) { require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " + - "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + "lowerBoundsOnIntercepts must be equal to 1 for binomial regression, or the number of " + s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.") } if (isSet(upperBoundsOnIntercepts)) { require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " + - "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + "upperBoundsOnIntercepts must be equal to 1 for binomial regression, or the number of " + s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.") } if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 703bcdf4ca725..b96481acf46d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Matrices, Vector, Vectors} import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -97,7 +97,13 @@ private[r] object LogisticRegressionWrapper standardization: Boolean, thresholds: Array[Double], weightCol: String, - aggregationDepth: Int + aggregationDepth: Int, + numRowsOfBoundsOnCoefficients: Int, + numColsOfBoundsOnCoefficients: Int, + lowerBoundsOnCoefficients: Array[Double], + upperBoundsOnCoefficients: Array[Double], + lowerBoundsOnIntercepts: Array[Double], + upperBoundsOnIntercepts: Array[Double] ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -133,6 +139,30 @@ private[r] object LogisticRegressionWrapper if (weightCol != null) lr.setWeightCol(weightCol) + if (numRowsOfBoundsOnCoefficients != 0 && + numColsOfBoundsOnCoefficients != 0 && lowerBoundsOnCoefficients != null) { + val coef = Matrices.dense(numRowsOfBoundsOnCoefficients, + numColsOfBoundsOnCoefficients, lowerBoundsOnCoefficients) + lr.setLowerBoundsOnCoefficients(coef) + } + + if (numRowsOfBoundsOnCoefficients != 0 && + numColsOfBoundsOnCoefficients != 0 && upperBoundsOnCoefficients != null) { + val coef = Matrices.dense(numRowsOfBoundsOnCoefficients, + numColsOfBoundsOnCoefficients, upperBoundsOnCoefficients) + lr.setUpperBoundsOnCoefficients(coef) + } + + if (lowerBoundsOnIntercepts != null) { + val intercept = Vectors.dense(lowerBoundsOnIntercepts) + lr.setLowerBoundsOnIntercepts(intercept) + } + + if (upperBoundsOnIntercepts != null) { + val intercept = Vectors.dense(upperBoundsOnIntercepts) + lr.setUpperBoundsOnIntercepts(intercept) + } + val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) .setOutputCol(PREDICTED_LABEL_COL) From d66b143eec7f604595089f72d8786edbdcd74282 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Jun 2017 23:43:21 -0700 Subject: [PATCH 0768/1765] [SPARK-21167][SS] Decode the path generated by File sink to handle special characters ## What changes were proposed in this pull request? Decode the path generated by File sink to handle special characters. ## How was this patch tested? The added unit test. Author: Shixiong Zhu Closes #18381 from zsxwing/SPARK-21167. --- .../streaming/FileStreamSinkLog.scala | 5 +++- .../sql/streaming/FileStreamSinkSuite.scala | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index 8d718b2164d22..c9939ac1db746 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.net.URI + import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization @@ -47,7 +49,8 @@ case class SinkFileStatus( action: String) { def toFileStatus: FileStatus = { - new FileStatus(size, isDir, blockReplication, blockSize, modificationTime, new Path(path)) + new FileStatus( + size, isDir, blockReplication, blockSize, modificationTime, new Path(new URI(path))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1a2d3a13f3a4a..bb6a27803bb20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -64,6 +64,35 @@ class FileStreamSinkSuite extends StreamTest { } } + test("SPARK-21167: encode and decode path correctly") { + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + val query = ds.map(s => (s, s.length)) + .toDF("value", "len") + .writeStream + .partitionBy("value") + .option("checkpointLocation", checkpointDir) + .format("parquet") + .start(outputDir) + + try { + // The output is partitoned by "value", so the value will appear in the file path. + // This is to test if we handle spaces in the path correctly. + inputData.addData("hello world") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + val outputDf = spark.read.parquet(outputDir) + checkDatasetUnorderly(outputDf.as[(Int, String)], ("hello world".length, "hello world")) + } finally { + query.stop() + } + } + test("partitioned writing and batch reading") { val inputData = MemoryStream[Int] val ds = inputData.toDS() From 67c75021c59d93cda9b5d70c0ef6d547fff92083 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Jun 2017 16:22:02 +0800 Subject: [PATCH 0769/1765] [SPARK-21163][SQL] DataFrame.toPandas should respect the data type ## What changes were proposed in this pull request? Currently we convert a spark DataFrame to Pandas Dataframe by `pd.DataFrame.from_records`. It infers the data type from the data and doesn't respect the spark DataFrame Schema. This PR fixes it. ## How was this patch tested? a new regression test Author: hyukjinkwon Author: Wenchen Fan Author: Wenchen Fan Closes #18378 from cloud-fan/to_pandas. --- python/pyspark/sql/dataframe.py | 31 ++++++++++++++++++++++++++++++- python/pyspark/sql/tests.py | 24 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8541403dfe2f1..0649271ed2246 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1721,7 +1721,18 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type + + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf ########################################################################################## # Pandas compatibility @@ -1750,6 +1761,24 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) +def _to_corrected_pandas_type(dt): + """ + When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. + This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. + """ + import numpy as np + if type(dt) == ByteType: + return np.int8 + elif type(dt) == ShortType: + return np.int16 + elif type(dt) == IntegerType: + return np.int32 + elif type(dt) == FloatType: + return np.float32 + else: + return None + + class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3b308579a3778..0a1cd6856b8e8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -46,6 +46,14 @@ else: import unittest +_have_pandas = False +try: + import pandas + _have_pandas = True +except: + # No Pandas, but that's okay, we'll skip those tests + pass + from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * @@ -2290,6 +2298,22 @@ def count_bucketed_cols(names, table="pyspark_bucket"): .mode("overwrite").saveAsTable("pyspark_bucket")) self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_to_pandas(self): + import numpy as np + schema = StructType().add("a", IntegerType()).add("b", StringType())\ + .add("c", BooleanType()).add("d", FloatType()) + data = [ + (1, "foo", True, 3.0), (2, "foo", True, 5.0), + (3, "bar", False, -1.0), (4, "bar", False, 6.0), + ] + df = self.spark.createDataFrame(data, schema) + types = df.toPandas().dtypes + self.assertEquals(types[0], np.int32) + self.assertEquals(types[1], np.object) + self.assertEquals(types[2], np.bool) + self.assertEquals(types[3], np.float32) + class HiveSparkSubmitTests(SparkSubmitTests): From 97b307c87c0f262ea3e020bf3d72383deef76619 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 22 Jun 2017 10:12:33 +0100 Subject: [PATCH 0770/1765] [SQL][DOC] Fix documentation of lpad ## What changes were proposed in this pull request? Fix incomplete documentation for `lpad`. Author: actuaryzhang Closes #18367 from actuaryzhang/SQLDoc. --- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9a35a5c4658e3..839cbf42024e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2292,7 +2292,8 @@ object functions { } /** - * Left-pad the string column with + * Left-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -2350,7 +2351,8 @@ object functions { def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** - * Right-padded with pad to a length of len. + * Right-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 From 2dadea95c8e2c727e97fca91b0060f666fc0c65b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 22 Jun 2017 20:48:12 +0800 Subject: [PATCH 0771/1765] [SPARK-20832][CORE] Standalone master should explicitly inform drivers of worker deaths and invalidate external shuffle service outputs ## What changes were proposed in this pull request? In standalone mode, master should explicitly inform each active driver of any worker deaths, so the invalid external shuffle service outputs on the lost host would be removed from the shuffle mapStatus, thus we can avoid future `FetchFailure`s. ## How was this patch tested? Manually tested by the following steps: 1. Start a standalone Spark cluster with one driver node and two worker nodes; 2. Run a Job with ShuffleMapStage, ensure the outputs distribute on each worker; 3. Run another Job to make all executors exit, but the workers are all alive; 4. Kill one of the workers; 5. Run rdd.collect(), before this change, we should see `FetchFailure`s and failed Stages, while after the change, the job should complete without failure. Before the change: ![image](https://user-images.githubusercontent.com/4784782/27335366-c251c3d6-55fe-11e7-99dd-d1fdcb429210.png) After the change: ![image](https://user-images.githubusercontent.com/4784782/27335393-d1c71640-55fe-11e7-89ed-bd760f1f39af.png) Author: Xingbo Jiang Closes #18362 from jiangxb1987/removeWorker. --- .../apache/spark/deploy/DeployMessage.scala | 2 ++ .../deploy/client/StandaloneAppClient.scala | 4 +++ .../client/StandaloneAppClientListener.scala | 8 +++-- .../apache/spark/deploy/master/Master.scala | 15 ++++++---- .../apache/spark/scheduler/DAGScheduler.scala | 30 +++++++++++++++++++ .../spark/scheduler/DAGSchedulerEvent.scala | 3 ++ .../spark/scheduler/TaskScheduler.scala | 5 ++++ .../spark/scheduler/TaskSchedulerImpl.scala | 5 ++++ .../cluster/CoarseGrainedClusterMessage.scala | 3 ++ .../CoarseGrainedSchedulerBackend.scala | 25 +++++++++++++--- .../cluster/StandaloneSchedulerBackend.scala | 5 ++++ .../spark/deploy/client/AppClientSuite.scala | 2 ++ .../spark/scheduler/DAGSchedulerSuite.scala | 2 ++ .../ExternalClusterManagerSuite.scala | 1 + 14 files changed, 98 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index c1a91c27eef2d..49a319abb3238 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -158,6 +158,8 @@ private[deploy] object DeployMessages { case class ApplicationRemoved(message: String) + case class WorkerRemoved(id: String, host: String, message: String) + // DriverClient <-> Master case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 93f58ce63799f..757c930b84eb2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -182,6 +182,10 @@ private[spark] class StandaloneAppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } + case WorkerRemoved(id, host, message) => + logInfo("Master removed worker %s: %s".format(id, message)) + listener.workerRemoved(id, host, message) + case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) master = Some(masterRef) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index 64255ec92b72a..d8bc1a883def1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -18,9 +18,9 @@ package org.apache.spark.deploy.client /** - * Callbacks invoked by deploy client when various events happen. There are currently four events: - * connecting to the cluster, disconnecting, being given an executor, and having an executor - * removed (either due to failure or due to revocation). + * Callbacks invoked by deploy client when various events happen. There are currently five events: + * connecting to the cluster, disconnecting, being given an executor, having an executor removed + * (either due to failure or due to revocation), and having a worker removed. * * Users of this API should *not* block inside the callback methods. */ @@ -38,4 +38,6 @@ private[spark] trait StandaloneAppClientListener { def executorRemoved( fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit + + def workerRemoved(workerId: String, host: String, message: String): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f10a41286c52f..c192a0cc82ef6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -498,7 +498,7 @@ private[deploy] class Master( override def onDisconnected(address: RpcAddress): Unit = { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) + addressToWorker.get(address).foreach(removeWorker(_, s"${address} got disassociated")) addressToApp.get(address).foreach(finishApplication) if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } } @@ -544,7 +544,8 @@ private[deploy] class Master( state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. - workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) + workers.filter(_.state == WorkerState.UNKNOWN).foreach( + removeWorker(_, "Not responding for recovery")) apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) // Update the state of recovered apps to RUNNING @@ -755,7 +756,7 @@ private[deploy] class Master( if (oldWorker.state == WorkerState.UNKNOWN) { // A worker registering from UNKNOWN implies that the worker was restarted during recovery. // The old worker must thus be dead, so we will remove it and accept the new worker. - removeWorker(oldWorker) + removeWorker(oldWorker, "Worker replaced by a new worker with same address") } else { logInfo("Attempted to re-register worker at same address: " + workerAddress) return false @@ -771,7 +772,7 @@ private[deploy] class Master( true } - private def removeWorker(worker: WorkerInfo) { + private def removeWorker(worker: WorkerInfo, msg: String) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id @@ -795,6 +796,10 @@ private[deploy] class Master( removeDriver(driver.id, DriverState.ERROR, None) } } + logInfo(s"Telling app of lost worker: " + worker.id) + apps.filterNot(completedApps.contains(_)).foreach { app => + app.driver.send(WorkerRemoved(worker.id, worker.host, msg)) + } persistenceEngine.removeWorker(worker) } @@ -979,7 +984,7 @@ private[deploy] class Master( if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( worker.id, WORKER_TIMEOUT_MS / 1000)) - removeWorker(worker) + removeWorker(worker, s"Not receiving heartbeat for ${WORKER_TIMEOUT_MS / 1000} seconds") } else { if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index fafe9cafdc18f..3422a5f204b12 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -259,6 +259,13 @@ class DAGScheduler( eventProcessLoop.post(ExecutorLost(execId, reason)) } + /** + * Called by TaskScheduler implementation when a worker is removed. + */ + def workerRemoved(workerId: String, host: String, message: String): Unit = { + eventProcessLoop.post(WorkerRemoved(workerId, host, message)) + } + /** * Called by TaskScheduler implementation when a host is added. */ @@ -1432,6 +1439,26 @@ class DAGScheduler( } } + /** + * Responds to a worker being removed. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. + * + * We will assume that we've lost all shuffle blocks associated with the host if a worker is + * removed, so we will remove them all from MapStatus. + * + * @param workerId identifier of the worker that is removed. + * @param host host of the worker that is removed. + * @param message the reason why the worker is removed. + */ + private[scheduler] def handleWorkerRemoved( + workerId: String, + host: String, + message: String): Unit = { + logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host)) + mapOutputTracker.removeOutputsOnHost(host) + clearCacheLocs() + } + private[scheduler] def handleExecutorAdded(execId: String, host: String) { // remove from failedEpoch(execId) ? if (failedEpoch.contains(execId)) { @@ -1727,6 +1754,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } dagScheduler.handleExecutorLost(execId, workerLost) + case WorkerRemoved(workerId, host, message) => + dagScheduler.handleWorkerRemoved(workerId, host, message) + case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index cda0585f154a9..3f8d5639a2b90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -86,6 +86,9 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason) extends DAGSchedulerEvent +private[scheduler] case class WorkerRemoved(workerId: String, host: String, message: String) + extends DAGSchedulerEvent + private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 3de7d1f7de22b..90644fea23ab1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -89,6 +89,11 @@ private[spark] trait TaskScheduler { */ def executorLost(executorId: String, reason: ExecutorLossReason): Unit + /** + * Process a removed worker + */ + def workerRemoved(workerId: String, host: String, message: String): Unit + /** * Get an application's attempt ID associated with the job. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 629cfc7c7a8ce..bba0b294f1afb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -569,6 +569,11 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { + logInfo(s"Handle removed worker $workerId: $message") + dagScheduler.workerRemoved(workerId, host, message) + } + private def logExecutorLoss( executorId: String, hostPort: String, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 6b49bd699a13a..89a9ad6811e18 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -85,6 +85,9 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) extends CoarseGrainedClusterMessage + case class RemoveWorker(workerId: String, host: String, message: String) + extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index dc82bb7704727..0b396b794ddce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -219,6 +219,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) + case RemoveWorker(workerId, host, message) => + removeWorker(workerId, host, message) + context.reply(true) + case RetrieveSparkAppConfig => val reply = SparkAppConfig(sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey()) @@ -231,8 +235,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { // Filter out executors under killing val activeExecutors = executorDataMap.filterKeys(executorIsAlive) - val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + val workOffers = activeExecutors.map { + case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -331,6 +336,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + // Remove a lost worker from the cluster + private def removeWorker(workerId: String, host: String, message: String): Unit = { + logDebug(s"Asked to remove worker $workerId with reason $message") + scheduler.workerRemoved(workerId, host, message) + } + /** * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. @@ -449,8 +460,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { // Only log the failure since we don't care about the result. - driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { case t => - logError(t.getMessage, t) + driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { + case t => logError(t.getMessage, t) + }(ThreadUtils.sameThread) + } + + protected def removeWorker(workerId: String, host: String, message: String): Unit = { + driverEndpoint.ask[Boolean](RemoveWorker(workerId, host, message)).onFailure { + case t => logError(t.getMessage, t) }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 0529fe9eed4da..fd8e64454bf70 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -161,6 +161,11 @@ private[spark] class StandaloneSchedulerBackend( removeExecutor(fullId.split("/")(1), reason) } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { + logInfo("Worker %s removed: %s".format(workerId, message)) + removeWorker(workerId, host, message) + } + override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 936639b845789..a1707e6540b39 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -214,6 +214,8 @@ class AppClientSuite id: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { execRemovedList.add(id) } + + def workerRemoved(workerId: String, host: String, message: String): Unit = {} } /** Create AppClient and supporting objects */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index ddd3281106745..453be26ed8d0c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -131,6 +131,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None } @@ -632,6 +633,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None } val noKillScheduler = new DAGScheduler( diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index ba56af8215cd7..a4e4ea7cd2894 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -84,6 +84,7 @@ private class DummyTaskScheduler extends TaskScheduler { override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None def executorHeartbeatReceived( execId: String, From 19331b8e44ad910550f810b80e2a0caf0ef62cb3 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 22 Jun 2017 10:16:51 -0700 Subject: [PATCH 0772/1765] [SPARK-20889][SPARKR] Grouped documentation for DATETIME column methods ## What changes were proposed in this pull request? Grouped documentation for datetime column methods. Author: actuaryzhang Closes #18114 from actuaryzhang/sparkRDocDate. --- R/pkg/R/functions.R | 532 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 69 ++++-- 2 files changed, 273 insertions(+), 328 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 01ca8b8c4527d..31028585aaa13 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -34,6 +34,58 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} NULL +#' Date time functions for Column operations +#' +#' Date time functions defined for \code{Column}. +#' +#' @param x Column to compute on. +#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse +#' x Column to DateType or TimestampType. For \code{trunc}, it is the string used +#' for specifying the truncation method. For example, "year", "yyyy", "yy" for +#' truncate by year, or "month", "mon", "mm" for truncate by month. +#' @param ... additional argument(s). +#' @name column_datetime_functions +#' @rdname column_datetime_functions +#' @family data time functions +#' @examples +#' \dontrun{ +#' dts <- c("2005-01-02 18:47:22", +#' "2005-12-24 16:30:58", +#' "2005-10-28 07:30:05", +#' "2005-12-28 07:01:05", +#' "2006-01-24 00:01:10") +#' y <- c(2.0, 2.2, 3.4, 2.5, 1.8) +#' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} +NULL + +#' Date time arithmetic functions for Column operations +#' +#' Date time arithmetic functions defined for \code{Column}. +#' +#' @param y Column to compute on. +#' @param x For class \code{Column}, it is the column used to perform arithmetic operations +#' with column \code{y}. For class \code{numeric}, it is the number of months or +#' days to be added to or subtracted from \code{y}. For class \code{character}, it is +#' \itemize{ +#' \item \code{date_format}: date format specification. +#' \item \code{from_utc_timestamp}, \code{to_utc_timestamp}: time zone to use. +#' \item \code{next_day}: day of the week string. +#' } +#' +#' @name column_datetime_diff_functions +#' @rdname column_datetime_diff_functions +#' @family data time functions +#' @examples +#' \dontrun{ +#' dts <- c("2005-01-02 18:47:22", +#' "2005-12-24 16:30:58", +#' "2005-10-28 07:30:05", +#' "2005-12-28 07:01:05", +#' "2006-01-24 00:01:10") +#' y <- c(2.0, 2.2, 3.4, 2.5, 1.8) +#' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -546,18 +598,20 @@ setMethod("hash", column(jc) }) -#' dayofmonth -#' -#' Extracts the day of the month as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{dayofmonth}: Extracts the day of the month as an integer from a +#' given date/timestamp/string. #' -#' @rdname dayofmonth -#' @name dayofmonth -#' @family date time functions -#' @aliases dayofmonth,Column-method +#' @rdname column_datetime_functions +#' @aliases dayofmonth dayofmonth,Column-method #' @export -#' @examples \dontrun{dayofmonth(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, year(df$time), quarter(df$time), month(df$time), +#' dayofmonth(df$time), dayofyear(df$time), weekofyear(df$time))) +#' head(agg(groupBy(df, year(df$time)), count(df$y), avg(df$y))) +#' head(agg(groupBy(df, month(df$time)), avg(df$y)))} #' @note dayofmonth since 1.5.0 setMethod("dayofmonth", signature(x = "Column"), @@ -566,18 +620,13 @@ setMethod("dayofmonth", column(jc) }) -#' dayofyear -#' -#' Extracts the day of the year as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{dayofyear}: Extracts the day of the year as an integer from a +#' given date/timestamp/string. #' -#' @rdname dayofyear -#' @name dayofyear -#' @family date time functions -#' @aliases dayofyear,Column-method +#' @rdname column_datetime_functions +#' @aliases dayofyear dayofyear,Column-method #' @export -#' @examples \dontrun{dayofyear(df$c)} #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -763,18 +812,19 @@ setMethod("hex", column(jc) }) -#' hour -#' -#' Extracts the hours as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{hour}: Extracts the hours as an integer from a given date/timestamp/string. #' -#' @rdname hour -#' @name hour -#' @aliases hour,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases hour hour,Column-method #' @export -#' @examples \dontrun{hour(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, hour(df$time), minute(df$time), second(df$time))) +#' head(agg(groupBy(df, dayofmonth(df$time)), avg(df$y))) +#' head(agg(groupBy(df, hour(df$time)), avg(df$y))) +#' head(agg(groupBy(df, minute(df$time)), avg(df$y)))} #' @note hour since 1.5.0 setMethod("hour", signature(x = "Column"), @@ -893,20 +943,18 @@ setMethod("last", column(jc) }) -#' last_day -#' -#' Given a date column, returns the last day of the month which the given date belongs to. -#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the -#' month in July 2015. -#' -#' @param x Column to compute on. +#' @details +#' \code{last_day}: Given a date column, returns the last day of the month which the +#' given date belongs to. For example, input "2015-07-27" returns "2015-07-31" since +#' July 31 is the last day of the month in July 2015. #' -#' @rdname last_day -#' @name last_day -#' @aliases last_day,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases last_day last_day,Column-method #' @export -#' @examples \dontrun{last_day(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, last_day(df$time), month(df$time)))} #' @note last_day since 1.5.0 setMethod("last_day", signature(x = "Column"), @@ -1129,18 +1177,12 @@ setMethod("min", column(jc) }) -#' minute -#' -#' Extracts the minutes as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{minute}: Extracts the minutes as an integer from a given date/timestamp/string. #' -#' @rdname minute -#' @name minute -#' @aliases minute,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases minute minute,Column-method #' @export -#' @examples \dontrun{minute(df$c)} #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1177,18 +1219,12 @@ setMethod("monotonically_increasing_id", column(jc) }) -#' month -#' -#' Extracts the month as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{month}: Extracts the month as an integer from a given date/timestamp/string. #' -#' @rdname month -#' @name month -#' @aliases month,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases month month,Column-method #' @export -#' @examples \dontrun{month(df$c)} #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1217,18 +1253,12 @@ setMethod("negate", column(jc) }) -#' quarter -#' -#' Extracts the quarter as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{quarter}: Extracts the quarter as an integer from a given date/timestamp/string. #' -#' @rdname quarter -#' @name quarter -#' @family date time functions -#' @aliases quarter,Column-method +#' @rdname column_datetime_functions +#' @aliases quarter quarter,Column-method #' @export -#' @examples \dontrun{quarter(df$c)} #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1364,18 +1394,12 @@ setMethod("sd", stddev_samp(x) }) -#' second -#' -#' Extracts the seconds as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{second}: Extracts the seconds as an integer from a given date/timestamp/string. #' -#' @rdname second -#' @name second -#' @family date time functions -#' @aliases second,Column-method +#' @rdname column_datetime_functions +#' @aliases second second,Column-method #' @export -#' @examples \dontrun{second(df$c)} #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1725,29 +1749,28 @@ setMethod("toRadians", column(jc) }) -#' to_date -#' -#' Converts the column into a DateType. You may optionally specify a format -#' according to the rules in: +#' @details +#' \code{to_date}: Converts the column into a DateType. You may optionally specify +#' a format according to the rules in: #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a DateType if the format is omitted #' (equivalent to \code{cast(df$x, "date")}). #' -#' @param x Column to parse. -#' @param format string to use to parse x Column to DateType. (optional) -#' -#' @rdname to_date -#' @name to_date -#' @family date time functions -#' @aliases to_date,Column,missing-method +#' @rdname column_datetime_functions +#' @aliases to_date to_date,Column,missing-method #' @export #' @examples +#' #' \dontrun{ -#' to_date(df$c) -#' to_date(df$c, 'yyyy-MM-dd') -#' } +#' tmp <- createDataFrame(data.frame(time_string = dts)) +#' tmp2 <- mutate(tmp, date1 = to_date(tmp$time_string), +#' date2 = to_date(tmp$time_string, "yyyy-MM-dd"), +#' date3 = date_format(tmp$time_string, "MM/dd/yyy"), +#' time1 = to_timestamp(tmp$time_string), +#' time2 = to_timestamp(tmp$time_string, "yyyy-MM-dd")) +#' head(tmp2)} #' @note to_date(Column) since 1.5.0 setMethod("to_date", signature(x = "Column", format = "missing"), @@ -1756,9 +1779,7 @@ setMethod("to_date", column(jc) }) -#' @rdname to_date -#' @name to_date -#' @family date time functions +#' @rdname column_datetime_functions #' @aliases to_date,Column,character-method #' @export #' @note to_date(Column, character) since 2.2.0 @@ -1801,29 +1822,18 @@ setMethod("to_json", signature(x = "Column"), column(jc) }) -#' to_timestamp -#' -#' Converts the column into a TimestampType. You may optionally specify a format -#' according to the rules in: +#' @details +#' \code{to_timestamp}: Converts the column into a TimestampType. You may optionally specify +#' a format according to the rules in: #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a TimestampType if the format is omitted #' (equivalent to \code{cast(df$x, "timestamp")}). #' -#' @param x Column to parse. -#' @param format string to use to parse x Column to TimestampType. (optional) -#' -#' @rdname to_timestamp -#' @name to_timestamp -#' @family date time functions -#' @aliases to_timestamp,Column,missing-method +#' @rdname column_datetime_functions +#' @aliases to_timestamp to_timestamp,Column,missing-method #' @export -#' @examples -#' \dontrun{ -#' to_timestamp(df$c) -#' to_timestamp(df$c, 'yyyy-MM-dd') -#' } #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1832,9 +1842,7 @@ setMethod("to_timestamp", column(jc) }) -#' @rdname to_timestamp -#' @name to_timestamp -#' @family date time functions +#' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method #' @export #' @note to_timestamp(Column, character) since 2.2.0 @@ -1984,18 +1992,12 @@ setMethod("var_samp", column(jc) }) -#' weekofyear -#' -#' Extracts the week number as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{weekofyear}: Extracts the week number as an integer from a given date/timestamp/string. #' -#' @rdname weekofyear -#' @name weekofyear -#' @aliases weekofyear,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases weekofyear weekofyear,Column-method #' @export -#' @examples \dontrun{weekofyear(df$c)} #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -2004,18 +2006,12 @@ setMethod("weekofyear", column(jc) }) -#' year -#' -#' Extracts the year as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{year}: Extracts the year as an integer from a given date/timestamp/string. #' -#' @rdname year -#' @name year -#' @family date time functions -#' @aliases year,Column-method +#' @rdname column_datetime_functions +#' @aliases year year,Column-method #' @export -#' @examples \dontrun{year(df$c)} #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -2048,19 +2044,20 @@ setMethod("atan2", signature(y = "Column"), column(jc) }) -#' datediff -#' -#' Returns the number of days from \code{start} to \code{end}. -#' -#' @param x start Column to use. -#' @param y end Column to use. +#' @details +#' \code{datediff}: Returns the number of days from \code{y} to \code{x}. #' -#' @rdname datediff -#' @name datediff -#' @aliases datediff,Column-method -#' @family date time functions +#' @rdname column_datetime_diff_functions +#' @aliases datediff datediff,Column-method #' @export -#' @examples \dontrun{datediff(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- createDataFrame(data.frame(time_string1 = as.POSIXct(dts), +#' time_string2 = as.POSIXct(dts[order(runif(length(dts)))]))) +#' tmp2 <- mutate(tmp, datediff = datediff(tmp$time_string1, tmp$time_string2), +#' monthdiff = months_between(tmp$time_string1, tmp$time_string2)) +#' head(tmp2)} #' @note datediff since 1.5.0 setMethod("datediff", signature(y = "Column"), function(y, x) { @@ -2117,19 +2114,12 @@ setMethod("levenshtein", signature(y = "Column"), column(jc) }) -#' months_between -#' -#' Returns number of months between dates \code{date1} and \code{date2}. -#' -#' @param x start Column to use. -#' @param y end Column to use. +#' @details +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' -#' @rdname months_between -#' @name months_between -#' @family date time functions -#' @aliases months_between,Column-method +#' @rdname column_datetime_diff_functions +#' @aliases months_between months_between,Column-method #' @export -#' @examples \dontrun{months_between(df$c, x)} #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2348,26 +2338,18 @@ setMethod("n", signature(x = "Column"), count(x) }) -#' date_format -#' -#' Converts a date/timestamp/string to a value of string in the format specified by the date -#' format given by the second argument. -#' -#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All +#' @details +#' \code{date_format}: Converts a date/timestamp/string to a value of string in the format +#' specified by the date format given by the second argument. A pattern could be for instance +#' \code{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. -#' #' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' -#' @param y Column to compute on. -#' @param x date format specification. +#' @rdname column_datetime_diff_functions #' -#' @family date time functions -#' @rdname date_format -#' @name date_format -#' @aliases date_format,Column,character-method +#' @aliases date_format date_format,Column,character-method #' @export -#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2414,20 +2396,20 @@ setMethod("from_json", signature(x = "Column", schema = "structType"), column(jc) }) -#' from_utc_timestamp -#' -#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp -#' that corresponds to the same time of day in the given timezone. +#' @details +#' \code{from_utc_timestamp}: Given a timestamp, which corresponds to a certain time of day in UTC, +#' returns another timestamp that corresponds to the same time of day in the given timezone. #' -#' @param y Column to compute on. -#' @param x time zone to use. +#' @rdname column_datetime_diff_functions #' -#' @family date time functions -#' @rdname from_utc_timestamp -#' @name from_utc_timestamp -#' @aliases from_utc_timestamp,Column,character-method +#' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method #' @export -#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, 'PST'), +#' to_utc = to_utc_timestamp(df$time, 'PST')) +#' head(tmp)} #' @note from_utc_timestamp since 1.5.0 setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2458,30 +2440,16 @@ setMethod("instr", signature(y = "Column", x = "character"), column(jc) }) -#' next_day -#' -#' Given a date column, returns the first date which is later than the value of the date column -#' that is on the specified day of the week. -#' -#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first -#' Sunday after 2015-07-27. -#' -#' Day of the week parameter is case insensitive, and accepts first three or two characters: -#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". -#' -#' @param y Column to compute on. -#' @param x Day of the week string. +#' @details +#' \code{next_day}: Given a date column, returns the first date which is later than the value of +#' the date column that is on the specified day of the week. For example, +#' \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first Sunday +#' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or +#' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' -#' @family date time functions -#' @rdname next_day -#' @name next_day -#' @aliases next_day,Column,character-method +#' @rdname column_datetime_diff_functions +#' @aliases next_day next_day,Column,character-method #' @export -#' @examples -#'\dontrun{ -#'next_day(df$d, 'Sun') -#'next_day(df$d, 'Sunday') -#'} #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2489,20 +2457,13 @@ setMethod("next_day", signature(y = "Column", x = "character"), column(jc) }) -#' to_utc_timestamp -#' -#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns -#' another timestamp that corresponds to the same time of day in UTC. -#' -#' @param y Column to compute on -#' @param x timezone to use +#' @details +#' \code{to_utc_timestamp}: Given a timestamp, which corresponds to a certain time of day +#' in the given timezone, returns another timestamp that corresponds to the same time of day in UTC. #' -#' @family date time functions -#' @rdname to_utc_timestamp -#' @name to_utc_timestamp -#' @aliases to_utc_timestamp,Column,character-method +#' @rdname column_datetime_diff_functions +#' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method #' @export -#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2510,19 +2471,20 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' add_months -#' -#' Returns the date that is numMonths after startDate. -#' -#' @param y Column to compute on -#' @param x Number of months to add +#' @details +#' \code{add_months}: Returns the date that is numMonths (\code{x}) after startDate (\code{y}). #' -#' @name add_months -#' @family date time functions -#' @rdname add_months -#' @aliases add_months,Column,numeric-method +#' @rdname column_datetime_diff_functions +#' @aliases add_months add_months,Column,numeric-method #' @export -#' @examples \dontrun{add_months(df$d, 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, t1 = add_months(df$time, 1), +#' t2 = date_add(df$time, 2), +#' t3 = date_sub(df$time, 3), +#' t4 = next_day(df$time, 'Sun')) +#' head(tmp)} #' @note add_months since 1.5.0 setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2530,19 +2492,12 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), column(jc) }) -#' date_add -#' -#' Returns the date that is \code{x} days after -#' -#' @param y Column to compute on -#' @param x Number of days to add +#' @details +#' \code{date_add}: Returns the date that is \code{x} days after. #' -#' @family date time functions -#' @rdname date_add -#' @name date_add -#' @aliases date_add,Column,numeric-method +#' @rdname column_datetime_diff_functions +#' @aliases date_add date_add,Column,numeric-method #' @export -#' @examples \dontrun{date_add(df$d, 1)} #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2550,19 +2505,13 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), column(jc) }) -#' date_sub -#' -#' Returns the date that is \code{x} days before +#' @details +#' \code{date_sub}: Returns the date that is \code{x} days before. #' -#' @param y Column to compute on -#' @param x Number of days to substract +#' @rdname column_datetime_diff_functions #' -#' @family date time functions -#' @rdname date_sub -#' @name date_sub -#' @aliases date_sub,Column,numeric-method +#' @aliases date_sub date_sub,Column,numeric-method #' @export -#' @examples \dontrun{date_sub(df$d, 1)} #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2774,27 +2723,24 @@ setMethod("format_string", signature(format = "character", x = "Column"), column(jc) }) -#' from_unixtime +#' @details +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a +#' string representing the timestamp of that moment in the current system time zone in the JVM in the +#' given format. See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. #' -#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string -#' representing the timestamp of that moment in the current system time zone in the given -#' format. +#' @rdname column_datetime_functions #' -#' @param x a Column of unix timestamp. -#' @param format the target format. See -#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ -#' Customizing Formats} for available options. -#' @param ... further arguments to be passed to or from other methods. -#' @family date time functions -#' @rdname from_unixtime -#' @name from_unixtime -#' @aliases from_unixtime,Column-method +#' @aliases from_unixtime from_unixtime,Column-method #' @export #' @examples -#'\dontrun{ -#'from_unixtime(df$t) -#'from_unixtime(df$t, 'yyyy/MM/dd HH') -#'} +#' +#' \dontrun{ +#' tmp <- mutate(df, to_unix = unix_timestamp(df$time), +#' to_unix2 = unix_timestamp(df$time, 'yyyy-MM-dd HH'), +#' from_unix = from_unixtime(unix_timestamp(df$time)), +#' from_unix2 = from_unixtime(unix_timestamp(df$time), 'yyyy-MM-dd HH:mm')) +#' head(tmp)} #' @note from_unixtime since 1.5.0 setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -3111,21 +3057,12 @@ setMethod("translate", column(jc) }) -#' unix_timestamp -#' -#' Gets current Unix timestamp in seconds. +#' @details +#' \code{unix_timestamp}: Gets current Unix timestamp in seconds. #' -#' @family date time functions -#' @rdname unix_timestamp -#' @name unix_timestamp -#' @aliases unix_timestamp,missing,missing-method +#' @rdname column_datetime_functions +#' @aliases unix_timestamp unix_timestamp,missing,missing-method #' @export -#' @examples -#'\dontrun{ -#'unix_timestamp() -#'unix_timestamp(df$t) -#'unix_timestamp(df$t, 'yyyy-MM-dd HH') -#'} #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -3133,8 +3070,7 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), column(jc) }) -#' @rdname unix_timestamp -#' @name unix_timestamp +#' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method #' @export #' @note unix_timestamp(Column) since 1.5.0 @@ -3144,12 +3080,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), column(jc) }) -#' @param x a Column of date, in string, date or timestamp type. -#' @param format the target format. See -#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ -#' Customizing Formats} for available options. -#' @rdname unix_timestamp -#' @name unix_timestamp +#' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method #' @export #' @note unix_timestamp(Column, character) since 1.5.0 @@ -3931,26 +3862,17 @@ setMethod("input_file_name", signature("missing"), column(jc) }) -#' trunc -#' -#' Returns date truncated to the unit specified by the format. -#' -#' @param x Column to compute on. -#' @param format string used for specify the truncation method. For example, "year", "yyyy", -#' "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. +#' @details +#' \code{trunc}: Returns date truncated to the unit specified by the format. #' -#' @rdname trunc -#' @name trunc -#' @family date time functions -#' @aliases trunc,Column-method +#' @rdname column_datetime_functions +#' @aliases trunc trunc,Column-method #' @export #' @examples +#' #' \dontrun{ -#' trunc(df$c, "year") -#' trunc(df$c, "yy") -#' trunc(df$c, "month") -#' trunc(df$c, "mon") -#' } +#' head(select(df, df$time, trunc(df$time, "year"), trunc(df$time, "yy"), +#' trunc(df$time, "month"), trunc(df$time, "mon")))} #' @note trunc since 2.3.0 setMethod("trunc", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b3cc4868a0b33..f105174cea70d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -903,8 +903,9 @@ setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy" ###################### Expression Function Methods ########################## -#' @rdname add_months +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions @@ -1002,28 +1003,34 @@ setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @export setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) -#' @rdname datediff +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) -#' @rdname date_add +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) -#' @rdname date_format +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) -#' @rdname date_sub +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) -#' @rdname dayofmonth +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) -#' @rdname dayofyear +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname decode @@ -1051,8 +1058,9 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) -#' @rdname from_utc_timestamp +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname format_number @@ -1067,8 +1075,9 @@ setGeneric("format_string", function(format, x, ...) { standardGeneric("format_s #' @export setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) -#' @rdname from_unixtime +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname greatest @@ -1089,8 +1098,9 @@ setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) -#' @rdname hour +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname hypot @@ -1128,8 +1138,9 @@ setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) -#' @rdname last_day +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname lead @@ -1168,8 +1179,9 @@ setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) #' @export setGeneric("md5", function(x) { standardGeneric("md5") }) -#' @rdname minute +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @param x empty. Should be used with no argument. @@ -1178,12 +1190,14 @@ setGeneric("minute", function(x) { standardGeneric("minute") }) setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) -#' @rdname month +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) -#' @rdname months_between +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count @@ -1202,8 +1216,9 @@ setGeneric("negate", function(x) { standardGeneric("negate") }) #' @export setGeneric("not", function(x) { standardGeneric("not") }) -#' @rdname next_day +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname ntile @@ -1232,8 +1247,9 @@ setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @export setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) -#' @rdname quarter +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname rand @@ -1287,8 +1303,9 @@ setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) -#' @rdname second +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname sha1 @@ -1377,20 +1394,23 @@ setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname to_date +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname to_json #' @export setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) -#' @rdname to_timestamp +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) -#' @rdname to_utc_timestamp +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname translate @@ -1409,8 +1429,9 @@ setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @export setGeneric("unhex", function(x) { standardGeneric("unhex") }) -#' @rdname unix_timestamp +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname upper @@ -1437,16 +1458,18 @@ setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) -#' @rdname weekofyear +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname window #' @export setGeneric("window", function(x, ...) { standardGeneric("window") }) -#' @rdname year +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) From e55a105ae04f1d1c35ee8f02005a3ab71d789124 Mon Sep 17 00:00:00 2001 From: Lubo Zhang Date: Thu, 22 Jun 2017 11:18:58 -0700 Subject: [PATCH 0773/1765] [SPARK-20599][SS] ConsoleSink should work with (batch) ## What changes were proposed in this pull request? Currently, if we read a batch and want to display it on the console sink, it will lead a runtime exception. Changes: - In this PR, we add a match rule to check whether it is a ConsoleSinkProvider, we will display the Dataset if using console format. ## How was this patch tested? spark.read.schema().json(path).write.format("console").save Author: Lubo Zhang Author: lubozhan Closes #18347 from lubozhan/dev. --- .../sql/execution/streaming/console.scala | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 38c63191106d0..9e889ff679450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.types.StructType class ConsoleSink(options: Map[String, String]) extends Sink with Logging { // Number of rows to display, by default 20 rows @@ -51,7 +53,14 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { } } -class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { +case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) + extends BaseRelation { + override def schema: StructType = data.schema +} + +class ConsoleSinkProvider extends StreamSinkProvider + with DataSourceRegister + with CreatableRelationProvider { def createSink( sqlContext: SQLContext, parameters: Map[String, String], @@ -60,5 +69,20 @@ class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { new ConsoleSink(parameters) } + def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + // Number of rows to display, by default 20 rows + val numRowsToShow = parameters.get("numRows").map(_.toInt).getOrElse(20) + + // Truncate the displayed data if it is too long, by default it is true + val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) + data.showInternal(numRowsToShow, isTruncated) + + ConsoleRelation(sqlContext, data) + } + def shortName(): String = "console" } From 58434acdd8cec0c762b4f09ace25e41d603af0a4 Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 22 Jun 2017 14:10:51 -0700 Subject: [PATCH 0774/1765] [SPARK-19937] Collect metrics for remote bytes read to disk during shuffle. In current code(https://github.com/apache/spark/pull/16989), big blocks are shuffled to disk. This pr proposes to collect metrics for remote bytes fetched to disk. Author: jinxing Closes #18249 from jinxing64/SPARK-19937. --- .../apache/spark/InternalAccumulator.scala | 1 + .../spark/executor/ShuffleReadMetrics.scala | 13 +++++ .../apache/spark/executor/TaskMetrics.scala | 1 + .../status/api/v1/AllStagesResource.scala | 2 + .../org/apache/spark/status/api/v1/api.scala | 2 + .../storage/ShuffleBlockFetcherIterator.scala | 6 +++ .../org/apache/spark/ui/jobs/UIData.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 3 ++ .../one_stage_attempt_json_expectation.json | 8 +++ .../one_stage_json_expectation.json | 8 +++ .../stage_task_list_expectation.json | 20 ++++++++ ...multi_attempt_app_json_1__expectation.json | 8 +++ ...multi_attempt_app_json_2__expectation.json | 8 +++ ...k_list_w__offset___length_expectation.json | 50 +++++++++++++++++++ ...stage_task_list_w__sortBy_expectation.json | 20 ++++++++ ...tBy_short_names___runtime_expectation.json | 20 ++++++++ ...rtBy_short_names__runtime_expectation.json | 20 ++++++++ ...mmary_w__custom_quantiles_expectation.json | 1 + ...sk_summary_w_shuffle_read_expectation.json | 1 + ...k_summary_w_shuffle_write_expectation.json | 1 + ...age_with_accumulable_json_expectation.json | 8 +++ .../spark/executor/TaskMetricsSuite.scala | 3 ++ .../apache/spark/util/JsonProtocolSuite.scala | 33 ++++++++---- project/MimaExcludes.scala | 6 ++- 24 files changed, 234 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 82d3098e2e055..18b10d23da94c 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -50,6 +50,7 @@ private[spark] object InternalAccumulator { val REMOTE_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "remoteBlocksFetched" val LOCAL_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "localBlocksFetched" val REMOTE_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesReadToDisk" val LOCAL_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "localBytesRead" val FETCH_WAIT_TIME = SHUFFLE_READ_METRICS_PREFIX + "fetchWaitTime" val RECORDS_READ = SHUFFLE_READ_METRICS_PREFIX + "recordsRead" diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 8dd1a1ea059be..4be395c8358b2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -31,6 +31,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[executor] val _remoteBlocksFetched = new LongAccumulator private[executor] val _localBlocksFetched = new LongAccumulator private[executor] val _remoteBytesRead = new LongAccumulator + private[executor] val _remoteBytesReadToDisk = new LongAccumulator private[executor] val _localBytesRead = new LongAccumulator private[executor] val _fetchWaitTime = new LongAccumulator private[executor] val _recordsRead = new LongAccumulator @@ -50,6 +51,11 @@ class ShuffleReadMetrics private[spark] () extends Serializable { */ def remoteBytesRead: Long = _remoteBytesRead.sum + /** + * Total number of remotes bytes read to disk from the shuffle by this task. + */ + def remoteBytesReadToDisk: Long = _remoteBytesReadToDisk.sum + /** * Shuffle data that was read from the local disk (as opposed to from a remote executor). */ @@ -80,6 +86,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v) private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v) private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v) + private[spark] def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk.add(v) private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v) private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v) private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) @@ -87,6 +94,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v) private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v) private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v) + private[spark] def setRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk.setValue(v) private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v) private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) @@ -99,6 +107,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _remoteBlocksFetched.setValue(0) _localBlocksFetched.setValue(0) _remoteBytesRead.setValue(0) + _remoteBytesReadToDisk.setValue(0) _localBytesRead.setValue(0) _fetchWaitTime.setValue(0) _recordsRead.setValue(0) @@ -106,6 +115,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _remoteBlocksFetched.add(metric.remoteBlocksFetched) _localBlocksFetched.add(metric.localBlocksFetched) _remoteBytesRead.add(metric.remoteBytesRead) + _remoteBytesReadToDisk.add(metric.remoteBytesReadToDisk) _localBytesRead.add(metric.localBytesRead) _fetchWaitTime.add(metric.fetchWaitTime) _recordsRead.add(metric.recordsRead) @@ -122,6 +132,7 @@ private[spark] class TempShuffleReadMetrics { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L + private[this] var _remoteBytesReadToDisk = 0L private[this] var _localBytesRead = 0L private[this] var _fetchWaitTime = 0L private[this] var _recordsRead = 0L @@ -129,6 +140,7 @@ private[spark] class TempShuffleReadMetrics { def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v + def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v def incLocalBytesRead(v: Long): Unit = _localBytesRead += v def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v def incRecordsRead(v: Long): Unit = _recordsRead += v @@ -136,6 +148,7 @@ private[spark] class TempShuffleReadMetrics { def remoteBlocksFetched: Long = _remoteBlocksFetched def localBlocksFetched: Long = _localBlocksFetched def remoteBytesRead: Long = _remoteBytesRead + def remoteBytesReadToDisk: Long = _remoteBytesReadToDisk def localBytesRead: Long = _localBytesRead def fetchWaitTime: Long = _fetchWaitTime def recordsRead: Long = _recordsRead diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3ce3d1ccc5e3..341a6da8107ef 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -215,6 +215,7 @@ class TaskMetrics private[spark] () extends Serializable { shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched, shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched, shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead, + shuffleRead.REMOTE_BYTES_READ_TO_DISK -> shuffleReadMetrics._remoteBytesReadToDisk, shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead, shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime, shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 1818935392eb3..56028710ecc66 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -200,6 +200,7 @@ private[v1] object AllStagesResource { readBytes = submetricQuantiles(_.totalBytesRead), readRecords = submetricQuantiles(_.recordsRead), remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), localBlocksFetched = submetricQuantiles(_.localBlocksFetched), totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), @@ -281,6 +282,7 @@ private[v1] object AllStagesResource { localBlocksFetched = internal.localBlocksFetched, fetchWaitTime = internal.fetchWaitTime, remoteBytesRead = internal.remoteBytesRead, + remoteBytesReadToDisk = internal.remoteBytesReadToDisk, localBytesRead = internal.localBytesRead, recordsRead = internal.recordsRead ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index f6203271f3cd2..05948f2661056 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -208,6 +208,7 @@ class ShuffleReadMetrics private[spark]( val localBlocksFetched: Long, val fetchWaitTime: Long, val remoteBytesRead: Long, + val remoteBytesReadToDisk: Long, val localBytesRead: Long, val recordsRead: Long) @@ -249,6 +250,7 @@ class ShuffleReadMetricDistributions private[spark]( val localBlocksFetched: IndexedSeq[Double], val fetchWaitTime: IndexedSeq[Double], val remoteBytesRead: IndexedSeq[Double], + val remoteBytesReadToDisk: IndexedSeq[Double], val totalBlocksFetched: IndexedSeq[Double]) class ShuffleWriteMetricDistributions private[spark]( diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index bded3a1e4eb54..a10f1feadd0af 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -165,6 +165,9 @@ final class ShuffleBlockFetcherIterator( case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() @@ -363,6 +366,9 @@ final class ShuffleBlockFetcherIterator( case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } shuffleMetrics.incRemoteBlocksFetched(1) } bytesInFlight -= size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 6764daa0df529..9448baac096dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -251,6 +251,7 @@ private[spark] object UIData { remoteBlocksFetched: Long, localBlocksFetched: Long, remoteBytesRead: Long, + remoteBytesReadToDisk: Long, localBytesRead: Long, fetchWaitTime: Long, recordsRead: Long, @@ -274,6 +275,7 @@ private[spark] object UIData { remoteBlocksFetched = metrics.remoteBlocksFetched, localBlocksFetched = metrics.localBlocksFetched, remoteBytesRead = metrics.remoteBytesRead, + remoteBytesReadToDisk = metrics.remoteBytesReadToDisk, localBytesRead = metrics.localBytesRead, fetchWaitTime = metrics.fetchWaitTime, recordsRead = metrics.recordsRead, @@ -282,7 +284,7 @@ private[spark] object UIData { ) } } - private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0) + private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0, 0) } case class ShuffleWriteMetricsUIData( diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8296c4294242c..806d14e7cc119 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -339,6 +339,7 @@ private[spark] object JsonProtocol { ("Local Blocks Fetched" -> taskMetrics.shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> taskMetrics.shuffleReadMetrics.fetchWaitTime) ~ ("Remote Bytes Read" -> taskMetrics.shuffleReadMetrics.remoteBytesRead) ~ + ("Remote Bytes Read To Disk" -> taskMetrics.shuffleReadMetrics.remoteBytesReadToDisk) ~ ("Local Bytes Read" -> taskMetrics.shuffleReadMetrics.localBytesRead) ~ ("Total Records Read" -> taskMetrics.shuffleReadMetrics.recordsRead) val shuffleWriteMetrics: JValue = @@ -804,6 +805,8 @@ private[spark] object JsonProtocol { readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) + Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index c2f450ba87c6d..6fb40f6f1713b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -60,6 +60,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -105,6 +106,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -150,6 +152,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -195,6 +198,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -240,6 +244,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -285,6 +290,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -330,6 +336,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -375,6 +382,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 506859ae545b1..f5a89a2107646 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -60,6 +60,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -105,6 +106,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -150,6 +152,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -195,6 +198,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -240,6 +244,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -285,6 +290,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -330,6 +336,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -375,6 +382,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index f4cec68fbfdf2..9b401b414f8d4 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 496a21c328da9..2ebee66a6d7c2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -38,6 +38,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -87,6 +88,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -136,6 +138,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -185,6 +188,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -234,6 +238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -283,6 +288,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -332,6 +338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -381,6 +388,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 4328dc753c5d4..965a31a4104c3 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -38,6 +38,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -87,6 +88,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -136,6 +138,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -185,6 +188,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -234,6 +238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -283,6 +288,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -332,6 +338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -381,6 +388,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 8c571430f3a1f..31132e156937c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -913,6 +933,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -957,6 +978,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1001,6 +1023,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1045,6 +1068,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1089,6 +1113,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1133,6 +1158,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1177,6 +1203,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1221,6 +1248,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1265,6 +1293,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1309,6 +1338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1353,6 +1383,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1397,6 +1428,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1441,6 +1473,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1485,6 +1518,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1529,6 +1563,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1573,6 +1608,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1617,6 +1653,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1661,6 +1698,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1705,6 +1743,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1749,6 +1788,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1793,6 +1833,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1837,6 +1878,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1881,6 +1923,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1925,6 +1968,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1969,6 +2013,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2013,6 +2058,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2057,6 +2103,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2101,6 +2148,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2145,6 +2193,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2189,6 +2238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 0bd614bdc756e..6af1cfbeb8f7e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 0bd614bdc756e..6af1cfbeb8f7e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index b58f1a51ba481..c26daf4b8d7bd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 0ed609d5b7f92..f8e27703c0def 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 0.0, 0.0, 0.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 0.0, 0.0, 0.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index 6d230ac653776..a28bda16a956e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0, 1.0, 1.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index aea0f5413d8b9..ede3eaed1d1d2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index a449926ee7dc6..44b5f66efe339 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -69,6 +69,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -119,6 +120,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -169,6 +171,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -219,6 +222,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -269,6 +273,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -319,6 +324,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -369,6 +375,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -419,6 +426,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index eae26fa742a23..7bcc2fb5231db 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -94,6 +94,8 @@ class TaskMetricsSuite extends SparkFunSuite { sr.setRemoteBytesRead(30L) sr.incRemoteBytesRead(3L) sr.incRemoteBytesRead(3L) + sr.setRemoteBytesReadToDisk(10L) + sr.incRemoteBytesReadToDisk(8L) sr.setLocalBytesRead(400L) sr.setLocalBytesRead(40L) sr.incLocalBytesRead(4L) @@ -110,6 +112,7 @@ class TaskMetricsSuite extends SparkFunSuite { assert(sr.remoteBlocksFetched == 12) assert(sr.localBlocksFetched == 24) assert(sr.remoteBytesRead == 36L) + assert(sr.remoteBytesReadToDisk == 18L) assert(sr.localBytesRead == 48L) assert(sr.fetchWaitTime == 60L) assert(sr.recordsRead == 72L) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a77c8e3cab4e8..57452d4912abe 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -848,6 +848,7 @@ private[spark] object JsonProtocolSuite extends Assertions { } else { val sr = t.createTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) + sr.incRemoteBytesReadToDisk(b) sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) @@ -1128,6 +1129,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, | "Remote Bytes Read": 1000, + | "Remote Bytes Read To Disk": 400, | "Local Bytes Read": 1100, | "Total Records Read": 10 | }, @@ -1228,6 +1230,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Local Blocks Fetched" : 0, | "Fetch Wait Time" : 0, | "Remote Bytes Read" : 0, + | "Remote Bytes Read To Disk" : 0, | "Local Bytes Read" : 0, | "Total Records Read" : 0 | }, @@ -1328,10 +1331,11 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Local Blocks Fetched" : 0, | "Fetch Wait Time" : 0, | "Remote Bytes Read" : 0, + | "Remote Bytes Read To Disk" : 0, | "Local Bytes Read" : 0, | "Total Records Read" : 0 | }, - | "Shuffle Write Metrics" : { + | "Shuffle Write Metrics": { | "Shuffle Bytes Written" : 0, | "Shuffle Write Time" : 0, | "Shuffle Records Written" : 0 @@ -1915,76 +1919,83 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 14, - | "Name": "${shuffleRead.LOCAL_BYTES_READ}", + | "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 15, - | "Name": "${shuffleRead.FETCH_WAIT_TIME}", + | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 16, - | "Name": "${shuffleRead.RECORDS_READ}", + | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 17, - | "Name": "${shuffleWrite.BYTES_WRITTEN}", + | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 18, - | "Name": "${shuffleWrite.RECORDS_WRITTEN}", + | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 19, - | "Name": "${shuffleWrite.WRITE_TIME}", + | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 20, + | "Name": "${shuffleWrite.WRITE_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 21, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 21, + | "ID": 22, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 22, + | "ID": 23, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 23, + | "ID": 24, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 24, + | "ID": 25, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3cc089dcede38..1793da03a2c3e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,7 +37,11 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( // [SPARK-20495][SQL] Add StorageLevel to cacheTable API - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), + + // [SPARK-19937] Add remote bytes read to disk. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this") ) // Exclude rules for 2.2.x From e44697606f429b01808c1a22cb44cb5b89585c5c Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 23 Jun 2017 09:01:13 +0800 Subject: [PATCH 0775/1765] [SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas ## What changes were proposed in this pull request? Integrate Apache Arrow with Spark to increase performance of `DataFrame.toPandas`. This has been done by using Arrow to convert data partitions on the executor JVM to Arrow payload byte arrays where they are then served to the Python process. The Python DataFrame can then collect the Arrow payloads where they are combined and converted to a Pandas DataFrame. All non-complex data types are currently supported, otherwise an `UnsupportedOperation` exception is thrown. Additions to Spark include a Scala package private method `Dataset.toArrowPayloadBytes` that will convert data partitions in the executor JVM to `ArrowPayload`s as byte arrays so they can be easily served. A package private class/object `ArrowConverters` that provide data type mappings and conversion routines. In Python, a public method `DataFrame.collectAsArrow` is added to collect Arrow payloads and an optional flag in `toPandas(useArrow=False)` to enable using Arrow (uses the old conversion by default). ## How was this patch tested? Added a new test suite `ArrowConvertersSuite` that will run tests on conversion of Datasets to Arrow payloads for supported types. The suite will generate a Dataset and matching Arrow JSON data, then the dataset is converted to an Arrow payload and finally validated against the JSON data. This will ensure that the schema and data has been converted correctly. Added PySpark tests to verify the `toPandas` method is producing equal DataFrames with and without pyarrow. A roundtrip test to ensure the pandas DataFrame produced by pyspark is equal to a one made directly with pandas. Author: Bryan Cutler Author: Li Jin Author: Li Jin Author: Wes McKinney Closes #15821 from BryanCutler/wip-toPandas_with_arrow-SPARK-13534. --- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 + dev/deps/spark-deps-hadoop-2.7 | 5 + dev/run-pip-tests | 6 + pom.xml | 20 + python/pyspark/serializers.py | 17 + python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 79 +- .../apache/spark/sql/internal/SQLConf.scala | 22 + sql/core/pom.xml | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 20 + .../sql/execution/arrow/ArrowConverters.scala | 429 ++++++ .../arrow/ArrowConvertersSuite.scala | 1222 +++++++++++++++++ 13 files changed, 1866 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8a..8eeea7716cc98 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9287bd47cf113..9868c1ab7c2ab 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 9127413ab6c23..57c78cfe12087 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c5..225e9209536f0 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,6 +83,8 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" + conda install -y -c conda-forge pyarrow=0.4.0 + TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -120,6 +122,10 @@ for python in "${PYTHON_EXECS[@]}"; do python "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" python "$FWDIR"/python/pyspark/context.py + if [ -n "$TEST_PYARROW" ]; then + echo "Run tests for pyarrow" + SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests + fi cd "$FWDIR" diff --git a/pom.xml b/pom.xml index 5f524079495c0..f124ba45007b7 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 2.6 1.8 1.0.0 + 0.4.0 ${java.home} @@ -1878,6 +1879,25 @@ paranamer ${paranamer.version} + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + io.netty + netty-handler + + + diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ea5e00e9eeef5..d5c2a7518b18f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,6 +182,23 @@ def loads(self, obj): raise NotImplementedError +class ArrowSerializer(FramedSerializer): + """ + Serializes an Arrow stream. + """ + + def dumps(self, obj): + raise NotImplementedError + + def loads(self, obj): + import pyarrow as pa + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) + return reader.read_all() + + def __repr__(self): + return "ArrowSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed2246..760f113dfd197 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,8 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1708,7 +1709,8 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. @@ -1721,18 +1723,42 @@ def toPandas(self): 1 5 Bob """ import pandas as pd + if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + try: + import pyarrow + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + return table.to_pandas() + else: + return pd.DataFrame.from_records([], columns=self.columns) + except ImportError as e: + msg = "note: pyarrow must be installed and available on calling Python process " \ + "if using spark.sql.execution.arrow.enable=true" + raise ImportError("%s\n%s" % (e.message, msg)) + else: + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + def _collectAsArrow(self): + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b8e8..326e8548a617c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,12 +58,21 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: + # No Arrow, but that's okay, we'll skip those tests + pass + + class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2620,6 +2629,74 @@ def range_frame_match(): importlib.reload(window) + +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class ArrowTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] + + def assertFramesEqual(self, df_with_arrow, df_without): + msg = ("DataFrame from Arrow is not equal" + + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + + ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) + self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + + def test_unsupported_datatype(self): + schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) + df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + pdf = df.toPandas() + self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_pandas_round_trip(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + pdf = pd.DataFrame(data=data_dict) + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6ab3a615e6cc0..e609256db2802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -846,6 +846,24 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + val ARROW_EXECUTION_ENABLE = + buildConf("spark.sql.execution.arrow.enable") + .internal() + .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + + "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + + "LongType, ShortType") + .booleanConf + .createWithDefault(false) + + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") + .internal() + .doc("When using Apache Arrow, limit the maximum number of records that can be written " + + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + .intConf + .createWithDefault(10000) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1104,6 +1122,10 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + + def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1bc34a6b069d9..661c31ded7148 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.arrow + arrow-vector + org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d28ff7888d127..a2af9c2efe2ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2922,6 +2923,16 @@ class Dataset[T] private[sql]( } } + /** + * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + */ + private[sql] def collectAsArrowToPython(): Int = { + withNewExecutionId { + val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + PythonRDD.serveIterator(iter, "serve-Arrow") + } + } + private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -3003,4 +3014,13 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } + + /** Convert to an RDD of ArrowPayload byte arrays */ + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + val schemaCaptured = this.schema + val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + queryExecution.toRdd.mapPartitionsInternal { iter => + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala new file mode 100644 index 0000000000000..6af5c73422377 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -0,0 +1,429 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.arrow + +import java.io.ByteArrayOutputStream +import java.nio.channels.Channels + +import scala.collection.JavaConverters._ + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector._ +import org.apache.arrow.vector.BaseValueVector.BaseMutator +import org.apache.arrow.vector.file._ +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +/** + * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + */ +private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { + + /** + * Convert the ArrowPayload to an ArrowRecordBatch. + */ + def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { + ArrowConverters.byteArrayToBatch(payload, allocator) + } + + /** + * Get the ArrowPayload as a type that can be served to Python. + */ + def asPythonSerializable: Array[Byte] = payload +} + +private[sql] object ArrowPayload { + + /** + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. + */ + def apply( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): ArrowPayload = { + new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) + } +} + +private[sql] object ArrowConverters { + + /** + * Map a Spark DataType to ArrowType. + */ + private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { + dataType match { + case BooleanType => ArrowType.Bool.INSTANCE + case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) + case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case ByteType => new ArrowType.Int(8, true) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } + + /** + * Convert a Spark Dataset schema to Arrow schema. + */ + private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map { f => + new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) + } + new Schema(arrowFields.toList.asJava) + } + + /** + * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload + * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + */ + private[sql] def toPayloadIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { + new Iterator[ArrowPayload] { + private val _allocator = new RootAllocator(Long.MaxValue) + private var _nextPayload = if (rowIter.nonEmpty) convert() else null + + override def hasNext: Boolean = _nextPayload != null + + override def next(): ArrowPayload = { + val obj = _nextPayload + if (hasNext) { + if (rowIter.hasNext) { + _nextPayload = convert() + } else { + _allocator.close() + _nextPayload = null + } + } + obj + } + + private def convert(): ArrowPayload = { + val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) + ArrowPayload(batch, schema, _allocator) + } + } + } + + /** + * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed + * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, + * then rowIter will be fully consumed. + */ + private def internalRowIterToArrowBatch( + rowIter: Iterator[InternalRow], + schema: StructType, + allocator: BufferAllocator, + maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { + + val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => + ColumnWriter(field.dataType, ordinal, allocator).init() + } + + val writerLength = columnWriters.length + var recordsInBatch = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { + val row = rowIter.next() + var i = 0 + while (i < writerLength) { + columnWriters(i).write(row) + i += 1 + } + recordsInBatch += 1 + } + + val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip + val buffers = bufferArrays.flatten + + val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 + val recordBatch = new ArrowRecordBatch(rowLength, + fieldNodes.toList.asJava, buffers.toList.asJava) + + buffers.foreach(_.release()) + recordBatch + } + + /** + * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, + * the batch can no longer be used. + */ + private[arrow] def batchToByteArray( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): Array[Byte] = { + val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + // Write a batch to byte stream, ensure the batch, allocator and writer are closed + Utils.tryWithSafeFinally { + val loader = new VectorLoader(root) + loader.load(batch) + writer.writeBatch() // writeBatch can throw IOException + } { + batch.close() + root.close() + writer.close() + } + out.toByteArray + } + + /** + * Convert a byte array to an ArrowRecordBatch. + */ + private[arrow] def byteArrayToBatch( + batchBytes: Array[Byte], + allocator: BufferAllocator): ArrowRecordBatch = { + val in = new ByteArrayReadableSeekableByteChannel(batchBytes) + val reader = new ArrowFileReader(in, allocator) + + // Read a batch from a byte stream, ensure the reader is closed + Utils.tryWithSafeFinally { + val root = reader.getVectorSchemaRoot // throws IOException + val unloader = new VectorUnloader(root) + reader.loadNextBatch() // throws IOException + unloader.getRecordBatch + } { + reader.close() + } + } +} + +/** + * Interface for writing InternalRows to Arrow Buffers. + */ +private[arrow] trait ColumnWriter { + def init(): this.type + def write(row: InternalRow): Unit + + /** + * Clear the column writer and return the ArrowFieldNode and ArrowBuf. + * This should be called only once after all the data is written. + */ + def finish(): (ArrowFieldNode, Array[ArrowBuf]) +} + +/** + * Base class for flat arrow column writer, i.e., column without children. + */ +private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) + extends ColumnWriter { + + def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) + + def valueVector: BaseDataValueVector + def valueMutator: BaseMutator + + def setNull(): Unit + def setValue(row: InternalRow): Unit + + protected var count = 0 + protected var nullCount = 0 + + override def init(): this.type = { + valueVector.allocateNew() + this + } + + override def write(row: InternalRow): Unit = { + if (row.isNullAt(ordinal)) { + setNull() + nullCount += 1 + } else { + setValue(row) + } + count += 1 + } + + override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { + valueMutator.setValueCount(count) + val fieldNode = new ArrowFieldNode(count, nullCount) + val valueBuffers = valueVector.getBuffers(true) + (fieldNode, valueBuffers) + } +} + +private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBitVector + = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) +} + +private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableSmallIntVector + = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) + override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getShort(ordinal)) +} + +private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableIntVector + = new NullableIntVector("IntValue", getFieldType(dtype), allocator) + override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getInt(ordinal)) +} + +private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBigIntVector + = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getLong(ordinal)) +} + +private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat4Vector + = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getFloat(ordinal)) +} + +private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat8Vector + = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getDouble(ordinal)) +} + +private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableUInt1Vector + = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) + override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getByte(ordinal)) +} + +private[arrow] class UTF8StringColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarCharVector + = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val str = row.getUTF8String(ordinal) + valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) + } +} + +private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarBinaryVector + = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val bytes = row.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableDateDayVector + = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) + override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getInt(ordinal)) + } +} + +private[arrow] class TimeStampColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableTimeStampMicroVector + = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) + override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getLong(ordinal)) + } +} + +private[arrow] object ColumnWriter { + + /** + * Create an Arrow ColumnWriter given the type and ordinal of row. + */ + def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { + val dtype = ArrowConverters.sparkTypeToArrowType(dataType) + dataType match { + case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) + case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) + case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) + case LongType => new LongColumnWriter(dtype, ordinal, allocator) + case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) + case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) + case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) + case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) + case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) + case DateType => new DateColumnWriter(dtype, ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala new file mode 100644 index 0000000000000..159328cc0d958 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -0,0 +1,1222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.arrow + +import java.io.File +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale + +import com.google.common.io.Files +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.util.Validator +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.util.Utils + + +class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { + import testImplicits._ + + private var tempDataPath: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath + } + + test("collect to arrow record batch") { + val indexData = (1 to 6).toDF("i") + val arrowPayloads = indexData.toArrowPayload.collect() + assert(arrowPayloads.nonEmpty) + assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val rowCount = arrowRecordBatches.map(_.getLength).sum + assert(rowCount === indexData.count()) + arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("short conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_s", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] + | }, { + | "name" : "b_s", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] + | } ] + | } ] + |} + """.stripMargin + + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) + val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val df = a_s.zip(b_s).toDF("a_s", "b_s") + + collectAndValidate(df, json, "integer-16bit.json") + } + + test("int conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + collectAndValidate(df, json, "integer-32bit.json") + } + + test("long conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_l", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] + | }, { + | "name" : "b_l", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] + | } ] + | } ] + |} + """.stripMargin + + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) + val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val df = a_l.zip(b_l).toDF("a_l", "b_l") + + collectAndValidate(df, json, "integer-64bit.json") + } + + test("float conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_f", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] + | }, { + | "name" : "b_f", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) + val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val df = a_f.zip(b_f).toDF("a_f", "b_f") + + collectAndValidate(df, json, "floating_point-single_precision.json") + } + + test("double conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) + val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "floating_point-double_precision.json") + } + + test("index conversion") { + val data = List[Int](1, 2, 3, 4, 5, 6) + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + val df = data.toDF("i") + + collectAndValidate(df, json, "indexData-ints.json") + } + + test("mixed numeric type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "e", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "b", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "c", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "e", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + + val data = List(1, 2, 3, 4, 5, 6) + val data_tuples = for (d <- data) yield { + (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) + } + val df = data_tuples.toDF("a", "b", "c", "d", "e") + + collectAndValidate(df, json, "mixed_numeric_types.json") + } + + test("string type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "upper_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "lower_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "null_str", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "upper_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "A", "B", "C" ] + | }, { + | "name" : "lower_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "a", "b", "c" ] + | }, { + | "name" : "null_str", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "OFFSET" : [ 0, 2, 5, 5 ], + | "DATA" : [ "ab", "CDE", "" ] + | } ] + | } ] + |} + """.stripMargin + + val upperCase = Seq("A", "B", "C") + val lowerCase = Seq("a", "b", "c") + val nullStr = Seq("ab", "CDE", null) + val df = (upperCase, lowerCase, nullStr).zipped.toList + .toDF("upper_case", "lower_case", "null_str") + + collectAndValidate(df, json, "stringData.json") + } + + test("boolean type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_bool", + | "type" : { + | "name" : "bool" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_bool", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ true, true, false, true ] + | } ] + | } ] + |} + """.stripMargin + val df = Seq(true, true, false, true).toDF("a_bool") + collectAndValidate(df, json, "boolData.json") + } + + test("byte type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_byte", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 8 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_byte", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 64, 127 ] + | } ] + | } ] + |} + | + """.stripMargin + val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") + collectAndValidate(df, json, "byteData.json") + } + + test("binary type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_binary", + | "type" : { + | "name" : "binary" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_binary", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 3, 4, 6 ], + | "DATA" : [ "616263", "64", "6566" ] + | } ] + | } ] + |} + """.stripMargin + + val data = Seq("abc", "d", "ef") + val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) + val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) + + collectAndValidate(df, json, "binaryData.json") + } + + test("floating-point NaN") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "NaN_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "NaN_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 2, + | "columns" : [ { + | "name" : "NaN_f", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1.2000000476837158, "NaN" ] + | }, { + | "name" : "NaN_d", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ "NaN", 1.2 ] + | } ] + | } ] + |} + """.stripMargin + + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + + collectAndValidate(df, json, "nanData-floating_point.json") + } + + test("partitioned DataFrame") { + val json1 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 1, 2 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 1 ] + | } ] + | } ] + |} + """.stripMargin + val json2 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 3, 3 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 1, 2 ] + | } ] + | } ] + |} + """.stripMargin + + val arrowPayloads = testData2.toArrowPayload.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowPayloads.length === 2) + val schema = testData2.schema + + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") + Files.write(json1, tempFile1, StandardCharsets.UTF_8) + Files.write(json2, tempFile2, StandardCharsets.UTF_8) + + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) + } + + test("empty frame collect") { + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + assert(arrowPayload.isEmpty) + + val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + assert(filteredArrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("max records in batch conf") { + val totalRecords = 10 + val maxRecordsPerBatch = 3 + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") + val arrowPayloads = df.toArrowPayload.collect() + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + var recordCount = 0 + arrowRecordBatches.foreach { batch => + assert(batch.getLength > 0) + assert(batch.getLength <= maxRecordsPerBatch) + recordCount += batch.getLength + batch.close() + } + assert(recordCount == totalRecords) + allocator.close() + spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + + testQuietly("unsupported types") { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[SparkException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) + } + + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { arrayData.toDF().toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } + + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } + } + + test("test Arrow Validator") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val json_diff_col_order = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + // Different schema + intercept[IllegalArgumentException] { + collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") + } + } + + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ + private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val tempFile = new File(tempDataPath, file) + Files.write(json, tempFile, StandardCharsets.UTF_8) + validateConversion(df.schema, arrowPayload, tempFile) + } + + private def validateConversion( + sparkSchema: StructType, + arrowPayload: ArrowPayload, + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) + + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val jsonSchema = jsonReader.start() + Validator.compareSchemas(arrowSchema, jsonSchema) + + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + val arrowRecordBatch = arrowPayload.loadBatch(allocator) + vectorLoader.load(arrowRecordBatch) + val jsonRoot = jsonReader.read() + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + + jsonRoot.close() + jsonReader.close() + arrowRecordBatch.close() + arrowRoot.close() + allocator.close() + } +} From 5b5a69bea9de806e2c39b04b248ee82a7b664d7b Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 23 Jun 2017 09:19:02 +0800 Subject: [PATCH 0776/1765] [SPARK-20923] turn tracking of TaskMetrics._updatedBlockStatuses off ## What changes were proposed in this pull request? Turn tracking of TaskMetrics._updatedBlockStatuses off by default. As far as I can see its not used by anything and it uses a lot of memory when caching and processing a lot of blocks. In my case it was taking 5GB of a 10GB heap and I even went up to 50GB heap and the job still ran out of memory. With this change in place the same job easily runs in less then 10GB of heap. We leave the api there as well as a config to turn it back on just in case anyone is using it. TaskMetrics is exposed via SparkListenerTaskEnd so if users are relying on it they can turn it back on. ## How was this patch tested? Ran unit tests that were modified and manually tested on a couple of jobs (with and without caching). Clicked through the UI and didn't see anything missing. Ran my very large hive query job with 200,000 small tasks, 1000 executors, cached 6+TB of data this runs fine now whereas without this change it would go into full gcs and eventually die. Author: Thomas Graves Author: Tom Graves Closes #18162 from tgravescs/SPARK-20923. --- .../apache/spark/executor/TaskMetrics.scala | 6 ++++ .../spark/internal/config/package.scala | 8 +++++ .../apache/spark/storage/BlockManager.scala | 6 ++-- .../spark/storage/BlockManagerSuite.scala | 32 ++++++++++++++++++- 4 files changed, 49 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 341a6da8107ef..85b2745a2aec4 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -112,6 +112,12 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. + * + * Tracking the _updatedBlockStatuses can use a lot of memory. + * It is not used anywhere inside of Spark so we would ideally remove it, but its exposed to + * the user in SparkListenerTaskEnd so the api is kept for compatibility. + * Tracking can be turned off to save memory via config + * TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES. */ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { // This is called on driver. All accumulator updates have a fixed value. So it's safe to use diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 615497d36fd14..462c1890fd8df 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -322,4 +322,12 @@ package object config { "above this threshold. This is to avoid a giant request takes too much memory.") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("200m") + + private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = + ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") + .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. Off by default since " + + "tracking the block statuses can use a lot of memory and its not used anywhere within " + + "spark.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 74be70348305c..adbe3cfd89ea6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1473,8 +1473,10 @@ private[spark] class BlockManager( } private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) + if (conf.get(config.TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES)) { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 88f18294aa015..086adccea954c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -922,8 +922,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("turn off updated block statuses") { + val conf = new SparkConf() + conf.set(TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES, false) + store = makeBlockManager(12000, testConf = Some(conf)) + + store.registerTask(0) + val list = List.fill(2)(new Array[Byte](2000)) + + def getUpdatedBlocks(task: => Unit): Seq[(BlockId, BlockStatus)] = { + val context = TaskContext.empty() + try { + TaskContext.setTaskContext(context) + task + } finally { + TaskContext.unset() + } + context.taskMetrics.updatedBlockStatuses + } + + // 1 updated block (i.e. list1) + val updatedBlocks1 = getUpdatedBlocks { + store.putIterator( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + assert(updatedBlocks1.size === 0) + } + + test("updated block statuses") { - store = makeBlockManager(12000) + val conf = new SparkConf() + conf.set(TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES, true) + store = makeBlockManager(12000, testConf = Some(conf)) store.registerTask(0) val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) From b8a743b6a531432e57eb50ecff06798ebc19483e Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 23 Jun 2017 09:27:35 +0800 Subject: [PATCH 0777/1765] [SPARK-21174][SQL] Validate sampling fraction in logical operator level ## What changes were proposed in this pull request? Currently the validation of sampling fraction in dataset is incomplete. As an improvement, validate sampling fraction in logical operator level: 1) if with replacement: fraction should be nonnegative 2) else: fraction should be on interval [0, 1] Also add test cases for the validation. ## How was this patch tested? integration tests gatorsmile cloud-fan Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18387 from gengliangwang/sample_ratio_validate. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 13 ++++ .../scala/org/apache/spark/sql/Dataset.scala | 3 - .../sql-tests/inputs/tablesample-negative.sql | 14 +++++ .../sql-tests/results/operators.sql.out | 8 +-- .../results/tablesample-negative.sql.out | 62 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 28 +++++++++ 8 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ef5648c6dbe47..9456031736528 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -440,7 +440,7 @@ joinCriteria sample : TABLESAMPLE '(' - ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) | (expression sampleType=ROWS) | sampleType=BYTELENGTH_LITERAL | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 500d999c30da7..315c6721b3f65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -636,7 +636,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.PERCENTLIT => val fraction = ctx.percentage.getText.toDouble - sample(fraction / 100.0d) + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) case SqlBaseParser.BYTELENGTH_LITERAL => throw new ParseException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6878b6b179c3a..6e88b7a57dc33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.util.random.RandomSampler /** * When planning take() or collect() operations, this special node that is inserted at the top of @@ -817,6 +818,18 @@ case class Sample( child: LogicalPlan)( val isTableSample: java.lang.Boolean = false) extends UnaryNode { + val eps = RandomSampler.roundingEpsilon + val fraction = upperBound - lowerBound + if (withReplacement) { + require( + fraction >= 0.0 - eps, + s"Sampling fraction ($fraction) must be nonnegative with replacement") + } else { + require( + fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement") + } + override def output: Seq[Attribute] = child.output override def computeStats(conf: SQLConf): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a2af9c2efe2ab..767dad3e63a6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1806,9 +1806,6 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { - require(fraction >= 0, - s"Fraction must be nonnegative, but got ${fraction}") - withTypedPlan { Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } diff --git a/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql new file mode 100644 index 0000000000000..72508f59bee27 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql @@ -0,0 +1,14 @@ +-- Negative testcases for tablesample +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +-- Negative tests: negative percentage +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT); + +-- Negative tests: percentage over 100 +-- The TABLESAMPLE clause samples without replacement, so the value of PERCENT must not exceed 100 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT); + +-- reset +DROP DATABASE mydb1 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 5cb6ed3e27bf2..fec423fca5bbe 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 56 +-- Number of queries: 57 -- !query 0 @@ -462,9 +462,9 @@ struct 3.13 2.19 --- !query 55 +-- !query 56 select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) --- !query 55 schema +-- !query 56 schema struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> --- !query 55 output +-- !query 56 output -1.11 -1.11 1.11 1.11 diff --git a/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out new file mode 100644 index 0000000000000..35f3931736b83 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out @@ -0,0 +1,62 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT) +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.catalyst.parser.ParseException + +Sampling fraction (-0.01) must be on interval [0, 1](line 1, pos 24) + +== SQL == +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT) +------------------------^^^ + + +-- !query 4 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.parser.ParseException + +Sampling fraction (1.01) must be on interval [0, 1](line 1, pos 24) + +== SQL == +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT) +------------------------^^^ + + +-- !query 5 +DROP DATABASE mydb1 CASCADE +-- !query 5 schema +struct<> +-- !query 5 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8eb381b91f46d..165176f6c040e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -457,6 +457,34 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3, 17, 27, 58, 62) } + test("sample fraction should not be negative with replacement") { + val data = sparkContext.parallelize(1 to 2, 1).toDS() + val errMsg = intercept[IllegalArgumentException] { + data.sample(withReplacement = true, -0.1, 0) + }.getMessage + assert(errMsg.contains("Sampling fraction (-0.1) must be nonnegative with replacement")) + + // Sampling fraction can be greater than 1 with replacement. + checkDataset( + data.sample(withReplacement = true, 1.05, seed = 13), + 1, 2) + } + + test("sample fraction should be on interval [0, 1] without replacement") { + val data = sparkContext.parallelize(1 to 2, 1).toDS() + val errMsg1 = intercept[IllegalArgumentException] { + data.sample(withReplacement = false, -0.1, 0) + }.getMessage() + assert(errMsg1.contains( + "Sampling fraction (-0.1) must be on interval [0, 1] without replacement")) + + val errMsg2 = intercept[IllegalArgumentException] { + data.sample(withReplacement = false, 1.1, 0) + }.getMessage() + assert(errMsg2.contains( + "Sampling fraction (1.1) must be on interval [0, 1] without replacement")) + } + test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") { val simpleUdf = udf((n: Int) => { require(n != 1, "simpleUdf shouldn't see id=1!") From fe24634d14bc0973ca38222db2f58eafbf0c890d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Jun 2017 00:43:21 -0700 Subject: [PATCH 0778/1765] [SPARK-21145][SS] Added StateStoreProviderId with queryRunId to reload StateStoreProviders when query is restarted ## What changes were proposed in this pull request? StateStoreProvider instances are loaded on-demand in a executor when a query is started. When a query is restarted, the loaded provider instance will get reused. Now, there is a non-trivial chance, that the task of the previous query run is still running, while the tasks of the restarted run has started. So for a stateful partition, there may be two concurrent tasks related to the same stateful partition, and there for using the same provider instance. This can lead to inconsistent results and possibly random failures, as state store implementations are not designed to be thread-safe. To fix this, I have introduced a `StateStoreProviderId`, that unique identifies a provider loaded in an executor. It has the query run id in it, thus making sure that restarted queries will force the executor to load a new provider instance, thus avoiding two concurrent tasks (from two different runs) from reusing the same provider instance. Additional minor bug fixes - All state stores related to query run is marked as deactivated in the `StateStoreCoordinator` so that the executors can unload them and clear resources. - Moved the code that determined the checkpoint directory of a state store from implementation-specific code (`HDFSBackedStateStoreProvider`) to non-specific code (StateStoreId), so that implementation do not accidentally get it wrong. - Also added store name to the path, to support multiple stores per sql operator partition. *Note:* This change does not address the scenario where two tasks of the same run (e.g. speculative tasks) are concurrently running in the same executor. The chance of this very small, because ideally speculative tasks should never run in the same executor. ## How was this patch tested? Existing unit tests + new unit test. Author: Tathagata Das Closes #18355 from tdas/SPARK-21145. --- .../sql/execution/aggregate/AggUtils.scala | 2 +- .../sql/execution/command/commands.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 7 +- .../streaming/IncrementalExecution.scala | 27 +++--- .../execution/streaming/StreamExecution.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 16 ++-- .../streaming/state/StateStore.scala | 91 +++++++++++++----- .../state/StateStoreCoordinator.scala | 41 ++++---- .../streaming/state/StateStoreRDD.scala | 21 ++++- .../execution/streaming/state/package.scala | 25 ++--- .../streaming/statefulOperators.scala | 38 ++++---- .../sql/streaming/StreamingQueryManager.scala | 1 + .../state/StateStoreCoordinatorSuite.scala | 61 ++++++++++-- .../streaming/state/StateStoreRDDSuite.scala | 51 +++++----- .../streaming/state/StateStoreSuite.scala | 93 +++++++++++++++---- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 13 ++- 17 files changed, 329 insertions(+), 166 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index aa789af6f812f..12f8cffb6774a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -311,7 +311,7 @@ object AggUtils { val saved = StateStoreSaveExec( groupingAttributes, - stateId = None, + stateInfo = None, outputMode = None, eventTimeWatermark = None, partialMerged2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2d82fcf4da6e9..81bc93e7ebcf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.UUID + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -117,7 +119,8 @@ case class ExplainCommand( // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. new IncrementalExecution( - sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) + sparkSession, logicalPlan, OutputMode.Append(), "", + UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 2aad8701a4eca..9dcac33b4107c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec( } child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, groupingAttributes.toStructType, stateAttributes.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 622e049630db2..ab89dc6b705d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging @@ -36,6 +37,7 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, + val runId: UUID, val currentBatchId: Long, offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -69,7 +71,13 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private val operatorId = new AtomicInteger(0) + private val statefulOperatorId = new AtomicInteger(0) + + /** Get the state info of the next stateful operator */ + private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo( + checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + } /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -78,35 +86,28 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - + val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, - Some(stateId), + Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, - Some(stateId), + Some(aggStateInfo), child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, - Some(stateId), + Some(nextStatefulOperationStateInfo), Some(offsetSeqMetadata.batchWatermarkMs)) case m: FlatMapGroupsWithStateExec => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) m.copy( - stateId = Some(stateId), + stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 74f0f509bbf85..06bdec8b06407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -652,6 +652,7 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), + runId, currentBatchId, offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 67d86daf10812..bae7a15165e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null - override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId override def get(key: UnsafeRow): UnsafeRow = { mapToUpdate.get(key) @@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** * Whether all updates have been committed */ - override private[streaming] def hasCommitted: Boolean = { + override def hasCommitted: Boolean = { state == COMMITTED } @@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { - this.stateStoreId = stateStoreId + this.stateStoreId_ = stateStoreId this.keySchema = keySchema this.valueSchema = valueSchema this.storeConf = storeConf @@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit fs.mkdirs(baseDir) } - override def id: StateStoreId = stateStoreId + override def stateStoreId: StateStoreId = stateStoreId_ /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { @@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def toString(): String = { - s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[" + + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" } /* Internal fields and methods */ - @volatile private var stateStoreId: StateStoreId = _ + @volatile private var stateStoreId_ : StateStoreId = _ @volatile private var keySchema: StructType = _ @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ private lazy val loadedMaps = new mutable.HashMap[Long, MapType] - private lazy val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fs = baseDir.getFileSystem(hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 29c456f86e1ed..a94ff8a7ebd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID import java.util.concurrent.{ScheduledFuture, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -24,14 +25,14 @@ import scala.collection.mutable import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} - /** * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific * version of state data, and such instances are created through a [[StateStoreProvider]]. @@ -99,7 +100,7 @@ trait StateStore { /** * Whether all updates have been committed */ - private[streaming] def hasCommitted: Boolean + def hasCommitted: Boolean } @@ -147,7 +148,7 @@ trait StateStoreProvider { * Return the id of the StateStores this provider will generate. * Should be the same as the one passed in init(). */ - def id: StateStoreId + def stateStoreId: StateStoreId /** Called when the provider instance is unloaded from the executor */ def close(): Unit @@ -179,13 +180,46 @@ object StateStoreProvider { } } +/** + * Unique identifier for a provider, used to identify when providers can be reused. + * Note that `queryRunId` is used uniquely identify a provider, so that the same provider + * instance is not reused across query restarts. + */ +case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID) -/** Unique identifier for a bunch of keyed state data. */ +/** + * Unique identifier for a bunch of keyed state data. + * @param checkpointRootLocation Root directory where all the state data of a query is stored + * @param operatorId Unique id of a stateful operator + * @param partitionId Index of the partition of an operators state data + * @param storeName Optional, name of the store. Each partition can optionally use multiple state + * stores, but they have to be identified by distinct names. + */ case class StateStoreId( - checkpointLocation: String, + checkpointRootLocation: String, operatorId: Long, partitionId: Int, - name: String = "") + storeName: String = StateStoreId.DEFAULT_STORE_NAME) { + + /** + * Checkpoint directory to be used by a single state store, identified uniquely by the tuple + * (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should + * use this path for saving state data, as this ensures that distinct stores will write to + * different locations. + */ + def storeCheckpointLocation(): Path = { + if (storeName == StateStoreId.DEFAULT_STORE_NAME) { + // For reading state store data that was generated before store names were used (Spark <= 2.2) + new Path(checkpointRootLocation, s"$operatorId/$partitionId") + } else { + new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName") + } + } +} + +object StateStoreId { + val DEFAULT_STORE_NAME = "default" +} /** Mutable, and reusable class for representing a pair of UnsafeRows. */ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { @@ -211,7 +245,7 @@ object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 @GuardedBy("loadedProviders") - private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() /** * Runs the `task` periodically and automatically cancels it if there is an exception. `onError` @@ -253,7 +287,7 @@ object StateStore extends Logging { /** Get or create a store associated with the id. */ def get( - storeId: StateStoreId, + storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -264,24 +298,24 @@ object StateStore extends Logging { val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( - storeId, + storeProviderId, StateStoreProvider.instantiate( - storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) - reportActiveStoreInstance(storeId) + reportActiveStoreInstance(storeProviderId) provider } storeProvider.getStore(version) } /** Unload a state store provider */ - def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId).foreach(_.close()) + def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeProviderId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ - def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { - loadedProviders.contains(storeId) + def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeProviderId) } def isMaintenanceRunning: Boolean = loadedProviders.synchronized { @@ -340,21 +374,21 @@ object StateStore extends Logging { } } - private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = { if (SparkEnv.get != null) { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) - logDebug(s"Reported that the loaded instance $storeId is active") + coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId)) + logInfo(s"Reported that the loaded instance $storeProviderId is active") } } - private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = { if (SparkEnv.get != null) { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = - coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verified whether the loaded instance $storeId is active: $verified") + coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false) + logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified") verified } else { false @@ -364,12 +398,21 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - if (_coordRef == null) { + logInfo("Env is not null") + val isDriver = + env.executorId == SparkContext.DRIVER_IDENTIFIER || + env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in _coordRef may be have become inactive + // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, + // always recreate the reference. + if (isDriver || _coordRef == null) { + logInfo("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } - logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { + logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d0f81887e62d1..3884f5e6ce766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.collection.mutable import org.apache.spark.SparkEnv @@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils private sealed trait StateStoreCoordinatorMessage extends Serializable /** Classes representing messages */ -private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) +private case class ReportActiveInstance( + storeId: StateStoreProviderId, + host: String, + executorId: String) extends StateStoreCoordinatorMessage -private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) +private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executorId: String) extends StateStoreCoordinatorMessage -private case class GetLocation(storeId: StateStoreId) +private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(checkpointLocation: String) +private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging { class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( - storeId: StateStoreId, + stateStoreProviderId: StateStoreProviderId, host: String, executorId: String): Unit = { - rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, executorId)) } /** Verify whether the given executor has the active instance of a state store */ - private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId)) + private[state] def verifyIfInstanceActive( + stateStoreProviderId: StateStoreProviderId, + executorId: String): Boolean = { + rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId)) } /** Get the location of the state store */ - private[state] def getLocation(storeId: StateStoreId): Option[String] = { - rpcEndpointRef.askSync[Option[String]](GetLocation(storeId)) + private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = { + rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId)) } - /** Deactivate instances related to a set of operator */ - private[state] def deactivateInstances(storeRootLocation: String): Unit = { - rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation)) + /** Deactivate instances related to a query */ + private[sql] def deactivateInstances(runId: UUID): Unit = { + rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } private[state] def stop(): Unit = { @@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) logDebug(s"Got location of the state store $id: $executorId") context.reply(executorId) - case DeactivateInstances(checkpointLocation) => + case DeactivateInstances(runId) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq + instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove - logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index b744c25dc97a8..01d8e75980993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} @@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, @@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) - storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) + val storeProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + store = StateStore.get( - storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 228fe86d59940..a0086e251f9c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.TaskContext @@ -32,20 +34,14 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo, keySchema, valueSchema, indexOrdinal, @@ -56,10 +52,7 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -79,10 +72,10 @@ package object state { new StateStoreRDD( dataRDD, wrappedF, - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo.checkpointLocation, + stateInfo.queryRunId, + stateInfo.operatorId, + stateInfo.storeVersion, keySchema, valueSchema, indexOrdinal, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3e57f3fbada32..c5722466a33af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.TimeUnit._ import org.apache.spark.rdd.RDD @@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ -case class OperatorStateId( +case class StatefulOperatorStateInfo( checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - batchId: Long) + storeVersion: Long) /** - * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should - * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + * An operator that reads or writes state from the [[StateStore]]. + * The [[StatefulOperatorStateInfo]] should be filled in by `prepareForExecution` in + * [[IncrementalExecution]]. */ trait StatefulOperator extends SparkPlan { - def stateId: Option[OperatorStateId] + def stateInfo: Option[StatefulOperatorStateInfo] - protected def getStateId: OperatorStateId = attachTree(this) { - stateId.getOrElse { + protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { + stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") } } @@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode { */ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], child: SparkPlan) extends UnaryExecNode with StateStoreReader { @@ -148,10 +151,7 @@ case class StateStoreRestoreExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeName = "default", - storeVersion = getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -177,7 +177,7 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) @@ -189,10 +189,7 @@ case class StateStoreSaveExec( "Incorrect planning in IncrementalExecution, outputMode has not been set") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -319,7 +316,7 @@ case class StateStoreSaveExec( case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: SparkPlan, - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, eventTimeWatermark: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -331,10 +328,7 @@ case class StreamingDeduplicateExec( metrics // force lazy init at driver child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 002c45413b4c2..48b0ea20e5da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } awaitTerminationLock.notifyAll() } + stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a7e32626264cc..9a7595eee7bd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count +import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("report, verify, getLocation") { withCoordinatorRef(sc) { coordinatorRef => - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) @@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("make inactive") { withCoordinatorRef(sc) { coordinatorRef => - val id1 = StateStoreId("x", 0, 0) - val id2 = StateStoreId("y", 1, 0) - val id3 = StateStoreId("x", 0, 1) + val runId1 = UUID.randomUUID + val runId2 = UUID.randomUUID + val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1) + val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2) + val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1) val host = "hostX" val exec = "exec1" @@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) } - coordinatorRef.deactivateInstances("x") + coordinatorRef.deactivateInstances(runId1) assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinatorRef.getLocation(id3) === None) - coordinatorRef.deactivateInstances("y") + coordinatorRef.deactivateInstances(runId2) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) assert(coordinatorRef.getLocation(id2) === None) } @@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) coordRef1.reportActiveInstance(id, "hostX", "exec1") @@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } } + + test("query stop deactivates related store providers") { + var coordRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + import spark.implicits._ + coordRef = spark.streams.stateStoreCoordinator + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + + // Start a query and run a batch to load state stores + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Verify state store has been loaded + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId) + assert(coordRef.getLocation(providerId).nonEmpty) + + // Stop and verify whether the stores are deactivated in the coordinator + query.stop() + assert(coordRef.getLocation(providerId).isEmpty) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordRef != null) coordRef.stop() + StateStore.stop() + } + } } object StateStoreCoordinatorSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 4a1a089af54c2..defb9ed63a881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.nio.file.Files +import java.util.UUID import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.LocalSparkSession._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val opId = 0 - val rdd1 = - makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) @@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( @@ -85,7 +81,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = storeVersion), + keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("preferred locations using StateStoreCoordinator") { quietly { + val queryRunId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") + val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) + val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) + coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1") + coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === + require( + coordinatorRef.getLocation(storeProviderId1) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( - increment) + sqlContext, operatorStateInfo(path, queryRunId = queryRunId), + keySchema, valueSchema, None)(increment) require(rdd.partitions.length === 2) assert( @@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + private def operatorStateInfo( + path: String, + queryRunId: UUID = UUID.randomUUID, + version: Int = 0): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + } + private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index af2b9f1c11fb6..c2087ec219e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -143,7 +147,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] provider.getStore(0).commit() // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true).asScala.filter(_.getName.startsWith("temp-")) assert(tempFiles.isEmpty) } @@ -183,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("StateStore.get") { quietly { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -243,18 +247,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] .set("spark.rpc.numRetries", "1") val opId = 0 val dir = newDir() - val storeId = StateStoreId(dir, opId, 0) + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = newStoreProvider(storeId) + val provider = newStoreProvider(storeProviderId.storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get(storeId, keySchema, valueSchema, None, + val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() @@ -274,7 +278,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator - assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + assert(coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not reported") // Background maintenance should clean up and generate snapshots assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") @@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } - // If driver decides to deactivate all instances of the store, then this instance - // should be unloaded - coordinatorRef.deactivateInstances(dir) + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) // If some other executor loads the store, then this instance should be unloaded - coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) } } // Verify if instance is unloaded if SparkContext is stopped eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) assert(!StateStore.isMaintenanceRunning) } } @@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-18416: do not create temp delta file until the store is updated") { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() val deltaFileDir = new File(s"$dir/0/0/") @@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(numDeltaFiles === 3) } + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but the should be different objects + assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { - newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation) } override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { @@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] override def getData( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = newStoreProvider(provider.id) + val reloadedProvider = newStoreProvider(provider.stateStoreId) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 4ede4fd9a035e..86c3a35a59c13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider { throw new Exception("Successfully instantiated") } - override def id: StateStoreId = null + override def stateStoreId: StateStoreId = null override def close(): Unit = { } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 2a4039cc5831a..b2c42eef88f6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -26,9 +26,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import org.scalatest.Assertions +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. @@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() // stop the state store maintenance thread and unload store providers + } /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds From 153dd49b74e1b6df2b8e35760806c9754ca7bfae Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 23 Jun 2017 20:41:17 +0800 Subject: [PATCH 0779/1765] [SPARK-21047] Add test suites for complicated cases in ColumnarBatchSuite ## What changes were proposed in this pull request? Current ColumnarBatchSuite has very simple test cases for `Array` and `Struct`. This pr wants to add some test suites for complicated cases in ColumnVector. Author: jinxing Closes #18327 from jinxing64/SPARK-21047. --- .../execution/vectorized/ColumnarBatch.java | 35 ++++- .../vectorized/ColumnarBatchSuite.scala | 122 ++++++++++++++++++ 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8b7b0e655b31d..e23a64350cbc5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -241,7 +241,40 @@ public MapData getMap(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - throw new UnsupportedOperationException(); + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e48e3f6402901..80d41577dcf2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -739,6 +739,128 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Nest Array in Array.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = ColumnVector.allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), + memMode) + val childColumn = column.arrayData() + val data = column.arrayData().arrayData() + (0 until 6).foreach { + case 3 => data.putNull(3) + case i => data.putInt(i, i) + } + // Arrays in child column: [0], [1, 2], [], [null, 4, 5] + childColumn.putArray(0, 0, 1) + childColumn.putArray(1, 1, 2) + childColumn.putArray(2, 2, 0) + childColumn.putArray(3, 3, 3) + // Arrays in column: [[0]], [[1, 2], []], [[], [null, 4, 5]], null + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 2) + column.putNull(3) + + assert(column.getArray(0).getArray(0).toIntArray() === Array(0)) + assert(column.getArray(1).getArray(0).toIntArray() === Array(1, 2)) + assert(column.getArray(1).getArray(1).toIntArray() === Array()) + assert(column.getArray(2).getArray(0).toIntArray() === Array()) + assert(column.getArray(2).getArray(1).isNullAt(0)) + assert(column.getArray(2).getArray(1).getInt(1) === 4) + assert(column.getArray(2).getArray(1).getInt(2) === 5) + assert(column.isNullAt(3)) + } + } + + test("Nest Struct in Array.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val column = ColumnVector.allocate(10, new ArrayType(schema, true), memMode) + val data = column.arrayData() + val c0 = data.getChildColumn(0) + val c1 = data.getChildColumn(1) + // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) + (0 until 6).foreach { i => + c0.putInt(i, i) + c1.putLong(i, i * 10) + } + // Arrays in column: [(0, 0), (1, 10)], [(1, 10), (2, 20), (3, 30)], + // [(4, 40), (5, 50)] + column.putArray(0, 0, 2) + column.putArray(1, 1, 3) + column.putArray(2, 4, 2) + + assert(column.getArray(0).getStruct(0, 2).toSeq(schema) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(schema) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(schema) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(schema) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(schema) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(schema) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(schema) === Seq(5, 50)) + } + } + + test("Nest Array in Struct.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val schema = new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true)) + val column = ColumnVector.allocate(10, schema, memMode) + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1Child = c1.arrayData() + (0 until 6).foreach { i => + c1Child.putInt(i, i) + } + // Arrays in c1: [0, 1], [2], [3, 4, 5] + c1.putArray(0, 0, 2) + c1.putArray(1, 2, 1) + c1.putArray(2, 3, 3) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getArray(1).toIntArray() === Array(0, 1)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) + } + } + + test("Nest Struct in Struct.") { + (MemoryMode.ON_HEAP :: Nil).foreach { memMode => + val subSchema = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + val schema = new StructType() + .add("int", IntegerType) + .add("struct", subSchema) + val column = ColumnVector.allocate(10, schema, memMode) + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1c0 = c1.getChildColumn(0) + val c1c1 = c1.getChildColumn(1) + // Structs in c1: (7, 70), (8, 80), (9, 90) + c1c0.putInt(0, 7) + c1c0.putInt(1, 8) + c1c0.putInt(2, 9) + c1c1.putInt(0, 70) + c1c1.putInt(1, 80) + c1c1.putInt(2, 90) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getStruct(1, 2).toSeq(subSchema) === Seq(7, 70)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) + } + } + test("ColumnarBatch basic") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType() From acd208ee50b29bde4e097bf88761867b1d57a665 Mon Sep 17 00:00:00 2001 From: 10129659 Date: Fri, 23 Jun 2017 20:53:26 +0800 Subject: [PATCH 0780/1765] [SPARK-21115][CORE] If the cores left is less than the coresPerExecutor,the cores left will not be allocated, so it should not to check in every schedule ## What changes were proposed in this pull request? If we start an app with the param --total-executor-cores=4 and spark.executor.cores=3, the cores left is always 1, so it will try to allocate executors in the function org.apache.spark.deploy.master.startExecutorsOnWorkers in every schedule. Another question is, is it will be better to allocate another executor with 1 core for the cores left. ## How was this patch tested? unit test Author: 10129659 Closes #18322 from eatoncys/leftcores. --- .../scala/org/apache/spark/SparkConf.scala | 11 +++++++ .../apache/spark/deploy/master/Master.scala | 29 ++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ba7a65f79c414..de2f475c6895f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -543,6 +543,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + if (contains("spark.cores.max") && contains("spark.executor.cores")) { + val totalCores = getInt("spark.cores.max", 1) + val executorCores = getInt("spark.executor.cores", 1) + val leftCores = totalCores % executorCores + if (leftCores != 0) { + logWarning(s"Total executor cores: ${totalCores} is not " + + s"divisible by cores per executor: ${executorCores}, " + + s"the left cores: ${leftCores} will not be allocated") + } + } + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c192a0cc82ef6..0dee25fb2ebe2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -659,19 +659,22 @@ private[deploy] class Master( private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. - for (app <- waitingApps if app.coresLeft > 0) { - val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor - // Filter out workers that don't have enough resources to launch an executor - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) - - // Now that we've decided how many cores to allocate on each worker, let's allocate them - for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { - allocateWorkerResourceToExecutors( - app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) + for (app <- waitingApps) { + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + // If the cores left is less than the coresPerExecutor,the cores left will not be allocated + if (app.coresLeft >= coresPerExecutor) { + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), app.desc.coresPerExecutor, usableWorkers(pos)) + } } } } From 5dca10b8fdec81a3cc476301fa4f82ea917c34ec Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Jun 2017 21:51:55 +0800 Subject: [PATCH 0781/1765] [SPARK-21193][PYTHON] Specify Pandas version in setup.py ## What changes were proposed in this pull request? It looks we missed specifying the Pandas version. This PR proposes to fix it. For the current state, it should be Pandas 0.13.0 given my test. This PR propose to fix it as 0.13.0. Running the codes below: ```python from pyspark.sql.types import * schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType()) data = [ (1, "foo", True, 3.0,), (2, "foo", True, 5.0), (3, "bar", False, -1.0), (4, "bar", False, 6.0), ] spark.createDataFrame(data, schema).toPandas().dtypes ``` prints ... **With Pandas 0.13.0** - released, 2014-01 ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.12.0** - - released, 2013-06 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.11.0** - released, 2013-03 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.10.0** - released, 2012-12 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int64 # <- this should be 'int32' b object c bool d float64 # <- this should be 'float32' ``` ## How was this patch tested? Manually tested with Pandas from 0.10.0 to 0.13.0. Author: hyukjinkwon Closes #18403 from HyukjinKwon/SPARK-21193. --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index f50035435e26b..2644d3e79dea1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -199,7 +199,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas'] + 'sql': ['pandas>=0.13.0'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', From f3dea60793d86212ba1068e88ad89cb3dcf07801 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 23 Jun 2017 09:28:02 -0700 Subject: [PATCH 0782/1765] [SPARK-21144][SQL] Print a warning if the data schema and partition schema have the duplicate columns ## What changes were proposed in this pull request? The current master outputs unexpected results when the data schema and partition schema have the duplicate columns: ``` withTempPath { dir => val basePath = dir.getCanonicalPath spark.range(0, 3).toDF("foo").write.parquet(new Path(basePath, "foo=1").toString) spark.range(0, 3).toDF("foo").write.parquet(new Path(basePath, "foo=a").toString) spark.read.parquet(basePath).show() } +---+ |foo| +---+ | 1| | 1| | a| | a| | 1| | a| +---+ ``` This patch added code to print a warning when the duplication found. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #18375 from maropu/SPARK-21144-3. --- .../apache/spark/sql/util/SchemaUtils.scala | 53 +++++++++++++++++++ .../execution/datasources/DataSource.scala | 6 +++ 2 files changed, 59 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala new file mode 100644 index 0000000000000..e881685ce6262 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.internal.Logging + + +/** + * Utils for handling schemas. + * + * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]]. + */ +private[spark] object SchemaUtils extends Logging { + + /** + * Checks if input column names have duplicate identifiers. Prints a warning message if + * the duplication exists. + * + * @param columnNames column names to check + * @param colType column type name, used in a warning message + * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not + */ + def checkColumnNameDuplication( + columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { + val names = if (caseSensitiveAnalysis) { + columnNames + } else { + columnNames.map(_.toLowerCase) + } + if (names.distinct.length != names.length) { + val duplicateColumns = names.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => s"`$x`" + } + logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " + + "You might need to assign different column names.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 08c78e6e326af..75e530607570f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils /** @@ -182,6 +183,11 @@ case class DataSource( throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") } + + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + (dataSchema, partitionSchema) } From 07479b3cfb7a617a18feca14e9e31c208c80630e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Jun 2017 09:59:24 -0700 Subject: [PATCH 0783/1765] [SPARK-21149][R] Add job description API for R ## What changes were proposed in this pull request? Extend `setJobDescription` to SparkR API. ## How was this patch tested? It looks difficult to add a test. Manually tested as below: ```r df <- createDataFrame(iris) count(df) setJobDescription("This is an example job.") count(df) ``` prints ... ![2017-06-22 12 05 49](https://user-images.githubusercontent.com/6477701/27415670-2a649936-5743-11e7-8e95-312f1cd103af.png) Author: hyukjinkwon Closes #18382 from HyukjinKwon/SPARK-21149. --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/sparkR.R | 17 +++++++++++++++++ R/pkg/tests/fulltests/test_context.R | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 229de4a997eef..b7fdae58de459 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -75,7 +75,8 @@ exportMethods("glm", # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", - "cancelJobGroup") + "cancelJobGroup", + "setJobDescription") # Export Utility methods export("setLogLevel") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d0a12b7ecec65..f2d2620e5447a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -535,6 +535,23 @@ cancelJobGroup <- function(sc, groupId) { } } +#' Set a human readable description of the current job. +#' +#' Set a description that is shown as a job description in UI. +#' +#' @param value The job description of the current job. +#' @rdname setJobDescription +#' @name setJobDescription +#' @examples +#'\dontrun{ +#' setJobDescription("This is an example job.") +#'} +#' @note setJobDescription since 2.3.0 +setJobDescription <- function(value) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setJobDescription", value)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 710485d56685a..77635c5a256b9 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -100,6 +100,7 @@ test_that("job group functions can be called", { setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() + setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) From b803b66a8133f705463039325ee71ee6827ce1a7 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 23 Jun 2017 10:33:53 -0700 Subject: [PATCH 0784/1765] [SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan ## What changes were proposed in this pull request? After wiring `SQLConf` in logical plan ([PR 18299](https://github.com/apache/spark/pull/18299)), we can remove the need of passing `conf` into `def stats` and `def computeStats`. ## How was this patch tested? Covered by existing tests, plus some modified existing tests. Author: wangzhenhua Author: Zhenhua Wang Closes #18391 from wzhfy/removeConf. --- .../sql/catalyst/catalog/interface.scala | 3 +- .../optimizer/CostBasedJoinReorder.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../optimizer/StarSchemaDetection.scala | 14 ++-- .../plans/logical/LocalRelation.scala | 3 +- .../catalyst/plans/logical/LogicalPlan.scala | 15 ++--- .../plans/logical/basicLogicalOperators.scala | 65 +++++++++---------- .../sql/catalyst/plans/logical/hints.scala | 5 +- .../statsEstimation/AggregateEstimation.scala | 7 +- .../statsEstimation/EstimationUtils.scala | 5 +- .../statsEstimation/FilterEstimation.scala | 5 +- .../statsEstimation/JoinEstimation.scala | 21 +++--- .../statsEstimation/ProjectEstimation.scala | 7 +- .../optimizer/JoinOptimizationSuite.scala | 2 +- .../optimizer/LimitPushdownSuite.scala | 6 +- .../AggregateEstimationSuite.scala | 30 +++++---- .../BasicStatsEstimationSuite.scala | 27 +++++--- .../FilterEstimationSuite.scala | 2 +- .../statsEstimation/JoinEstimationSuite.scala | 26 ++++---- .../ProjectEstimationSuite.scala | 4 +- .../StatsEstimationTestBase.scala | 18 +++-- .../spark/sql/execution/ExistingRDD.scala | 5 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 13 ++-- .../execution/columnar/InMemoryRelation.scala | 3 +- .../datasources/LogicalRelation.scala | 3 +- .../sql/execution/streaming/memory.scala | 3 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../spark/sql/StatisticsCollectionSuite.scala | 18 ++--- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../datasources/HadoopFsRelationSuite.scala | 2 +- .../execution/streaming/MemorySinkSuite.scala | 6 +- .../apache/spark/sql/test/SQLTestData.scala | 3 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 10 +-- .../PruneFileSourcePartitionsSuite.scala | 2 +- 38 files changed, 178 insertions(+), 173 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c043ed9c431b7..b63bef9193332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -436,7 +435,7 @@ case class CatalogRelation( createTime = -1 )) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 51eca6ca33760..3a7543e2141e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && - items.forall(_.stats(conf).rowCount.isDefined)) { + items.forall(_.stats.rowCount.isDefined)) { JoinReorderDP.search(conf, items, conditions, output) } else { plan @@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging { /** Get the cost of the root node of this plan tree. */ def rootCost(conf: SQLConf): Cost = { if (itemIds.size > 1) { - val rootStats = plan.stats(conf) + val rootStats = plan.stats Cost(rootStats.rowCount.get, rootStats.sizeInBytes) } else { // If the plan is a leaf item, it has zero cost. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3ab70fb90470c..b410312030c5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => - if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) { + if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { join.copy(left = maybePushLimit(exp, left)) } else { join.copy(right = maybePushLimit(exp, right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 97ee9988386dd..ca729127e7d1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { // Find if the input plans are eligible for star join detection. // An eligible plan is a base table access with valid statistics. val foundEligibleJoin = input.forall { - case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true case _ => false } @@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) + val stats = t.stats stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { @@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) + val stats = t.stats stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) case None => false } @@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { */ private def getTableAccessCardinality( input: LogicalPlan): Option[BigInt] = input match { - case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => - if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { - Option(input.stats(conf).rowCount.get) + case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined => + if (conf.cboEnabled && input.stats.rowCount.isDefined) { + Option(input.stats.rowCount.get) } else { - Option(t.stats(conf).rowCount.get) + Option(t.stats.rowCount.get) } case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 9cd5dfd21b160..dc2add64b68b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 95b4165f6b10d..0c098ac0209e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai * first time. If the configuration changes, the cache can be invalidated by calling * [[invalidateStatsCache()]]. */ - final def stats(conf: SQLConf): Statistics = statsCache.getOrElse { - statsCache = Some(computeStats(conf)) + final def stats: Statistics = statsCache.getOrElse { + statsCache = Some(computeStats) statsCache.get } @@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai * * [[LeafNode]]s must override this. */ - protected def computeStats(conf: SQLConf): Statistics = { + protected def computeStats: Statistics = { if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } - Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product) + Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product) } override def verboseStringWithSuffix: String = { @@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 // Assume there will be the same number of rows as child has. - var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize + var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). @@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan { } // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints) + Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6e88b7a57dc33..d8f89b108e63f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (conf.cboEnabled) { - ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) + ProjectEstimation.estimate(this).getOrElse(super.computeStats) } else { - super.computeStats(conf) + super.computeStats } } } @@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (conf.cboEnabled) { - FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) + FilterEstimation(this).estimate.getOrElse(super.computeStats) } else { - super.computeStats(conf) + super.computeStats } } } @@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def computeStats(conf: SQLConf): Statistics = { - val leftSize = left.stats(conf).sizeInBytes - val rightSize = right.stats(conf).sizeInBytes + override def computeStats: Statistics = { + val leftSize = left.stats.sizeInBytes + val rightSize = right.stats.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize Statistics( sizeInBytes = sizeInBytes, - hints = left.stats(conf).hints.resetForJoin()) + hints = left.stats.hints.resetForJoin()) } } @@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected def validConstraints: Set[Expression] = leftConstraints - override def computeStats(conf: SQLConf): Statistics = { - left.stats(conf).copy() + override def computeStats: Statistics = { + left.stats.copy() } } @@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats(conf: SQLConf): Statistics = { - val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum + override def computeStats: Statistics = { + val sizeInBytes = children.map(_.stats.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -357,20 +356,20 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left - left.stats(conf) + left.stats case _ => // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - val stats = super.computeStats(conf) + val stats = super.computeStats stats.copy(hints = stats.hints.resetForJoin()) } if (conf.cboEnabled) { - JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation) + JoinEstimation.estimate(this).getOrElse(simpleEstimation) } else { simpleEstimation } @@ -523,7 +522,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -556,20 +555,20 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), rowCount = Some(1), - hints = child.stats(conf).hints) + hints = child.stats.hints) } else { - super.computeStats(conf) + super.computeStats } } if (conf.cboEnabled) { - AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation) + AggregateEstimation.estimate(this).getOrElse(simpleEstimation) } else { simpleEstimation } @@ -672,8 +671,8 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats(conf: SQLConf): Statistics = { - val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length + override def computeStats: Statistics = { + val sizeInBytes = super.computeStats.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats(conf) + val childStats = child.stats val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) // Don't propagate column stats, because we don't know the distribution after a limit operation Statistics( @@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats(conf) + val childStats = child.stats if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). @@ -832,9 +831,9 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val ratio = upperBound - lowerBound - val childStats = child.stats(conf) + val childStats = child.stats var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 @@ -898,7 +897,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1) + override def computeStats: Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index e49970df80457..8479c702d7561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf /** * A general hint for the child that is not yet resolved. This node is generated by the parser and @@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override lazy val canonicalized: LogicalPlan = child.canonicalized - override def computeStats(conf: SQLConf): Statistics = { - val stats = child.stats(conf) + override def computeStats: Statistics = { + val stats = child.stats stats.copy(hints = hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index a0c23198451a8..c41fac4015ec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} -import org.apache.spark.sql.internal.SQLConf object AggregateEstimation { @@ -29,13 +28,13 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = { - val childStats = agg.child.stats(conf) + def estimate(agg: Aggregate): Option[Statistics] = { + val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) } - if (rowCountsExist(conf, agg.child) && colStatsExist) { + if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index e5fcdf9039be9..9c34a9b7aa756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean = - plans.forall(_.stats(conf).rowCount.isDefined) + def rowCountsExist(plans: LogicalPlan*): Boolean = + plans.forall(_.stats.rowCount.isDefined) /** Check if each attribute has column stat in the corresponding statistics. */ def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index df190867189ec..5a3bee7b9e449 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging { +case class FilterEstimation(plan: Filter) extends Logging { - private val childStats = plan.child.stats(catalystConf) + private val childStats = plan.child.stats private val colStatsMap = new ColumnStatsMap(childStats.attributeStats) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 8ef905c45d50d..f48196997a24d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.internal.SQLConf object JoinEstimation extends Logging { @@ -34,12 +33,12 @@ object JoinEstimation extends Logging { * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(conf: SQLConf, join: Join): Option[Statistics] = { + def estimate(join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(conf, join).doEstimate() + InnerOuterEstimation(join).doEstimate() case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(conf, join).doEstimate() + LeftSemiAntiEstimation(join).doEstimate() case _ => logDebug(s"[CBO] Unsupported join type: ${join.joinType}") None @@ -47,16 +46,16 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { +case class InnerOuterEstimation(join: Join) extends Logging { - private val leftStats = join.left.stats(conf) - private val rightStats = join.right.stats(conf) + private val leftStats = join.left.stats + private val rightStats = join.right.stats /** * Estimate output size and number of rows after a join operator, and update output column stats. */ def doEstimate(): Option[Statistics] = join match { - case _ if !rowCountsExist(conf, join.left, join.right) => + case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => @@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { } } -case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) { +case class LeftSemiAntiEstimation(join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. - if (rowCountsExist(conf, join.left)) { - val leftStats = join.left.stats(conf) + if (rowCountsExist(join.left)) { + val leftStats = join.left.stats // Propagate the original column stats for cartesian product val outputRows = leftStats.rowCount.get Some(Statistics( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index d700cd3b20f7d..489eb904ffd05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} -import org.apache.spark.sql.internal.SQLConf object ProjectEstimation { import EstimationUtils._ - def estimate(conf: SQLConf, project: Project): Option[Statistics] = { - if (rowCountsExist(conf, project.child)) { - val childStats = project.child.stats(conf) + def estimate(project: Project): Option[Statistics] = { + if (rowCountsExist(project.child)) { + val childStats = project.child.stats val inputAttrStats = childStats.attributeStats // Match alias with its child's column stat val aliasStats = project.expressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 105407d43bf39..a6584aa5fbba7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, expected) val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r + case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r } assert(broadcastChildren.size == 1) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index fb34c82de468b..d8302dfc9462d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest { } test("full outer join where neither side is limited and both sides have same statistics") { - assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes) + assert(x.stats.sizeInBytes === y.stats.sizeInBytes) val originalQuery = x.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze @@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) - assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes) + assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes) val originalQuery = xBig.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze @@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) - assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes) + assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) val originalQuery = x.join(yBig, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 38483a298cef0..30ddf03bd3c4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { size = Some(4 * (8 + 4)), attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) - val noGroupAgg = Aggregate(groupingExpressions = Nil, - aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) - assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == - // overhead + count result size - Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) - - val hasGroupAgg = Aggregate(groupingExpressions = attributes, - aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) - assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == - // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize - Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + try { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + val noGroupAgg = Aggregate(groupingExpressions = Nil, + aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) + assert(noGroupAgg.stats == + // overhead + count result size + Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) + + val hasGroupAgg = Aggregate(groupingExpressions = attributes, + aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) + assert(hasGroupAgg.stats == + // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize + Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + } finally { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + } } private def checkAggStats( @@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedOutputRowCount), attributeStats = expectedAttrStats) - assert(testAgg.stats(conf) == expectedStats) + assert(testAgg.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 833f5a71994f7..e9ed36feec48c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) // LocalLimit's stats is just its child's stats except column stats - checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) } test("limit estimation: limit > child's rowCount") { val localLimit = LocalLimit(Literal(20), plan) val globalLimit = GlobalLimit(Literal(20), plan) - checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. - checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) } test("limit estimation: limit = 0") { @@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - // Invalidate statistics - plan.invalidateStatsCache() - assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn) - - plan.invalidateStatsCache() - assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff) + val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + try { + // Invalidate statistics + plan.invalidateStatsCache() + SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) + assert(plan.stats == expectedStatsCboOn) + + plan.invalidateStatsCache() + SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + assert(plan.stats == expectedStatsCboOff) + } finally { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + } } /** Check estimated stats when it's the same whether cbo is turned on or off. */ @@ -135,6 +142,6 @@ private case class DummyLogicalPlan( cboStats: Statistics) extends LogicalPlan { override def output: Seq[Attribute] = Nil override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = if (conf.cboEnabled) cboStats else defaultStats } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2fa53a6466ef2..455037e6c9952 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRowCount), attributeStats = expectedAttributeMap) - val filterStats = filter.stats(conf) + val filterStats = filter.stats assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) assert(filterStats.rowCount == expectedStats.rowCount) val rowCountValue = filterStats.rowCount.getOrElse(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 2d6b6e8e21f34..097c78eb27fca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Keep the column stat from both sides unchanged. attributeStats = AttributeMap( Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint inner join") { @@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1, rowCount = Some(0), attributeStats = AttributeMap(Nil)) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint left outer join") { @@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Null count for right side columns = left row count Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5), nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint right outer join") { @@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Null count for left side columns = right row count Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3), nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint full outer join") { @@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("inner join") { @@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat, nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("inner join with multiple equi-join keys") { @@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("left outer join") { @@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"), nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("right outer join") { @@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat, nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("full outer join") { @@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Keep the column stat from both sides unchanged. attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"), nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("left semi/anti join") { @@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 3 * (8 + 4 * 2), rowCount = Some(3), attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } @@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))), rowCount = Some(1), attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } } @@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1, rowCount = Some(0), attributeStats = AttributeMap(Nil)) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index a5c4d22a29386..cda54fa9d64f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 2 * (8 + 4 + 4), rowCount = Some(2), attributeStats = expectedAttrStats) - assert(proj.stats(conf) == expectedStats) + assert(proj.stats == expectedStats) } test("project on empty table") { @@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = expectedSize, rowCount = Some(expectedRowCount), attributeStats = projectAttrMap) - assert(proj.stats(conf) == expectedStats) + assert(proj.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 263f4e18803d5..eaa33e44a6a5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED} import org.apache.spark.sql.types.{IntegerType, StringType} trait StatsEstimationTestBase extends SparkFunSuite { - /** Enable stats estimation based on CBO. */ - protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true) + var originalValue: Boolean = false + + override def beforeAll(): Unit = { + super.beforeAll() + // Enable stats estimation based on CBO. + originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) + } + + override def afterAll(): Unit = { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + super.afterAll() + } def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes @@ -55,7 +65,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats(conf: SQLConf): Statistics = Statistics( + override def computeStats: Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 3d1b481a53e75..66f66a289a065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -89,7 +88,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: SQLConf): Statistics = Statistics( + @transient override def computeStats: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -157,7 +156,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: SQLConf): Statistics = Statistics( + @transient override def computeStats: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 34998cbd61552..c7cac332a0377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { def stringWithStats: String = { // trigger to compute stats for logical plans - optimizedPlan.stats(sparkSession.sessionState.conf) + optimizedPlan.stats // only show optimized logical plan and physical plan s"""== Optimized Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ea86f6e00fefa..a57d5abb90c0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).hints.broadcast || - (plan.stats(conf).sizeInBytes >= 0 && - plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) + plan.stats.hints.broadcast || + (plan.stats.sizeInBytes >= 0 && + plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold) } /** @@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * dynamic. */ private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions } /** @@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * use the size of bytes here as estimation. */ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes } private def canBuildRight(joinType: JoinType): Boolean = joinType match { @@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { BuildRight } else { BuildLeft diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 456a8f3b20f30..2972132336de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -70,7 +69,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3813f953e06a3..c1b2895f1747e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -46,7 +45,7 @@ case class LogicalRelation( // Only care about relation when canonicalizing. override def preCanonicalized: LogicalPlan = copy(catalogTable = None) - @transient override def computeStats(conf: SQLConf): Statistics = { + @transient override def computeStats: Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7eaa803a9ecb4..a5dac469f85b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 8532a5b5bc8eb..506cc2548e260 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum - assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes) + assert(cached.stats.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 165176f6c040e..87b7b090de3bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1146,7 +1146,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.logicalPlan.stats.sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a66aa85f5a02..895ca196a7a51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { - df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes + df.queryExecution.optimizedPlan.stats.sizeInBytes } test("equi-join is hash-join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 601324f2c0172..9824062f969b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = df1.join(df2, Seq("k"), "left") val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.stats(conf).sizeInBytes + g.stats.sizeInBytes } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") @@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) } @@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils test("SPARK-18856: non-empty partitioned table should not report zero size") { withTable("ds_tbl", "hive_tbl") { spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf) + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf) + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") } } @@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) // Check relation statistics - assert(relation.stats(conf).sizeInBytes == 0) - assert(relation.stats(conf).rowCount == Some(0)) - assert(relation.stats(conf).attributeStats.size == 1) - val (attribute, colStat) = relation.stats(conf).attributeStats.head + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head assert(attribute.name == "c1") assert(colStat == emptyColStat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 109b1d9db60d2..8d411eb191cd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { .toDF().createOrReplaceTempView("sizeTst") spark.catalog.cacheTable("sizeTst") assert( - spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes > + spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes > spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index becb3aa270401..caf03885e3873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { }) val totalSize = allFiles.map(_.length()).sum val df = spark.read.parquet(dir.toString) - assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize)) + assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 24a7b7740fa5b..e8420eee7fe9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { // Before adding data, check output checkAnswer(sink.allData, Seq.empty) - assert(plan.stats(sqlConf).sizeInBytes === 0) + assert(plan.stats.sizeInBytes === 0) sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 12) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 24) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index f9b3ff8405823..0cfe260e52152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} -import org.apache.spark.sql.internal.SQLConf /** * A collection of sample data used in SQL tests. @@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf private[sql] trait SQLTestData { self => protected def spark: SparkSession - protected def sqlConf: SQLConf = spark.sessionState.conf - // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.spark.sqlContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ff5afc8e3ce05..808dc013f170b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong + val sizeInBytes = relation.stats.sizeInBytes.toLong val fileIndex = { val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes) if (lazyPruningEnabled) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 001bbc230ff18..279db9a397258 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") - val sizeInBytes = relation.stats(conf).sizeInBytes + val sizeInBytes = relation.stats.sizeInBytes assert(sizeInBytes === BigInt(file1.length() + file2.length())) } } @@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("analyze Hive serde tables") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats.sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold, @@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats.sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index d91f25a4da013..3a724aa14f2a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te case relation: LogicalRelation => relation } assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") - val size2 = relations(0).computeStats(conf).sizeInBytes + val size2 = relations(0).computeStats.sizeInBytes assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) assert(size2 < tableStats.get.sizeInBytes) } From 1ebe7ffe072bcac03360e65e959a6cd36530a9c4 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Fri, 23 Jun 2017 10:36:29 -0700 Subject: [PATCH 0785/1765] [SPARK-21181] Release byteBuffers to suppress netty error messages ## What changes were proposed in this pull request? We are explicitly calling release on the byteBuf's used to encode the string to Base64 to suppress the memory leak error message reported by netty. This is to make it less confusing for the user. ### Changes proposed in this fix By explicitly invoking release on the byteBuf's we are decrement the internal reference counts for the wrappedByteBuf's. Now, when the GC kicks in, these would be reclaimed as before, just that netty wouldn't report any memory leak error messages as the internal ref. counts are now 0. ## How was this patch tested? Ran a few spark-applications and examined the logs. The error message no longer appears. Original PR was opened against branch-2.1 => https://github.com/apache/spark/pull/18392 Author: Dhruve Ashar Closes #18407 from dhruve/master. --- .../spark/network/sasl/SparkSaslServer.java | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index e24fdf0c74de3..00f3e83dbc8b3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,6 +34,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; @@ -187,14 +188,31 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(StandardCharsets.UTF_8))) - .toString(StandardCharsets.UTF_8); + return getBase64EncodedString(identifier); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(StandardCharsets.UTF_8))) - .toString(StandardCharsets.UTF_8).toCharArray(); + return getBase64EncodedString(password).toCharArray(); + } + + /** Return a Base64-encoded string. */ + private static String getBase64EncodedString(String str) { + ByteBuf byteBuf = null; + ByteBuf encodedByteBuf = null; + try { + byteBuf = Unpooled.wrappedBuffer(str.getBytes(StandardCharsets.UTF_8)); + encodedByteBuf = Base64.encode(byteBuf); + return encodedByteBuf.toString(StandardCharsets.UTF_8); + } finally { + // The release is called to suppress the memory leak error messages raised by netty. + if (byteBuf != null) { + byteBuf.release(); + if (encodedByteBuf != null) { + encodedByteBuf.release(); + } + } + } } } From 2ebd0838d165fe33b404e8d86c0fa445d1f47439 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Jun 2017 10:55:02 -0700 Subject: [PATCH 0786/1765] [SPARK-21192][SS] Preserve State Store provider class configuration across StreamingQuery restarts ## What changes were proposed in this pull request? If the SQL conf for StateStore provider class is changed between restarts (i.e. query started with providerClass1 and attempted to restart using providerClass2), then the query will fail in a unpredictable way as files saved by one provider class cannot be used by the newer one. Ideally, the provider class used to start the query should be used to restart the query, and the configuration in the session where it is being restarted should be ignored. This PR saves the provider class config to OffsetSeqLog, in the same way # shuffle partitions is saved and recovered. ## How was this patch tested? new unit tests Author: Tathagata Das Closes #18402 from tdas/SPARK-21192. --- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../sql/execution/streaming/OffsetSeq.scala | 39 +++++++++++++- .../execution/streaming/StreamExecution.scala | 26 +++------- .../streaming/state/StateStore.scala | 3 +- .../streaming/state/StateStoreConf.scala | 2 +- .../streaming/OffsetSeqLogSuite.scala | 10 ++-- .../spark/sql/streaming/StreamSuite.scala | 51 +++++++++++++++---- 7 files changed, 96 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e609256db2802..9c8e26a8eeadf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -601,7 +601,8 @@ object SQLConf { "The class used to manage state data in stateful streaming queries. This class must " + "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") .stringConf - .createOptional + .createWithDefault( + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") @@ -897,7 +898,7 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) - def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 8249adab4bba8..4e0a468b962a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -20,6 +20,10 @@ package org.apache.spark.sql.execution.streaming import org.json4s.NoTypeHints import org.json4s.jackson.Serialization +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} + /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance @@ -78,7 +82,40 @@ case class OffsetSeqMetadata( def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } -object OffsetSeqMetadata { +object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) + private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) + + def apply( + batchWatermarkMs: Long, + batchTimestampMs: Long, + sessionConf: RuntimeConfig): OffsetSeqMetadata = { + val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap + OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs) + } + + /** Set the SparkSession configuration with the values in the metadata */ + def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): Unit = { + OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey => + + metadata.conf.get(confKey) match { + + case Some(valueInMetadata) => + // Config value exists in the metadata, update the session config with this value + val optionalValueInSession = sessionConf.getOption(confKey) + if (optionalValueInSession.isDefined && optionalValueInSession.get != valueInMetadata) { + logWarning(s"Updating the value of conf '$confKey' in current session from " + + s"'${optionalValueInSession.get}' to '$valueInMetadata'.") + } + sessionConf.set(confKey, valueInMetadata) + + case None => + // For backward compatibility, if a config was not recorded in the offset log, + // then log it, and let the existing conf value in SparkSession prevail. + logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 06bdec8b06407..d5f8d2acba92b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -125,9 +125,8 @@ class StreamExecution( } /** Metadata associated with the offset seq of a batch in the query. */ - protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, - conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> - sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString)) + protected var offsetSeqMetadata = OffsetSeqMetadata( + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) override val id: UUID = UUID.fromString(streamMetadata.id) @@ -285,9 +284,8 @@ class StreamExecution( val sparkSessionToRunBatches = sparkSession.cloneSession() // Adaptive execution can change num shuffle partitions, disallow sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") - offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, - conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> - sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key))) + offsetSeqMetadata = OffsetSeqMetadata( + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf) if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` @@ -441,21 +439,9 @@ class StreamExecution( // update offset metadata nextOffsets.metadata.foreach { metadata => - val shufflePartitionsSparkSession: Int = - sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS) - val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, { - // For backward compatibility, if # partitions was not recorded in the offset log, - // then ensure it is not missing. The new value is picked up from the conf. - logWarning("Number of shuffle partitions from previous run not found in checkpoint. " - + s"Using the value from the conf, $shufflePartitionsSparkSession partitions.") - shufflePartitionsSparkSession - }) + OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( - metadata.batchWatermarkMs, metadata.batchTimestampMs, - metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString)) - // Update conf with correct number of shuffle partitions - sparkSessionToRunBatches.conf.set( - SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) + metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) } /* identify the current batch id: if commit log indicates we successfully processed the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a94ff8a7ebd1e..86886466c4f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -172,8 +172,7 @@ object StateStoreProvider { indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { - val providerClass = storeConf.providerClass.map(Utils.classForName) - .getOrElse(classOf[HDFSBackedStateStoreProvider]) + val providerClass = Utils.classForName(storeConf.providerClass) val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) provider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index bab297c7df594..765ff076cb467 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -38,7 +38,7 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) * Optional fully qualified name of the subclass of [[StateStoreProvider]] * managing state data. That is, the implementation of the State Store to use. */ - val providerClass: Option[String] = sqlConf.stateStoreProviderClass + val providerClass: String = sqlConf.stateStoreProviderClass /** * Additional configurations related to state store. This will capture all configs in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index dc556322beddb..e6cdc063c4e9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -37,16 +37,18 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } // None set - assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + assert(new OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) // One set - assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(new OffsetSeqMetadata(1, 0, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(new OffsetSeqMetadata(0, 2, Map.empty) === + OffsetSeqMetadata("""{"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) // Two set - assert(OffsetSeqMetadata(1, 2, Map.empty) === + assert(new OffsetSeqMetadata(1, 2, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 86c3a35a59c13..6f7b9d35a6bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -637,19 +637,11 @@ class StreamSuite extends StreamTest { } testQuietly("specify custom state store provider") { - val queryName = "memStream" val providerClassName = classOf[TestStateStoreProvider].getCanonicalName withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { val input = MemoryStream[Int] - val query = input - .toDS() - .groupBy() - .count() - .writeStream - .outputMode("complete") - .format("memory") - .queryName(queryName) - .start() + val df = input.toDS().groupBy().count() + val query = df.writeStream.outputMode("complete").format("memory").queryName("name").start() input.addData(1, 2, 3) val e = intercept[Exception] { query.awaitTermination() @@ -659,6 +651,45 @@ class StreamSuite extends StreamTest { assert(e.getMessage.contains("instantiated")) } } + + testQuietly("custom state store provider read from offset log") { + val input = MemoryStream[Int] + val df = input.toDS().groupBy().count() + val providerConf1 = "spark.sql.streaming.stateStore.providerClass" -> + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider" + val providerConf2 = "spark.sql.streaming.stateStore.providerClass" -> + classOf[TestStateStoreProvider].getCanonicalName + + def runQuery(queryName: String, checkpointLoc: String): Unit = { + val query = df.writeStream + .outputMode("complete") + .format("memory") + .queryName(queryName) + .option("checkpointLocation", checkpointLoc) + .start() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + } + + withTempDir { dir => + val checkpointLoc1 = new File(dir, "1").getCanonicalPath + withSQLConf(providerConf1) { + runQuery("query1", checkpointLoc1) // generate checkpoints + } + + val checkpointLoc2 = new File(dir, "2").getCanonicalPath + withSQLConf(providerConf2) { + // Verify new query will use new provider that throw error on loading + intercept[Exception] { + runQuery("query2", checkpointLoc2) + } + + // Verify old query from checkpoint will still use old provider + runQuery("query1", checkpointLoc1) + } + } + } } abstract class FakeSource extends StreamSourceProvider { From 4cc62951a2b12a372a2b267bf8597a0a31e2b2cb Mon Sep 17 00:00:00 2001 From: Ong Ming Yang Date: Fri, 23 Jun 2017 10:56:59 -0700 Subject: [PATCH 0787/1765] [MINOR][DOCS] Docs in DataFrameNaFunctions.scala use wrong method ## What changes were proposed in this pull request? * Following the first few examples in this file, the remaining methods should also be methods of `df.na` not `df`. * Filled in some missing parentheses ## How was this patch tested? N/A Author: Ong Ming Yang Closes #18398 from ongmingyang/master. --- .../spark/sql/DataFrameNaFunctions.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ee949e78fa3ba..871fff71e5538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -268,13 +268,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * df.na.replace("height", ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -295,10 +295,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param cols list of columns to apply the value replacement @@ -319,13 +319,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", Map(1.0 -> 2.0)) + * df.na.replace("height", Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", Map("UNKNOWN" -> "unnamed") + * df.na.replace("name", Map("UNKNOWN" -> "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", Map("UNKNOWN" -> "unnamed") + * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -348,10 +348,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * * @param cols list of columns to apply the value replacement From 13c2a4f2f8c6d3484f920caadddf4e5edce0a945 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 23 Jun 2017 11:02:54 -0700 Subject: [PATCH 0788/1765] [SPARK-20417][SQL] Move subquery error handling to checkAnalysis from Analyzer ## What changes were proposed in this pull request? Currently we do a lot of validations for subquery in the Analyzer. We should move them to CheckAnalysis which is the framework to catch and report Analysis errors. This was mentioned as a review comment in SPARK-18874. ## How was this patch tested? Exists tests + A few tests added to SQLQueryTestSuite. Author: Dilip Biswal Closes #17713 from dilipbiswal/subquery_checkanalysis. --- .../sql/catalyst/analysis/Analyzer.scala | 230 +----------- .../sql/catalyst/analysis/CheckAnalysis.scala | 338 ++++++++++++++---- .../sql/catalyst/expressions/predicates.scala | 46 ++- .../analysis/AnalysisErrorSuite.scala | 3 +- .../analysis/ResolveSubquerySuite.scala | 2 +- .../negative-cases/subq-input-typecheck.sql | 47 +++ .../subq-input-typecheck.sql.out | 106 ++++++ .../org/apache/spark/sql/SubquerySuite.scala | 2 +- 8 files changed, 464 insertions(+), 310 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 647fc0b9342c1..193082eb77024 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ -import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ @@ -1257,217 +1256,16 @@ class Analyzer( } /** - * Validates to make sure the outer references appearing inside the subquery - * are legal. This function also returns the list of expressions - * that contain outer references. These outer references would be kept as children - * of subquery expressions by the caller of this function. - */ - private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { - val outerReferences = ArrayBuffer.empty[Expression] - - // Validate that correlated aggregate expression do not contain a mixture - // of outer and local references. - def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { - expr.foreach { - case a: AggregateExpression if containsOuter(a) => - val outer = a.collect { case OuterReference(e) => e.toAttribute } - val local = a.references -- outer - if (local.nonEmpty) { - val msg = - s""" - |Found an aggregate expression in a correlated predicate that has both - |outer and local references, which is not supported yet. - |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, - |Outer references: ${outer.map(_.sql).mkString(", ")}, - |Local references: ${local.map(_.sql).mkString(", ")}. - """.stripMargin.replace("\n", " ").trim() - failAnalysis(msg) - } - case _ => - } - } - - // Make sure a plan's subtree does not contain outer references - def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (hasOuterReferences(p)) { - failAnalysis(s"Accessing outer query column is not allowed in:\n$p") - } - } - - // Make sure a plan's expressions do not contain : - // 1. Aggregate expressions that have mixture of outer and local references. - // 2. Expressions containing outer references on plan nodes other than Filter. - def failOnInvalidOuterReference(p: LogicalPlan): Unit = { - p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) - if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { - failAnalysis( - "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + - s"clauses:\n$p") - } - } - - // SPARK-17348: A potential incorrect result case. - // When a correlated predicate is a non-equality predicate, - // certain operators are not permitted from the operator - // hosting the correlated predicate up to the operator on the outer table. - // Otherwise, the pull up of the correlated predicate - // will generate a plan with a different semantics - // which could return incorrect result. - // Currently we check for Aggregate and Window operators - // - // Below shows an example of a Logical Plan during Analyzer phase that - // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] - // through the Aggregate (or Window) operator could alter the result of - // the Aggregate. - // - // Project [c1#76] - // +- Project [c1#87, c2#88] - // : (Aggregate or Window operator) - // : +- Filter [outer(c2#77) >= c2#88)] - // : +- SubqueryAlias t2, `t2` - // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] - // : +- LocalRelation [_1#84, _2#85] - // +- SubqueryAlias t1, `t1` - // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] - // +- LocalRelation [_1#73, _2#74] - def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { - if (found) { - // Report a non-supported case as an exception - failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") - } - } - - var foundNonEqualCorrelatedPred : Boolean = false - - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { - - // Whitelist operators allowed in a correlated subquery - // There are 4 categories: - // 1. Operators that are allowed anywhere in a correlated subquery, and, - // by definition of the operators, they either do not contain - // any columns or cannot host outer references. - // 2. Operators that are allowed anywhere in a correlated subquery - // so long as they do not host outer references. - // 3. Operators that need special handlings. These operators are - // Project, Filter, Join, Aggregate, and Generate. - // - // Any operators that are not in the above list are allowed - // in a correlated subquery only if they are not on a correlation path. - // In other word, these operators are allowed only under a correlation point. - // - // A correlation path is defined as the sub-tree of all the operators that - // are on the path from the operator hosting the correlated expressions - // up to the operator producing the correlated values. - - // Category 1: - // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => - - // Category 2: - // These operators can be anywhere in a correlated subquery. - // so long as they do not host outer references in the operators. - case s: Sort => - failOnInvalidOuterReference(s) - case r: RepartitionByExpression => - failOnInvalidOuterReference(r) - - // Category 3: - // Filter is one of the two operators allowed to host correlated expressions. - // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f: Filter => - // Find all predicates with an outer reference. - val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - - // Find any non-equality correlated predicates - foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { - case _: EqualTo | _: EqualNullSafe => false - case _ => true - } - - failOnInvalidOuterReference(f) - // The aggregate expressions are treated in a special way by getOuterReferences. If the - // aggregate expression contains only outer reference attributes then the entire aggregate - // expression is isolated as an OuterReference. - // i.e min(OuterReference(b)) => OuterReference(min(b)) - outerReferences ++= getOuterReferences(correlated) - - // Project cannot host any correlated expressions - // but can be anywhere in a correlated subquery. - case p: Project => - failOnInvalidOuterReference(p) - - // Aggregate cannot host any correlated expressions - // It can be on a correlation path if the correlation contains - // only equality correlated predicates. - // It cannot be on a correlation path if the correlation has - // non-equality correlated predicates. - case a: Aggregate => - failOnInvalidOuterReference(a) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - - // Join can host correlated expressions. - case j @ Join(left, right, joinType, _) => - joinType match { - // Inner join, like Filter, can be anywhere. - case _: InnerLike => - failOnInvalidOuterReference(j) - - // Left outer join's right operand cannot be on a correlation path. - // LeftAnti and ExistenceJoin are special cases of LeftOuter. - // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame - // so it should not show up here in Analysis phase. This is just a safety net. - // - // LeftSemi does not allow output from the right operand. - // Any correlated references in the subplan - // of the right operand cannot be pulled up. - case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(right) - - // Likewise, Right outer join's left operand cannot be on a correlation path. - case RightOuter => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(left) - - // Any other join types not explicitly listed above, - // including Full outer join, are treated as Category 4. - case _ => - failOnOuterReferenceInSubTree(j) - } - - // Generator with join=true, i.e., expressed with - // LATERAL VIEW [OUTER], similar to inner join, - // allows to have correlation under it - // but must not host any outer references. - // Note: - // Generator with join=false is treated as Category 4. - case g: Generate if g.join => - failOnInvalidOuterReference(g) - - // Category 4: Any other operators not in the above 3 categories - // cannot be on a correlation path, that is they are allowed only - // under a correlation point but they and their descendant operators - // are not allowed to have any correlated expressions. - case p => - failOnOuterReferenceInSubTree(p) - } - outerReferences - } - - /** - * Resolves the subquery. The subquery is resolved using its outer plans. This method - * will resolve the subquery by alternating between the regular analyzer and by applying the - * resolveOuterReferences rule. + * Resolves the subquery plan that is referenced in a subquery expression. The normal + * attribute references are resolved using regular analyzer and the outer references are + * resolved from the outer plans using the resolveOuterReferences method. * * Outer references from the correlated predicates are updated as children of * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, - plans: Seq[LogicalPlan], - requiredColumns: Int = 0)( + plans: Seq[LogicalPlan])( f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { // Step 1: Resolve the outer expressions. var previous: LogicalPlan = null @@ -1488,15 +1286,8 @@ class Analyzer( // Step 2: If the subquery plan is fully resolved, pull the outer references and record // them as children of SubqueryExpression. if (current.resolved) { - // Make sure the resolved query has the required number of output columns. This is only - // needed for Scalar and IN subqueries. - if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + - s"does not match the required number of columns ($requiredColumns)") - } - // Validate the outer reference and record the outer references as children of - // subquery expression. - f(current, checkAndGetOuterReferences(current)) + // Record the outer references as children of subquery expression. + f(current, SubExprUtils.getOuterReferences(current)) } else { e.withNewPlan(current) } @@ -1514,16 +1305,11 @@ class Analyzer( private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => - // Get the left hand side expressions. - val expressions = value match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId)) In(value, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2e3ac3e474866..fb81a7006bc5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -129,61 +131,8 @@ trait CheckAnalysis extends PredicateHelper { case None => w } - case s @ ScalarSubquery(query, conditions, _) => - checkAnalysis(query) - - // If no correlation, the output must be exactly one column - if (conditions.isEmpty && query.output.size != 1) { - failAnalysis( - s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - def checkAggregate(agg: Aggregate): Unit = { - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates containing exactly one aggregate expression. - // The analyzer has already checked that subquery contained only one output column, - // and added all the grouping expressions to the aggregate. - val aggregates = agg.expressions.flatMap(_.collect { - case a: AggregateExpression => a - }) - if (aggregates.isEmpty) { - failAnalysis("The output of a correlated scalar subquery must be aggregated") - } - - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) - // Collect the local references from the correlated predicate in the subquery. - val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) - .filterNot(conditions.flatMap(_.references).contains) - val correlatedCols = AttributeSet(subqueryColumns) - val invalidCols = groupByCols -- correlatedCols - // GROUP BY columns must be a subset of columns in the predicates - if (invalidCols.nonEmpty) { - failAnalysis( - "A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - } - - // Skip subquery aliases added by the Analyzer. - // For projects, do the necessary mapping and skip to its child. - def cleanQuery(p: LogicalPlan): LogicalPlan = p match { - case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => cleanQuery(p.child) - case child => child - } - - cleanQuery(query) match { - case a: Aggregate => checkAggregate(a) - case Filter(_, a: Aggregate) => checkAggregate(a) - case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") - } - } - s - case s: SubqueryExpression => - checkAnalysis(s.plan) + checkSubqueryExpression(operator, s) s } @@ -291,19 +240,6 @@ trait CheckAnalysis extends PredicateHelper { case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr) - case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => - p match { - case _: Filter | _: Aggregate | _: Project => // Ok - case other => failAnalysis( - s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") - } - - case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => - p match { - case _: Filter => // Ok - case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") - } - case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) def ordinalNumber(i: Int): String = i match { @@ -414,4 +350,272 @@ trait CheckAnalysis extends PredicateHelper { plan.foreach(_.setAnalyzed()) } + + /** + * Validates subquery expressions in the plan. Upon failure, returns an user facing error. + */ + private def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = { + def checkAggregateInScalarSubquery( + conditions: Seq[Expression], + query: LogicalPlan, agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates containing exactly one aggregate expression. + val aggregates = agg.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } + } + + // Skip subquery aliases added by the Analyzer. + // For projects, do the necessary mapping and skip to its child. + def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match { + case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child) + case p: Project => cleanQueryInScalarSubquery(p.child) + case child => child + } + + // Validate the subquery plan. + checkAnalysis(expr.plan) + + expr match { + case ScalarSubquery(query, conditions, _) => + // Scalar subquery must return one column as output. + if (query.output.size != 1) { + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") + } + + if (conditions.nonEmpty) { + cleanQueryInScalarSubquery(query) match { + case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a) + case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a) + case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail") + } + + // Only certain operators are allowed to host subquery expression containing + // outer references. + plan match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + "Correlated scalar sub-queries can only be used in a " + + s"Filter/Aggregate/Project: $plan") + } + } + + case inSubqueryOrExistsSubquery => + plan match { + case _: Filter => // Ok + case _ => + failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in a Filter: $plan") + } + } + + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan) + } + + /** + * Validates to make sure the outer references appearing inside the subquery + * are allowed. + */ + private def checkCorrelationsInSubquery(sub: LogicalPlan): Unit = { + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. + """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } + + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (hasOuterReferences(p)) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") + } + } + + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { + failAnalysis( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + + s"clauses:\n$p") + } + } + + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] + // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. + // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + + var foundNonEqualCorrelatedPred: Boolean = false + + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handlings. These operators are + // Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other word, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. + + // Category 1: + // ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case p: Project => + failOnInvalidOuterReference(p) + + case s: Sort => + failOnInvalidOuterReference(s) + + case r: RepartitionByExpression => + failOnInvalidOuterReference(r) + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. + case f: Filter => + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) + + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + failOnInvalidOuterReference(f) + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. + case a: Aggregate => + failOnInvalidOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. + case _: InnerLike => + failOnInvalidOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. + // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. + case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. + case _ => + failOnOuterReferenceInSubTree(j) + } + + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case g: Generate if g.join => + failOnInvalidOuterReference(g) + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. + case p => + failOnOuterReferenceInSubTree(p) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c15ee2ab270bc..f3fe58caa6fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -144,27 +144,39 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case cns: CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - - val mismatchedColumns = valExprs.zip(sub.output).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None - } - - if (mismatchedColumns.nonEmpty) { + if (valExprs.length != sub.output.length) { TypeCheckResult.TypeCheckFailure( s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${valExprs.length}. + |#columns in right hand side: ${sub.output.length}. + |Left side columns: + |[${valExprs.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${sub.output.map(_.sql).mkString(", ")}]. """.stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } } case _ => if (list.exists(l => l.dataType != value.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5050318d96358..4ed995e20d7ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -111,8 +111,7 @@ class AnalysisErrorSuite extends AnalysisTest { "scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "The number of columns in the subquery (2)" :: - "does not match the required number of columns (1)":: Nil) + "Scalar subquery must return only one column, but got 2" :: Nil) errorTest( "scalar subquery with no column", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 55693121431a2..1bf8d76da04d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -35,7 +35,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { - SimpleAnalyzer.ResolveSubquery(expr) + SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage assert(m.contains( "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses")) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql new file mode 100644 index 0000000000000..b15f4da81dd93 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -0,0 +1,47 @@ +-- The test file contains negative test cases +-- of invalid queries where error messages are expected. + +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c); + +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c); + +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c); + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.03 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a); + +-- TC 01.04 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a); + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out new file mode 100644 index 0000000000000..9ea9d3c4c6f40 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -0,0 +1,106 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 4 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 5 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 1. +#columns in right hand side: 2. +Left side columns: +[t1.`t1a`]. +Right side columns: +[t2.`t2a`, t2.`t2b`]. + ; + + +-- !query 6 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[t1.`t1a`, t1.`t1b`]. +Right side columns: +[t2.`t2a`]. + ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4629a8c0dbe5f..820cff655c4ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -517,7 +517,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") } - assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + assert(msg1.getMessage.contains("Correlated scalar subqueries must be aggregated")) val msg2 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") From 03eb6117affcca21798be25706a39e0d5a2f7288 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 23 Jun 2017 14:48:33 -0700 Subject: [PATCH 0789/1765] [SPARK-21164][SQL] Remove isTableSample from Sample and isGenerated from Alias and AttributeReference ## What changes were proposed in this pull request? `isTableSample` and `isGenerated ` were introduced for SQL Generation respectively by https://github.com/apache/spark/pull/11148 and https://github.com/apache/spark/pull/11050 Since SQL Generation is removed, we do not need to keep `isTableSample`. ## How was this patch tested? The existing test cases Author: Xiao Li Closes #18379 from gatorsmile/CleanSample. --- .../sql/catalyst/analysis/Analyzer.scala | 8 ++--- .../expressions/namedExpressions.scala | 34 +++++++------------ .../optimizer/RewriteDistinctAggregates.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 4 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 6 +--- .../analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../optimizer/ColumnPruningSuite.scala | 8 ++--- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +-- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +++--- .../BasicStatsEstimationSuite.scala | 4 +-- .../scala/org/apache/spark/sql/Dataset.scala | 4 +-- 15 files changed, 40 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 193082eb77024..7e5ebfc93286f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -874,7 +874,7 @@ class Analyzer( def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { - case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) + case a: Alias => Alias(a.child, a.name)() case other => other } } @@ -1368,7 +1368,7 @@ class Analyzer( val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + Alias(havingCondition, "havingCondition")() :: Nil, child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = @@ -1424,7 +1424,7 @@ class Analyzer( try { val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) val aliasedOrdering = - unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true)) + unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -1935,7 +1935,7 @@ class Analyzer( leafNondeterministic.distinct.map { e => val ne = e match { case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")(isGenerated = true) + case _ => Alias(e, "_nondeterministic")() } e -> ne } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c842f85af693c..29c33804f077a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -81,9 +81,6 @@ trait NamedExpression extends Expression { /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty - /** Returns true if the expression is generated by Catalyst */ - def isGenerated: java.lang.Boolean = false - /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression @@ -128,13 +125,11 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. - * @param isGenerated A flag to indicate if this alias is generated by Catalyst */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifier: Option[String] = None, - val explicitMetadata: Option[Metadata] = None, - override val isGenerated: java.lang.Boolean = false) + val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -159,13 +154,11 @@ case class Alias(child: Expression, name: String)( } def newInstance(): NamedExpression = - Alias(child, name)( - qualifier = qualifier, explicitMetadata = explicitMetadata, isGenerated = isGenerated) + Alias(child, name)(qualifier = qualifier, explicitMetadata = explicitMetadata) override def toAttribute: Attribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable, metadata)( - exprId, qualifier, isGenerated) + AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier) } else { UnresolvedAttribute(name) } @@ -174,7 +167,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil + exprId :: qualifier :: explicitMetadata :: Nil } override def hashCode(): Int = { @@ -207,7 +200,6 @@ case class Alias(child: Expression, name: String)( * @param qualifier An optional string that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. - * @param isGenerated A flag to indicate if this reference is generated by Catalyst */ case class AttributeReference( name: String, @@ -215,8 +207,7 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None, - override val isGenerated: java.lang.Boolean = false) + val qualifier: Option[String] = None) extends Attribute with Unevaluable { /** @@ -253,8 +244,7 @@ case class AttributeReference( } override def newInstance(): AttributeReference = - AttributeReference(name, dataType, nullable, metadata)( - qualifier = qualifier, isGenerated = isGenerated) + AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -263,7 +253,7 @@ case class AttributeReference( if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier, isGenerated) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier) } } @@ -271,7 +261,7 @@ case class AttributeReference( if (name == newName) { this } else { - AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, isGenerated) + AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier) } } @@ -282,7 +272,7 @@ case class AttributeReference( if (newQualifier == qualifier) { this } else { - AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, isGenerated) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier) } } @@ -290,16 +280,16 @@ case class AttributeReference( if (exprId == newExprId) { this } else { - AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, isGenerated) + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier) } } override def withMetadata(newMetadata: Metadata): Attribute = { - AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated) + AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: isGenerated :: Nil + exprId :: qualifier :: Nil } /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 3b27cd2ffe028..4448ace7105a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -134,7 +134,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Aggregation strategy can handle queries with a single distinct group. if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. - val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 315c6721b3f65..ef79cbcaa0ce6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -627,7 +627,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } ctx.sampleType.getType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ef925f92ecc7e..7f370fb731b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -80,12 +80,12 @@ object PhysicalOperation extends PredicateHelper { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref) - .map(Alias(_, name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)) + .map(Alias(_, name)(a.exprId, a.qualifier)) .getOrElse(a) case a: AttributeReference => aliases.get(a) - .map(Alias(_, a.name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)).getOrElse(a) + .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 1f6d05bc8d816..01b3da3f7c482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -200,7 +200,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + Alias(normalizedChild, "")(ExprId(id), a.qualifier) case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => // Top level `AttributeReference` may also be used for output like `Alias`, we should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0c098ac0209e8..0d30aa76049a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -221,7 +221,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { - if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { + if (resolver(attribute.name, nameParts.head)) { Option((attribute.withName(nameParts.head), nameParts.tail.toList)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d8f89b108e63f..e89caabf252d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -807,15 +807,13 @@ case class SubqueryAlias( * @param withReplacement Whether to sample with replacement. * @param seed the random seed * @param child the LogicalPlan - * @param isTableSample Is created from TABLESAMPLE in the parser. */ case class Sample( lowerBound: Double, upperBound: Double, withReplacement: Boolean, seed: Long, - child: LogicalPlan)( - val isTableSample: java.lang.Boolean = false) extends UnaryNode { + child: LogicalPlan) extends UnaryNode { val eps = RandomSampler.roundingEpsilon val fraction = upperBound - lowerBound @@ -842,8 +840,6 @@ case class Sample( // Don't propagate column stats, because we don't know the distribution after a sample operation Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) } - - override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 4ed995e20d7ce..7311dc3899e53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -573,7 +573,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))).select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index c39e372c272b1..f68d930f60523 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -491,7 +491,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Other unary operations testUnaryOperatorInStreamingPlan( - "sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling") + "sample", Sample(0.1, 1, true, 1L, _), expectedMsg = "sampling") testUnaryOperatorInStreamingPlan( "window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 0b419e9631b29..08e58d47e0e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -349,14 +349,14 @@ class ColumnPruningSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val x = testRelation.subquery('x) - val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a) + val query1 = Sample(0.0, 0.6, false, 11L, x).select('a) val optimized1 = Optimize.execute(query1.analyze) - val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))() + val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a)) comparePlans(optimized1, expected1.analyze) - val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa) + val query2 = Sample(0.0, 0.6, false, 11L, x).select('a as 'aa) val optimized2 = Optimize.execute(query2.analyze) - val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa) + val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a)).select('a as 'aa) comparePlans(optimized2, expected2.analyze) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 0a4ae098d65cc..bf15b85d5b510 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -411,9 +411,9 @@ class PlanParserSuite extends AnalysisTest { assertEqual(s"$sql tablesample(100 rows)", table("t").limit(100).select(star())) assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star())) assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", - Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x")).select(star())) intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported") intercept(s"$sql tablesample(bucket 11 out of 10) as x", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 25313af2be184..6883d23d477e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -63,14 +63,14 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { - case filter @ Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + case Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And), child) case sample: Sample => - sample.copy(seed = 0L)(true) - case join @ Join(left, right, joinType, condition) if condition.isDefined => + sample.copy(seed = 0L) + case Join(left, right, joinType, condition) if condition.isDefined => val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And) Join(left, right, joinType, Some(newCondition)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index e9ed36feec48c..912c5fed63450 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -78,14 +78,14 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { } test("sample estimation") { - val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) // Child doesn't have rowCount in stats val childStats = Statistics(sizeInBytes = 120) val childPlan = DummyLogicalPlan(childStats, childStats) val sample2 = - Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) checkStats(sample2, Statistics(sizeInBytes = 14)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 767dad3e63a6d..6e66e92091ff9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1807,7 +1807,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } } @@ -1863,7 +1863,7 @@ class Dataset[T] private[sql]( val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan), encoder) }.toArray } From 7525ce98b4575b1ac4e44cc9b3a5773f03eba19e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 24 Jun 2017 11:39:41 +0800 Subject: [PATCH 0790/1765] [SPARK-20431][SS][FOLLOWUP] Specify a schema by using a DDL-formatted string in DataStreamReader ## What changes were proposed in this pull request? This pr supported a DDL-formatted string in `DataStreamReader.schema`. This fix could make users easily define a schema without importing the type classes. For example, ```scala scala> spark.readStream.schema("col0 INT, col1 DOUBLE").load("/tmp/abc").printSchema() root |-- col0: integer (nullable = true) |-- col1: double (nullable = true) ``` ## How was this patch tested? Added tests in `DataStreamReaderWriterSuite`. Author: hyukjinkwon Closes #18373 from HyukjinKwon/SPARK-20431. --- python/pyspark/sql/readwriter.py | 2 ++ python/pyspark/sql/streaming.py | 24 ++++++++++++------- .../sql/streaming/DataStreamReader.scala | 12 ++++++++++ .../test/DataStreamReaderWriterSuite.scala | 12 ++++++++++ 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index aef71f9ca7001..7279173df6e4f 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -98,6 +98,8 @@ def schema(self, schema): :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). + + >>> s = spark.read.schema("col0 INT, col1 DOUBLE") """ from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 58aa2468e006d..5bbd70cf0a789 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -319,16 +319,21 @@ def schema(self, schema): .. note:: Evolving. - :param schema: a :class:`pyspark.sql.types.StructType` object + :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string + (For example ``col0 INT, col1 DOUBLE``). >>> s = spark.readStream.schema(sdf_schema) + >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE") """ from pyspark.sql import SparkSession - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") spark = SparkSession.builder.getOrCreate() - jschema = spark._jsparkSession.parseDataType(schema.json()) - self._jreader = self._jreader.schema(jschema) + if isinstance(schema, StructType): + jschema = spark._jsparkSession.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + elif isinstance(schema, basestring): + self._jreader = self._jreader.schema(schema) + else: + raise TypeError("schema should be StructType or string") return self @since(2.0) @@ -372,7 +377,8 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options >>> json_sdf = spark.readStream.format("json") \\ @@ -415,7 +421,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -542,7 +549,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non .. note:: Evolving. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 7e8e6394b4862..70ddfa8e9b835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -59,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can + * infer the input schema automatically from data. By specifying the schema here, the underlying + * data source can skip the schema inference step, and thus speed up data loading. + * + * @since 2.3.0 + */ + def schema(schemaString: String): DataStreamReader = { + this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString)) + this + } + /** * Adds an input option for the underlying data source. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index b5f1e28d7396a..3de0ae67a3892 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -663,4 +663,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } assert(fs.exists(checkpointDir)) } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .schema("aa INT") + .load() + + assert(LastOptions.schema.isDefined) + assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil)) + + LastOptions.clear() + } } From b837bf9ae97cf7ee7558c10a5a34636e69367a05 Mon Sep 17 00:00:00 2001 From: Gabor Feher Date: Fri, 23 Jun 2017 21:53:38 -0700 Subject: [PATCH 0791/1765] [SPARK-20555][SQL] Fix mapping of Oracle DECIMAL types to Spark types in read path ## What changes were proposed in this pull request? This PR is to revert some code changes in the read path of https://github.com/apache/spark/pull/14377. The original fix is https://github.com/apache/spark/pull/17830 When merging this PR, please give the credit to gaborfeher ## How was this patch tested? Added a test case to OracleIntegrationSuite.scala Author: Gabor Feher Author: gatorsmile Closes #18408 from gatorsmile/OracleType. --- .../sql/jdbc/OracleIntegrationSuite.scala | 65 +++++++++++++------ .../apache/spark/sql/jdbc/OracleDialect.scala | 4 -- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index f7b1ec34ced76..b2f096964427e 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Timestamp} import java.util.Properties +import java.math.BigDecimal import org.apache.spark.sql.Row import org.apache.spark.sql.test.SharedSQLContext @@ -93,8 +94,31 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo |USING org.apache.spark.sql.jdbc |OPTIONS (url '$jdbcUrl', dbTable 'datetime1', oracle.jdbc.mapDateToTimestamp 'false') """.stripMargin.replaceAll("\n", " ")) + + + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement( + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); + conn.commit(); } + + test("SPARK-16625 : Importing Oracle numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val rows = df.collect() + assert(rows.size == 1) + val row = rows(0) + // The main point of the below assertions is not to make sure that these Oracle types are + // mapped to decimal types, but to make sure that the returned values are correct. + // A value > 1 from DECIMAL(1) is correct: + assert(row.getDecimal(0).compareTo(BigDecimal.valueOf(4)) == 0) + // A value with fractions from DECIMAL(3, 2) is correct: + assert(row.getDecimal(1).compareTo(BigDecimal.valueOf(1.23)) == 0) + // A value > Int.MaxValue from DECIMAL(10) is correct: + assert(row.getDecimal(2).compareTo(BigDecimal.valueOf(9999999999l)) == 0) + } + + test("SPARK-12941: String datatypes to be mapped to Varchar in Oracle") { // create a sample dataframe with string type val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x") @@ -154,27 +178,28 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val dfRead = spark.read.jdbc(jdbcUrl, tableName, props) val rows = dfRead.collect() // verify the data type is inserted - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types(0).equals("class java.lang.Boolean")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Long")) - assert(types(3).equals("class java.lang.Float")) - assert(types(4).equals("class java.lang.Float")) - assert(types(5).equals("class java.lang.Integer")) - assert(types(6).equals("class java.lang.Integer")) - assert(types(7).equals("class java.lang.String")) - assert(types(8).equals("class [B")) - assert(types(9).equals("class java.sql.Date")) - assert(types(10).equals("class java.sql.Timestamp")) + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DecimalType(1, 0))) + assert(types(1).equals(DecimalType(10, 0))) + assert(types(2).equals(DecimalType(19, 0))) + assert(types(3).equals(DecimalType(19, 4))) + assert(types(4).equals(DecimalType(19, 4))) + assert(types(5).equals(DecimalType(3, 0))) + assert(types(6).equals(DecimalType(5, 0))) + assert(types(7).equals(StringType)) + assert(types(8).equals(BinaryType)) + assert(types(9).equals(DateType)) + assert(types(10).equals(TimestampType)) + // verify the value is the inserted correct or not val values = rows(0) - assert(values.getBoolean(0).equals(booleanVal)) - assert(values.getInt(1).equals(integerVal)) - assert(values.getLong(2).equals(longVal)) - assert(values.getFloat(3).equals(floatVal)) - assert(values.getFloat(4).equals(doubleVal.toFloat)) - assert(values.getInt(5).equals(byteVal.toInt)) - assert(values.getInt(6).equals(shortVal.toInt)) + assert(values.getDecimal(0).compareTo(BigDecimal.valueOf(1)) == 0) + assert(values.getDecimal(1).compareTo(BigDecimal.valueOf(integerVal)) == 0) + assert(values.getDecimal(2).compareTo(BigDecimal.valueOf(longVal)) == 0) + assert(values.getDecimal(3).compareTo(BigDecimal.valueOf(floatVal)) == 0) + assert(values.getDecimal(4).compareTo(BigDecimal.valueOf(doubleVal)) == 0) + assert(values.getDecimal(5).compareTo(BigDecimal.valueOf(byteVal)) == 0) + assert(values.getDecimal(6).compareTo(BigDecimal.valueOf(shortVal)) == 0) assert(values.getString(7).equals(stringVal)) assert(values.getAs[Array[Byte]](8).mkString.equals("678")) assert(values.getDate(9).equals(dateVal)) @@ -183,7 +208,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo test("SPARK-19318: connection property keys should be case-sensitive") { def checkRow(row: Row): Unit = { - assert(row.getInt(0) == 1) + assert(row.getDecimal(0).equals(BigDecimal.valueOf(1))) assert(row.getDate(1).equals(Date.valueOf("1991-11-09"))) assert(row.getTimestamp(2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index f541996b651e9..20e634c06b610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -43,10 +43,6 @@ private case object OracleDialect extends JdbcDialect { // Not sure if there is a more robust way to identify the field as a float (or other // numeric types that do not specify a scale. case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case 1 => Option(BooleanType) - case 3 | 5 | 10 => Option(IntegerType) - case 19 if scale == 0L => Option(LongType) - case 19 if scale == 4L => Option(FloatType) case _ => None } } else { From bfd73a7c48b87456d1b84d826e04eca938a1be64 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 24 Jun 2017 13:23:43 +0800 Subject: [PATCH 0792/1765] [SPARK-21159][CORE] Don't try to connect to launcher in standalone cluster mode. Monitoring for standalone cluster mode is not implemented (see SPARK-11033), but the same scheduler implementation is used, and if it tries to connect to the launcher it will fail. So fix the scheduler so it only tries that in client mode; cluster mode applications will be correctly launched and will work, but monitoring through the launcher handle will not be available. Tested by running a cluster mode app with "SparkLauncher.startApplication". Author: Marcelo Vanzin Closes #18397 from vanzin/SPARK-21159. --- .../scheduler/cluster/StandaloneSchedulerBackend.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index fd8e64454bf70..a4e2a74341283 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -58,7 +58,13 @@ private[spark] class StandaloneSchedulerBackend( override def start() { super.start() - launcherBackend.connect() + + // SPARK-21159. The scheduler backend should only try to connect to the launcher when in client + // mode. In cluster mode, the code that submits the application to the Master needs to connect + // to the launcher instead. + if (sc.deployMode == "client") { + launcherBackend.connect() + } // The endpoint for executors to talk to us val driverUrl = RpcEndpointAddress( From 7c7bc8fc0ff85fe70968b47433bb7757326a6b12 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 24 Jun 2017 10:14:31 +0100 Subject: [PATCH 0793/1765] [SPARK-21189][INFRA] Handle unknown error codes in Jenkins rather then leaving incomplete comment in PRs ## What changes were proposed in this pull request? Recently, Jenkins tests were unstable due to unknown reasons as below: ``` /home/jenkins/workspace/SparkPullRequestBuilder/dev/lint-r ; process was terminated by signal 9 test_result_code, test_result_note = run_tests(tests_timeout) File "./dev/run-tests-jenkins.py", line 140, in run_tests test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] KeyError: -9 ``` ``` Traceback (most recent call last): File "./dev/run-tests-jenkins.py", line 226, in main() File "./dev/run-tests-jenkins.py", line 213, in main test_result_code, test_result_note = run_tests(tests_timeout) File "./dev/run-tests-jenkins.py", line 140, in run_tests test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] KeyError: -10 ``` This exception looks causing failing to update the comments in the PR. For example: ![2017-06-23 4 19 41](https://user-images.githubusercontent.com/6477701/27470626-d035ecd8-582f-11e7-883e-0ae6941659b7.png) ![2017-06-23 4 19 50](https://user-images.githubusercontent.com/6477701/27470629-d11ba782-582f-11e7-97e0-64d28cbc19aa.png) these comment just remain. This always requires, for both reviewers and the author, a overhead to click and check the logs, which I believe are not really useful. This PR proposes to leave the code in the PR comment messages and let update the comments. ## How was this patch tested? Jenkins tests below, I manually gave the error code to test this. Author: hyukjinkwon Closes #18399 from HyukjinKwon/jenkins-print-errors. --- dev/run-tests-jenkins.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 53061bc947e5f..914eb93622d51 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -137,7 +137,9 @@ def run_tests(tests_timeout): if test_result_code == 0: test_result_note = ' * This patch passes all tests.' else: - test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] + note = failure_note_by_errcode.get( + test_result_code, "due to an unknown error code, %s" % test_result_code) + test_result_note = ' * This patch **fails %s**.' % note return [test_result_code, test_result_note] From 2e1586f60a77ea0adb6f3f68ba74323f0c242199 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 24 Jun 2017 22:35:59 +0800 Subject: [PATCH 0794/1765] [SPARK-21203][SQL] Fix wrong results of insertion of Array of Struct ### What changes were proposed in this pull request? ```SQL CREATE TABLE `tab1` (`custom_fields` ARRAY>) USING parquet INSERT INTO `tab1` SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) SELECT custom_fields.id, custom_fields.value FROM tab1 ``` The above query always return the last struct of the array, because the rule `SimplifyCasts` incorrectly rewrites the query. The underlying cause is we always use the same `GenericInternalRow` object when doing the cast. ### How was this patch tested? Author: gatorsmile Closes #18412 from gatorsmile/castStruct. --- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../spark/sql/sources/InsertSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a53ef426f79b5..43df19ba009a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -482,15 +482,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericInternalRow(from.fields.length) buildCast[InternalRow](_, row => { + val newRow = new GenericInternalRow(from.fields.length) var i = 0 while (i < row.numFields) { newRow.update(i, if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) i += 1 } - newRow.copy() + newRow }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2eae66dda88de..41abff2a5da25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -345,4 +345,25 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } } + + test("SPARK-21203 wrong results of insertion of Array of Struct") { + val tabName = "tab1" + withTable(tabName) { + spark.sql( + """ + |CREATE TABLE `tab1` + |(`custom_fields` ARRAY>) + |USING parquet + """.stripMargin) + spark.sql( + """ + |INSERT INTO `tab1` + |SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) + """.stripMargin) + + checkAnswer( + spark.sql("SELECT custom_fields.id, custom_fields.value FROM tab1"), + Row(Array(1, 2), Array("a", "b"))) + } + } } From b449a1d6aa322a50cf221cd7a2ae85a91d6c7e9f Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Sat, 24 Jun 2017 22:49:35 -0700 Subject: [PATCH 0795/1765] [SPARK-21079][SQL] Calculate total size of a partition table as a sum of individual partitions ## What changes were proposed in this pull request? Storage URI of a partitioned table may or may not point to a directory under which individual partitions are stored. In fact, individual partitions may be located in totally unrelated directories. Before this change, ANALYZE TABLE table COMPUTE STATISTICS command calculated total size of a table by adding up sizes of files found under table's storage URI. This calculation could produce 0 if partitions are stored elsewhere. This change uses storage URIs of individual partitions to calculate the sizes of all partitions of a table and adds these up to produce the total size of a table. CC: wzhfy ## How was this patch tested? Added unit test. Ran ANALYZE TABLE xxx COMPUTE STATISTICS on a partitioned Hive table and verified that sizeInBytes is calculated correctly. Before this change, the size would be zero. Author: Masha Basmanova Closes #18309 from mbasmanova/mbasmanova-analyze-part-table. --- .../command/AnalyzeTableCommand.scala | 29 ++++++-- .../spark/sql/hive/StatisticsSuite.scala | 72 +++++++++++++++++++ 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3c59b982c2dca..06e588f56f1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -81,6 +83,21 @@ case class AnalyzeTableCommand( object AnalyzeTableCommand extends Logging { def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map(p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + ).sum + } + } + + private def calculateLocationSize( + sessionState: SessionState, + tableId: TableIdentifier, + locationUri: Option[URI]): Long = { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -91,13 +108,13 @@ object AnalyzeTableCommand extends Logging { // countFileSize to count the table size. val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - def calculateTableSize(fs: FileSystem, path: Path): Long = { + def calculateLocationSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) + calculateLocationSize(fs, status.getPath) } else { 0L } @@ -109,16 +126,16 @@ object AnalyzeTableCommand extends Logging { size } - catalogTable.storage.locationUri.map { p => + locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateTableSize(fs, path) + calculateLocationSize(fs, path) } catch { case NonFatal(e) => logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + s"Failed to get the size of table ${tableId.table} in the " + + s"database ${tableId.database} because of ${e.toString}", e) 0L } }.getOrElse(0L) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 279db9a397258..0ee18bbe9befe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -128,6 +129,77 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + test("SPARK-21079 - analyze table with location different than that of individual partitions") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val tableName = "analyzeTable_part" + withTable(tableName) { + withTempPath { path => + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src") + } + + sql(s"ALTER TABLE $tableName SET LOCATION '$path'") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + + assert(queryTotalSize(tableName) === BigInt(17436)) + } + } + } + + test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val sourceTableName = "analyzeTable_part" + val tableName = "analyzeTable_part_vis" + withTable(sourceTableName, tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a single top-level directory 'path' + sql( + s""" + |CREATE TABLE $sourceTableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql( + s""" + |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds') + |SELECT * FROM src + """.stripMargin) + } + + // Create another table referring to the same location + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + // Register only one of the partitions found on disk + val ds = partitionDates.head + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + + // Analyze original table - expect 3 partitions + sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + + // Analyze partial-copy table - expect only 1 partition + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(tableName) === BigInt(5812)) + } + } + } + test("analyzing views is not supported") { def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { val err = intercept[AnalysisException] { From 884347e1f79e4e7c157834881e79447d7ee58f88 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 25 Jun 2017 15:06:29 +0100 Subject: [PATCH 0796/1765] [HOT FIX] fix stats functions in the recent patch ## What changes were proposed in this pull request? Builds failed due to the recent [merge](https://github.com/apache/spark/commit/b449a1d6aa322a50cf221cd7a2ae85a91d6c7e9f). This is because [PR#18309](https://github.com/apache/spark/pull/18309) needed update after [this patch](https://github.com/apache/spark/commit/b803b66a8133f705463039325ee71ee6827ce1a7) was merged. ## How was this patch tested? N/A Author: Zhenhua Wang Closes #18415 from wzhfy/hotfixStats. --- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0ee18bbe9befe..64deb3818d5d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -131,7 +130,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("SPARK-21079 - analyze table with location different than that of individual partitions") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes val tableName = "analyzeTable_part" withTable(tableName) { @@ -154,7 +153,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes val sourceTableName = "analyzeTable_part" val tableName = "analyzeTable_part_vis" From 6b3d02285ee0debc73cbcab01b10398a498fbeb8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 25 Jun 2017 11:05:57 -0700 Subject: [PATCH 0797/1765] [SPARK-21093][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak ## What changes were proposed in this pull request? `mcfork` in R looks opening a pipe ahead but the existing logic does not properly close it when it is executed hot. This leads to the failure of more forking due to the limit for number of files open. This hot execution looks particularly for `gapply`/`gapplyCollect`. For unknown reason, this happens more easily in CentOS and could be reproduced in Mac too. All the details are described in https://issues.apache.org/jira/browse/SPARK-21093 This PR proposes simply to terminate R's worker processes in the parent of R's daemon to prevent a leak. ## How was this patch tested? I ran the codes below on both CentOS and Mac with that configuration disabled/enabled. ```r df <- createDataFrame(list(list(1L, 1, "1", 0.1)), c("a", "b", "c", "d")) collect(gapply(df, "a", function(key, x) { x }, schema(df))) collect(gapply(df, "a", function(key, x) { x }, schema(df))) ... # 30 times ``` Also, now it passes R tests on CentOS as below: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../spark``` Author: hyukjinkwon Closes #18320 from HyukjinKwon/SPARK-21093. --- R/pkg/inst/worker/daemon.R | 59 +++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06d..6e385b2a27622 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,8 +30,55 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) +# Waits indefinitely for a socket connecion by default. +selectTimeout <- NULL + +# Exit code that children send to the parent to indicate they exited. +exitCode <- 1 + while (TRUE) { - ready <- socketSelect(list(inputCon)) + ready <- socketSelect(list(inputCon), timeout = selectTimeout) + + # Note that the children should be terminated in the parent. If each child terminates + # itself, it appears that the resource is not released properly, that causes an unexpected + # termination of this daemon due to, for example, running out of file descriptors + # (see SPARK-21093). Therefore, the current implementation tries to retrieve children + # that are exited (but not terminated) and then sends a kill signal to terminate them properly + # in the parent. + # + # There are two paths that it attempts to send a signal to terminate the children in the parent. + # + # 1. Every second if any socket connection is not available and if there are child workers + # running. + # 2. Right after a socket connection is available. + # + # In other words, the parent attempts to send the signal to the children every second if + # any worker is running or right before launching other worker children from the following + # new socket connection. + + # Only the process IDs of children sent data to the parent are returned below. The children + # send a custom exit code to the parent after being exited and the parent tries + # to terminate them only if they sent the exit code. + children <- parallel:::selectChildren(timeout = 0) + + if (is.integer(children)) { + lapply(children, function(child) { + # This data should be raw bytes if any data was sent from this child. + # Otherwise, this returns the PID. + data <- parallel:::readChild(child) + if (is.raw(data)) { + # This checks if the data from this child is the exit code that indicates an exited child. + if (unserialize(data) == exitCode) { + # If so, we terminate this child. + tools::pskill(child, tools::SIGUSR1) + } + } + }) + } else if (is.null(children)) { + # If it is NULL, there are no children. Waits indefinitely for a socket connecion. + selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +91,16 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) - parallel:::mcexit(0L) + # Note that this mcexit does not fully terminate this child. So, this writes back + # a custom exit code so that the parent can read and terminate this child. + parallel:::mcexit(0L, send = exitCode) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } From 5282bae0408dec8aa0cefafd7673dd34d232ead9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 26 Jun 2017 01:26:32 -0700 Subject: [PATCH 0798/1765] [SPARK-21153] Use project instead of expand in tumbling windows ## What changes were proposed in this pull request? Time windowing in Spark currently performs an Expand + Filter, because there is no way to guarantee the amount of windows a timestamp will fall in, in the general case. However, for tumbling windows, a record is guaranteed to fall into a single bucket. In this case, doubling the number of records with Expand is wasteful, and can be improved by using a simple Projection instead. Benchmarks show that we get an order of magnitude performance improvement after this patch. ## How was this patch tested? Existing unit tests. Benchmarked using the following code: ```scala import org.apache.spark.sql.functions._ spark.time { spark.range(numRecords) .select(from_unixtime((current_timestamp().cast("long") * 1000 + 'id / 1000) / 1000) as 'time) .select(window('time, "10 seconds")) .count() } ``` Setup: - 1 c3.2xlarge worker (8 cores) ![image](https://user-images.githubusercontent.com/5243515/27348748-ed991b84-55a9-11e7-8f8b-6e7abc524417.png) 1 B rows ran in 287 seconds after this optimization. I didn't wait for it to finish without the optimization. Shows about 5x improvement for large number of records. Author: Burak Yavuz Closes #18364 from brkyvz/opt-tumble. --- .../sql/catalyst/analysis/Analyzer.scala | 72 +++++++++++++------ .../sql/catalyst/expressions/TimeWindow.scala | 12 ++-- .../sql/DataFrameTimeWindowingSuite.scala | 49 +++++++++---- 3 files changed, 94 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7e5ebfc93286f..434b6ffee37fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2301,6 +2301,7 @@ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { object TimeWindowing extends Rule[LogicalPlan] { import org.apache.spark.sql.catalyst.dsl.expressions._ + private final val WINDOW_COL_NAME = "window" private final val WINDOW_START = "start" private final val WINDOW_END = "end" @@ -2336,49 +2337,76 @@ object TimeWindowing extends Rule[LogicalPlan] { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = - p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet // Only support a single window expression for now if (windowExpressions.size == 1 && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { + val window = windowExpressions.head val metadata = window.timeColumn match { case a: Attribute => a.metadata case _ => Metadata.empty } - val windowAttr = - AttributeReference("window", window.dataType, metadata = metadata)() - - val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt - val windows = Seq.tabulate(maxNumOverlapping + 1) { i => - val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / - window.slideDuration) - val windowStart = (windowId + i - maxNumOverlapping) * - window.slideDuration + window.startTime + + def getWindow(i: Int, overlappingWindows: Int): Expression = { + val division = (PreciseTimestampConversion( + window.timeColumn, TimestampType, LongType) - window.startTime) / window.slideDuration + val ceil = Ceil(division) + // if the division is equal to the ceiling, our record is the start of a window + val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil)) + val windowStart = (windowId + i - overlappingWindows) * + window.slideDuration + window.startTime val windowEnd = windowStart + window.windowDuration CreateNamedStruct( - Literal(WINDOW_START) :: windowStart :: - Literal(WINDOW_END) :: windowEnd :: Nil) + Literal(WINDOW_START) :: + PreciseTimestampConversion(windowStart, LongType, TimestampType) :: + Literal(WINDOW_END) :: + PreciseTimestampConversion(windowEnd, LongType, TimestampType) :: + Nil) } - val projections = windows.map(_ +: p.children.head.output) + val windowAttr = AttributeReference( + WINDOW_COL_NAME, window.dataType, metadata = metadata)() + + if (window.windowDuration == window.slideDuration) { + val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)( + exprId = windowAttr.exprId) + + val replacedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(window.timeColumn) - val filterExpr = - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(windowStruct +: child.output, child)) :: Nil) + } else { + val overlappingWindows = + math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = + Seq.tabulate(overlappingWindows)(i => getWindow(i, overlappingWindows)) + + val projections = windows.map(_ +: child.output) + + val filterExpr = + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) - val expandedPlan = - Filter(filterExpr, + val substitutedPlan = Filter(filterExpr, Expand(projections, windowAttr +: child.output, child)) - val substitutedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } + val renamedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } - substitutedPlan.withNewChildren(expandedPlan :: Nil) + renamedPlan.withNewChildren(substitutedPlan :: Nil) + } } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + "of rows, therefore they are currently not supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 7ff61ee479452..9a9f579b37f58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -152,12 +152,15 @@ object TimeWindow { } /** - * Expression used internally to convert the TimestampType to Long without losing + * Expression used internally to convert the TimestampType to Long and back without losing * precision, i.e. in microseconds. Used in time windowing. */ -case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) - override def dataType: DataType = LongType +case class PreciseTimestampConversion( + child: Expression, + fromType: DataType, + toType: DataType) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(fromType) + override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + @@ -165,4 +168,5 @@ case class PreciseTimestamp(child: Expression) extends UnaryExpression with Expe |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } + override def nullSafeEval(input: Any): Any = input } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 22d5c47a6fb51..6fe356877c268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import java.util.TimeZone - import org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StringType @@ -29,11 +28,27 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B import testImplicits._ + test("simple tumbling window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( df.groupBy(window($"time", "10 seconds")) .agg(count("*").as("counts")) @@ -59,14 +74,18 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("tumbling window with multi-column projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Tumbling windows shouldn't require expand") checkAnswer( - df.select(window($"time", "10 seconds"), $"value") - .orderBy($"window.start".asc) - .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + df, Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), @@ -104,13 +123,17 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("sliding window projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.nonEmpty, "Sliding windows require expand") checkAnswer( - df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") - .orderBy($"window.start".asc, $"value".desc).select("value"), + df, // 2016-03-27 19:39:27 UTC -> 4 bins // 2016-03-27 19:39:34 UTC -> 3 bins // 2016-03-27 19:39:56 UTC -> 3 bins From 9e50a1d37a4cf0c34e20a7c1a910ceaff41535a2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 26 Jun 2017 11:14:03 -0500 Subject: [PATCH 0799/1765] [SPARK-13669][SPARK-20898][CORE] Improve the blacklist mechanism to handle external shuffle service unavailable situation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently we are running into an issue with Yarn work preserving enabled + external shuffle service. In the work preserving enabled scenario, the failure of NM will not lead to the exit of executors, so executors can still accept and run the tasks. The problem here is when NM is failed, external shuffle service is actually inaccessible, so reduce tasks will always complain about the “Fetch failure”, and the failure of reduce stage will make the parent stage (map stage) rerun. The tricky thing here is Spark scheduler is not aware of the unavailability of external shuffle service, and will reschedule the map tasks on the executor where NM is failed, and again reduce stage will be failed with “Fetch failure”, and after 4 retries, the job is failed. This could also apply to other cluster manager with external shuffle service. So here the main problem is that we should avoid assigning tasks to those bad executors (where shuffle service is unavailable). Current Spark's blacklist mechanism could blacklist executors/nodes by failure tasks, but it doesn't handle this specific fetch failure scenario. So here propose to improve the current application blacklist mechanism to handle fetch failure issue (especially with external shuffle service unavailable issue), to blacklist the executors/nodes where shuffle fetch is unavailable. ## How was this patch tested? Unit test and small cluster verification. Author: jerryshao Closes #17113 from jerryshao/SPARK-13669. --- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/BlacklistTracker.scala | 95 ++++++++++++++----- .../spark/scheduler/TaskSchedulerImpl.scala | 18 +--- .../spark/scheduler/TaskSetManager.scala | 6 ++ .../scheduler/BlacklistTrackerSuite.scala | 55 +++++++++++ .../scheduler/TaskSchedulerImplSuite.scala | 4 +- .../spark/scheduler/TaskSetManagerSuite.scala | 32 +++++++ docs/configuration.md | 9 ++ 8 files changed, 186 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 462c1890fd8df..be63c637a3a13 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -149,6 +149,11 @@ package object config { .internal() .timeConf(TimeUnit.MILLISECONDS) .createOptional + + private[spark] val BLACKLIST_FETCH_FAILURE_ENABLED = + ConfigBuilder("spark.blacklist.application.fetchFailure.enabled") + .booleanConf + .createWithDefault(false) // End blacklist confs private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index e130e609e4f63..cd8e61d6d0208 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -61,6 +61,7 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + private val BLACKLIST_FETCH_FAILURE_ENABLED = conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED) /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -145,6 +146,74 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } + private def killBlacklistedExecutor(exec: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing blacklisted executor id $exec " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + a.killExecutors(Seq(exec), true, true) + case None => + logWarning(s"Not attempting to kill blacklisted executor id $exec " + + s"since allocation client is not defined.") + } + } + } + + private def killExecutorsOnBlacklistedNode(node: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing all executors on blacklisted host $node " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + if (a.killExecutorsOnHost(node) == false) { + logError(s"Killing executors on node $node failed.") + } + case None => + logWarning(s"Not attempting to kill executors on blacklisted host $node " + + s"since allocation client is not defined.") + } + } + } + + def updateBlacklistForFetchFailure(host: String, exec: String): Unit = { + if (BLACKLIST_FETCH_FAILURE_ENABLED) { + // If we blacklist on fetch failures, we are implicitly saying that we believe the failure is + // non-transient, and can't be recovered from (even if this is the first fetch failure, + // stage is retried after just one failure, so we don't always get a chance to collect + // multiple fetch failures). + // If the external shuffle-service is on, then every other executor on this node would + // be suffering from the same issue, so we should blacklist (and potentially kill) all + // of them immediately. + + val now = clock.getTimeMillis() + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + + if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + if (!nodeIdToBlacklistExpiryTime.contains(host)) { + logInfo(s"blacklisting node $host due to fetch failure of external shuffle service") + + nodeIdToBlacklistExpiryTime.put(host, expiryTimeForNewBlacklists) + listenerBus.post(SparkListenerNodeBlacklisted(now, host, 1)) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + killExecutorsOnBlacklistedNode(host) + updateNextExpiryTime() + } + } else if (!executorIdToBlacklistStatus.contains(exec)) { + logInfo(s"Blacklisting executor $exec due to fetch failure") + + executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(host, expiryTimeForNewBlacklists)) + // We hardcoded number of failure tasks to 1 for fetch failure, because there's no + // reattempt for such failure. + listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, 1)) + updateNextExpiryTime() + killBlacklistedExecutor(exec) + + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + blacklistedExecsOnNode += exec + } + } + } def updateBlacklistForSuccessfulTaskSet( stageId: Int, @@ -174,17 +243,7 @@ private[scheduler] class BlacklistTracker ( listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, newTotal)) executorIdToFailureList.remove(exec) updateNextExpiryTime() - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing blacklisted executor id $exec " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - allocationClient.killExecutors(Seq(exec), true, true) - case None => - logWarning(s"Not attempting to kill blacklisted executor id $exec " + - s"since allocation client is not defined.") - } - } + killBlacklistedExecutor(exec) // In addition to blacklisting the executor, we also update the data for failures on the // node, and potentially put the entire node into a blacklist as well. @@ -199,19 +258,7 @@ private[scheduler] class BlacklistTracker ( nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing all executors on blacklisted host $node " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - if (allocationClient.killExecutorsOnHost(node) == false) { - logError(s"Killing executors on node $node failed.") - } - case None => - logWarning(s"Not attempting to kill executors on blacklisted host $node " + - s"since allocation client is not defined.") - } - } + killExecutorsOnBlacklistedNode(node) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bba0b294f1afb..91ec172ffeda1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -51,29 +51,21 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class TaskSchedulerImpl private[scheduler]( +private[spark] class TaskSchedulerImpl( val sc: SparkContext, val maxTaskFailures: Int, - private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) extends TaskScheduler with Logging { import TaskSchedulerImpl._ def this(sc: SparkContext) = { - this( - sc, - sc.conf.get(config.MAX_TASK_FAILURES), - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc)) + this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = { - this( - sc, - maxTaskFailures, - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc), - isLocal = isLocal) - } + // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // because ExecutorAllocationClient is created after this TaskSchedulerImpl. + private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) val conf = sc.conf diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a41b059fa7dec..02d374dc37cd5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -774,6 +774,12 @@ private[spark] class TaskSetManager( tasksSuccessful += 1 } isZombie = true + + if (fetchFailed.bmAddress != null) { + blacklistTracker.foreach(_.updateBlacklistForFetchFailure( + fetchFailed.bmAddress.host, fetchFailed.bmAddress.executorId)) + } + None case ef: ExceptionFailure => diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 571c6bbb4585d..7ff03c44b0611 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -530,4 +530,59 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock).killExecutors(Seq("2"), true, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } + + test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { + val allocationClientMock = mock[ExecutorAllocationClient] + when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { + // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist + // is updated before we ask the executor allocation client to kill all the executors + // on a particular host. + override def answer(invocation: InvocationOnMock): Boolean = { + if (blacklist.nodeBlacklist.contains("hostA") == false) { + throw new IllegalStateException("hostA should be on the blacklist") + } + true + } + }) + + conf.set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + + // Disable auto-kill. Blacklist an executor and make sure killExecutors is not called. + conf.set(config.BLACKLIST_KILL_ENABLED, false) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. + conf.set(config.BLACKLIST_KILL_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + assert(blacklist.executorIdToBlacklistStatus.contains("1")) + assert(blacklist.executorIdToBlacklistStatus("1").node === "hostA") + assert(blacklist.executorIdToBlacklistStatus("1").expiryTime === + 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + + // Enable external shuffle service to see if all the executors on this node will be killed. + conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "2") + + verify(allocationClientMock, never).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutorsOnHost("hostA") + + assert(blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) + assert(blacklist.nodeIdToBlacklistExpiryTime("hostA") === + 2000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 8b9d45f734cda..a00337776dadc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -87,7 +87,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B conf.set(config.BLACKLIST_ENABLED, true) sc = new SparkContext(conf) taskScheduler = - new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4), Some(blacklist)) { + new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = { val tsm = super.createTaskSetManager(taskSet, maxFailures) // we need to create a spied tsm just so we can set the TaskSetBlacklist @@ -98,6 +98,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B stageToMockTaskSetBlacklist(taskSet.stageId) = taskSetBlacklist tsmSpy } + + override private[scheduler] lazy val blacklistTrackerOpt = Some(blacklist) } setupHelper() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index db14c9acfdce5..80fb674725814 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1140,6 +1140,38 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) } + test("update application blacklist for shuffle-fetch") { + // Setup a taskset, and fail some one task for fetch failure. + val conf = new SparkConf() + .set(config.BLACKLIST_ENABLED, true) + .set(config.SHUFFLE_SERVICE_ENABLED, true) + .set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + val blacklistTracker = new BlacklistTracker(sc, None) + val tsm = new TaskSetManager(sched, taskSet, 4, Some(blacklistTracker)) + + // make some offers to our taskset, to get tasks we will fail + val taskDescs = Seq( + "exec1" -> "host1", + "exec2" -> "host2" + ).flatMap { case (exec, host) => + // offer each executor twice (simulating 2 cores per executor) + (0 until 2).flatMap{ _ => tsm.resourceOffer(exec, host, TaskLocality.ANY)} + } + assert(taskDescs.size === 4) + + assert(!blacklistTracker.isExecutorBlacklisted(taskDescs(0).executorId)) + assert(!blacklistTracker.isNodeBlacklisted("host1")) + + // Fail the task with fetch failure + tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + + assert(blacklistTracker.isNodeBlacklisted("host1")) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/docs/configuration.md b/docs/configuration.md index f4bec589208be..c8e61537a457c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1479,6 +1479,15 @@ Apart from these, the following properties are also available, and may be useful all of the executors on that node will be killed. + + spark.blacklist.application.fetchFailure.enabled + false + + (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch + failure happenes. If external shuffle service is enabled, then the whole node will be + blacklisted. + + spark.speculation false From c22810004fb2db249be6477c9801d09b807af851 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 27 Jun 2017 02:35:51 +0800 Subject: [PATCH 0800/1765] [SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId ## What changes were proposed in this pull request? in https://github.com/apache/spark/pull/18064, to work around the nested sql execution id issue, we introduced several internal methods in `Dataset`, like `collectInternal`, `countInternal`, `showInternal`, etc., to avoid nested execution id. However, this approach has poor expansibility. When we hit other nested execution id cases, we may need to add more internal methods in `Dataset`. Our goal is to ignore the nested execution id in some cases, and we can have a better approach to achieve this goal, by introducing `SQLExecution.ignoreNestedExecutionId`. Whenever we find a place which needs to ignore the nested execution, we can just wrap the action with `SQLExecution.ignoreNestedExecutionId`, and this is more expansible than the previous approach. The idea comes from https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57 by rdblue ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18419 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++----------------- .../spark/sql/execution/SQLExecution.scala | 39 +++++++++++++++++-- .../command/AnalyzeTableCommand.scala | 5 ++- .../spark/sql/execution/command/cache.scala | 19 ++++----- .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/jdbc/JDBCRelation.scala | 14 +++---- .../sql/execution/streaming/console.scala | 13 +++++-- .../sql/execution/streaming/memory.scala | 33 +++++++++------- 8 files changed, 89 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e66e92091ff9..268a37ff5d271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -246,13 +246,8 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) - showString(takeResult, numRows, truncate, vertical) - } - - private def showString( - dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = { - val hasMoreData = dataWithOneMoreRow.length > numRows - val data = dataWithOneMoreRow.take(numRows) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) lazy val timeZone = DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) @@ -688,19 +683,6 @@ class Dataset[T] private[sql]( println(showString(numRows, truncate = 0)) } - // An internal version of `show`, which won't set execution id and trigger listeners. - private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = { - val numRows = _numRows.max(0) - val takeResult = toDF().takeInternal(numRows + 1) - - if (truncate) { - println(showString(takeResult, numRows, truncate = 20, vertical = false)) - } else { - println(showString(takeResult, numRows, truncate = 0, vertical = false)) - } - } - // scalastyle:on println - /** * Displays the Dataset in a tabular form. For example: * {{{ @@ -2467,11 +2449,6 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) - // An internal version of `take`, which won't set execution id and trigger listeners. - private[sql] def takeInternal(n: Int): Array[T] = { - collectFromPlan(limit(n).queryExecution.executedPlan) - } - /** * Returns the first `n` rows in the Dataset as a list. * @@ -2496,11 +2473,6 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) - // An internal version of `collect`, which won't set execution id and trigger listeners. - private[sql] def collectInternal(): Array[T] = { - collectFromPlan(queryExecution.executedPlan) - } - /** * Returns a Java list that contains all rows in this Dataset. * @@ -2542,11 +2514,6 @@ class Dataset[T] private[sql]( plan.executeCollect().head.getLong(0) } - // An internal version of `count`, which won't set execution id and trigger listeners. - private[sql] def countInternal(): Long = { - groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0) - } - /** * Returns a new Dataset that has exactly `numPartitions` partitions. * @@ -2792,7 +2759,7 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = true, global = true) } - private[spark] def createTempViewCommand( + private def createTempViewCommand( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index bb206e84325fd..ca8bed5214f87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -29,6 +29,8 @@ object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" + private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" + private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -42,8 +44,11 @@ object SQLExecution { private val testing = sys.props.contains("spark.testing") private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { + val sc = sparkSession.sparkContext + val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null + val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) { + if (testing && !isNestedExecution && !hasExecutionId) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -65,7 +70,7 @@ object SQLExecution { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) executionIdToQueryExecution.put(executionId, queryExecution) - val r = try { + try { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" @@ -84,7 +89,15 @@ object SQLExecution { executionIdToQueryExecution.remove(executionId) sc.setLocalProperty(EXECUTION_ID_KEY, null) } - r + } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { + // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the + // `body`, so that Spark jobs issued in the `body` won't be tracked. + try { + sc.setLocalProperty(EXECUTION_ID_KEY, null) + body + } finally { + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + } } else { // Don't support nested `withNewExecutionId`. This is an example of the nested // `withNewExecutionId`: @@ -100,7 +113,9 @@ object SQLExecution { // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. // // A real case is the `DataFrame.count` method. - throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set") + throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + + "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + + "jobs issued by the nested execution.") } } @@ -118,4 +133,20 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } + + /** + * Wrap an action which may have nested execution id. This method can be used to run an execution + * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, + * all Spark jobs issued in the body won't be tracked in UI. + */ + def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) + try { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") + body + } finally { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 06e588f56f1e9..13b8faff844c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.internal.SessionState @@ -58,7 +59,9 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).countInternal() + val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { + sparkSession.table(tableIdentWithDB).count() + } if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 184d0387ebfa9..d36eb7587a3ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution case class CacheTableCommand( tableIdent: TableIdentifier, @@ -33,16 +34,16 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan) - .createTempViewCommand(tableIdent.quotedString, replace = false, global = false) - .run(sparkSession) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + SQLExecution.ignoreNestedExecutionId(sparkSession) { + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).countInternal() + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() + } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index eadc6c94f4b3c..99133bd70989a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption + val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + } inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index a06f1ce3287e6..b11da7045de22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -129,14 +130,11 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - import scala.collection.JavaConverters._ - - val options = jdbcOptions.asProperties.asScala + - ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table) - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - - new JdbcRelationProvider().createRelation( - data.sparkSession.sqlContext, mode, options.toMap, data) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + } } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9e889ff679450..6fa7c113defaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.types.StructType class ConsoleSink(options: Map[String, String]) extends Sink with Logging { @@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema) - .showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) + } } } @@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - data.showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { + data.show(numRowsToShow, isTruncated) + } ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index a5dac469f85b6..198a342582804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } } else { logDebug(s"Skipping already committed batch: $batchId") From 3cb3ccce120fa9f0273133912624b877b42d95fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 27 Jun 2017 17:24:46 +0800 Subject: [PATCH 0801/1765] [SPARK-21196] Split codegen info of query plan into sequence codegen info of query plan can be very long. In debugging console / web page, it would be more readable if the subtrees and corresponding codegen are split into sequence. Example: ```java codegenStringSeq(sql("select 1").queryExecution.executedPlan) ``` The example will return Seq[(String, String)] of length 1, containing the subtree as string and the corresponding generated code. The subtree as string: > (*Project [1 AS 1#0] > +- Scan OneRowRelation[] The generated code: ```java /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow project_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ project_result = new UnsafeRow(1); /* 022 */ project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0); /* 023 */ project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1); /* 024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ project_rowWriter.write(0, 1); /* 031 */ append(project_result); /* 032 */ if (shouldStop()) return; /* 033 */ } /* 034 */ } /* 035 */ /* 036 */ } ``` ## What changes were proposed in this pull request? add method codegenToSeq: split codegen info of query plan into sequence ## How was this patch tested? unit test cloud-fan gatorsmile Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18409 from gengliangwang/codegen. --- .../spark/sql/execution/QueryExecution.scala | 9 +++++ .../spark/sql/execution/debug/package.scala | 35 ++++++++++++++----- .../sql/execution/debug/DebuggingSuite.scala | 7 ++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c7cac332a0377..9533144214a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -245,5 +245,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { println(org.apache.spark.sql.execution.debug.codegenString(executedPlan)) // scalastyle:on println } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenToSeq(): Seq[(String, String)] = { + org.apache.spark.sql.execution.debug.codegenStringSeq(executedPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 0395c43ba2cbc..a717cbd4a7df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -50,7 +50,31 @@ package object debug { // scalastyle:on println } + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan into one String + * + * @param plan the query plan for codegen + * @return single String containing all WholeStageCodegen subtrees and corresponding codegen + */ def codegenString(plan: SparkPlan): String = { + val codegenSeq = codegenStringSeq(plan) + var output = s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n" + for (((subtree, code), i) <- codegenSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n" + output += subtree + output += "\nGenerated code:\n" + output += s"${code}\n" + } + output + } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @param plan the query plan for codegen + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenStringSeq(plan: SparkPlan): Seq[(String, String)] = { val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() plan transform { case s: WholeStageCodegenExec => @@ -58,15 +82,10 @@ package object debug { s case s => s } - var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" - for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" - output += s - output += "\nGenerated code:\n" - val (_, source) = s.doCodeGen() - output += s"${CodeFormatter.format(source)}\n" + codegenSubtrees.toSeq.map { subtree => + val (_, source) = subtree.doCodeGen() + (subtree.toString, CodeFormatter.format(source)) } - output } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4fc52c99fbeeb..adcaf2d76519f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -38,4 +38,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } + + test("debugCodegenStringSeq") { + val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall{ case (subtree, code) => + subtree.contains("Range") && code.contains("Object[]")}) + } } From b32bd005e46443bbd487b7a1f1078578c8f4c181 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 27 Jun 2017 13:14:12 +0100 Subject: [PATCH 0802/1765] [INFRA] Close stale PRs ## What changes were proposed in this pull request? This PR proposes to close stale PRs, mostly the same instances with https://github.com/apache/spark/pull/18017 I believe the author in #14807 removed his account. Closes #7075 Closes #8927 Closes #9202 Closes #9366 Closes #10861 Closes #11420 Closes #12356 Closes #13028 Closes #13506 Closes #14191 Closes #14198 Closes #14330 Closes #14807 Closes #15839 Closes #16225 Closes #16685 Closes #16692 Closes #16995 Closes #17181 Closes #17211 Closes #17235 Closes #17237 Closes #17248 Closes #17341 Closes #17708 Closes #17716 Closes #17721 Closes #17937 Added: Closes #14739 Closes #17139 Closes #17445 Closes #18042 Closes #18359 Added: Closes #16450 Closes #16525 Closes #17738 Added: Closes #16458 Closes #16508 Closes #17714 Added: Closes #17830 Closes #14742 ## How was this patch tested? N/A Author: hyukjinkwon Closes #18417 from HyukjinKwon/close-stale-pr. From fd8c931a30a084ee981b75aa469fc97dda6cfaa9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Jun 2017 00:57:05 +0800 Subject: [PATCH 0803/1765] [SPARK-19104][SQL] Lambda variables in ExternalMapToCatalyst should be global ## What changes were proposed in this pull request? The issue happens in `ExternalMapToCatalyst`. For example, the following codes create `ExternalMapToCatalyst` to convert Scala Map to catalyst map format. val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) val ds = spark.createDataset(data) The `valueConverter` in `ExternalMapToCatalyst` looks like: if (isnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true))) null else named_struct(name, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true)).name, true), value, assertnotnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true)).value) There is a `CreateNamedStruct` expression (`named_struct`) to create a row of `InnerData.name` and `InnerData.value` that are referred by `ExternalMapToCatalyst_value52`. Because `ExternalMapToCatalyst_value52` are local variable, when `CreateNamedStruct` splits expressions to individual functions, the local variable can't be accessed anymore. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #18418 from viirya/SPARK-19104. --- .../catalyst/expressions/objects/objects.scala | 18 ++++++++++++------ .../spark/sql/DatasetPrimitiveSuite.scala | 8 ++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 073993cccdf8a..4b651836ff4d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -911,6 +911,12 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") + val keyElementJavaType = ctx.javaType(keyType) + val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState(keyElementJavaType, key, "") + ctx.addMutableState("boolean", valueIsNull, "") + ctx.addMutableState(valueElementJavaType, value, "") + val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => val javaIteratorCls = classOf[java.util.Iterator[_]].getName @@ -922,8 +928,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + $value = (${ctx.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -937,17 +943,17 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${ctx.boxedType(keyType)}) $entry._1(); + $value = (${ctx.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"boolean $valueIsNull = false;" + s"$valueIsNull = false;" } else { - s"boolean $valueIsNull = $value == null;" + s"$valueIsNull = $value == null;" } val arrayCls = classOf[GenericArrayData].getName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 4126660b5d102..a6847dcfbffc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -39,6 +39,9 @@ case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) +case class InnerData(name: String, value: Int) +case class NestedData(id: Int, param: Map[String, InnerData]) + package object packageobject { case class PackageClass(value: Int) } @@ -354,4 +357,9 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } + test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") { + val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) + val ds = spark.createDataset(data) + checkDataset(ds, data: _*) + } } From 2d686a19e341a31d976aa42228b7589f87dfd6c2 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Wed, 28 Jun 2017 09:26:33 +0800 Subject: [PATCH 0804/1765] [SPARK-21155][WEBUI] Add (? running tasks) into Spark UI progress ## What changes were proposed in this pull request? Add metric on number of running tasks to status bar on Jobs / Active Jobs. ## How was this patch tested? Run a long running (1 minute) query in spark-shell and use localhost:4040 web UI to observe progress. See jira for screen snapshot. Author: Eric Vandenberg Closes #18369 from ericvandenbergfb/runningTasks. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 2610f673d27f6..ba798df13c95d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -356,6 +356,7 @@ private[spark] object UIUtils extends Logging {
    {completed}/{total} + { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } { reasonToNumKilled.toSeq.sortBy(-_._2).map { From e793bf248bc3c71b9664f26377bce06b0ffa97a7 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 27 Jun 2017 23:15:45 -0700 Subject: [PATCH 0805/1765] [SPARK-20889][SPARKR] Grouped documentation for MATH column methods ## What changes were proposed in this pull request? Grouped documentation for math column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18371 from actuaryzhang/sparkRDocMath. --- R/pkg/R/functions.R | 619 +++++++++++++++----------------------------- R/pkg/R/generics.R | 48 ++-- 2 files changed, 241 insertions(+), 426 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 31028585aaa13..23ccdf941a8c7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -86,6 +86,31 @@ NULL #' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} NULL +#' Math functions for Column operations +#' +#' Math functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, +#' this is the number of bits to shift. +#' @param y Column to compute on. +#' @param ... additional argument(s). +#' @name column_math_functions +#' @rdname column_math_functions +#' @family math functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = log(df$mpg), v2 = cbrt(df$disp), +#' v3 = bround(df$wt, 1), v4 = bin(df$cyl), +#' v5 = hex(df$wt), v6 = toDegrees(df$gear), +#' v7 = atan2(df$cyl, df$am), v8 = hypot(df$cyl, df$am), +#' v9 = pmod(df$hp, df$cyl), v10 = shiftLeft(df$disp, 1), +#' v11 = conv(df$hp, 10, 16), v12 = sign(df$vs - 0.5), +#' v13 = sqrt(df$disp), v14 = ceil(df$wt)) +#' head(tmp)} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -112,18 +137,12 @@ setMethod("lit", signature("ANY"), column(jc) }) -#' abs -#' -#' Computes the absolute value. -#' -#' @param x Column to compute on. +#' @details +#' \code{abs}: Computes the absolute value. #' -#' @rdname abs -#' @name abs -#' @family non-aggregate functions +#' @rdname column_math_functions #' @export -#' @examples \dontrun{abs(df$c)} -#' @aliases abs,Column-method +#' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", signature(x = "Column"), @@ -132,19 +151,13 @@ setMethod("abs", column(jc) }) -#' acos -#' -#' Computes the cosine inverse of the given value; the returned angle is in the range -#' 0.0 through pi. -#' -#' @param x Column to compute on. +#' @details +#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in +#' the range 0.0 through pi. #' -#' @rdname acos -#' @name acos -#' @family math functions +#' @rdname column_math_functions #' @export -#' @examples \dontrun{acos(df$c)} -#' @aliases acos,Column-method +#' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", signature(x = "Column"), @@ -196,19 +209,13 @@ setMethod("ascii", column(jc) }) -#' asin -#' -#' Computes the sine inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. -#' -#' @param x Column to compute on. +#' @details +#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in +#' the range -pi/2 through pi/2. #' -#' @rdname asin -#' @name asin -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases asin,Column-method -#' @examples \dontrun{asin(df$c)} +#' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", signature(x = "Column"), @@ -217,18 +224,12 @@ setMethod("asin", column(jc) }) -#' atan -#' -#' Computes the tangent inverse of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{atan}: Computes the tangent inverse of the given value. #' -#' @rdname atan -#' @name atan -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases atan,Column-method -#' @examples \dontrun{atan(df$c)} +#' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", signature(x = "Column"), @@ -276,19 +277,13 @@ setMethod("base64", column(jc) }) -#' bin -#' -#' An expression that returns the string representation of the binary value of the given long -#' column. For example, bin("12") returns "1100". -#' -#' @param x Column to compute on. +#' @details +#' \code{bin}: An expression that returns the string representation of the binary value +#' of the given long column. For example, bin("12") returns "1100". #' -#' @rdname bin -#' @name bin -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases bin,Column-method -#' @examples \dontrun{bin(df$c)} +#' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", signature(x = "Column"), @@ -317,18 +312,12 @@ setMethod("bitwiseNOT", column(jc) }) -#' cbrt -#' -#' Computes the cube-root of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cbrt}: Computes the cube-root of the given value. #' -#' @rdname cbrt -#' @name cbrt -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases cbrt,Column-method -#' @examples \dontrun{cbrt(df$c)} +#' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", signature(x = "Column"), @@ -337,18 +326,12 @@ setMethod("cbrt", column(jc) }) -#' Computes the ceiling of the given value -#' -#' Computes the ceiling of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{ceil}: Computes the ceiling of the given value. #' -#' @rdname ceil -#' @name ceil -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases ceil,Column-method -#' @examples \dontrun{ceil(df$c)} +#' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", signature(x = "Column"), @@ -357,6 +340,19 @@ setMethod("ceil", column(jc) }) +#' @details +#' \code{ceiling}: Alias for \code{ceil}. +#' +#' @rdname column_math_functions +#' @aliases ceiling ceiling,Column-method +#' @export +#' @note ceiling since 1.5.0 +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + #' Returns the first column that is not NA #' #' Returns the first column that is not NA, or NA if all inputs are. @@ -405,6 +401,7 @@ setMethod("column", function(x) { col(x) }) + #' corr #' #' Computes the Pearson Correlation Coefficient for two Columns. @@ -493,18 +490,12 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr column(jc) }) -#' cos -#' -#' Computes the cosine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cos}: Computes the cosine of the given value. #' -#' @rdname cos -#' @name cos -#' @family math functions -#' @aliases cos,Column-method +#' @rdname column_math_functions +#' @aliases cos cos,Column-method #' @export -#' @examples \dontrun{cos(df$c)} #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -513,18 +504,12 @@ setMethod("cos", column(jc) }) -#' cosh -#' -#' Computes the hyperbolic cosine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cosh}: Computes the hyperbolic cosine of the given value. #' -#' @rdname cosh -#' @name cosh -#' @family math functions -#' @aliases cosh,Column-method +#' @rdname column_math_functions +#' @aliases cosh cosh,Column-method #' @export -#' @examples \dontrun{cosh(df$c)} #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -679,18 +664,12 @@ setMethod("encode", column(jc) }) -#' exp -#' -#' Computes the exponential of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{exp}: Computes the exponential of the given value. #' -#' @rdname exp -#' @name exp -#' @family math functions -#' @aliases exp,Column-method +#' @rdname column_math_functions +#' @aliases exp exp,Column-method #' @export -#' @examples \dontrun{exp(df$c)} #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -699,18 +678,12 @@ setMethod("exp", column(jc) }) -#' expm1 -#' -#' Computes the exponential of the given value minus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{expm1}: Computes the exponential of the given value minus one. #' -#' @rdname expm1 -#' @name expm1 -#' @aliases expm1,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases expm1 expm1,Column-method #' @export -#' @examples \dontrun{expm1(df$c)} #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -719,18 +692,12 @@ setMethod("expm1", column(jc) }) -#' factorial -#' -#' Computes the factorial of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{factorial}: Computes the factorial of the given value. #' -#' @rdname factorial -#' @name factorial -#' @aliases factorial,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases factorial factorial,Column-method #' @export -#' @examples \dontrun{factorial(df$c)} #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -772,18 +739,12 @@ setMethod("first", column(jc) }) -#' floor -#' -#' Computes the floor of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{floor}: Computes the floor of the given value. #' -#' @rdname floor -#' @name floor -#' @aliases floor,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases floor floor,Column-method #' @export -#' @examples \dontrun{floor(df$c)} #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -792,18 +753,12 @@ setMethod("floor", column(jc) }) -#' hex -#' -#' Computes hex value of the given column. -#' -#' @param x Column to compute on. +#' @details +#' \code{hex}: Computes hex value of the given column. #' -#' @rdname hex -#' @name hex -#' @family math functions -#' @aliases hex,Column-method +#' @rdname column_math_functions +#' @aliases hex hex,Column-method #' @export -#' @examples \dontrun{hex(df$c)} #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -983,18 +938,12 @@ setMethod("length", column(jc) }) -#' log -#' -#' Computes the natural logarithm of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{log}: Computes the natural logarithm of the given value. #' -#' @rdname log -#' @name log -#' @aliases log,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases log log,Column-method #' @export -#' @examples \dontrun{log(df$c)} #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1003,18 +952,12 @@ setMethod("log", column(jc) }) -#' log10 -#' -#' Computes the logarithm of the given value in base 10. -#' -#' @param x Column to compute on. +#' @details +#' \code{log10}: Computes the logarithm of the given value in base 10. #' -#' @rdname log10 -#' @name log10 -#' @family math functions -#' @aliases log10,Column-method +#' @rdname column_math_functions +#' @aliases log10 log10,Column-method #' @export -#' @examples \dontrun{log10(df$c)} #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1023,18 +966,12 @@ setMethod("log10", column(jc) }) -#' log1p -#' -#' Computes the natural logarithm of the given value plus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{log1p}: Computes the natural logarithm of the given value plus one. #' -#' @rdname log1p -#' @name log1p -#' @family math functions -#' @aliases log1p,Column-method +#' @rdname column_math_functions +#' @aliases log1p log1p,Column-method #' @export -#' @examples \dontrun{log1p(df$c)} #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1043,18 +980,12 @@ setMethod("log1p", column(jc) }) -#' log2 -#' -#' Computes the logarithm of the given column in base 2. -#' -#' @param x Column to compute on. +#' @details +#' \code{log2}: Computes the logarithm of the given column in base 2. #' -#' @rdname log2 -#' @name log2 -#' @family math functions -#' @aliases log2,Column-method +#' @rdname column_math_functions +#' @aliases log2 log2,Column-method #' @export -#' @examples \dontrun{log2(df$c)} #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1287,19 +1218,13 @@ setMethod("reverse", column(jc) }) -#' rint -#' -#' Returns the double value that is closest in value to the argument and +#' @details +#' \code{rint}: Returns the double value that is closest in value to the argument and #' is equal to a mathematical integer. #' -#' @param x Column to compute on. -#' -#' @rdname rint -#' @name rint -#' @family math functions -#' @aliases rint,Column-method +#' @rdname column_math_functions +#' @aliases rint rint,Column-method #' @export -#' @examples \dontrun{rint(df$c)} #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1308,18 +1233,13 @@ setMethod("rint", column(jc) }) -#' round -#' -#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode. -#' -#' @param x Column to compute on. +#' @details +#' \code{round}: Returns the value of the column rounded to 0 decimal places +#' using HALF_UP rounding mode. #' -#' @rdname round -#' @name round -#' @family math functions -#' @aliases round,Column-method +#' @rdname column_math_functions +#' @aliases round round,Column-method #' @export -#' @examples \dontrun{round(df$c)} #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1328,24 +1248,18 @@ setMethod("round", column(jc) }) -#' bround -#' -#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding -#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. +#' @details +#' \code{bround}: Returns the value of the column \code{e} rounded to \code{scale} decimal places +#' using HALF_EVEN rounding mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param x Column to compute on. #' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, #' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left #' of the decimal point when \code{scale} < 0. -#' @param ... further arguments to be passed to or from other methods. -#' @rdname bround -#' @name bround -#' @family math functions -#' @aliases bround,Column-method +#' @rdname column_math_functions +#' @aliases bround bround,Column-method #' @export -#' @examples \dontrun{bround(df$c, 0)} #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1354,7 +1268,6 @@ setMethod("bround", column(jc) }) - #' rtrim #' #' Trim the spaces from right end for the specified string value. @@ -1375,7 +1288,6 @@ setMethod("rtrim", column(jc) }) - #' @details #' \code{sd}: Alias for \code{stddev_samp}. #' @@ -1429,18 +1341,12 @@ setMethod("sha1", column(jc) }) -#' signum -#' -#' Computes the signum of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{signum}: Computes the signum of the given value. #' -#' @rdname sign -#' @name signum -#' @aliases signum,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases signum signum,Column-method #' @export -#' @examples \dontrun{signum(df$c)} #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1449,18 +1355,24 @@ setMethod("signum", column(jc) }) -#' sin -#' -#' Computes the sine of the given value. +#' @details +#' \code{sign}: Alias for \code{signum}. #' -#' @param x Column to compute on. +#' @rdname column_math_functions +#' @aliases sign sign,Column-method +#' @export +#' @note sign since 1.5.0 +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @details +#' \code{sin}: Computes the sine of the given value. #' -#' @rdname sin -#' @name sin -#' @family math functions -#' @aliases sin,Column-method +#' @rdname column_math_functions +#' @aliases sin sin,Column-method #' @export -#' @examples \dontrun{sin(df$c)} #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1469,18 +1381,12 @@ setMethod("sin", column(jc) }) -#' sinh -#' -#' Computes the hyperbolic sine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{sinh}: Computes the hyperbolic sine of the given value. #' -#' @rdname sinh -#' @name sinh -#' @family math functions -#' @aliases sinh,Column-method +#' @rdname column_math_functions +#' @aliases sinh sinh,Column-method #' @export -#' @examples \dontrun{sinh(df$c)} #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1616,18 +1522,12 @@ setMethod("struct", column(jc) }) -#' sqrt -#' -#' Computes the square root of the specified float value. -#' -#' @param x Column to compute on. +#' @details +#' \code{sqrt}: Computes the square root of the specified float value. #' -#' @rdname sqrt -#' @name sqrt -#' @family math functions -#' @aliases sqrt,Column-method +#' @rdname column_math_functions +#' @aliases sqrt sqrt,Column-method #' @export -#' @examples \dontrun{sqrt(df$c)} #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1669,18 +1569,12 @@ setMethod("sumDistinct", column(jc) }) -#' tan -#' -#' Computes the tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tan}: Computes the tangent of the given value. #' -#' @rdname tan -#' @name tan -#' @family math functions -#' @aliases tan,Column-method +#' @rdname column_math_functions +#' @aliases tan tan,Column-method #' @export -#' @examples \dontrun{tan(df$c)} #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1689,18 +1583,12 @@ setMethod("tan", column(jc) }) -#' tanh -#' -#' Computes the hyperbolic tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tanh}: Computes the hyperbolic tangent of the given value. #' -#' @rdname tanh -#' @name tanh -#' @family math functions -#' @aliases tanh,Column-method +#' @rdname column_math_functions +#' @aliases tanh tanh,Column-method #' @export -#' @examples \dontrun{tanh(df$c)} #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1709,18 +1597,13 @@ setMethod("tanh", column(jc) }) -#' toDegrees -#' -#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. -#' -#' @param x Column to compute on. +#' @details +#' \code{toDegrees}: Converts an angle measured in radians to an approximately equivalent angle +#' measured in degrees. #' -#' @rdname toDegrees -#' @name toDegrees -#' @family math functions -#' @aliases toDegrees,Column-method +#' @rdname column_math_functions +#' @aliases toDegrees toDegrees,Column-method #' @export -#' @examples \dontrun{toDegrees(df$c)} #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1729,18 +1612,13 @@ setMethod("toDegrees", column(jc) }) -#' toRadians -#' -#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. -#' -#' @param x Column to compute on. +#' @details +#' \code{toRadians}: Converts an angle measured in degrees to an approximately equivalent angle +#' measured in radians. #' -#' @rdname toRadians -#' @name toRadians -#' @family math functions -#' @aliases toRadians,Column-method +#' @rdname column_math_functions +#' @aliases toRadians toRadians,Column-method #' @export -#' @examples \dontrun{toRadians(df$c)} #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1894,19 +1772,13 @@ setMethod("unbase64", column(jc) }) -#' unhex -#' -#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' @details +#' \code{unhex}: Inverse of hex. Interprets each pair of characters as a hexadecimal number #' and converts to the byte representation of number. #' -#' @param x Column to compute on. -#' -#' @rdname unhex -#' @name unhex -#' @family math functions -#' @aliases unhex,Column-method +#' @rdname column_math_functions +#' @aliases unhex unhex,Column-method #' @export -#' @examples \dontrun{unhex(df$c)} #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -2020,20 +1892,13 @@ setMethod("year", column(jc) }) -#' atan2 -#' -#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to -#' polar coordinates (r, theta). -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates +#' (x, y) to polar coordinates (r, theta). #' -#' @rdname atan2 -#' @name atan2 -#' @family math functions -#' @aliases atan2,Column-method +#' @rdname column_math_functions +#' @aliases atan2 atan2,Column-method #' @export -#' @examples \dontrun{atan2(df$c, x)} #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2068,19 +1933,12 @@ setMethod("datediff", signature(y = "Column"), column(jc) }) -#' hypot -#' -#' Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{hypot}: Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. #' -#' @rdname hypot -#' @name hypot -#' @family math functions -#' @aliases hypot,Column-method +#' @rdname column_math_functions +#' @aliases hypot hypot,Column-method #' @export -#' @examples \dontrun{hypot(df$c, x)} #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2154,20 +2012,13 @@ setMethod("nanvl", signature(y = "Column"), column(jc) }) -#' pmod -#' -#' Returns the positive value of dividend mod divisor. -#' -#' @param x divisor Column. -#' @param y dividend Column. +#' @details +#' \code{pmod}: Returns the positive value of dividend mod divisor. +#' Column \code{x} is divisor column, and column \code{y} is the dividend column. #' -#' @rdname pmod -#' @name pmod -#' @docType methods -#' @family math functions -#' @aliases pmod,Column-method +#' @rdname column_math_functions +#' @aliases pmod pmod,Column-method #' @export -#' @examples \dontrun{pmod(df$c, x)} #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2290,31 +2141,6 @@ setMethod("least", column(jc) }) -#' @rdname ceil -#' -#' @name ceiling -#' @aliases ceiling,Column-method -#' @export -#' @examples \dontrun{ceiling(df$c)} -#' @note ceiling since 1.5.0 -setMethod("ceiling", - signature(x = "Column"), - function(x) { - ceil(x) - }) - -#' @rdname sign -#' -#' @name sign -#' @aliases sign,Column-method -#' @export -#' @examples \dontrun{sign(df$c)} -#' @note sign since 1.5.0 -setMethod("sign", signature(x = "Column"), - function(x) { - signum(x) - }) - #' @details #' \code{n_distinct}: Returns the number of distinct items in a group. #' @@ -2564,20 +2390,13 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftLeft -#' -#' Shift the given value numBits left. If the given value is a long value, this function -#' will return a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftLeft}: Shifts the given value numBits left. If the given value is a long value, +#' this function will return a long value else it will return an integer value. #' -#' @family math functions -#' @rdname shiftLeft -#' @name shiftLeft -#' @aliases shiftLeft,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftLeft shiftLeft,Column,numeric-method #' @export -#' @examples \dontrun{shiftLeft(df$c, 1)} #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2587,20 +2406,13 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRight -#' -#' (Signed) shift the given value numBits right. If the given value is a long value, it will return -#' a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. #' -#' @family math functions -#' @rdname shiftRight -#' @name shiftRight -#' @aliases shiftRight,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRight shiftRight,Column,numeric-method #' @export -#' @examples \dontrun{shiftRight(df$c, 1)} #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2610,20 +2422,13 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRightUnsigned -#' -#' Unsigned shift the given value numBits right. If the given value is a long value, +#' @details +#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. #' -#' @param y column to compute on. -#' @param x number of bits to shift. -#' -#' @family math functions -#' @rdname shiftRightUnsigned -#' @name shiftRightUnsigned -#' @aliases shiftRightUnsigned,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method #' @export -#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2656,20 +2461,14 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), column(jc) }) -#' conv -#' -#' Convert a number in a string column from one base to another. +#' @details +#' \code{conv}: Converts a number in a string column from one base to another. #' -#' @param x column to convert. #' @param fromBase base to convert from. #' @param toBase base to convert to. -#' -#' @family math functions -#' @rdname conv -#' @aliases conv,Column,numeric,numeric-method -#' @name conv +#' @rdname column_math_functions +#' @aliases conv conv,Column,numeric,numeric-method #' @export -#' @examples \dontrun{conv(df$n, 2, 16)} #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f105174cea70d..0248ec585d771 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -931,24 +931,28 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("base64", function(x) { standardGeneric("base64") }) -#' @rdname bin +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname bitwiseNOT #' @export setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) -#' @rdname bround +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) -#' @rdname cbrt +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) -#' @rdname ceil +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions @@ -973,8 +977,9 @@ setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @export setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) -#' @rdname conv +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions @@ -1094,8 +1099,9 @@ setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) -#' @rdname hex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions @@ -1103,8 +1109,9 @@ setGeneric("hex", function(x) { standardGeneric("hex") }) #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) -#' @rdname hypot +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname initcap @@ -1235,8 +1242,9 @@ setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @export setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) -#' @rdname pmod +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname posexplode @@ -1281,8 +1289,9 @@ setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) -#' @rdname rint +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @param x empty. Should be used with no argument. @@ -1316,20 +1325,24 @@ setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @export setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) -#' @rdname shiftLeft +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) -#' @rdname shiftRight +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) -#' @rdname shiftRightUnsigned +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) -#' @rdname sign +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname size @@ -1386,12 +1399,14 @@ setGeneric("substring_index", function(x, delim, count) { standardGeneric("subst #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname toDegrees +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname toRadians +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions @@ -1425,8 +1440,9 @@ setGeneric("trim", function(x) { standardGeneric("trim") }) #' @export setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) -#' @rdname unhex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions From 838effb98a0d3410766771533402ce0386133af3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Jun 2017 14:28:40 +0800 Subject: [PATCH 0806/1765] Revert "[SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas" This reverts commit e44697606f429b01808c1a22cb44cb5b89585c5c. --- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 - dev/deps/spark-deps-hadoop-2.7 | 5 - dev/run-pip-tests | 6 - pom.xml | 20 - python/pyspark/serializers.py | 17 - python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 79 +- .../apache/spark/sql/internal/SQLConf.scala | 22 - sql/core/pom.xml | 4 - .../scala/org/apache/spark/sql/Dataset.scala | 20 - .../sql/execution/arrow/ArrowConverters.scala | 429 ------ .../arrow/ArrowConvertersSuite.scala | 1222 ----------------- 13 files changed, 13 insertions(+), 1866 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index 8eeea7716cc98..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$@" + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9868c1ab7c2ab..9287bd47cf113 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,9 +13,6 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -58,7 +55,6 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -81,7 +77,6 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 57c78cfe12087..9127413ab6c23 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,9 +13,6 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -58,7 +55,6 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -81,7 +77,6 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 225e9209536f0..d51dde12a03c5 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,8 +83,6 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" - conda install -y -c conda-forge pyarrow=0.4.0 - TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -122,10 +120,6 @@ for python in "${PYTHON_EXECS[@]}"; do python "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" python "$FWDIR"/python/pyspark/context.py - if [ -n "$TEST_PYARROW" ]; then - echo "Run tests for pyarrow" - SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests - fi cd "$FWDIR" diff --git a/pom.xml b/pom.xml index f124ba45007b7..5f524079495c0 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 2.6 1.8 1.0.0 - 0.4.0 ${java.home} @@ -1879,25 +1878,6 @@ paranamer ${paranamer.version} - - org.apache.arrow - arrow-vector - ${arrow.version} - - - com.fasterxml.jackson.core - jackson-annotations - - - com.fasterxml.jackson.core - jackson-databind - - - io.netty - netty-handler - - - diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18f..ea5e00e9eeef5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,23 +182,6 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): - """ - Serializes an Arrow stream. - """ - - def dumps(self, obj): - raise NotImplementedError - - def loads(self, obj): - import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() - - def __repr__(self): - return "ArrowSerializer" - - class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 760f113dfd197..0649271ed2246 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,7 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ - UTF8Deserializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1709,8 +1708,7 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """ - Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. @@ -1723,42 +1721,18 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": - try: - import pyarrow - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - return table.to_pandas() - else: - return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" - raise ImportError("%s\n%s" % (e.message, msg)) - else: - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - def _collectAsArrow(self): - """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. - - .. note:: Experimental. - """ - with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 326e8548a617c..0a1cd6856b8e8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,21 +58,12 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2629,74 +2620,6 @@ def range_frame_match(): importlib.reload(window) - -@unittest.skipIf(not _have_arrow, "Arrow not installed") -class ArrowTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") - cls.schema = StructType([ - StructField("1_str_t", StringType(), True), - StructField("2_int_t", IntegerType(), True), - StructField("3_long_t", LongType(), True), - StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] - - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - - def test_unsupported_datatype(self): - schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) - df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) - - def test_null_conversion(self): - df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + - self.data) - pdf = df_null.toPandas() - null_counts = pdf.isnull().sum().tolist() - self.assertTrue(all([c == 1 for c in null_counts])) - - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") - pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") - pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) - - def test_pandas_round_trip(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - pdf = pd.DataFrame(data=data_dict) - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) - - def test_filtered_frame(self): - df = self.spark.range(3).toDF("i") - pdf = df.filter("i < 0").toPandas() - self.assertEqual(len(pdf.columns), 1) - self.assertEqual(pdf.columns[0], "i") - self.assertTrue(pdf.empty) - - if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9c8e26a8eeadf..c641e4d3a23e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -847,24 +847,6 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) - val ARROW_EXECUTION_ENABLE = - buildConf("spark.sql.execution.arrow.enable") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") - .booleanConf - .createWithDefault(false) - - val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = - buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() - .doc("When using Apache Arrow, limit the maximum number of records that can be written " + - "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") - .intConf - .createWithDefault(10000) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1123,10 +1105,6 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) - - def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 661c31ded7148..1bc34a6b069d9 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,10 +103,6 @@ jackson-databind ${fasterxml.jackson.version} - - org.apache.arrow - arrow-vector - org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 268a37ff5d271..7be4aa1ca9562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2887,16 +2886,6 @@ class Dataset[T] private[sql]( } } - /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. - */ - private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") - } - } - private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -2978,13 +2967,4 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } - - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - val schemaCaptured = this.schema - val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch - queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala deleted file mode 100644 index 6af5c73422377..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ /dev/null @@ -1,429 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.arrow - -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels - -import scala.collection.JavaConverters._ - -import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.{BufferAllocator, RootAllocator} -import org.apache.arrow.vector._ -import org.apache.arrow.vector.BaseValueVector.BaseMutator -import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - - -/** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. - */ -private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { - - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } - - /** - * Get the ArrowPayload as a type that can be served to Python. - */ - def asPythonSerializable: Array[Byte] = payload -} - -private[sql] object ArrowPayload { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. - */ - def apply( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): ArrowPayload = { - new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } -} - -private[sql] object ArrowConverters { - - /** - * Map a Spark DataType to ArrowType. - */ - private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { - dataType match { - case BooleanType => ArrowType.Bool.INSTANCE - case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) - case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case ByteType => new ArrowType.Int(8, true) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } - - /** - * Convert a Spark Dataset schema to Arrow schema. - */ - private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { f => - new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) - } - new Schema(arrowFields.toList.asJava) - } - - /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. - */ - private[sql] def toPayloadIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { - new Iterator[ArrowPayload] { - private val _allocator = new RootAllocator(Long.MaxValue) - private var _nextPayload = if (rowIter.nonEmpty) convert() else null - - override def hasNext: Boolean = _nextPayload != null - - override def next(): ArrowPayload = { - val obj = _nextPayload - if (hasNext) { - if (rowIter.hasNext) { - _nextPayload = convert() - } else { - _allocator.close() - _nextPayload = null - } - } - obj - } - - private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - ArrowPayload(batch, schema, _allocator) - } - } - } - - /** - * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed - * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, - * then rowIter will be fully consumed. - */ - private def internalRowIterToArrowBatch( - rowIter: Iterator[InternalRow], - schema: StructType, - allocator: BufferAllocator, - maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { - - val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(field.dataType, ordinal, allocator).init() - } - - val writerLength = columnWriters.length - var recordsInBatch = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { - val row = rowIter.next() - var i = 0 - while (i < writerLength) { - columnWriters(i).write(row) - i += 1 - } - recordsInBatch += 1 - } - - val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip - val buffers = bufferArrays.flatten - - val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - val recordBatch = new ArrowRecordBatch(rowLength, - fieldNodes.toList.asJava, buffers.toList.asJava) - - buffers.foreach(_.release()) - recordBatch - } - - /** - * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, - * the batch can no longer be used. - */ - private[arrow] def batchToByteArray( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - - // Write a batch to byte stream, ensure the batch, allocator and writer are closed - Utils.tryWithSafeFinally { - val loader = new VectorLoader(root) - loader.load(batch) - writer.writeBatch() // writeBatch can throw IOException - } { - batch.close() - root.close() - writer.close() - } - out.toByteArray - } - - /** - * Convert a byte array to an ArrowRecordBatch. - */ - private[arrow] def byteArrayToBatch( - batchBytes: Array[Byte], - allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } - } -} - -/** - * Interface for writing InternalRows to Arrow Buffers. - */ -private[arrow] trait ColumnWriter { - def init(): this.type - def write(row: InternalRow): Unit - - /** - * Clear the column writer and return the ArrowFieldNode and ArrowBuf. - * This should be called only once after all the data is written. - */ - def finish(): (ArrowFieldNode, Array[ArrowBuf]) -} - -/** - * Base class for flat arrow column writer, i.e., column without children. - */ -private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) - extends ColumnWriter { - - def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) - - def valueVector: BaseDataValueVector - def valueMutator: BaseMutator - - def setNull(): Unit - def setValue(row: InternalRow): Unit - - protected var count = 0 - protected var nullCount = 0 - - override def init(): this.type = { - valueVector.allocateNew() - this - } - - override def write(row: InternalRow): Unit = { - if (row.isNullAt(ordinal)) { - setNull() - nullCount += 1 - } else { - setValue(row) - } - count += 1 - } - - override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { - valueMutator.setValueCount(count) - val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers = valueVector.getBuffers(true) - (fieldNode, valueBuffers) - } -} - -private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) -} - -private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) - override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getShort(ordinal)) -} - -private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", getFieldType(dtype), allocator) - override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getInt(ordinal)) -} - -private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getLong(ordinal)) -} - -private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat4Vector - = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getFloat(ordinal)) -} - -private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getDouble(ordinal)) -} - -private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) - override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getByte(ordinal)) -} - -private[arrow] class UTF8StringColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarCharVector - = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val str = row.getUTF8String(ordinal) - valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) - } -} - -private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val bytes = row.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) - } -} - -private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableDateDayVector - = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) - override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getInt(ordinal)) - } -} - -private[arrow] class TimeStampColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) - override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal)) - } -} - -private[arrow] object ColumnWriter { - - /** - * Create an Arrow ColumnWriter given the type and ordinal of row. - */ - def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowConverters.sparkTypeToArrowType(dataType) - dataType match { - case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) - case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) - case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) - case LongType => new LongColumnWriter(dtype, ordinal, allocator) - case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) - case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) - case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) - case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) - case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) - case DateType => new DateColumnWriter(dtype, ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala deleted file mode 100644 index 159328cc0d958..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ /dev/null @@ -1,1222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.arrow - -import java.io.File -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat -import java.util.Locale - -import com.google.common.io.Files -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} -import org.apache.arrow.vector.file.json.JsonFileReader -import org.apache.arrow.vector.util.Validator -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} -import org.apache.spark.util.Utils - - -class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { - import testImplicits._ - - private var tempDataPath: String = _ - - override def beforeAll(): Unit = { - super.beforeAll() - tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath - } - - test("collect to arrow record batch") { - val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - val rowCount = arrowRecordBatches.map(_.getLength).sum - assert(rowCount === indexData.count()) - arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - test("short conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_s", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | }, { - | "name" : "b_s", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_s", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] - | }, { - | "name" : "b_s", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] - | } ] - | } ] - |} - """.stripMargin - - val a_s = List[Short](1, -1, 2, -2, 32767, -32768) - val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) - val df = a_s.zip(b_s).toDF("a_s", "b_s") - - collectAndValidate(df, json, "integer-16bit.json") - } - - test("int conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - - val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) - val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") - - collectAndValidate(df, json, "integer-32bit.json") - } - - test("long conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_l", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_l", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_l", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] - | }, { - | "name" : "b_l", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] - | } ] - | } ] - |} - """.stripMargin - - val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) - val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) - val df = a_l.zip(b_l).toDF("a_l", "b_l") - - collectAndValidate(df, json, "integer-64bit.json") - } - - test("float conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_f", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] - | }, { - | "name" : "b_f", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] - | } ] - | } ] - |} - """.stripMargin - - val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) - val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) - val df = a_f.zip(b_f).toDF("a_f", "b_f") - - collectAndValidate(df, json, "floating_point-single_precision.json") - } - - test("double conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] - | }, { - | "name" : "b_d", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] - | } ] - | } ] - |} - """.stripMargin - - val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) - val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) - val df = a_d.zip(b_d).toDF("a_d", "b_d") - - collectAndValidate(df, json, "floating_point-double_precision.json") - } - - test("index conversion") { - val data = List[Int](1, 2, 3, 4, 5, 6) - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | } ] - | } ] - |} - """.stripMargin - val df = data.toDF("i") - - collectAndValidate(df, json, "indexData-ints.json") - } - - test("mixed numeric type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "c", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "e", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | }, { - | "name" : "b", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] - | }, { - | "name" : "c", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | }, { - | "name" : "d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] - | }, { - | "name" : "e", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | } ] - | } ] - |} - """.stripMargin - - val data = List(1, 2, 3, 4, 5, 6) - val data_tuples = for (d <- data) yield { - (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) - } - val df = data_tuples.toDF("a", "b", "c", "d", "e") - - collectAndValidate(df, json, "mixed_numeric_types.json") - } - - test("string type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "upper_case", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | }, { - | "name" : "lower_case", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | }, { - | "name" : "null_str", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "upper_case", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 1, 2, 3 ], - | "DATA" : [ "A", "B", "C" ] - | }, { - | "name" : "lower_case", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 1, 2, 3 ], - | "DATA" : [ "a", "b", "c" ] - | }, { - | "name" : "null_str", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 0 ], - | "OFFSET" : [ 0, 2, 5, 5 ], - | "DATA" : [ "ab", "CDE", "" ] - | } ] - | } ] - |} - """.stripMargin - - val upperCase = Seq("A", "B", "C") - val lowerCase = Seq("a", "b", "c") - val nullStr = Seq("ab", "CDE", null) - val df = (upperCase, lowerCase, nullStr).zipped.toList - .toDF("upper_case", "lower_case", "null_str") - - collectAndValidate(df, json, "stringData.json") - } - - test("boolean type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_bool", - | "type" : { - | "name" : "bool" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 1 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "a_bool", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ true, true, false, true ] - | } ] - | } ] - |} - """.stripMargin - val df = Seq(true, true, false, true).toDF("a_bool") - collectAndValidate(df, json, "boolData.json") - } - - test("byte type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_byte", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 8 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "a_byte", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 64, 127 ] - | } ] - | } ] - |} - | - """.stripMargin - val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") - collectAndValidate(df, json, "byteData.json") - } - - test("binary type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_binary", - | "type" : { - | "name" : "binary" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a_binary", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 3, 4, 6 ], - | "DATA" : [ "616263", "64", "6566" ] - | } ] - | } ] - |} - """.stripMargin - - val data = Seq("abc", "d", "ef") - val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) - val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) - - collectAndValidate(df, json, "binaryData.json") - } - - test("floating-point NaN") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "NaN_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "NaN_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 2, - | "columns" : [ { - | "name" : "NaN_f", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ 1.2000000476837158, "NaN" ] - | }, { - | "name" : "NaN_d", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ "NaN", 1.2 ] - | } ] - | } ] - |} - """.stripMargin - - val fnan = Seq(1.2F, Float.NaN) - val dnan = Seq(Double.NaN, 1.2) - val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") - - collectAndValidate(df, json, "nanData-floating_point.json") - } - - test("partitioned DataFrame") { - val json1 = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 1, 1, 2 ] - | }, { - | "name" : "b", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 1, 2, 1 ] - | } ] - | } ] - |} - """.stripMargin - val json2 = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 2, 3, 3 ] - | }, { - | "name" : "b", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 2, 1, 2 ] - | } ] - | } ] - |} - """.stripMargin - - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) - val schema = testData2.schema - - val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") - val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - Files.write(json1, tempFile1, StandardCharsets.UTF_8) - Files.write(json2, tempFile2, StandardCharsets.UTF_8) - - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) - } - - test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) - - val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) - } - - test("empty partition collect") { - val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - assert(arrowRecordBatches.head.getLength == 1) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - test("max records in batch conf") { - val totalRecords = 10 - val maxRecordsPerBatch = 3 - spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) - val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - var recordCount = 0 - arrowRecordBatches.foreach { batch => - assert(batch.getLength > 0) - assert(batch.getLength <= maxRecordsPerBatch) - recordCount += batch.getLength - batch.close() - } - assert(recordCount == totalRecords) - allocator.close() - spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - } - - testQuietly("unsupported types") { - def runUnsupported(block: => Unit): Unit = { - val msg = intercept[SparkException] { - block - } - assert(msg.getMessage.contains("Unsupported data type")) - assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) - } - - runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } - } - - test("test Arrow Validator") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - val json_diff_col_order = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - - val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) - val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") - - // Different schema - intercept[IllegalArgumentException] { - collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") - } - - // Different values - intercept[IllegalArgumentException] { - collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") - } - } - - /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { - // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head - val tempFile = new File(tempDataPath, file) - Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) - } - - private def validateConversion( - sparkSchema: StructType, - arrowPayload: ArrowPayload, - jsonFile: File): Unit = { - val allocator = new RootAllocator(Long.MaxValue) - val jsonReader = new JsonFileReader(jsonFile, allocator) - - val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) - val jsonSchema = jsonReader.start() - Validator.compareSchemas(arrowSchema, jsonSchema) - - val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) - val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) - vectorLoader.load(arrowRecordBatch) - val jsonRoot = jsonReader.read() - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) - - jsonRoot.close() - jsonReader.close() - arrowRecordBatch.close() - arrowRoot.close() - allocator.close() - } -} From e68aed70fbf1cfa59ba51df70287d718d737a193 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 28 Jun 2017 10:45:45 -0700 Subject: [PATCH 0807/1765] [SPARK-21216][SS] Hive strategies missed in Structured Streaming IncrementalExecution ## What changes were proposed in this pull request? If someone creates a HiveSession, the planner in `IncrementalExecution` doesn't take into account the Hive scan strategies. This causes joins of Streaming DataFrame's with Hive tables to fail. ## How was this patch tested? Regression test Author: Burak Yavuz Closes #18426 from brkyvz/hive-join. --- .../streaming/IncrementalExecution.scala | 4 ++ .../sql/hive/execution/HiveDDLSuite.scala | 41 ++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ab89dc6b705d5..dbe652b3b1ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -47,6 +47,10 @@ class IncrementalExecution( sparkSession.sparkContext, sparkSession.sessionState.conf, sparkSession.sessionState.experimentalMethods) { + override def strategies: Seq[Strategy] = + extraPlanningStrategies ++ + sparkSession.sessionState.planner.strategies + override def extraPlanningStrategies: Seq[Strategy] = StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index aca964907d4cd..31fa3d2447467 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -160,7 +160,6 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } - } class HiveDDLSuite @@ -1956,4 +1955,44 @@ class HiveDDLSuite } } } + + test("SPARK-21216: join with a streaming DataFrame") { + import org.apache.spark.sql.execution.streaming.MemoryStream + import testImplicits._ + + implicit val _sqlContext = spark.sqlContext + + Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word").createOrReplaceTempView("t1") + // Make a table and ensure it will be broadcast. + sql("""CREATE TABLE smallTable(word string, number int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |STORED AS TEXTFILE + """.stripMargin) + + sql( + """INSERT INTO smallTable + |SELECT word, number from t1 + """.stripMargin) + + val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF() + .join(spark.table("smallTable"), $"value" === $"number") + + val sq = joined.writeStream + .format("memory") + .queryName("t2") + .start() + try { + inputData.addData(1, 2) + + sq.processAllAvailable() + + checkAnswer( + spark.table("t2"), + Seq(Row(1, "one", 1), Row(2, "two", 2)) + ) + } finally { + sq.stop() + } + } } From b72b8521d9cad878a1a4e4dbb19cf980169dcbc7 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 29 Jun 2017 08:47:31 +0800 Subject: [PATCH 0808/1765] [SPARK-21222] Move elimination of Distinct clause from analyzer to optimizer ## What changes were proposed in this pull request? Move elimination of Distinct clause from analyzer to optimizer Distinct clause is useless after MAX/MIN clause. For example, "Select MAX(distinct a) FROM src from" is equivalent of "Select MAX(a) FROM src from" However, this optimization is implemented in analyzer. It should be in optimizer. ## How was this patch tested? Unit test gatorsmile cloud-fan Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18429 from gengliangwang/distinct_opt. --- .../sql/catalyst/analysis/Analyzer.scala | 5 -- .../spark/sql/catalyst/dsl/package.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 15 +++++ .../optimizer/EliminateDistinctSuite.scala | 56 +++++++++++++++++++ 4 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 434b6ffee37fa..53536496d0457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1197,11 +1197,6 @@ class Analyzer( case u @ UnresolvedFunction(funcId, children, isDistinct) => withPosition(u) { catalog.lookupFunction(funcId, children) match { - // DISTINCT is not meaningful for a Max or a Min. - case max: Max if isDistinct => - AggregateExpression(max, Complete, isDistinct = false) - case min: Min if isDistinct => - AggregateExpression(min, Complete, isDistinct = false) // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index beee93d906f0f..f6792569b704e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,7 +159,9 @@ package object dsl { def first(e: Expression): Expression = new First(e).toAggregateExpression() def last(e: Expression): Expression = new Last(e).toAggregateExpression() def min(e: Expression): Expression = Min(e).toAggregateExpression() + def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true) def max(e: Expression): Expression = Max(e).toAggregateExpression() + def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b410312030c5d..946fa7bae0199 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,6 +40,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) def batches: Seq[Batch] = { + Batch("Eliminate Distinct", Once, EliminateDistinct) :: // Technically some of the rules in Finish Analysis are not optimizer rules and belong more // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). // However, because we also use the analyzer to canonicalized queries (for view definition), @@ -151,6 +152,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil } +/** + * Remove useless DISTINCT for MAX and MIN. + * This rule should be applied before RewriteDistinctAggregates. + */ +object EliminateDistinct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions { + case ae: AggregateExpression if ae.isDistinct => + ae.aggregateFunction match { + case _: Max | _: Min => ae.copy(isDistinct = false) + case _ => ae + } + } +} + /** * An optimizer used in test code. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala new file mode 100644 index 0000000000000..f40691bd1a038 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class EliminateDistinctSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", Once, + EliminateDistinct) :: Nil + } + + val testRelation = LocalRelation('a.int) + + test("Eliminate Distinct in Max") { + val query = testRelation + .select(maxDistinct('a).as('result)) + .analyze + val answer = testRelation + .select(max('a).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } + + test("Eliminate Distinct in Min") { + val query = testRelation + .select(minDistinct('a).as('result)) + .analyze + val answer = testRelation + .select(min('a).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } +} From 376d90d556fcd4fd84f70ee42a1323e1f48f829d Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 28 Jun 2017 19:31:54 -0700 Subject: [PATCH 0809/1765] [SPARK-20889][SPARKR] Grouped documentation for STRING column methods ## What changes were proposed in this pull request? Grouped documentation for string column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18366 from actuaryzhang/sparkRDocString. --- R/pkg/R/functions.R | 573 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 84 ++++--- 2 files changed, 300 insertions(+), 357 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 23ccdf941a8c7..70ea620b471fe 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -111,6 +111,27 @@ NULL #' head(tmp)} NULL +#' String functions for Column operations +#' +#' String functions defined for \code{Column}. +#' +#' @param x Column to compute on except in the following methods: +#' \itemize{ +#' \item \code{instr}: \code{character}, the substring to check. See 'Details'. +#' \item \code{format_number}: \code{numeric}, the number of decimal place to +#' format to. See 'Details'. +#' } +#' @param y Column to compute on. +#' @param ... additional columns. +#' @name column_string_functions +#' @rdname column_string_functions +#' @family string functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -188,19 +209,17 @@ setMethod("approxCountDistinct", column(jc) }) -#' ascii -#' -#' Computes the numeric value of the first character of the string column, and returns the -#' result as a int column. -#' -#' @param x Column to compute on. +#' @details +#' \code{ascii}: Computes the numeric value of the first character of the string column, +#' and returns the result as an int column. #' -#' @rdname ascii -#' @name ascii -#' @family string functions +#' @rdname column_string_functions #' @export -#' @aliases ascii,Column-method -#' @examples \dontrun{\dontrun{ascii(df$c)}} +#' @aliases ascii ascii,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, ascii(df$Class), ascii(df$Sex)))} #' @note ascii since 1.5.0 setMethod("ascii", signature(x = "Column"), @@ -256,19 +275,22 @@ setMethod("avg", column(jc) }) -#' base64 -#' -#' Computes the BASE64 encoding of a binary column and returns it as a string column. -#' This is the reverse of unbase64. -#' -#' @param x Column to compute on. +#' @details +#' \code{base64}: Computes the BASE64 encoding of a binary column and returns it as +#' a string column. This is the reverse of unbase64. #' -#' @rdname base64 -#' @name base64 -#' @family string functions +#' @rdname column_string_functions #' @export -#' @aliases base64,Column-method -#' @examples \dontrun{base64(df$c)} +#' @aliases base64 base64,Column-method +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = encode(df$Class, "UTF-8")) +#' str(tmp) +#' tmp2 <- mutate(tmp, s2 = base64(tmp$s1), s3 = decode(tmp$s1, "UTF-8"), +#' s4 = soundex(tmp$Sex)) +#' head(tmp2) +#' head(select(tmp2, unbase64(tmp2$s2)))} #' @note base64 since 1.5.0 setMethod("base64", signature(x = "Column"), @@ -620,20 +642,16 @@ setMethod("dayofyear", column(jc) }) -#' decode -#' -#' Computes the first argument into a string from a binary using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' @details +#' \code{decode}: Computes the first argument into a string from a binary using the provided +#' character set. #' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' "UTF-16LE", "UTF-16"). #' -#' @rdname decode -#' @name decode -#' @family string functions -#' @aliases decode,Column,character-method +#' @rdname column_string_functions +#' @aliases decode decode,Column,character-method #' @export -#' @examples \dontrun{decode(df$c, "UTF-8")} #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -642,20 +660,13 @@ setMethod("decode", column(jc) }) -#' encode -#' -#' Computes the first argument into a binary from a string using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). -#' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @details +#' \code{encode}: Computes the first argument into a binary from a string using the provided +#' character set. #' -#' @rdname encode -#' @name encode -#' @family string functions -#' @aliases encode,Column,character-method +#' @rdname column_string_functions +#' @aliases encode encode,Column,character-method #' @export -#' @examples \dontrun{encode(df$c, "UTF-8")} #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -788,21 +799,23 @@ setMethod("hour", column(jc) }) -#' initcap -#' -#' Returns a new string column by converting the first letter of each word to uppercase. -#' Words are delimited by whitespace. -#' -#' For example, "hello world" will become "Hello World". -#' -#' @param x Column to compute on. +#' @details +#' \code{initcap}: Returns a new string column by converting the first letter of +#' each word to uppercase. Words are delimited by whitespace. For example, "hello world" +#' will become "Hello World". #' -#' @rdname initcap -#' @name initcap -#' @family string functions -#' @aliases initcap,Column-method +#' @rdname column_string_functions +#' @aliases initcap initcap,Column-method #' @export -#' @examples \dontrun{initcap(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, sex_lower = lower(df$Sex), age_upper = upper(df$age), +#' sex_age = concat_ws(" ", lower(df$sex), lower(df$age))) +#' head(tmp) +#' tmp2 <- mutate(tmp, s1 = initcap(tmp$sex_lower), s2 = initcap(tmp$sex_age), +#' s3 = reverse(df$Sex)) +#' head(tmp2)} #' @note initcap since 1.5.0 setMethod("initcap", signature(x = "Column"), @@ -918,18 +931,12 @@ setMethod("last_day", column(jc) }) -#' length -#' -#' Computes the length of a given string or binary column. -#' -#' @param x Column to compute on. +#' @details +#' \code{length}: Computes the length of a given string or binary column. #' -#' @rdname length -#' @name length -#' @aliases length,Column-method -#' @family string functions +#' @rdname column_string_functions +#' @aliases length length,Column-method #' @export -#' @examples \dontrun{length(df$c)} #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -994,18 +1001,12 @@ setMethod("log2", column(jc) }) -#' lower -#' -#' Converts a string column to lower case. -#' -#' @param x Column to compute on. +#' @details +#' \code{lower}: Converts a string column to lower case. #' -#' @rdname lower -#' @name lower -#' @family string functions -#' @aliases lower,Column-method +#' @rdname column_string_functions +#' @aliases lower lower,Column-method #' @export -#' @examples \dontrun{lower(df$c)} #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1014,18 +1015,24 @@ setMethod("lower", column(jc) }) -#' ltrim -#' -#' Trim the spaces from left end for the specified string value. -#' -#' @param x Column to compute on. +#' @details +#' \code{ltrim}: Trims the spaces from left end for the specified string value. #' -#' @rdname ltrim -#' @name ltrim -#' @family string functions -#' @aliases ltrim,Column-method +#' @rdname column_string_functions +#' @aliases ltrim ltrim,Column-method #' @export -#' @examples \dontrun{ltrim(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, " "), SexRpad = rpad(df$Sex, 7, " ")) +#' head(select(tmp, length(tmp$Sex), length(tmp$SexLpad), length(tmp$SexRpad))) +#' tmp2 <- mutate(tmp, SexLtrim = ltrim(tmp$SexLpad), SexRtrim = rtrim(tmp$SexRpad), +#' SexTrim = trim(tmp$SexLpad)) +#' head(select(tmp2, length(tmp2$Sex), length(tmp2$SexLtrim), +#' length(tmp2$SexRtrim), length(tmp2$SexTrim))) +#' +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, "xx"), SexRpad = rpad(df$Sex, 7, "xx")) +#' head(tmp)} #' @note ltrim since 1.5.0 setMethod("ltrim", signature(x = "Column"), @@ -1198,18 +1205,12 @@ setMethod("quarter", column(jc) }) -#' reverse -#' -#' Reverses the string column and returns it as a new string column. -#' -#' @param x Column to compute on. +#' @details +#' \code{reverse}: Reverses the string column and returns it as a new string column. #' -#' @rdname reverse -#' @name reverse -#' @family string functions -#' @aliases reverse,Column-method +#' @rdname column_string_functions +#' @aliases reverse reverse,Column-method #' @export -#' @examples \dontrun{reverse(df$c)} #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1268,18 +1269,12 @@ setMethod("bround", column(jc) }) -#' rtrim -#' -#' Trim the spaces from right end for the specified string value. -#' -#' @param x Column to compute on. +#' @details +#' \code{rtrim}: Trims the spaces from right end for the specified string value. #' -#' @rdname rtrim -#' @name rtrim -#' @family string functions -#' @aliases rtrim,Column-method +#' @rdname column_string_functions +#' @aliases rtrim rtrim,Column-method #' @export -#' @examples \dontrun{rtrim(df$c)} #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column"), @@ -1409,18 +1404,12 @@ setMethod("skewness", column(jc) }) -#' soundex -#' -#' Return the soundex code for the specified expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{soundex}: Returns the soundex code for the specified expression. #' -#' @rdname soundex -#' @name soundex -#' @family string functions -#' @aliases soundex,Column-method +#' @rdname column_string_functions +#' @aliases soundex soundex,Column-method #' @export -#' @examples \dontrun{soundex(df$c)} #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1731,18 +1720,12 @@ setMethod("to_timestamp", column(jc) }) -#' trim -#' -#' Trim the spaces from both ends for the specified string column. -#' -#' @param x Column to compute on. +#' @details +#' \code{trim}: Trims the spaces from both ends for the specified string column. #' -#' @rdname trim -#' @name trim -#' @family string functions -#' @aliases trim,Column-method +#' @rdname column_string_functions +#' @aliases trim trim,Column-method #' @export -#' @examples \dontrun{trim(df$c)} #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column"), @@ -1751,19 +1734,13 @@ setMethod("trim", column(jc) }) -#' unbase64 -#' -#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' @details +#' \code{unbase64}: Decodes a BASE64 encoded string column and returns it as a binary column. #' This is the reverse of base64. #' -#' @param x Column to compute on. -#' -#' @rdname unbase64 -#' @name unbase64 -#' @family string functions -#' @aliases unbase64,Column-method +#' @rdname column_string_functions +#' @aliases unbase64 unbase64,Column-method #' @export -#' @examples \dontrun{unbase64(df$c)} #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1787,18 +1764,12 @@ setMethod("unhex", column(jc) }) -#' upper -#' -#' Converts a string column to upper case. -#' -#' @param x Column to compute on. +#' @details +#' \code{upper}: Converts a string column to upper case. #' -#' @rdname upper -#' @name upper -#' @family string functions -#' @aliases upper,Column-method +#' @rdname column_string_functions +#' @aliases upper upper,Column-method #' @export -#' @examples \dontrun{upper(df$c)} #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1949,19 +1920,19 @@ setMethod("hypot", signature(y = "Column"), column(jc) }) -#' levenshtein -#' -#' Computes the Levenshtein distance of the two given string columns. -#' -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{levenshtein}: Computes the Levenshtein distance of the two given string columns. #' -#' @rdname levenshtein -#' @name levenshtein -#' @family string functions -#' @aliases levenshtein,Column-method +#' @rdname column_string_functions +#' @aliases levenshtein levenshtein,Column-method #' @export -#' @examples \dontrun{levenshtein(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, d1 = levenshtein(df$Class, df$Sex), +#' d2 = levenshtein(df$Age, df$Sex), +#' d3 = levenshtein(df$Age, df$Age)) +#' head(tmp)} #' @note levenshtein since 1.5.0 setMethod("levenshtein", signature(y = "Column"), function(y, x) { @@ -2061,20 +2032,22 @@ setMethod("countDistinct", column(jc) }) - -#' concat -#' -#' Concatenates multiple input string columns together into a single string column. -#' -#' @param x Column to compute on -#' @param ... other columns +#' @details +#' \code{concat}: Concatenates multiple input string columns together into a single string column. #' -#' @family string functions -#' @rdname concat -#' @name concat -#' @aliases concat,Column-method +#' @rdname column_string_functions +#' @aliases concat concat,Column-method #' @export -#' @examples \dontrun{concat(df$strings, df$strings2)} +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), +#' s2 = concat(df$Class, df$Sex, df$Age), +#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), +#' s4 = concat_ws("_", df$Class, df$Sex), +#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2243,22 +2216,21 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' instr -#' -#' Locate the position of the first occurrence of substr column in the given string. -#' Returns null if either of the arguments are null. -#' -#' Note: The position is not zero based, but 1 based index. Returns 0 if substr -#' could not be found in str. +#' @details +#' \code{instr}: Locates the position of the first occurrence of a substring (\code{x}) +#' in the given string column (\code{y}). Returns null if either of the arguments are null. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the substring +#' could not be found in the string column. #' -#' @param y column to check -#' @param x substring to check -#' @family string functions -#' @aliases instr,Column,character-method -#' @rdname instr -#' @name instr +#' @rdname column_string_functions +#' @aliases instr instr,Column,character-method #' @export -#' @examples \dontrun{instr(df$c, 'b')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = instr(df$Sex, "m"), s2 = instr(df$Sex, "M"), +#' s3 = locate("m", df$Sex), s4 = locate("m", df$Sex, pos = 4)) +#' head(tmp)} #' @note instr since 1.5.0 setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { @@ -2345,22 +2317,22 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), column(jc) }) -#' format_number -#' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places -#' with HALF_EVEN round mode, and returns the result as a string column. -#' -#' If x is 0, the result has no decimal point or fractional part. -#' If x < 0, the result will be null. +#' @details +#' \code{format_number}: Formats numeric column \code{y} to a format like '#,###,###.##', +#' rounded to \code{x} decimal places with HALF_EVEN round mode, and returns the result +#' as a string column. +#' If \code{x} is 0, the result has no decimal point or fractional part. +#' If \code{x} < 0, the result will be null. #' -#' @param y column to format -#' @param x number of decimal place to format to -#' @family string functions -#' @rdname format_number -#' @name format_number -#' @aliases format_number,Column,numeric-method +#' @rdname column_string_functions +#' @aliases format_number format_number,Column,numeric-method #' @export -#' @examples \dontrun{format_number(df$n, 4)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, v1 = df$Freq/3) +#' head(select(tmp, format_number(tmp$v1, 0), format_number(tmp$v1, 2), +#' format_string("%4.2f %s", tmp$v1, tmp$Sex)), 10)} #' @note format_number since 1.5.0 setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2438,21 +2410,14 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), column(jc) }) -#' concat_ws -#' -#' Concatenates multiple input string columns together into a single string column, -#' using the given separator. +#' @details +#' \code{concat_ws}: Concatenates multiple input string columns together into a single +#' string column, using the given separator. #' -#' @param x column to concatenate. #' @param sep separator to use. -#' @param ... other columns to concatenate. -#' -#' @family string functions -#' @rdname concat_ws -#' @name concat_ws -#' @aliases concat_ws,character,Column-method +#' @rdname column_string_functions +#' @aliases concat_ws concat_ws,character,Column-method #' @export -#' @examples \dontrun{concat_ws('-', df$s, df$d)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2499,19 +2464,14 @@ setMethod("expr", signature(x = "character"), column(jc) }) -#' format_string -#' -#' Formats the arguments in printf-style and returns the result as a string column. +#' @details +#' \code{format_string}: Formats the arguments in printf-style and returns the result +#' as a string column. #' #' @param format a character object of format strings. -#' @param x a Column. -#' @param ... additional Column(s). -#' @family string functions -#' @rdname format_string -#' @name format_string -#' @aliases format_string,character,Column-method +#' @rdname column_string_functions +#' @aliases format_string format_string,character,Column-method #' @export -#' @examples \dontrun{format_string('%d %s', df$a, df$b)} #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2620,23 +2580,17 @@ setMethod("window", signature(x = "Column"), column(jc) }) -#' locate -#' -#' Locate the position of the first occurrence of substr. -#' +#' @details +#' \code{locate}: Locates the position of the first occurrence of substr. #' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. #' @param str a Column where matches are sought for each entry. #' @param pos start position of search. -#' @param ... further arguments to be passed to or from other methods. -#' @family string functions -#' @rdname locate -#' @aliases locate,character,Column-method -#' @name locate +#' @rdname column_string_functions +#' @aliases locate locate,character,Column-method #' @export -#' @examples \dontrun{locate('b', df$c, 1)} #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2646,19 +2600,14 @@ setMethod("locate", signature(substr = "character", str = "Column"), column(jc) }) -#' lpad -#' -#' Left-pad the string column with +#' @details +#' \code{lpad}: Left-padded with pad to a length of len. #' -#' @param x the string Column to be left-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string functions -#' @rdname lpad -#' @aliases lpad,Column,numeric,character-method -#' @name lpad +#' @rdname column_string_functions +#' @aliases lpad lpad,Column,numeric,character-method #' @export -#' @examples \dontrun{lpad(df$c, 6, '#')} #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2728,20 +2677,27 @@ setMethod("randn", signature(seed = "numeric"), column(jc) }) -#' regexp_extract -#' -#' Extract a specific \code{idx} group identified by a Java regex, from the specified string column. -#' If the regex did not match, or the specified group did not match, an empty string is returned. +#' @details +#' \code{regexp_extract}: Extracts a specific \code{idx} group identified by a Java regex, +#' from the specified string column. If the regex did not match, or the specified group did +#' not match, an empty string is returned. #' -#' @param x a string Column. #' @param pattern a regular expression. #' @param idx a group index. -#' @family string functions -#' @rdname regexp_extract -#' @name regexp_extract -#' @aliases regexp_extract,Column,character,numeric-method +#' @rdname column_string_functions +#' @aliases regexp_extract regexp_extract,Column,character,numeric-method #' @export -#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = regexp_extract(df$Class, "(\\d+)\\w+", 1), +#' s2 = regexp_extract(df$Sex, "^(\\w)\\w+", 1), +#' s3 = regexp_replace(df$Class, "\\D+", ""), +#' s4 = substring_index(df$Sex, "a", 1), +#' s5 = substring_index(df$Sex, "a", -1), +#' s6 = translate(df$Sex, "ale", ""), +#' s7 = translate(df$Sex, "a", "-")) +#' head(tmp)} #' @note regexp_extract since 1.5.0 setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), @@ -2752,19 +2708,14 @@ setMethod("regexp_extract", column(jc) }) -#' regexp_replace -#' -#' Replace all substrings of the specified string value that match regexp with rep. +#' @details +#' \code{regexp_replace}: Replaces all substrings of the specified string value that +#' match regexp with rep. #' -#' @param x a string Column. -#' @param pattern a regular expression. #' @param replacement a character string that a matched \code{pattern} is replaced with. -#' @family string functions -#' @rdname regexp_replace -#' @name regexp_replace -#' @aliases regexp_replace,Column,character,character-method +#' @rdname column_string_functions +#' @aliases regexp_replace regexp_replace,Column,character,character-method #' @export -#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2775,19 +2726,12 @@ setMethod("regexp_replace", column(jc) }) -#' rpad -#' -#' Right-padded with pad to a length of len. +#' @details +#' \code{rpad}: Right-padded with pad to a length of len. #' -#' @param x the string Column to be right-padded. -#' @param len maximum length of each output result. -#' @param pad a character string to be padded with. -#' @family string functions -#' @rdname rpad -#' @name rpad -#' @aliases rpad,Column,numeric,character-method +#' @rdname column_string_functions +#' @aliases rpad rpad,Column,numeric,character-method #' @export -#' @examples \dontrun{rpad(df$c, 6, '#')} #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2797,28 +2741,20 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' substring_index -#' -#' Returns the substring from string str before count occurrences of the delimiter delim. -#' If count is positive, everything the left of the final delimiter (counting from left) is -#' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring_index performs a case-sensitive match when searching for delim. +#' @details +#' \code{substring_index}: Returns the substring from string str before count occurrences of +#' the delimiter delim. If count is positive, everything the left of the final delimiter +#' (counting from left) is returned. If count is negative, every to the right of the final +#' delimiter (counting from the right) is returned. substring_index performs a case-sensitive +#' match when searching for delim. #' -#' @param x a Column. #' @param delim a delimiter string. #' @param count number of occurrences of \code{delim} before the substring is returned. #' A positive number means counting from the left, while negative means #' counting from the right. -#' @family string functions -#' @rdname substring_index -#' @aliases substring_index,Column,character,numeric-method -#' @name substring_index +#' @rdname column_string_functions +#' @aliases substring_index substring_index,Column,character,numeric-method #' @export -#' @examples -#'\dontrun{ -#'substring_index(df$c, '.', 2) -#'substring_index(df$c, '.', -1) -#'} #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2829,24 +2765,19 @@ setMethod("substring_index", column(jc) }) -#' translate -#' -#' Translate any character in the src by a character in replaceString. +#' @details +#' \code{translate}: Translates any character in the src by a character in replaceString. #' The characters in replaceString is corresponding to the characters in matchingString. #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' -#' @param x a string Column. #' @param matchingString a source string where each character will be translated. #' @param replaceString a target string where each \code{matchingString} character will #' be replaced by the character in \code{replaceString} #' at the same location, if any. -#' @family string functions -#' @rdname translate -#' @name translate -#' @aliases translate,Column,character,character-method +#' @rdname column_string_functions +#' @aliases translate translate,Column,character,character-method #' @export -#' @examples \dontrun{translate(df$c, 'rnlt', '123')} #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -3419,28 +3350,20 @@ setMethod("collect_set", column(jc) }) -#' split_string -#' -#' Splits string on regular expression. -#' -#' Equivalent to \code{split} SQL function -#' -#' @param x Column to compute on -#' @param pattern Java regular expression +#' @details +#' \code{split_string}: Splits string on regular expression. +#' Equivalent to \code{split} SQL function. #' -#' @rdname split_string -#' @family string functions -#' @aliases split_string,Column-method +#' @rdname column_string_functions +#' @aliases split_string split_string,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") -#' -#' head(select(df, split_string(df$value, "\\s+"))) #' +#' \dontrun{ +#' head(select(df, split_string(df$Sex, "a"))) +#' head(select(df, split_string(df$Class, "\\d"))) #' # This is equivalent to the following SQL expression -#' head(selectExpr(df, "split(value, '\\\\s+')")) -#' } +#' head(selectExpr(df, "split(Class, '\\\\d')"))} #' @note split_string 2.3.0 setMethod("split_string", signature(x = "Column", pattern = "character"), @@ -3449,28 +3372,20 @@ setMethod("split_string", column(jc) }) -#' repeat_string -#' -#' Repeats string n times. -#' -#' Equivalent to \code{repeat} SQL function +#' @details +#' \code{repeat_string}: Repeats string n times. +#' Equivalent to \code{repeat} SQL function. #' -#' @param x Column to compute on #' @param n Number of repetitions -#' -#' @rdname repeat_string -#' @family string functions -#' @aliases repeat_string,Column-method +#' @rdname column_string_functions +#' @aliases repeat_string repeat_string,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") -#' -#' first(select(df, repeat_string(df$value, 3))) #' +#' \dontrun{ +#' head(select(df, repeat_string(df$Class, 3))) #' # This is equivalent to the following SQL expression -#' first(selectExpr(df, "repeat(value, 3)")) -#' } +#' head(selectExpr(df, "repeat(Class, 3)"))} #' @note repeat_string since 2.3.0 setMethod("repeat_string", signature(x = "Column", n = "numeric"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0248ec585d771..dc99e3d94b269 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -917,8 +917,9 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) -#' @rdname ascii +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. @@ -927,8 +928,9 @@ setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) -#' @rdname base64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions @@ -969,12 +971,14 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @export setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname concat +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) -#' @rdname concat_ws +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions @@ -1038,8 +1042,9 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname decode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @param x empty. Should be used with no argument. @@ -1047,8 +1052,9 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @export setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) -#' @rdname encode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname explode @@ -1068,12 +1074,14 @@ setGeneric("expr", function(x) { standardGeneric("expr") }) #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) -#' @rdname format_number +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) -#' @rdname format_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname from_json @@ -1114,8 +1122,9 @@ setGeneric("hour", function(x) { standardGeneric("hour") }) #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) -#' @rdname initcap +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @param x empty. Should be used with no argument. @@ -1124,8 +1133,9 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) -#' @rdname instr +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname is.nan @@ -1158,28 +1168,33 @@ setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("l #' @export setGeneric("least", function(x, ...) { standardGeneric("least") }) -#' @rdname levenshtein +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname lit #' @export setGeneric("lit", function(x) { standardGeneric("lit") }) -#' @rdname locate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) -#' @rdname lower +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) -#' @rdname lpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) -#' @rdname ltrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) #' @rdname md5 @@ -1272,21 +1287,25 @@ setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @export setGeneric("rank", function(x, ...) { standardGeneric("rank") }) -#' @rdname regexp_extract +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) -#' @rdname regexp_replace +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) -#' @rdname repeat_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname reverse +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions @@ -1299,12 +1318,14 @@ setGeneric("rint", function(x) { standardGeneric("rint") }) #' @export setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) -#' @rdname rpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname rtrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions @@ -1358,12 +1379,14 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) -#' @rdname split_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) -#' @rdname soundex +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @param x empty. Should be used with no argument. @@ -1390,8 +1413,9 @@ setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @export setGeneric("struct", function(x, ...) { standardGeneric("struct") }) -#' @rdname substring_index +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions @@ -1428,16 +1452,19 @@ setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) -#' @rdname translate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) -#' @rdname trim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("trim", function(x) { standardGeneric("trim") }) -#' @rdname unbase64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions @@ -1450,8 +1477,9 @@ setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) -#' @rdname upper +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions From 0c8444cf6d0620cd219ddcf5f50b12ff648639e9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 29 Jun 2017 10:32:32 +0800 Subject: [PATCH 0810/1765] [SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference category when encoding string terms ## What changes were proposed in this pull request? Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for detail of this bug. I searched online and test some other cases, found when we fit R glm model(or other models powered by R formula) w/o intercept on a dataset including string/category features, one of the categories in the first category feature is being used as reference category, we will not drop any category for that feature. I think we should keep consistent semantics between Spark RFormula and R formula. ## How was this patch tested? Add standard unit tests. cc mengxr Author: Yanbo Liang Closes #12414 from yanboliang/spark-14657. --- .../apache/spark/ml/feature/RFormula.scala | 10 ++- .../spark/ml/feature/RFormulaSuite.scala | 83 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1fad0a6fc9443..4b44878784c90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) }.toMap // Then we handle one-hot encoding and interactions between terms. + var keepReferenceCategory = false val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - encoderStages += new OneHotEncoder() + var encoder = new OneHotEncoder() .setInputCol(indexed(term)) .setOutputCol(encodedCol) + // Formula w/o intercept, one of the categories in the first category feature is + // being used as reference category, we will not drop any category for that feature. + if (!hasIntercept && !keepReferenceCategory) { + encoder = encoder.setDropLast(false) + keepReferenceCategory = true + } + encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 41d0062c2cabd..23570d6e0b4cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("formula w/o intercept, we should output reference category when encoding string terms") { + /* + R code: + + df <- data.frame(id = c(1, 2, 3, 4), + a = c("foo", "bar", "bar", "baz"), + b = c("zq", "zz", "zz", "zz"), + c = c(4, 4, 5, 5)) + model.matrix(id ~ a + b + c - 1, df) + + abar abaz afoo bzz c + 1 0 0 1 0 4 + 2 1 0 0 1 4 + 3 1 0 0 1 5 + 4 0 1 0 1 5 + + model.matrix(id ~ a:b + c - 1, df) + + c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz + 1 4 0 0 1 0 0 0 + 2 4 0 0 0 1 0 0 + 3 5 0 0 0 1 0 0 + 4 5 0 0 0 0 1 0 + */ + val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", "zz", 5), + (4, "baz", "zz", 5)).toDF("id", "a", "b", "c") + + val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model1 = formula1.fit(original) + val result1 = model1.transform(original) + val resultSchema1 = model1.transformSchema(original.schema) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result1.schema.toString == resultSchema1.toString) + assert(result1.collect() === expected1.collect()) + + val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) + val expectedAttrs1 = new AttributeGroup( + "features", + Array[Attribute]( + new BinaryAttribute(Some("a_foo"), Some(1)), + new BinaryAttribute(Some("a_baz"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(3)), + new BinaryAttribute(Some("b_zz"), Some(4)), + new NumericAttribute(Some("c"), Some(5)))) + assert(attrs1 === expectedAttrs1) + + // There is no impact for string terms interaction. + val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model2 = formula2.fit(original) + val result2 = model2.transform(original) + val resultSchema2 = model2.transformSchema(original.schema) + // Note the column order is different between R and Spark. + val expected2 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 2.0), + (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), + (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result2.schema.toString == resultSchema2.toString) + assert(result2.collect() === expected2.collect()) + + val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) + val expectedAttrs2 = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_foo:b_zz"), Some(1)), + new NumericAttribute(Some("a_foo:b_zq"), Some(2)), + new NumericAttribute(Some("a_baz:b_zz"), Some(3)), + new NumericAttribute(Some("a_baz:b_zq"), Some(4)), + new NumericAttribute(Some("a_bar:b_zz"), Some(5)), + new NumericAttribute(Some("a_bar:b_zq"), Some(6)), + new NumericAttribute(Some("c"), Some(7)))) + assert(attrs2 === expectedAttrs2) + } + test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = From db44f5f3e8b5bc28c33b154319539d51c05a089c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Jun 2017 19:36:00 -0700 Subject: [PATCH 0811/1765] [SPARK-21224][R] Specify a schema by using a DDL-formatted string when reading in R ## What changes were proposed in this pull request? This PR proposes to support a DDL-formetted string as schema as below: ```r mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}") jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines, jsonPath) df <- read.df(jsonPath, "json", "name STRING, age DOUBLE") collect(df) ``` ## How was this patch tested? Tests added in `test_streaming.R` and `test_sparkSQL.R` and manual tests. Author: hyukjinkwon Closes #18431 from HyukjinKwon/r-ddl-schema. --- R/pkg/R/SQLContext.R | 38 +++++++++++++------ R/pkg/tests/fulltests/test_sparkSQL.R | 20 +++++++++- R/pkg/tests/fulltests/test_streaming.R | 23 +++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 15 -------- 4 files changed, 67 insertions(+), 29 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e3528bc7c3135..3b7f71bbbffb8 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -584,7 +584,7 @@ tableToDF <- function(tableName) { #' #' @param path The path of files to load #' @param source The name of external data source -#' @param schema The data schema defined in structType +#' @param schema The data schema defined in structType or a DDL-formatted string. #' @param na.strings Default string value for NA when source is "csv" #' @param ... additional external data source specific named properties. #' @return SparkDataFrame @@ -600,6 +600,8 @@ tableToDF <- function(tableName) { #' structField("info", "map")) #' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") +#' stringSchema <- "name STRING, info MAP" +#' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE) #' } #' @name read.df #' @method read.df default @@ -623,14 +625,19 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string if (source == "csv" && is.null(options[["nullValue"]])) { options[["nullValue"]] <- na.strings } + read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, schema$jobj, options) - } else { - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, options) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") dataFrame(sdf) } @@ -717,8 +724,8 @@ read.jdbc <- function(url, tableName, #' "spark.sql.sources.default" will be used. #' #' @param source The name of external data source -#' @param schema The data schema defined in structType, this is required for file-based streaming -#' data source +#' @param schema The data schema defined in structType or a DDL-formatted string, this is +#' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for #' file-based streaming data source #' @return SparkDataFrame @@ -733,6 +740,8 @@ read.jdbc <- function(url, tableName, #' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") #' #' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' stringSchema <- "name STRING, info MAP" +#' df1 <- read.stream("json", path = jsonDir, schema = stringSchema, maxFilesPerTrigger = 1) #' } #' @name read.stream #' @note read.stream since 2.2.0 @@ -750,10 +759,15 @@ read.stream <- function(source = NULL, schema = NULL, ...) { read <- callJMethod(sparkSession, "readStream") read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - read <- callJMethod(read, "schema", schema$jobj) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } read <- callJMethod(read, "options", options) sdf <- handledCallJMethod(read, "load") - dataFrame(callJMethod(sdf, "toDF")) + dataFrame(sdf) } diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 911b73b9ee551..a2bcb5aefe16d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3248,9 +3248,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", + paste("Error in load : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) - expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.df("arbitrary_path"), "Error in load : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") @@ -3268,6 +3268,22 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.df with a user defined schema in a DDL-formatted string. + df1 <- read.df(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + expect_error(read.df(jsonPath, "json", "name stri"), "DataType stri is not supported.") + + # Test loadDF with a user defined schema in a DDL-formatted string. + df2 <- loadDF(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + expect_error(loadDF(jsonPath, "json", "name stri"), "DataType stri is not supported.") +}) + test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { ldf <- data.frame(col1 = c(0, 1, 2), col2 = c(as.POSIXct("2017-01-01 00:00:01"), diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index d691de7cd725d..54f40bbd5f517 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -46,6 +46,8 @@ schema <- structType(structField("name", "string"), structField("age", "integer"), structField("count", "double")) +stringSchema <- "name STRING, age INTEGER, count DOUBLE" + test_that("read.stream, write.stream, awaitTermination, stopQuery", { df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) @@ -111,6 +113,27 @@ test_that("Stream other format", { unlink(parquetPath) }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.stream with a user defined schema in a DDL-formatted string. + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_error(read.stream(path = parquetPath, schema = "name stri"), + "DataType stri is not supported.") + + unlink(parquetPath) +}) + test_that("Non-streaming DataFrame", { c <- as.DataFrame(cars) expect_false(isStreaming(c)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d94e528a3ad47..9bd2987057dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -193,21 +193,6 @@ private[sql] object SQLUtils extends Logging { } } - def loadDF( - sparkSession: SparkSession, - source: String, - options: java.util.Map[String, String]): DataFrame = { - sparkSession.read.format(source).options(options).load() - } - - def loadDF( - sparkSession: SparkSession, - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - sparkSession.read.format(source).schema(schema).options(options).load() - } - def readSqlObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 's' => From fc92d25f2a27e81ef2d5031dcf856af1cc1d8c31 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 28 Jun 2017 20:06:29 -0700 Subject: [PATCH 0812/1765] Revert "[SPARK-21094][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak" This reverts commit 6b3d02285ee0debc73cbcab01b10398a498fbeb8. --- R/pkg/inst/worker/daemon.R | 59 +++----------------------------------- 1 file changed, 4 insertions(+), 55 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 6e385b2a27622..3a318b71ea06d 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,55 +30,8 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) -# Waits indefinitely for a socket connecion by default. -selectTimeout <- NULL - -# Exit code that children send to the parent to indicate they exited. -exitCode <- 1 - while (TRUE) { - ready <- socketSelect(list(inputCon), timeout = selectTimeout) - - # Note that the children should be terminated in the parent. If each child terminates - # itself, it appears that the resource is not released properly, that causes an unexpected - # termination of this daemon due to, for example, running out of file descriptors - # (see SPARK-21093). Therefore, the current implementation tries to retrieve children - # that are exited (but not terminated) and then sends a kill signal to terminate them properly - # in the parent. - # - # There are two paths that it attempts to send a signal to terminate the children in the parent. - # - # 1. Every second if any socket connection is not available and if there are child workers - # running. - # 2. Right after a socket connection is available. - # - # In other words, the parent attempts to send the signal to the children every second if - # any worker is running or right before launching other worker children from the following - # new socket connection. - - # Only the process IDs of children sent data to the parent are returned below. The children - # send a custom exit code to the parent after being exited and the parent tries - # to terminate them only if they sent the exit code. - children <- parallel:::selectChildren(timeout = 0) - - if (is.integer(children)) { - lapply(children, function(child) { - # This data should be raw bytes if any data was sent from this child. - # Otherwise, this returns the PID. - data <- parallel:::readChild(child) - if (is.raw(data)) { - # This checks if the data from this child is the exit code that indicates an exited child. - if (unserialize(data) == exitCode) { - # If so, we terminate this child. - tools::pskill(child, tools::SIGUSR1) - } - } - }) - } else if (is.null(children)) { - # If it is NULL, there are no children. Waits indefinitely for a socket connecion. - selectTimeout <- NULL - } - + ready <- socketSelect(list(inputCon)) if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -91,16 +44,12 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { - # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Note that this mcexit does not fully terminate this child. So, this writes back - # a custom exit code so that the parent can read and terminate this child. - parallel:::mcexit(0L, send = exitCode) - } else { - # Forking succeeded and we need to check if they finished their jobs every second. - selectTimeout <- 1 + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) } } } From 25c2edf6f9da9d4d45fc628cf97de657f2a2cc7e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Jun 2017 11:21:50 +0800 Subject: [PATCH 0813/1765] [SPARK-21229][SQL] remove QueryPlan.preCanonicalized ## What changes were proposed in this pull request? `QueryPlan.preCanonicalized` is only overridden in a few places, and it does introduce an extra concept to `QueryPlan` which may confuse people. This PR removes it and override `canonicalized` in these places ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18440 from cloud-fan/minor. --- .../sql/catalyst/catalog/interface.scala | 23 +++++++++++-------- .../spark/sql/catalyst/plans/QueryPlan.scala | 13 ++++------- .../sql/execution/DataSourceScanExec.scala | 8 +++++-- .../datasources/LogicalRelation.scala | 5 +++- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b63bef9193332..da50b0e7e8e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -27,7 +27,8 @@ import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -425,15 +426,17 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( - identifier = tableMeta.identifier, - tableType = tableMeta.tableType, - storage = CatalogStorageFormat.empty, - schema = tableMeta.schema, - partitionColumnNames = tableMeta.partitionColumnNames, - bucketSpec = tableMeta.bucketSpec, - createTime = -1 - )) + override lazy val canonicalized: LogicalPlan = copy( + tableMeta = tableMeta.copy( + storage = CatalogStorageFormat.empty, + createTime = -1 + ), + dataCols = dataCols.zipWithIndex.map { + case (attr, index) => attr.withExprId(ExprId(index)) + }, + partitionCols = partitionCols.zipWithIndex.map { + case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) + }) override def computeStats: Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 01b3da3f7c482..7addbaaa9afa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -188,12 +188,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same * result. * - * Some nodes should overwrite this to provide proper canonicalize logic. + * Some nodes should overwrite this to provide proper canonicalize logic, but they should remove + * expressions cosmetic variations themselves. */ lazy val canonicalized: PlanType = { val canonicalizedChildren = children.map(_.canonicalized) var id = -1 - preCanonicalized.mapExpressions { + mapExpressions { case a: Alias => id += 1 // As the root of the expression, Alias will always take an arbitrary exprId, we need to @@ -206,18 +207,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // Top level `AttributeReference` may also be used for output like `Alias`, we should // normalize the epxrId too. id += 1 - ar.withExprId(ExprId(id)) + ar.withExprId(ExprId(id)).canonicalized case other => QueryPlan.normalizeExprId(other, allAttributes) }.withNewChildren(canonicalizedChildren) } - /** - * Do some simple transformation on this plan before canonicalizing. Implementations can override - * this method to provide customized canonicalize logic without rewriting the whole logic. - */ - protected def preCanonicalized: PlanType = this - /** * Returns true when the given query plan will return the same results as this query plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 74fc23a52a141..a0def68d88e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -138,8 +138,12 @@ case class RowDataSourceScanExec( } // Only care about `relation` and `metadata` when canonicalizing. - override def preCanonicalized: SparkPlan = - copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) + override lazy val canonicalized: SparkPlan = + copy( + output.map(QueryPlan.normalizeExprId(_, output)), + rdd = null, + outputPartitioning = null, + metastoreTableIdentifier = None) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index c1b2895f1747e..6ba190b9e5dcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -43,7 +44,9 @@ case class LogicalRelation( } // Only care about relation when canonicalizing. - override def preCanonicalized: LogicalPlan = copy(catalogTable = None) + override lazy val canonicalized: LogicalPlan = copy( + output = output.map(QueryPlan.normalizeExprId(_, output)), + catalogTable = None) @transient override def computeStats: Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( From 82e24912d6e15a9e4fbadd83da9a08d4f80a592b Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Thu, 29 Jun 2017 11:32:29 +0800 Subject: [PATCH 0814/1765] [SPARK-21237][SQL] Invalidate stats once table data is changed ## What changes were proposed in this pull request? Invalidate spark's stats after data changing commands: - InsertIntoHadoopFsRelationCommand - InsertIntoHiveTable - LoadDataCommand - TruncateTableCommand - AlterTableSetLocationCommand - AlterTableDropPartitionCommand ## How was this patch tested? Added test cases. Author: wangzhenhua Closes #18449 from wzhfy/removeStats. --- .../catalyst/catalog/ExternalCatalog.scala | 3 +- .../catalyst/catalog/InMemoryCatalog.scala | 4 +- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../catalog/ExternalCatalogSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 2 +- .../command/AnalyzeColumnCommand.scala | 4 +- .../command/AnalyzeTableCommand.scala | 76 +--------- .../sql/execution/command/CommandUtils.scala | 102 ++++++++++++++ .../spark/sql/execution/command/ddl.scala | 9 +- .../spark/sql/execution/command/tables.scala | 7 + .../InsertIntoHadoopFsRelationCommand.scala | 5 + .../spark/sql/StatisticsCollectionSuite.scala | 85 ++++++++++-- .../apache/spark/sql/test/SQLTestUtils.scala | 14 ++ .../spark/sql/hive/HiveExternalCatalog.scala | 24 ++-- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/StatisticsSuite.scala | 130 ++++++++++++++---- 16 files changed, 340 insertions(+), 133 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 12ba5aedde026..0254b6bb6d136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -160,7 +160,8 @@ abstract class ExternalCatalog */ def alterTableSchema(db: String, table: String, schema: StructType): Unit - def alterTableStats(db: String, table: String, stats: CatalogStatistics): Unit + /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9820522a230e3..747190faa3c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -315,10 +315,10 @@ class InMemoryCatalog( override def alterTableStats( db: String, table: String, - stats: CatalogStatistics): Unit = synchronized { + stats: Option[CatalogStatistics]): Unit = synchronized { requireTableExists(db, table) val origTable = catalog(db).tables(table).table - catalog(db).tables(table).table = origTable.copy(stats = Some(stats)) + catalog(db).tables(table).table = origTable.copy(stats = stats) } override def getTable(db: String, table: String): CatalogTable = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index cf02da8993658..7ece77df7fc14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -380,7 +380,7 @@ class SessionCatalog( * Alter Spark's statistics of an existing metastore table identified by the provided table * identifier. */ - def alterTableStats(identifier: TableIdentifier, newStats: CatalogStatistics): Unit = { + def alterTableStats(identifier: TableIdentifier, newStats: Option[CatalogStatistics]): Unit = { val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 557b0970b54e5..c22d55fc96a65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -260,7 +260,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val oldTableStats = catalog.getTable("db2", "tbl1").stats assert(oldTableStats.isEmpty) val newStats = CatalogStatistics(sizeInBytes = 1) - catalog.alterTableStats("db2", "tbl1", newStats) + catalog.alterTableStats("db2", "tbl1", Some(newStats)) val newTableStats = catalog.getTable("db2", "tbl1").stats assert(newTableStats.get == newStats) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index a6dc21b03d446..fc3893e197792 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -454,7 +454,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { val oldTableStats = catalog.getTableMetadata(tableId).stats assert(oldTableStats.isEmpty) val newStats = CatalogStatistics(sizeInBytes = 1) - catalog.alterTableStats(tableId, newStats) + catalog.alterTableStats(tableId, Some(newStats)) val newTableStats = catalog.getTableMetadata(tableId).stats assert(newTableStats.get == newStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 2f273b63e8348..6588993ef9ad9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -42,7 +42,7 @@ case class AnalyzeColumnCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val sizeInBytes = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) + val sizeInBytes = CommandUtils.calculateTotalSize(sessionState, tableMeta) // Compute stats for each column val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) @@ -54,7 +54,7 @@ case class AnalyzeColumnCommand( // Newly computed column stats should override the existing ones. colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTableStats(tableIdentWithDB, statistics) + sessionState.catalog.alterTableStats(tableIdentWithDB, Some(statistics)) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 13b8faff844c7..d780ef42f3fae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,18 +17,10 @@ package org.apache.spark.sql.execution.command -import java.net.URI - -import scala.util.control.NonFatal - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.internal.SessionState /** @@ -46,7 +38,7 @@ case class AnalyzeTableCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) + val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) @@ -74,7 +66,7 @@ case class AnalyzeTableCommand( // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. if (newStats.isDefined) { - sessionState.catalog.alterTableStats(tableIdentWithDB, newStats.get) + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } @@ -82,65 +74,3 @@ case class AnalyzeTableCommand( Seq.empty[Row] } } - -object AnalyzeTableCommand extends Logging { - - def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { - if (catalogTable.partitionColumnNames.isEmpty) { - calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) - } else { - // Calculate table size as a sum of the visible partitions. See SPARK-21079 - val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) - partitions.map(p => - calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) - ).sum - } - } - - private def calculateLocationSize( - sessionState: SessionState, - tableId: TableIdentifier, - locationUri: Option[URI]): Long = { - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. - val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateLocationSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateLocationSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateLocationSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${tableId.table} in the " + - s"database ${tableId.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala new file mode 100644 index 0000000000000..92397607f38fd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -0,0 +1,102 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.command + +import java.net.URI + +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable} +import org.apache.spark.sql.internal.SessionState + + +object CommandUtils extends Logging { + + /** Change statistics after changing data by commands. */ + def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { + if (table.stats.nonEmpty) { + val catalog = sparkSession.sessionState.catalog + catalog.alterTableStats(table.identifier, None) + } + } + + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): BigInt = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map { p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + }.sum + } + } + + def calculateLocationSize( + sessionState: SessionState, + identifier: TableIdentifier, + locationUri: Option[URI]): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def getPathSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + getPathSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sessionState.newHadoopConf()) + getPathSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${identifier.table} in the " + + s"database ${identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 413f5f3ba539c..ac897c1b22d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -433,9 +433,11 @@ case class AlterTableAddPartitionCommand( sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) CatalogTablePartition(normalizedSpec, table.storage.copy( - locationUri = location.map(CatalogUtils.stringToURI(_)))) + locationUri = location.map(CatalogUtils.stringToURI))) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) + + CommandUtils.updateTableStats(sparkSession, table) Seq.empty[Row] } @@ -519,6 +521,9 @@ case class AlterTableDropPartitionCommand( catalog.dropPartitions( table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge, retainData = retainData) + + CommandUtils.updateTableStats(sparkSession, table) + Seq.empty[Row] } @@ -768,6 +773,8 @@ case class AlterTableSetLocationCommand( // No partition spec is specified, so we set the location for the table itself catalog.alterTable(table.withNewStorage(locationUri = Some(locUri))) } + + CommandUtils.updateTableStats(sparkSession, table) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index b937a8a9f375b..8ded1060f7bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -400,6 +400,7 @@ case class LoadDataCommand( // Refresh the metadata cache to ensure the data visible to the users catalog.refreshTable(targetTable.identifier) + CommandUtils.updateTableStats(sparkSession, targetTable) Seq.empty[Row] } } @@ -487,6 +488,12 @@ case class TruncateTableCommand( case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) } + + if (table.stats.nonEmpty) { + // empty table after truncation + val newStats = CatalogStatistics(sizeInBytes = 0, rowCount = Some(0)) + catalog.alterTableStats(tableName, Some(newStats)) + } Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 00aa1240886e4..ab26f2affbce5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -161,6 +161,11 @@ case class InsertIntoHadoopFsRelationCommand( fileIndex.foreach(_.refresh()) // refresh data cache if table is cached sparkSession.catalog.refreshByPath(outputPath.toString) + + if (catalogTable.nonEmpty) { + CommandUtils.updateTableStats(sparkSession, catalogTable.get) + } + } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 9824062f969b3..b031c52dad8b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -40,17 +40,6 @@ import org.apache.spark.sql.types._ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { import testImplicits._ - private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) - : Option[CatalogStatistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } - test("estimates the size of a limit 0 on outer join") { withTempView("test") { Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") @@ -96,11 +85,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = Some(2)) } } @@ -168,6 +157,60 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(stats.simpleString == expectedString) } } + + test("change stats after truncate command") { + val table = "change_stats_truncate_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // truncate table command + sql(s"TRUNCATE TABLE $table") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + } + } + + test("change stats after set location command") { + val table = "change_stats_set_location_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // set location command + withTempDir { newLocation => + sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } + + test("change stats after insert command for datasource table") { + val table = "change_stats_insert_datasource_table" + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } @@ -219,6 +262,22 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils private val randomName = new Random(31) + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + /** * Compute column stats for the given DataFrame and compare it with colStats. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index f6d47734d7e83..d74a7cce25ed6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -149,6 +149,7 @@ private[sql] trait SQLTestUtils .getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -164,6 +165,19 @@ private[sql] trait SQLTestUtils } } + /** + * Creates the specified number of temporary directories, which is then passed to `f` and will be + * deleted after `f` returns. + */ + protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { + val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) + try f(files) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + files.foreach(Utils.deleteRecursively) + } + } + /** * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 6e7c475fa34c9..2a17849fa8a34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -631,21 +631,23 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def alterTableStats( db: String, table: String, - stats: CatalogStatistics): Unit = withClient { + stats: Option[CatalogStatistics]): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) // convert table statistics to properties so that we can persist them through hive client - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + val statsProperties = new mutable.HashMap[String, String]() + if (stats.isDefined) { + statsProperties += STATISTICS_TOTAL_SIZE -> stats.get.sizeInBytes.toString() + if (stats.get.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.get.rowCount.get.toString() + } + val colNameTypeMap: Map[String, DataType] = + rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap + stats.get.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 392b7cfaa8eff..223d375232393 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command.{CommandUtils, RunnableCommand} import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -434,6 +434,8 @@ case class InsertIntoHiveTable( sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) + CommandUtils.updateTableStats(sparkSession, table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 64deb3818d5d1..5fd266c2d033c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -30,10 +30,12 @@ import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("Hive serde tables should fallback to HDFS for size estimation") { @@ -219,23 +221,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - private def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats - - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - - stats - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -326,7 +311,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto descOutput: Seq[String], propKey: String): Option[BigInt] = { val str = descOutput - .filterNot(_.contains(HiveExternalCatalog.STATISTICS_PREFIX)) + .filterNot(_.contains(STATISTICS_PREFIX)) .filter(_.contains(propKey)) if (str.isEmpty) { None @@ -448,6 +433,103 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") } + /** + * To see if stats exist, we need to check spark's stats properties instead of catalog + * statistics, because hive would change stats in metastore and thus change catalog statistics. + */ + private def getStatsProperties(tableName: String): Map[String, String] = { + val hTable = hiveClient.getTable(spark.sessionState.catalog.getCurrentDatabase, tableName) + hTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + } + + test("change stats after insert command for hive table") { + val table = s"change_stats_insert_hive_table" + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string)") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + assert(getStatsProperties(table).isEmpty) + } + } + + test("change stats after load data command") { + val table = "change_stats_load_table" + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + withTempDir { loadPath => + // load data command + val file = new File(loadPath + "/data") + val writer = new PrintWriter(file) + writer.write("2,xyz") + writer.close() + sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") + assert(getStatsProperties(table).isEmpty) + } + } + } + + test("change stats after add/drop partition command") { + val table = "change_stats_part_table" + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") + // table has two partitions initially + for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { + sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") + } + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => + val file1 = new File(dir1 + "/data") + val writer1 = new PrintWriter(file1) + writer1.write("1,a") + writer1.close() + + val file2 = new File(dir2 + "/data") + val writer2 = new PrintWriter(file2) + writer2.write("1,a") + writer2.close() + + // add partition command + sql( + s""" + |ALTER TABLE $table ADD + |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' + |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' + """.stripMargin) + assert(getStatsProperties(table).isEmpty) + + // generate stats again + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.size == 2) + + // drop partition command + sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") + // only one partition left + assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + assert(getStatsProperties(table).isEmpty) + } + } + } + test("add/drop partitions - managed table") { val catalog = spark.sessionState.catalog val managedTable = "partitionedTable" @@ -483,23 +565,19 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(catalog.listPartitions(TableIdentifier(managedTable)).map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) - val stats2 = checkTableStats( - managedTable, hasSizeInBytes = true, expectedRowCounts = Some(4)) - assert(stats1 == stats2) - sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") - val stats3 = checkTableStats( + val stats2 = checkTableStats( managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) - assert(stats2.get.sizeInBytes > stats3.get.sizeInBytes) + assert(stats1.get.sizeInBytes > stats2.get.sizeInBytes) sql(s"ALTER TABLE $managedTable ADD PARTITION (ds='2008-04-08', hr='12')") sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") val stats4 = checkTableStats( managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) - assert(stats2.get.sizeInBytes > stats4.get.sizeInBytes) - assert(stats4.get.sizeInBytes == stats3.get.sizeInBytes) + assert(stats1.get.sizeInBytes > stats4.get.sizeInBytes) + assert(stats4.get.sizeInBytes == stats2.get.sizeInBytes) } } From a946be35ac177737e99942ad42de6f319f186138 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Thu, 29 Jun 2017 14:25:51 +0800 Subject: [PATCH 0815/1765] [SPARK-3577] Report Spill size on disk for UnsafeExternalSorter ## What changes were proposed in this pull request? Report Spill size on disk for UnsafeExternalSorter ## How was this patch tested? Tested by running a job on cluster and verify the spill size on disk. Author: Sital Kedia Closes #17471 from sitalkedia/fix_disk_spill_size. --- .../unsafe/sort/UnsafeExternalSorter.java | 9 +++---- .../sort/UnsafeExternalSorterSuite.java | 25 +++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index f312fa2b2ddd7..82d03e3e9190c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -54,7 +54,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final BlockManager blockManager; private final SerializerManager serializerManager; private final TaskContext taskContext; - private ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -144,10 +143,6 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - // The spill metrics are stored in a new ShuffleWriteMetrics, - // and then discarded (this fixes SPARK-16827). - // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). - this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( @@ -199,6 +194,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriters.size(), spillWriters.size() > 1 ? " times" : " time"); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // We only write out contents of the inMemSorter if it is not empty. if (inMemSorter.numRecords() > 0) { final UnsafeSorterSpillWriter spillWriter = @@ -226,6 +222,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { // pages, we might not be able to get memory for the pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += spillSize; return spillSize; } @@ -502,6 +499,7 @@ public long spill() throws IOException { UnsafeInMemorySorter.SortedIterator inMemIterator = ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); @@ -540,6 +538,7 @@ public long spill() throws IOException { inMemSorter.free(); inMemSorter = null; taskContext.taskMetrics().incMemoryBytesSpilled(released); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += released; return released; } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 771d39016c188..d31d7c1c0900c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -405,6 +405,31 @@ public void forcedSpillingWithoutComparator() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void testDiskSpilledBytes() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false); + } + // We will have at-least 2 memory pages allocated because of rounding happening due to + // integer division of pageSizeBytes and recordSize. + assertTrue(sorter.getNumberOfAllocatedPages() >= 2); + assertTrue(taskContext.taskMetrics().diskBytesSpilled() == 0); + UnsafeExternalSorter.SpillableIterator iter = + (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); + assertTrue(iter.spill() > 0); + assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0); + assertEquals(0, iter.spill()); + // Even if we did not spill second time, the disk spilled bytes should still be non-zero + assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0); + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; From 9f6b3e65ccfa0daec31b58c5a6386b3a890c2149 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Jun 2017 14:37:42 +0800 Subject: [PATCH 0816/1765] [SPARK-21238][SQL] allow nested SQL execution ## What changes were proposed in this pull request? This is kind of another follow-up for https://github.com/apache/spark/pull/18064 . In #18064 , we wrap every SQL command with SQL execution, which makes nested SQL execution very likely to happen. #18419 trid to improve it a little bit, by introduing `SQLExecition.ignoreNestedExecutionId`. However, this is not friendly to data source developers, they may need to update their code to use this `ignoreNestedExecutionId` API. This PR proposes a new solution, to just allow nested execution. The downside is that, we may have multiple executions for one query. We can improve this by updating the data organization in SQLListener, to have 1-n mapping from query to execution, instead of 1-1 mapping. This can be done in a follow-up. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18450 from cloud-fan/execution-id. --- .../spark/sql/execution/SQLExecution.scala | 88 ++++--------------- .../command/AnalyzeTableCommand.scala | 4 +- .../spark/sql/execution/command/cache.scala | 16 ++-- .../datasources/csv/CSVDataSource.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 8 +- .../sql/execution/streaming/console.scala | 12 +-- .../sql/execution/streaming/memory.scala | 32 ++++--- .../sql/execution/SQLExecutionSuite.scala | 24 ----- 8 files changed, 50 insertions(+), 138 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index ca8bed5214f87..e991da7df0bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, - SparkListenerSQLExecutionStart} +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" - private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" - private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -45,10 +42,8 @@ object SQLExecution { private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext - val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null - val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && !isNestedExecution && !hasExecutionId) { + if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -66,56 +61,27 @@ object SQLExecution { queryExecution: QueryExecution)(body: => T): T = { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) - if (oldExecutionId == null) { - val executionId = SQLExecution.nextExecutionId - sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) - executionIdToQueryExecution.put(executionId, queryExecution) - try { - // sparkContext.getCallSite() would first try to pick up any call site that was previously - // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on - // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() - - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - } - } finally { - executionIdToQueryExecution.remove(executionId) - sc.setLocalProperty(EXECUTION_ID_KEY, null) - } - } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { - // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the - // `body`, so that Spark jobs issued in the `body` won't be tracked. + val executionId = SQLExecution.nextExecutionId + sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) + executionIdToQueryExecution.put(executionId, queryExecution) + try { + // sparkContext.getCallSite() would first try to pick up any call site that was previously + // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on + // streaming queries would give us call site like "run at :0" + val callSite = sparkSession.sparkContext.getCallSite() + + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { - sc.setLocalProperty(EXECUTION_ID_KEY, null) body } finally { - sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } - } else { - // Don't support nested `withNewExecutionId`. This is an example of the nested - // `withNewExecutionId`: - // - // class DataFrame { - // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() } - // } - // - // Note: `collect` will call withNewExecutionId - // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan" - // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution - // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run, - // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. - // - // A real case is the `DataFrame.count` method. - throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + - "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + - "jobs issued by the nested execution.") + } finally { + executionIdToQueryExecution.remove(executionId) + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) } } @@ -133,20 +99,4 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } - - /** - * Wrap an action which may have nested execution id. This method can be used to run an execution - * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, - * all Spark jobs issued in the body won't be tracked in UI. - */ - def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { - val sc = sparkSession.sparkContext - val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) - try { - sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") - body - } finally { - sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d780ef42f3fae..42e2a9ca5c4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -51,9 +51,7 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { - sparkSession.table(tableIdentWithDB).count() - } + val newRowCount = sparkSession.table(tableIdentWithDB).count() if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index d36eb7587a3ef..47952f2f227a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -34,16 +34,14 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - SQLExecution.ignoreNestedExecutionId(sparkSession) { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).count() - } + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 99133bd70989a..2031381dd2e10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption - } + val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b11da7045de22..a521fd1323852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -130,11 +130,9 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - data.write - .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) - } + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 6fa7c113defaa..3baea6376069f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) } } @@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { - data.show(numRowsToShow, isTruncated) - } + data.show(numRowsToShow, isTruncated) ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 198a342582804..4979873ee3c7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collect()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") - } + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") } } else { logDebug(s"Skipping already committed batch: $batchId") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index fe78a76568837..f6b006b98edd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { test("concurrent query execution (SPARK-10548)") { - // Try to reproduce the issue with the old SparkContext val conf = new SparkConf() .setMaster("local[*]") .setAppName("test") - val badSparkContext = new BadSparkContext(conf) - try { - testConcurrentQueryExecution(badSparkContext) - fail("unable to reproduce SPARK-10548") - } catch { - case e: IllegalArgumentException => - assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) - } finally { - badSparkContext.stop() - } - - // Verify that the issue is fixed with the latest SparkContext val goodSparkContext = new SparkContext(conf) try { testConcurrentQueryExecution(goodSparkContext) @@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite { } } -/** - * A bad [[SparkContext]] that does not clone the inheritable thread local properties - * when passing them to children threads. - */ -private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { - protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() - } -} - object SQLExecutionSuite { @volatile var canProgress = false } From a2d5623548194f15989e7b68118d744673e33819 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 01:23:13 -0700 Subject: [PATCH 0817/1765] [SPARK-20889][SPARKR] Grouped documentation for NONAGGREGATE column methods ## What changes were proposed in this pull request? Grouped documentation for nonaggregate column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18422 from actuaryzhang/sparkRDocNonAgg. --- R/pkg/R/functions.R | 360 ++++++++++++++++++-------------------------- R/pkg/R/generics.R | 55 ++++--- 2 files changed, 182 insertions(+), 233 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 70ea620b471fe..cb09e847d739a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -132,23 +132,39 @@ NULL #' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} NULL -#' lit +#' Non-aggregate functions for Column operations #' -#' A new \linkS4class{Column} is created to represent the literal value. -#' If the parameter is a \linkS4class{Column}, it is returned unchanged. +#' Non-aggregate functions defined for \code{Column}. #' -#' @param x a literal value or a Column. +#' @param x Column to compute on. In \code{lit}, it is a literal value or a Column. +#' In \code{expr}, it contains an expression character object to be parsed. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_nonaggregate_functions +#' @rdname column_nonaggregate_functions +#' @seealso coalesce,SparkDataFrame-method #' @family non-aggregate functions -#' @rdname lit -#' @name lit +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + +#' @details +#' \code{lit}: A new Column is created to represent the literal value. +#' If the parameter is a Column, it is returned unchanged. +#' +#' @rdname column_nonaggregate_functions #' @export -#' @aliases lit,ANY-method +#' @aliases lit lit,ANY-method #' @examples +#' #' \dontrun{ -#' lit(df$name) -#' select(df, lit("x")) -#' select(df, lit("2015-01-01")) -#'} +#' tmp <- mutate(df, v1 = lit(df$mpg), v2 = lit("x"), v3 = lit("2015-01-01"), +#' v4 = negate(df$mpg), v5 = expr('length(model)'), +#' v6 = greatest(df$vs, df$am), v7 = least(df$vs, df$am), +#' v8 = column("mpg")) +#' head(tmp)} #' @note lit since 1.5.0 setMethod("lit", signature("ANY"), function(x) { @@ -314,18 +330,16 @@ setMethod("bin", column(jc) }) -#' bitwiseNOT -#' -#' Computes bitwise NOT. -#' -#' @param x Column to compute on. +#' @details +#' \code{bitwiseNOT}: Computes bitwise NOT. #' -#' @rdname bitwiseNOT -#' @name bitwiseNOT -#' @family non-aggregate functions +#' @rdname column_nonaggregate_functions #' @export -#' @aliases bitwiseNOT,Column-method -#' @examples \dontrun{bitwiseNOT(df$c)} +#' @aliases bitwiseNOT bitwiseNOT,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, bitwiseNOT(cast(df$vs, "int"))))} #' @note bitwiseNOT since 1.5.0 setMethod("bitwiseNOT", signature(x = "Column"), @@ -375,16 +389,12 @@ setMethod("ceiling", ceil(x) }) -#' Returns the first column that is not NA -#' -#' Returns the first column that is not NA, or NA if all inputs are. +#' @details +#' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' -#' @rdname coalesce -#' @name coalesce -#' @family non-aggregate functions +#' @rdname column_nonaggregate_functions #' @export #' @aliases coalesce,Column-method -#' @examples \dontrun{coalesce(df$c, df$d, df$e)} #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", signature(x = "Column"), @@ -824,22 +834,24 @@ setMethod("initcap", column(jc) }) -#' is.nan -#' -#' Return true if the column is NaN, alias for \link{isnan} -#' -#' @param x Column to compute on. +#' @details +#' \code{isnan}: Returns true if the column is NaN. +#' @rdname column_nonaggregate_functions +#' @aliases isnan isnan,Column-method +#' @note isnan since 2.0.0 +setMethod("isnan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) + column(jc) + }) + +#' @details +#' \code{is.nan}: Alias for \link{isnan}. #' -#' @rdname is.nan -#' @name is.nan -#' @family non-aggregate functions -#' @aliases is.nan,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases is.nan is.nan,Column-method #' @export -#' @examples -#' \dontrun{ -#' is.nan(df$c) -#' isnan(df$c) -#' } #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -847,17 +859,6 @@ setMethod("is.nan", isnan(x) }) -#' @rdname is.nan -#' @name isnan -#' @aliases isnan,Column-method -#' @note isnan since 2.0.0 -setMethod("isnan", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) - column(jc) - }) - #' @details #' \code{kurtosis}: Returns the kurtosis of the values in a group. #' @@ -1129,27 +1130,24 @@ setMethod("minute", column(jc) }) -#' monotonically_increasing_id -#' -#' Return a column that generates monotonically increasing 64-bit integers. -#' -#' The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. -#' The current implementation puts the partition ID in the upper 31 bits, and the record number -#' within each partition in the lower 33 bits. The assumption is that the SparkDataFrame has -#' less than 1 billion partitions, and each partition has less than 8 billion records. -#' -#' As an example, consider a SparkDataFrame with two partitions, each with 3 records. +#' @details +#' \code{monotonically_increasing_id}: Returns a column that generates monotonically increasing +#' 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, +#' but not consecutive. The current implementation puts the partition ID in the upper 31 bits, +#' and the record number within each partition in the lower 33 bits. The assumption is that the +#' SparkDataFrame has less than 1 billion partitions, and each partition has less than 8 billion +#' records. As an example, consider a SparkDataFrame with two partitions, each with 3 records. #' This expression would return the following IDs: #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. -#' #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. +#' The method should be used with no argument. #' -#' @rdname monotonically_increasing_id -#' @aliases monotonically_increasing_id,missing-method -#' @name monotonically_increasing_id -#' @family misc functions +#' @rdname column_nonaggregate_functions +#' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method #' @export -#' @examples \dontrun{select(df, monotonically_increasing_id())} +#' @examples +#' +#' \dontrun{head(select(df, monotonically_increasing_id()))} setMethod("monotonically_increasing_id", signature("missing"), function() { @@ -1171,18 +1169,12 @@ setMethod("month", column(jc) }) -#' negate -#' -#' Unary minus, i.e. negate the expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{negate}: Unary minus, i.e. negate the expression. #' -#' @rdname negate -#' @name negate -#' @family non-aggregate functions -#' @aliases negate,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases negate negate,Column-method #' @export -#' @examples \dontrun{negate(df$c)} #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1481,23 +1473,19 @@ setMethod("stddev_samp", column(jc) }) -#' struct -#' -#' Creates a new struct column that composes multiple input columns. -#' -#' @param x a column to compute on. -#' @param ... optional column(s) to be included. +#' @details +#' \code{struct}: Creates a new struct column that composes multiple input columns. #' -#' @rdname struct -#' @name struct -#' @family non-aggregate functions -#' @aliases struct,characterOrColumn-method +#' @rdname column_nonaggregate_functions +#' @aliases struct struct,characterOrColumn-method #' @export #' @examples +#' #' \dontrun{ -#' struct(df$c, df$d) -#' struct("col1", "col2") -#' } +#' tmp <- mutate(df, v1 = struct(df$mpg, df$cyl), v2 = struct("hp", "wt", "vs"), +#' v3 = create_array(df$mpg, df$cyl, df$hp), +#' v4 = create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))) +#' head(tmp)} #' @note struct since 1.6.0 setMethod("struct", signature(x = "characterOrColumn"), @@ -1959,20 +1947,13 @@ setMethod("months_between", signature(y = "Column"), column(jc) }) -#' nanvl -#' -#' Returns col1 if it is not NaN, or col2 if col1 is NaN. -#' Both inputs should be floating point columns (DoubleType or FloatType). -#' -#' @param x first Column. -#' @param y second Column. +#' @details +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if +#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). #' -#' @rdname nanvl -#' @name nanvl -#' @family non-aggregate functions -#' @aliases nanvl,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases nanvl nanvl,Column-method #' @export -#' @examples \dontrun{nanvl(df$c, x)} #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2060,20 +2041,13 @@ setMethod("concat", column(jc) }) -#' greatest -#' -#' Returns the greatest value of the list of column names, skipping null values. +#' @details +#' \code{greatest}: Returns the greatest value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family non-aggregate functions -#' @rdname greatest -#' @name greatest -#' @aliases greatest,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases greatest greatest,Column-method #' @export -#' @examples \dontrun{greatest(df$c, df$d)} #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2087,20 +2061,13 @@ setMethod("greatest", column(jc) }) -#' least -#' -#' Returns the least value of the list of column names, skipping null values. +#' @details +#' \code{least}: Returns the least value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family non-aggregate functions -#' @rdname least -#' @aliases least,Column-method -#' @name least +#' @rdname column_nonaggregate_functions +#' @aliases least least,Column-method #' @export -#' @examples \dontrun{least(df$c, df$d)} #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2445,18 +2412,13 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri column(jc) }) -#' expr -#' -#' Parses the expression string into the column that it represents, similar to -#' SparkDataFrame.selectExpr +#' @details +#' \code{expr}: Parses the expression string into the column that it represents, similar to +#' \code{SparkDataFrame.selectExpr} #' -#' @param x an expression character object to be parsed. -#' @family non-aggregate functions -#' @rdname expr -#' @aliases expr,character-method -#' @name expr +#' @rdname column_nonaggregate_functions +#' @aliases expr expr,character-method #' @export -#' @examples \dontrun{expr('length(name)')} #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2617,18 +2579,19 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' rand -#' -#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' @details +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples #' from U[0.0, 1.0]. #' +#' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. -#' @family non-aggregate functions -#' @rdname rand -#' @name rand -#' @aliases rand,missing-method +#' @aliases rand rand,missing-method #' @export -#' @examples \dontrun{rand()} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, r1 = rand(), r2 = rand(10), r3 = randn(), r4 = randn(10)) +#' head(tmp)} #' @note rand since 1.5.0 setMethod("rand", signature(seed = "missing"), function(seed) { @@ -2636,8 +2599,7 @@ setMethod("rand", signature(seed = "missing"), column(jc) }) -#' @rdname rand -#' @name rand +#' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method #' @export #' @note rand(numeric) since 1.5.0 @@ -2647,18 +2609,13 @@ setMethod("rand", signature(seed = "numeric"), column(jc) }) -#' randn -#' -#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' @details +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from #' the standard normal distribution. #' -#' @param seed a random seed. Can be missing. -#' @family non-aggregate functions -#' @rdname randn -#' @name randn -#' @aliases randn,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases randn randn,missing-method #' @export -#' @examples \dontrun{randn()} #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2666,8 +2623,7 @@ setMethod("randn", signature(seed = "missing"), column(jc) }) -#' @rdname randn -#' @name randn +#' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method #' @export #' @note randn(numeric) since 1.5.0 @@ -2819,20 +2775,26 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) column(jc) }) -#' when -#' -#' Evaluates a list of conditions and returns one of multiple possible result expressions. + +#' @details +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' +#' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. -#' @family non-aggregate functions -#' @rdname when -#' @name when -#' @aliases when,Column-method -#' @seealso \link{ifelse} +#' @aliases when when,Column-method #' @export -#' @examples \dontrun{when(df$age == 2, df$age + 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, mpg_na = otherwise(when(df$mpg > 20, df$mpg), lit(NaN)), +#' mpg2 = ifelse(df$mpg > 20 & df$am > 0, 0, 1), +#' mpg3 = ifelse(df$mpg > 20, df$mpg, 20.0)) +#' head(tmp) +#' tmp <- mutate(tmp, ind_na1 = is.nan(tmp$mpg_na), ind_na2 = isnan(tmp$mpg_na)) +#' head(select(tmp, coalesce(tmp$mpg_na, tmp$mpg))) +#' head(select(tmp, nanvl(tmp$mpg_na, tmp$hp)))} #' @note when since 1.5.0 setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { @@ -2842,25 +2804,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), column(jc) }) -#' ifelse -#' -#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' @details +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. #' Otherwise \code{no} is returned for unmatched conditions. #' +#' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. -#' @family non-aggregate functions -#' @rdname ifelse -#' @name ifelse -#' @aliases ifelse,Column-method -#' @seealso \link{when} +#' @aliases ifelse ifelse,Column-method #' @export -#' @examples -#' \dontrun{ -#' ifelse(df$a > 1 & df$b > 2, 0, 1) -#' ifelse(df$a > 1, df$a, 1) -#' } #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -3263,19 +3216,12 @@ setMethod("posexplode", column(jc) }) -#' create_array -#' -#' Creates a new array column. The input columns must all have the same data type. -#' -#' @param x Column to compute on -#' @param ... additional Column(s). +#' @details +#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. #' -#' @family non-aggregate functions -#' @rdname create_array -#' @name create_array -#' @aliases create_array,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_array create_array,Column-method #' @export -#' @examples \dontrun{create_array(df$x, df$y, df$z)} #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3288,22 +3234,15 @@ setMethod("create_array", column(jc) }) -#' create_map -#' -#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' @details +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, #' e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' -#' @param x Column to compute on -#' @param ... additional Column(s). -#' -#' @family non-aggregate functions -#' @rdname create_map -#' @name create_map -#' @aliases create_map,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_map create_map,Column-method #' @export -#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3554,21 +3493,18 @@ setMethod("grouping_id", column(jc) }) -#' input_file_name -#' -#' Creates a string column with the input file name for a given row +#' @details +#' \code{input_file_name}: Creates a string column with the input file name for a given row. +#' The method should be used with no argument. #' -#' @rdname input_file_name -#' @name input_file_name -#' @family non-aggregate functions -#' @aliases input_file_name,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases input_file_name input_file_name,missing-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") #' -#' head(select(df, input_file_name())) -#' } +#' \dontrun{ +#' tmp <- read.text("README.md") +#' head(select(tmp, input_file_name()))} #' @note input_file_name since 2.3.0 setMethod("input_file_name", signature("missing"), function() { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index dc99e3d94b269..1deb057bb1b82 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -422,9 +422,8 @@ setGeneric("cache", function(x) { standardGeneric("cache") }) setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce -#' @param x a Column or a SparkDataFrame. -#' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally -#' provided. +#' @param x a SparkDataFrame. +#' @param ... additional argument(s). #' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) @@ -863,8 +862,9 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) -#' @rdname when +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise @@ -938,8 +938,9 @@ setGeneric("base64", function(x) { standardGeneric("base64") }) #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) -#' @rdname bitwiseNOT +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions @@ -995,12 +996,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname create_array +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) -#' @rdname create_map +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname hash @@ -1065,8 +1068,9 @@ setGeneric("explode", function(x) { standardGeneric("explode") }) #' @export setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) -#' @rdname expr +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions @@ -1093,8 +1097,9 @@ setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) -#' @rdname greatest +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions @@ -1127,9 +1132,9 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) -#' @param x empty. Should be used with no argument. -#' @rdname input_file_name +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) @@ -1138,8 +1143,9 @@ setGeneric("input_file_name", #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname is.nan +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions @@ -1164,8 +1170,9 @@ setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @export setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) -#' @rdname least +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions @@ -1173,8 +1180,9 @@ setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) -#' @rdname lit +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions @@ -1206,9 +1214,9 @@ setGeneric("md5", function(x) { standardGeneric("md5") }) #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) -#' @param x empty. Should be used with no argument. -#' @rdname monotonically_increasing_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) @@ -1226,12 +1234,14 @@ setGeneric("months_between", function(y, x) { standardGeneric("months_between") #' @export setGeneric("n", function(x) { standardGeneric("n") }) -#' @rdname nanvl +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) -#' @rdname negate +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not @@ -1275,12 +1285,14 @@ setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) -#' @rdname rand +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) -#' @rdname randn +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname rank @@ -1409,8 +1421,9 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) -#' @rdname struct +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions From 70085e83d1ee728b23f7df15f570eb8d77f67a7a Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 29 Jun 2017 09:51:12 +0100 Subject: [PATCH 0818/1765] [SPARK-21210][DOC][ML] Javadoc 8 fixes for ML shared param traits PR #15999 included fixes for doc strings in the ML shared param traits (occurrences of `>` and `>=`). This PR simply uses the HTML-escaped version of the param doc to embed into the Scaladoc, to ensure that when `SharedParamsCodeGen` is run, the generated javadoc will be compliant for Java 8. ## How was this patch tested? Existing tests Author: Nick Pentreath Closes #18420 from MLnick/shared-params-javadoc8. --- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 5 ++++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c94b8b4e9dfda..013817a41baf5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.param.shared import java.io.PrintWriter import scala.reflect.ClassTag +import scala.xml.Utility /** * Code generator for shared params (sharedParams.scala). Run under the Spark folder with @@ -167,6 +168,8 @@ private[shared] object SharedParamsCodeGen { "def" } + val htmlCompliantDoc = Utility.escape(doc) + s""" |/** | * Trait for shared param $name$defaultValueDoc. @@ -174,7 +177,7 @@ private[shared] object SharedParamsCodeGen { |private[ml] trait Has$Name extends Params { | | /** - | * Param for $doc. + | * Param for $htmlCompliantDoc. | * @group ${groupStr(0)} | */ | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e3e03dfd43dd6..50619607a5054 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -176,7 +176,7 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. * @group param */ final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1) From d106a74c53f493c3c18741a9b19cb821dace4ba2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 29 Jun 2017 09:59:36 +0100 Subject: [PATCH 0819/1765] [SPARK-21240] Fix code style for constructing and stopping a SparkContext in UT. ## What changes were proposed in this pull request? Same with SPARK-20985. Fix code style for constructing and stopping a `SparkContext`. Assure the context is stopped to avoid other tests complain that there's only one `SparkContext` can exist. Author: jinxing Closes #18454 from jinxing64/SPARK-21240. --- .../scala/org/apache/spark/scheduler/MapStatusSuite.scala | 6 ++---- .../apache/spark/sql/execution/ui/SQLListenerSuite.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index e6120139f4958..276169e02f01d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -26,6 +26,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.internal.config +import org.apache.spark.LocalSparkContext._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId @@ -160,12 +161,9 @@ class MapStatusSuite extends SparkFunSuite { .set("spark.serializer", classOf[KryoSerializer].getName) .setMaster("local") .setAppName("SPARK-21133") - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length assert(count === 3000) - } finally { - sc.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e6cd41e4facf1..82eff5e6491ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -25,6 +25,7 @@ import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config +import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession} @@ -496,8 +497,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .setAppName("test") .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => SparkSession.sqlListener.set(null) val spark = new SparkSession(sc) import spark.implicits._ @@ -522,8 +522,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { assert(spark.sharedState.listener.executionIdToData.size <= 100) assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) - } finally { - sc.stop() } } } From d7da2b94d6107341b33ca9224e9bfa4c9a92ed88 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Thu, 29 Jun 2017 10:01:12 +0100 Subject: [PATCH 0820/1765] =?UTF-8?q?[SPARK-21135][WEB=20UI]=20On=20histor?= =?UTF-8?q?y=20server=20page=EF=BC=8Cduration=20of=20incompleted=20applica?= =?UTF-8?q?tions=20should=20be=20hidden=20instead=20of=20showing=20up=20as?= =?UTF-8?q?=200?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Hide duration of incompleted applications. ## How was this patch tested? manual tests Author: fjh100456 Closes #18351 from fjh100456/master. --- .../spark/ui/static/historypage-template.html | 4 ++-- .../org/apache/spark/ui/static/historypage.js | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index bfe31aae555ba..6cff0068d8bcb 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -44,7 +44,7 @@ Completed - + Duration @@ -74,7 +74,7 @@ {{attemptId}} {{startTime}} {{endTime}} - {{duration}} + {{duration}} {{sparkUser}} {{lastUpdated}} Download diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 5ec1ce15a2127..9edd3ba0e0ba6 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -182,12 +182,17 @@ $(document).ready(function() { for (i = 0; i < completedCells.length; i++) { completedCells[i].style.display='none'; } - } - var durationCells = document.getElementsByClassName("durationClass"); - for (i = 0; i < durationCells.length; i++) { - var timeInMilliseconds = parseInt(durationCells[i].title); - durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + var durationCells = document.getElementsByClassName("durationColumn"); + for (i = 0; i < durationCells.length; i++) { + durationCells[i].style.display='none'; + } + } else { + var durationCells = document.getElementsByClassName("durationClass"); + for (i = 0; i < durationCells.length; i++) { + var timeInMilliseconds = parseInt(durationCells[i].title); + durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + } } if ($(selector.concat(" tr")).length < 20) { From 29bd251dd5914fc3b6146eb4fe0b45f1c84dba62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E6=B2=BB=E5=9B=BD10192065?= Date: Thu, 29 Jun 2017 20:53:48 +0800 Subject: [PATCH 0821/1765] [SPARK-21225][CORE] Considering CPUS_PER_TASK when allocating task slots for each WorkerOffer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue:https://issues.apache.org/jira/browse/SPARK-21225 In the function "resourceOffers", It declare a variable "tasks" for storage the tasks which have allocated a executor. It declared like this: `val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))` But, I think this code only conside a situation for that one task per core. If the user set "spark.task.cpus" as 2 or 3, It really don't need so much Mem. I think It can motify as follow: val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) to instead. Motify like this the other earning is that it's more easy to understand the way how the tasks allocate offers. Author: 杨治国10192065 Closes #18435 from JackYangzg/motifyTaskCoreDisp. --- .../scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 91ec172ffeda1..737b383631148 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -345,7 +345,7 @@ private[spark] class TaskSchedulerImpl( val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker. - val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { From 18066f2e61f430b691ed8a777c9b4e5786bf9dbc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Jun 2017 21:28:48 +0800 Subject: [PATCH 0822/1765] [SPARK-21052][SQL] Add hash map metrics to join ## What changes were proposed in this pull request? This adds the average hash map probe metrics to join operator such as `BroadcastHashJoin` and `ShuffledHashJoin`. This PR adds the API to `HashedRelation` to get average hash map probe. ## How was this patch tested? Related test cases are added. Author: Liang-Chi Hsieh Closes #18301 from viirya/SPARK-21052. --- .../aggregate/HashAggregateExec.scala | 15 +- .../TungstenAggregationIterator.scala | 34 ++-- .../joins/BroadcastHashJoinExec.scala | 30 ++- .../spark/sql/execution/joins/HashJoin.scala | 8 +- .../sql/execution/joins/HashedRelation.scala | 43 +++- .../joins/ShuffledHashJoinExec.scala | 6 +- .../sql/execution/metric/SQLMetrics.scala | 32 ++- .../execution/metric/SQLMetricsSuite.scala | 188 ++++++++++++++++-- 8 files changed, 296 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 5027a615ced7a..56f61c30c4a38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -60,7 +60,7 @@ case class HashAggregateExec( "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"), - "avgHashmapProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hashmap probe")) + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,7 +94,7 @@ case class HashAggregateExec( val numOutputRows = longMetric("numOutputRows") val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") - val avgHashmapProbe = longMetric("avgHashmapProbe") + val avgHashProbe = longMetric("avgHashProbe") child.execute().mapPartitions { iter => @@ -119,7 +119,7 @@ case class HashAggregateExec( numOutputRows, peakMemory, spillSize, - avgHashmapProbe) + avgHashProbe) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) @@ -344,7 +344,7 @@ case class HashAggregateExec( sorter: UnsafeKVExternalSorter, peakMemory: SQLMetric, spillSize: SQLMetric, - avgHashmapProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + avgHashProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes @@ -355,8 +355,7 @@ case class HashAggregateExec( metrics.incPeakExecutionMemory(maxMemory) // Update average hashmap probe - val avgProbes = hashMap.getAverageProbesPerLookup() - avgHashmapProbe.add(avgProbes.ceil.toLong) + avgHashProbe.set(hashMap.getAverageProbesPerLookup()) if (sorter == null) { // not spilled @@ -584,7 +583,7 @@ case class HashAggregateExec( val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") - val avgHashmapProbe = metricTerm(ctx, "avgHashmapProbe") + val avgHashProbe = metricTerm(ctx, "avgHashProbe") def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -611,7 +610,7 @@ case class HashAggregateExec( s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize, - $avgHashmapProbe); + $avgHashProbe); } """) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 8efa95d48aea0..cfa930607360c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -89,7 +89,7 @@ class TungstenAggregationIterator( numOutputRows: SQLMetric, peakMemory: SQLMetric, spillSize: SQLMetric, - avgHashmapProbe: SQLMetric) + avgHashProbe: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -367,6 +367,22 @@ class TungstenAggregationIterator( } } + TaskContext.get().addTaskCompletionListener(_ => { + // At the end of the task, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + peakMemory.set(maxMemory) + spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore) + metrics.incPeakExecutionMemory(maxMemory) + + // Updating average hashmap probe + avgHashProbe.set(hashMap.getAverageProbesPerLookup()) + }) + /////////////////////////////////////////////////////////////////////////// // Part 7: Iterator's public methods. /////////////////////////////////////////////////////////////////////////// @@ -409,22 +425,6 @@ class TungstenAggregationIterator( } } - // If this is the last record, update the task's peak memory usage. Since we destroy - // the map to create the sorter, their memory usages should not overlap, so it is safe - // to just use the max of the two. - if (!hasNext) { - val mapMemory = hashMap.getPeakMemoryUsedBytes - val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) - val maxMemory = Math.max(mapMemory, sorterMemory) - val metrics = TaskContext.get().taskMetrics() - peakMemory += maxMemory - spillSize += metrics.memoryBytesSpilled - spillSizeBefore - metrics.incPeakExecutionMemory(maxMemory) - - // Update average hashmap probe if this is the last record. - val averageProbes = hashMap.getAverageProbesPerLookup() - avgHashmapProbe.add(averageProbes.ceil.toLong) - } numOutputRows += 1 res } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0bc261d593df4..bfa1e9d49a545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType +import org.apache.spark.util.TaskCompletionListener /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -46,7 +47,8 @@ case class BroadcastHashJoinExec( extends BinaryExecNode with HashJoin with CodegenSupport { override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) @@ -60,12 +62,13 @@ case class BroadcastHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val avgHashProbe = longMetric("avgHashProbe") val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) - join(streamedIter, hashed, numOutputRows) + join(streamedIter, hashed, numOutputRows, avgHashProbe) } } @@ -90,6 +93,23 @@ case class BroadcastHashJoinExec( } } + /** + * Returns the codes used to add a task completion listener to update avg hash probe + * at the end of the task. + */ + private def genTaskListener(avgHashProbe: String, relationTerm: String): String = { + val listenerClass = classOf[TaskCompletionListener].getName + val taskContextClass = classOf[TaskContext].getName + s""" + | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() { + | @Override + | public void onTaskCompletion($taskContextClass context) { + | $avgHashProbe.set($relationTerm.getAverageProbesPerLookup()); + | } + | }); + """.stripMargin + } + /** * Returns a tuple of Broadcast of HashedRelation and the variable name for it. */ @@ -99,10 +119,16 @@ case class BroadcastHashJoinExec( val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName + + // At the end of the task, we update the avg hash probe. + val avgHashProbe = metricTerm(ctx, "avgHashProbe") + val addTaskListener = genTaskListener(avgHashProbe, relationTerm) + ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.estimatedSize()); + | $addTaskListener """.stripMargin) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 1aef5f6864263..b09edf380c2d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -193,7 +194,8 @@ trait HashJoin { protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, - numOutputRows: SQLMetric): Iterator[InternalRow] = { + numOutputRows: SQLMetric, + avgHashProbe: SQLMetric): Iterator[InternalRow] = { val joinedIter = joinType match { case _: InnerLike => @@ -211,6 +213,10 @@ trait HashJoin { s"BroadcastHashJoin should not take $x as the JoinType") } + // At the end of the task, we update the avg hash probe. + TaskContext.get().addTaskCompletionListener(_ => + avgHashProbe.set(hashed.getAverageProbesPerLookup())) + val resultProj = createResultProjection joinedIter.map { r => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2dd1dc3da96c9..3c702856114f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -79,6 +79,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { * Release any used resources. */ def close(): Unit + + /** + * Returns the average number of probes per key lookup. + */ + def getAverageProbesPerLookup(): Double } private[execution] object HashedRelation { @@ -242,7 +247,8 @@ private[joins] class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision - pageSizeBytes) + pageSizeBytes, + true) var i = 0 var keyBuffer = new Array[Byte](1024) @@ -273,6 +279,8 @@ private[joins] class UnsafeHashedRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { read(in.readInt, in.readLong, in.readBytes) } + + override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup() } private[joins] object UnsafeHashedRelation { @@ -290,7 +298,8 @@ private[joins] object UnsafeHashedRelation { taskMemoryManager, // Only 70% of the slots can be used before growing, more capacity help to reduce collision (sizeEstimate * 1.5 + 1).toInt, - pageSizeBytes) + pageSizeBytes, + true) // Create a mapping of buildKeys -> rows val keyGenerator = UnsafeProjection.create(key) @@ -344,7 +353,7 @@ private[joins] object UnsafeHashedRelation { * determined by `key1 - minKey`. * * The map is created as sparse mode, then key-value could be appended into it. Once finish - * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * appending, caller could call optimize() to try to turn the map into dense mode, which is faster * to probe. * * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ @@ -385,6 +394,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The number of unique keys. private var numKeys = 0L + // Tracking average number of probes per key lookup. + private var numKeyLookups = 0L + private var numProbes = 0L + // needed by serializer def this() = { this( @@ -469,6 +482,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { + numKeyLookups += 1 + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -477,11 +492,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) + numKeyLookups += 1 + numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -509,6 +527,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { + numKeyLookups += 1 + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -517,11 +537,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) + numKeyLookups += 1 + numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -573,8 +596,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) assert(numKeys < array.length / 2) + numKeyLookups += 1 + numProbes += 1 while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) + numProbes += 1 } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. @@ -686,6 +712,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap writeLong(maxKey) writeLong(numKeys) writeLong(numValues) + writeLong(numKeyLookups) + writeLong(numProbes) writeLong(array.length) writeLongArray(writeBuffer, array, array.length) @@ -727,6 +755,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = readLong() numKeys = readLong() numValues = readLong() + numKeyLookups = readLong() + numProbes = readLong() val length = readLong().toInt mask = length - 2 @@ -742,6 +772,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap override def read(kryo: Kryo, in: Input): Unit = { read(in.readBoolean, in.readLong, in.readBytes) } + + /** + * Returns the average number of probes per key lookup. + */ + def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups } private[joins] class LongHashedRelation( @@ -793,6 +828,8 @@ private[joins] class LongHashedRelation( resultRow = new UnsafeRow(nFields) map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } + + override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index afb6e5e3dd235..f1df41ca49c27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -42,7 +42,8 @@ case class ShuffledHashJoinExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), - "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -62,9 +63,10 @@ case class ShuffledHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val avgHashProbe = longMetric("avgHashProbe") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) - join(streamIter, hashed, numOutputRows) + join(streamIter, hashed, numOutputRows, avgHashProbe) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 49cab04de2bf0..b4653c1b564f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -57,6 +57,12 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato override def add(v: Long): Unit = _value += v + // We can set a double value to `SQLMetric` which stores only long value, if it is + // average metrics. + def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) + + def set(v: Long): Unit = _value = v + def +=(v: Long): Unit = _value += v override def value: Long = _value @@ -74,6 +80,19 @@ object SQLMetrics { private val TIMING_METRIC = "timing" private val AVERAGE_METRIC = "average" + private val baseForAvgMetric: Int = 10 + + /** + * Converts a double value to long value by multiplying a base integer, so we can store it in + * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore + * it back to a double value up to the decimal places bound by the base integer. + */ + private[sql] def setDoubleForAverageMetrics(metric: SQLMetric, v: Double): Unit = { + assert(metric.metricType == AVERAGE_METRIC, + s"Can't set a double to a metric of metrics type: ${metric.metricType}") + metric.set((v * baseForAvgMetric).toLong) + } + def createMetric(sc: SparkContext, name: String): SQLMetric = { val acc = new SQLMetric(SUM_METRIC) acc.register(sc, name = Some(name), countFailedValues = false) @@ -104,15 +123,14 @@ object SQLMetrics { /** * Create a metric to report the average information (including min, med, max) like - * avg hashmap probe. Because `SQLMetric` stores long values, we take the ceil of the average - * values before storing them. This metric is used to record an average value computed in the - * end of a task. It should be set once. The initial values (zeros) of this metrics will be - * excluded after. + * avg hash probe. As average metrics are double values, this kind of metrics should be + * only set with `SQLMetric.set` method instead of other methods like `SQLMetric.add`. + * The initial values (zeros) of this metrics will be excluded after. */ def createAverageMetric(sc: SparkContext, name: String): SQLMetric = { // The final result of this metric in physical operator UI may looks like: // probe avg (min, med, max): - // (1, 2, 6) + // (1.2, 2.2, 6.3) val acc = new SQLMetric(AVERAGE_METRIC) acc.register(sc, name = Some(s"$name (min, med, max)"), countFailedValues = false) acc @@ -127,7 +145,7 @@ object SQLMetrics { val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) } else if (metricsType == AVERAGE_METRIC) { - val numberFormat = NumberFormat.getIntegerInstance(Locale.US) + val numberFormat = NumberFormat.getNumberInstance(Locale.US) val validValues = values.filter(_ > 0) val Seq(min, med, max) = { @@ -137,7 +155,7 @@ object SQLMetrics { val sorted = validValues.sorted Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } - metric.map(numberFormat.format) + metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } s"\n($min, $med, $max)" } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a12ce2b9eba34..cb3405b2fe19b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -47,9 +47,10 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { private def getSparkPlanMetrics( df: DataFrame, expectedNumOfJobs: Int, - expectedNodeIds: Set[Long]): Option[Map[Long, (String, Map[String, Any])]] = { + expectedNodeIds: Set[Long], + enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) @@ -110,6 +111,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + /** + * Generates a `DataFrame` by filling randomly generated bytes for hash collision. + */ + private def generateRandomBytesDF(numRows: Int = 65535): DataFrame = { + val random = new Random() + val manyBytes = (0 until numRows).map { _ => + val byteArrSize = random.nextInt(100) + val bytes = new Array[Byte](byteArrSize) + random.nextBytes(bytes) + (bytes, random.nextInt(100)) + } + manyBytes.toSeq.toDF("a", "b") + } + test("LocalTableScanExec computes metrics in collect and take") { val df1 = spark.createDataset(Seq(1, 2, 3)) val logical = df1.queryExecution.logical @@ -151,9 +166,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = testData2.groupBy().count() // 2 partitions val expected1 = Seq( Map("number of output rows" -> 2L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 1L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df, 1, Map( 2L -> ("HashAggregate", expected1(0)), 0L -> ("HashAggregate", expected1(1))) @@ -163,9 +178,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df2 = testData2.groupBy('a).count() val expected2 = Seq( Map("number of output rows" -> 4L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 3L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df2, 1, Map( 2L -> ("HashAggregate", expected2(0)), 0L -> ("HashAggregate", expected2(1))) @@ -173,19 +188,42 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("Aggregate metrics: track avg probe") { - val random = new Random() - val manyBytes = (0 until 65535).map { _ => - val byteArrSize = random.nextInt(100) - val bytes = new Array[Byte](byteArrSize) - random.nextBytes(bytes) - (bytes, random.nextInt(100)) - } - val df = manyBytes.toSeq.toDF("a", "b").repartition(1).groupBy('a).count() - val metrics = getSparkPlanMetrics(df, 1, Set(2L, 0L)).get - Seq(metrics(2L)._2("avg hashmap probe (min, med, max)"), - metrics(0L)._2("avg hashmap probe (min, med, max)")).foreach { probes => - probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => - assert(probe.toInt > 1) + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) + } else { + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } } } } @@ -267,10 +305,120 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 2L))) + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))) ) } + test("BroadcastHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#210, b#211, b#221] + // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight + // :- Project [_1#207 AS a#210, _2#208 AS b#211] + // : +- Filter isnotnull(_1#207) + // : +- LocalTableScan [_1#207, _2#208] + // +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true])) + // +- Project [_1#217 AS a#220, _2#218 AS b#221] + // +- Filter isnotnull(_1#217) + // +- LocalTableScan [_1#217, _2#218] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // BroadcastHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // BroadcastHashJoin(nodeId = 2) + // Project(nodeId = 3) + // Filter(nodeId = 4) + // ...(ignored) + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF() + val df2 = generateRandomBytesDF() + val df = df1.join(broadcast(df2), "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) + val df = df1.join(df2, "key") + val metrics = getSparkPlanMetrics(df, 1, Set(1L)) + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))) + ) + } + } + + test("ShuffledHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#308, b#309, b#319] + // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight + // :- Exchange hashpartitioning(a#308, 2) + // : +- Project [_1#305 AS a#308, _2#306 AS b#309] + // : +- Filter isnotnull(_1#305) + // : +- LocalTableScan [_1#305, _2#306] + // +- Exchange hashpartitioning(a#318, 2) + // +- Project [_1#315 AS a#318, _2#316 AS b#319] + // +- Filter isnotnull(_1#315) + // +- LocalTableScan [_1#315, _2#316] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // ShuffledHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // ShuffledHashJoin(nodeId = 2) + // ...(ignored) + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF(65535 * 5) + val df2 = generateRandomBytesDF(65535) + val df = df1.join(df2, "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + } + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") From f9151bebca986d44cdab7699959fec2bc050773a Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Thu, 29 Jun 2017 16:03:15 -0700 Subject: [PATCH 0823/1765] [SPARK-21188][CORE] releaseAllLocksForTask should synchronize the whole method ## What changes were proposed in this pull request? Since the objects `readLocksByTask`, `writeLocksByTask` and `info`s are coupled and supposed to be modified by other threads concurrently, all the read and writes of them in the method `releaseAllLocksForTask` should be protected by a single synchronized block like other similar methods. ## How was this patch tested? existing tests Author: Feng Liu Closes #18400 from liufengdb/synchronize. --- .../spark/storage/BlockInfoManager.scala | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 7064872ec1c77..219a0e799cc73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -341,15 +341,11 @@ private[storage] class BlockInfoManager extends Logging { * * @return the ids of blocks whose pins were released */ - def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() - val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) - } - val writeLocks = synchronized { - writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) - } + val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) + val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) for (blockId <- writeLocks) { infos.get(blockId).foreach { info => @@ -358,21 +354,19 @@ private[storage] class BlockInfoManager extends Logging { } blocksWithReleasedLocks += blockId } + readLocks.entrySet().iterator().asScala.foreach { entry => val blockId = entry.getElement val lockCount = entry.getCount blocksWithReleasedLocks += blockId - synchronized { - get(blockId).foreach { info => - info.readerCount -= lockCount - assert(info.readerCount >= 0) - } + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) } } - synchronized { - notifyAll() - } + notifyAll() + blocksWithReleasedLocks } From 4996c53949376153f9ebdc74524fed7226968808 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 10:56:48 +0800 Subject: [PATCH 0824/1765] [SPARK-21253][CORE] Fix a bug that StreamCallback may not be notified if network errors happen ## What changes were proposed in this pull request? If a network error happens before processing StreamResponse/StreamFailure events, StreamCallback.onFailure won't be called. This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #18472 from zsxwing/fix-stream-2. --- .../spark/network/client/TransportClient.java | 2 +- .../client/TransportResponseHandler.java | 38 ++++++++++++++----- .../TransportResponseHandlerSuite.java | 31 ++++++++++++++- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a6f527c118218..8f354ad78bbaa 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -179,7 +179,7 @@ public void stream(String streamId, StreamCallback callback) { // written to the socket atomically, so that callbacks are called in the right order // when responses arrive. synchronized (this) { - handler.addStreamCallback(callback); + handler.addStreamCallback(streamId, callback); channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 41bead546cad6..be9f18203c8e4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; - private final Queue streamCallbacks; + private final Queue> streamCallbacks; private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ @@ -88,9 +90,9 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } - public void addStreamCallback(StreamCallback callback) { + public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(callback); + streamCallbacks.offer(Tuple2.apply(streamId, callback)); } @VisibleForTesting @@ -104,15 +106,31 @@ public void deactivateStream() { */ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { - entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + try { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } catch (Exception e) { + logger.warn("ChunkReceivedCallback.onFailure throws exception", e); + } } for (Map.Entry entry : outstandingRpcs.entrySet()) { - entry.getValue().onFailure(cause); + try { + entry.getValue().onFailure(cause); + } catch (Exception e) { + logger.warn("RpcResponseCallback.onFailure throws exception", e); + } + } + for (Tuple2 entry : streamCallbacks) { + try { + entry._2().onFailure(entry._1(), cause); + } catch (Exception e) { + logger.warn("StreamCallback.onFailure throws exception", e); + } } // It's OK if new fetches appear, as they will fail immediately. outstandingFetches.clear(); outstandingRpcs.clear(); + streamCallbacks.clear(); } @Override @@ -190,8 +208,9 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); if (resp.byteCount > 0) { StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); @@ -216,8 +235,9 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamFailure) { StreamFailure resp = (StreamFailure) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); try { callback.onFailure(resp.streamId, new RuntimeException(resp.error)); } catch (IOException ioe) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 09fc80d12d510..b4032c4c3f031 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.nio.ByteBuffer; import io.netty.channel.Channel; @@ -127,7 +128,7 @@ public void testActiveStreams() throws Exception { StreamResponse response = new StreamResponse("stream", 1234L, null); StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(response); assertEquals(1, handler.numOutstandingRequests()); @@ -135,9 +136,35 @@ public void testActiveStreams() throws Exception { assertEquals(0, handler.numOutstandingRequests()); StreamFailure failure = new StreamFailure("stream", "uh-oh"); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(failure); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void failOutstandingStreamCallbackOnClose() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.channelInactive(); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } + + @Test + public void failOutstandingStreamCallbackOnException() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.exceptionCaught(new IOException("Oops!")); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } } From 80f7ac3a601709dd9471092244612023363f54cd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 11:02:22 +0800 Subject: [PATCH 0825/1765] [SPARK-21253][CORE] Disable spark.reducer.maxReqSizeShuffleToMem ## What changes were proposed in this pull request? Disable spark.reducer.maxReqSizeShuffleToMem because it breaks the old shuffle service. Credits to wangyum Closes #18466 ## How was this patch tested? Jenkins Author: Shixiong Zhu Author: Yuming Wang Closes #18467 from zsxwing/SPARK-21253. --- .../scala/org/apache/spark/internal/config/package.scala | 3 ++- docs/configuration.md | 8 -------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index be63c637a3a13..8dee0d970c4c6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -323,10 +323,11 @@ package object config { private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") + .internal() .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + "above this threshold. This is to avoid a giant request takes too much memory.") .bytesConf(ByteUnit.BYTE) - .createWithDefaultString("200m") + .createWithDefault(Long.MaxValue) private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") diff --git a/docs/configuration.md b/docs/configuration.md index c8e61537a457c..bd6a1f9e240e2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -528,14 +528,6 @@ Apart from these, the following properties are also available, and may be useful By allowing it to limit the number of fetch requests, this scenario can be mitigated. - - spark.reducer.maxReqSizeShuffleToMem - 200m - - The blocks of a shuffle request will be fetched to disk when size of the request is above - this threshold. This is to avoid a giant request takes too much memory. - - spark.shuffle.compress true From 88a536babf119b7e331d02aac5d52b57658803bf Mon Sep 17 00:00:00 2001 From: IngoSchuster Date: Fri, 30 Jun 2017 11:16:09 +0800 Subject: [PATCH 0826/1765] [SPARK-21176][WEB UI] Limit number of selector threads for admin ui proxy servlets to 8 ## What changes were proposed in this pull request? Please see also https://issues.apache.org/jira/browse/SPARK-21176 This change limits the number of selector threads that jetty creates to maximum 8 per proxy servlet (Jetty default is number of processors / 2). The newHttpClient for Jettys ProxyServlet class is overwritten to avoid the Jetty defaults (which are designed for high-performance http servers). Once https://github.com/eclipse/jetty.project/issues/1643 is available, the code could be cleaned up to avoid the method override. I really need this on v2.1.1 - what is the best way for a backport automatic merge works fine)? Shall I create another PR? ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) The patch was tested manually on a Spark cluster with a head node that has 88 processors using JMX to verify that the number of selector threads is now limited to 8 per proxy. gurvindersingh zsxwing can you please review the change? Author: IngoSchuster Author: Ingo Schuster Closes #18437 from IngoSchuster/master. --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index edf328b5ae538..b9371c7ad7b45 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -26,6 +26,8 @@ import scala.language.implicitConversions import scala.xml.Node import org.eclipse.jetty.client.api.Response +import org.eclipse.jetty.client.HttpClient +import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ @@ -208,6 +210,16 @@ private[spark] object JettyUtils extends Logging { rewrittenURI.toString() } + override def newHttpClient(): HttpClient = { + // SPARK-21176: Use the Jetty logic to calculate the number of selector threads (#CPUs/2), + // but limit it to 8 max. + // Otherwise, it might happen that we exhaust the threadpool since in reverse proxy mode + // a proxy is instantiated for each executor. If the head node has many processors, this + // can quickly add up to an unreasonably high number of threads. + val numSelectors = math.max(1, math.min(8, Runtime.getRuntime().availableProcessors() / 2)) + new HttpClient(new HttpClientTransportOverHTTP(numSelectors), null) + } + override def filterServerResponseHeader( clientRequest: HttpServletRequest, serverResponse: Response, From cfc696f4a4289acf132cb26baf7c02c5b6305277 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 29 Jun 2017 20:56:37 -0700 Subject: [PATCH 0827/1765] [SPARK-21253][CORE][HOTFIX] Fix Scala 2.10 build ## What changes were proposed in this pull request? A follow up PR to fix Scala 2.10 build for #18472 ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18478 from zsxwing/SPARK-21253-2. --- .../apache/spark/network/client/TransportResponseHandler.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be9f18203c8e4..340b8b96aabc6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -92,7 +92,7 @@ public void removeRpcRequest(long requestId) { public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(Tuple2.apply(streamId, callback)); + streamCallbacks.offer(new Tuple2<>(streamId, callback)); } @VisibleForTesting From e2f32ee45ac907f1f53fde7e412676a849a94872 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 30 Jun 2017 12:34:09 +0800 Subject: [PATCH 0828/1765] [SPARK-21258][SQL] Fix WindowExec complex object aggregation with spilling ## What changes were proposed in this pull request? `WindowExec` currently improperly stores complex objects (UnsafeRow, UnsafeArrayData, UnsafeMapData, UTF8String) during aggregation by keeping a reference in the buffer used by `GeneratedMutableProjections` to the actual input data. Things go wrong when the input object (or the backing bytes) are reused for other things. This could happen in window functions when it starts spilling to disk. When reading the back the spill files the `UnsafeSorterSpillReader` reuses the buffer to which the `UnsafeRow` points, leading to weird corruption scenario's. Note that this only happens for aggregate functions that preserve (parts of) their input, for example `FIRST`, `LAST`, `MIN` & `MAX`. This was not seen before, because the spilling logic was not doing actual spills as much and actually used an in-memory page. This page was not cleaned up during window processing and made sure unsafe objects point to their own dedicated memory location. This was changed by https://github.com/apache/spark/pull/16909, after this PR Spark spills more eagerly. This PR provides a surgical fix because we are close to releasing Spark 2.2. This change just makes sure that there cannot be any object reuse at the expensive of a little bit of performance. We will follow-up with a more subtle solution at a later point. ## How was this patch tested? Added a regression test to `DataFrameWindowFunctionsSuite`. Author: Herman van Hovell Closes #18470 from hvanhovell/SPARK-21258. --- .../execution/window/AggregateProcessor.scala | 7 ++- .../sql/DataFrameWindowFunctionsSuite.scala | 47 ++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index bc141b36e63b4..2195c6ea95948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,10 +145,13 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) + // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that + // MutableProjection makes copies of the complex input objects it buffer. + val copy = input.copy() + updateProjection(join(buffer, copy)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, input) + imperatives(i).update(buffer, copy) i += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c49104718..204858fa29787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types._ /** * Window function testing for DataFrame API. @@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-21258: complex object in combination with spilling") { + // Make sure we trigger the spilling path. + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + val sampleSchema = new StructType(). + add("f0", StringType). + add("f1", LongType). + add("f2", ArrayType(new StructType(). + add("f20", StringType))). + add("f3", ArrayType(new StructType(). + add("f30", StringType))) + + val w0 = Window.partitionBy("f0").orderBy("f1") + val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) + + val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" + val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" + + val input = + """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} + |{"f1":1497802179638} + |{"f1":1497802189347} + |{"f1":1497802189593} + |{"f1":1497802189597} + |{"f1":1497802189599} + |{"f1":1497802192103} + |{"f1":1497802193414} + |{"f1":1497802193577} + |{"f1":1497802193709} + |{"f1":1497802202883} + |{"f1":1497802203006} + |{"f1":1497802203743} + |{"f1":1497802203834} + |{"f1":1497802203887} + |{"f1":1497802203893} + |{"f1":1497802203976} + |{"f1":1497820168098} + |""".stripMargin.split("\n").toSeq + + import testImplicits._ + + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } + } } From fddb63f46345be36c40d9a7f3660920af6502bbd Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 21:35:01 -0700 Subject: [PATCH 0829/1765] [SPARK-20889][SPARKR] Grouped documentation for MISC column methods ## What changes were proposed in this pull request? Grouped documentation for column misc methods. Author: actuaryzhang Author: Wayne Zhang Closes #18448 from actuaryzhang/sparkRDocMisc. --- R/pkg/R/functions.R | 98 +++++++++++++++++++++------------------------ R/pkg/R/generics.R | 15 ++++--- 2 files changed, 55 insertions(+), 58 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index cb09e847d739a..67cb7a7f6db08 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -150,6 +150,27 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} NULL +#' Miscellaneous functions for Column operations +#' +#' Miscellaneous functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{sha2}, it is one of 224, 256, 384, or 512. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_misc_functions +#' @rdname column_misc_functions +#' @family misc functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)[, 1:2]) +#' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), +#' v3 = hash(df$model, df$mpg), v4 = md5(df$model), +#' v5 = sha1(df$model), v6 = sha2(df$model, 256)) +#' head(tmp) +#' } +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -569,19 +590,13 @@ setMethod("count", column(jc) }) -#' crc32 -#' -#' Calculates the cyclic redundancy check value (CRC32) of a binary column and -#' returns the value as a bigint. -#' -#' @param x Column to compute on. +#' @details +#' \code{crc32}: Calculates the cyclic redundancy check value (CRC32) of a binary column +#' and returns the value as a bigint. #' -#' @rdname crc32 -#' @name crc32 -#' @family misc functions -#' @aliases crc32,Column-method +#' @rdname column_misc_functions +#' @aliases crc32 crc32,Column-method #' @export -#' @examples \dontrun{crc32(df$c)} #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -590,19 +605,13 @@ setMethod("crc32", column(jc) }) -#' hash -#' -#' Calculates the hash code of given columns, and returns the result as a int column. -#' -#' @param x Column to compute on. -#' @param ... additional Column(s) to be included. +#' @details +#' \code{hash}: Calculates the hash code of given columns, and returns the result +#' as an int column. #' -#' @rdname hash -#' @name hash -#' @family misc functions -#' @aliases hash,Column-method +#' @rdname column_misc_functions +#' @aliases hash hash,Column-method #' @export -#' @examples \dontrun{hash(df$c)} #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -1055,19 +1064,13 @@ setMethod("max", column(jc) }) -#' md5 -#' -#' Calculates the MD5 digest of a binary column and returns the value +#' @details +#' \code{md5}: Calculates the MD5 digest of a binary column and returns the value #' as a 32 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname md5 -#' @name md5 -#' @family misc functions -#' @aliases md5,Column-method +#' @rdname column_misc_functions +#' @aliases md5 md5,Column-method #' @export -#' @examples \dontrun{md5(df$c)} #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1307,19 +1310,13 @@ setMethod("second", column(jc) }) -#' sha1 -#' -#' Calculates the SHA-1 digest of a binary column and returns the value +#' @details +#' \code{sha1}: Calculates the SHA-1 digest of a binary column and returns the value #' as a 40 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname sha1 -#' @name sha1 -#' @family misc functions -#' @aliases sha1,Column-method +#' @rdname column_misc_functions +#' @aliases sha1 sha1,Column-method #' @export -#' @examples \dontrun{sha1(df$c)} #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -2309,19 +2306,14 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), column(jc) }) -#' sha2 -#' -#' Calculates the SHA-2 family of hash functions of a binary column and -#' returns the value as a hex string. +#' @details +#' \code{sha2}: Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. The second argument \code{x} specifies the number +#' of bits, and is one of 224, 256, 384, or 512. #' -#' @param y column to compute SHA-2 on. -#' @param x one of 224, 256, 384, or 512. -#' @family misc functions -#' @rdname sha2 -#' @name sha2 -#' @aliases sha2,Column,numeric-method +#' @rdname column_misc_functions +#' @aliases sha2 sha2,Column,numeric-method #' @export -#' @examples \dontrun{sha2(df$c, 256)} #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1deb057bb1b82..bdd4b360f4973 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -992,8 +992,9 @@ setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) -#' @rdname crc32 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions @@ -1006,8 +1007,9 @@ setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) -#' @rdname hash +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @param x empty. Should be used with no argument. @@ -1205,8 +1207,9 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) -#' @rdname md5 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions @@ -1350,12 +1353,14 @@ setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) -#' @rdname sha1 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) -#' @rdname sha2 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions From 52981715bb8d653a1141f55b36da804412eb783a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 23:00:50 -0700 Subject: [PATCH 0830/1765] [SPARK-20889][SPARKR] Grouped documentation for COLLECTION column methods ## What changes were proposed in this pull request? Grouped documentation for column collection methods. Author: actuaryzhang Author: Wayne Zhang Closes #18458 from actuaryzhang/sparkRDocCollection. --- R/pkg/R/functions.R | 204 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 27 ++++-- 2 files changed, 108 insertions(+), 123 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 67cb7a7f6db08..a1f5c4f8cc18d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -171,6 +171,35 @@ NULL #' } NULL +#' Collection functions for Column operations +#' +#' Collection functions defined for \code{Column}. +#' +#' @param x Column to compute on. Note the difference in the following methods: +#' \itemize{ +#' \item \code{to_json}: it is the column containing the struct or array of the structs. +#' \item \code{from_json}: it is the column containing the JSON string. +#' } +#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains +#' additional named properties to control how it is converted, accepts the same +#' options as the JSON data source. +#' @name column_collection_functions +#' @rdname column_collection_functions +#' @family collection functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) +#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) +#' head(tmp2) +#' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))} +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -1642,30 +1671,23 @@ setMethod("to_date", column(jc) }) -#' to_json -#' -#' Converts a column containing a \code{structType} or array of \code{structType} into a Column -#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. -#' -#' @param x Column containing the struct or array of the structs -#' @param ... additional named properties to control how it is converted, accepts the same options -#' as the JSON data source. +#' @details +#' \code{to_json}: Converts a column containing a \code{structType} or array of \code{structType} +#' into a Column of JSON string. Resolving the Column can fail if an unsupported type is encountered. #' -#' @family non-aggregate functions -#' @rdname to_json -#' @name to_json -#' @aliases to_json,Column-method +#' @rdname column_collection_functions +#' @aliases to_json to_json,Column-method #' @export #' @examples +#' #' \dontrun{ #' # Converts a struct into a JSON object -#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") -#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df2, to_json(df2$d, dateFormat = 'dd/MM/yyyy')) #' #' # Converts an array of structs into a JSON array -#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") -#' select(df, to_json(df$people)) -#'} +#' df2 <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), function(x, ...) { @@ -2120,28 +2142,28 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) -#' from_json -#' -#' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. -#' If the string is unparseable, the Column will contains the value NA. +#' @details +#' \code{from_json}: Parses a column containing a JSON string into a Column of \code{structType} +#' with the specified \code{schema} or array of \code{structType} if \code{as.json.array} is set +#' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' -#' @param x Column containing the JSON string. +#' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @param ... additional named properties to control how the json is parsed, accepts the same -#' options as the JSON data source. -#' -#' @family non-aggregate functions -#' @rdname from_json -#' @name from_json -#' @aliases from_json,Column,structType-method +#' @aliases from_json from_json,Column,structType-method #' @export #' @examples +#' #' \dontrun{ -#' schema <- structType(structField("name", "string"), -#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) -#'} +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' df2 <- mutate(df2, d2 = to_json(df2$d, dateFormat = 'dd/MM/yyyy')) +#' schema <- structType(structField("date", "string")) +#' head(select(df2, from_json(df2$d2, schema, dateFormat = 'dd/MM/yyyy'))) + +#' df2 <- sql("SELECT named_struct('name', 'Bob') as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' schema <- structType(structField("name", "string")) +#' head(select(df2, from_json(df2$people_json, schema)))} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), function(x, schema, as.json.array = FALSE, ...) { @@ -3101,18 +3123,14 @@ setMethod("row_number", ###################### Collection functions###################### -#' array_contains -#' -#' Returns null if the array is null, true if the array contains the value, and false otherwise. +#' @details +#' \code{array_contains}: Returns null if the array is null, true if the array contains +#' the value, and false otherwise. #' -#' @param x A Column #' @param value A value to be checked if contained in the column -#' @rdname array_contains -#' @aliases array_contains,Column-method -#' @name array_contains -#' @family collection functions +#' @rdname column_collection_functions +#' @aliases array_contains array_contains,Column-method #' @export -#' @examples \dontrun{array_contains(df$c, 1)} #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3121,18 +3139,12 @@ setMethod("array_contains", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' -#' @rdname explode -#' @name explode -#' @family collection functions -#' @aliases explode,Column-method +#' @rdname column_collection_functions +#' @aliases explode explode,Column-method #' @export -#' @examples \dontrun{explode(df$c)} #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3141,18 +3153,12 @@ setMethod("explode", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @param x Column to compute on +#' @details +#' \code{size}: Returns length of array or map. #' -#' @rdname size -#' @name size -#' @aliases size,Column-method -#' @family collection functions +#' @rdname column_collection_functions +#' @aliases size size,Column-method #' @export -#' @examples \dontrun{size(df$c)} #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3161,25 +3167,16 @@ setMethod("size", column(jc) }) -#' sort_array -#' -#' Sorts the input array in ascending or descending order according +#' @details +#' \code{sort_array}: Sorts the input array in ascending or descending order according #' to the natural ordering of the array elements. #' -#' @param x A Column to sort +#' @rdname column_collection_functions #' @param asc A logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. -#' @rdname sort_array -#' @name sort_array -#' @aliases sort_array,Column-method -#' @family collection functions +#' @aliases sort_array sort_array,Column-method #' @export -#' @examples -#' \dontrun{ -#' sort_array(df$c) -#' sort_array(df$c, FALSE) -#' } #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3188,18 +3185,13 @@ setMethod("sort_array", column(jc) }) -#' posexplode -#' -#' Creates a new row for each element with position in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{posexplode}: Creates a new row for each element with position in the given array +#' or map column. #' -#' @rdname posexplode -#' @name posexplode -#' @family collection functions -#' @aliases posexplode,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode posexplode,Column-method #' @export -#' @examples \dontrun{posexplode(df$c)} #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3325,27 +3317,24 @@ setMethod("repeat_string", column(jc) }) -#' explode_outer -#' -#' Creates a new row for each element in the given array or map column. +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' Unlike \code{explode}, if the array/map is \code{null} or empty #' then \code{null} is produced. #' -#' @param x Column to compute on #' -#' @rdname explode_outer -#' @name explode_outer -#' @family collection functions -#' @aliases explode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases explode_outer explode_outer,Column-method #' @export #' @examples +#' #' \dontrun{ -#' df <- createDataFrame(data.frame( +#' df2 <- createDataFrame(data.frame( #' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") #' )) #' -#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) -#' } +#' head(select(df2, df2$id, explode_outer(split_string(df2$text, ",")))) +#' head(select(df2, df2$id, posexplode_outer(split_string(df2$text, ","))))} #' @note explode_outer since 2.3.0 setMethod("explode_outer", signature(x = "Column"), @@ -3354,27 +3343,14 @@ setMethod("explode_outer", column(jc) }) -#' posexplode_outer -#' -#' Creates a new row for each element with position in the given array or map column. -#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' @details +#' \code{posexplode_outer}: Creates a new row for each element with position in the given +#' array or map column. Unlike \code{posexplode}, if the array/map is \code{null} or empty #' then the row (\code{null}, \code{null}) is produced. #' -#' @param x Column to compute on -#' -#' @rdname posexplode_outer -#' @name posexplode_outer -#' @family collection functions -#' @aliases posexplode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode_outer posexplode_outer,Column-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(data.frame( -#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") -#' )) -#' -#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) -#' } #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index bdd4b360f4973..b901b74e4728d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -913,8 +913,9 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) -#' @rdname array_contains +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions @@ -1062,12 +1063,14 @@ setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) -#' @rdname explode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) -#' @rdname explode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions @@ -1090,8 +1093,9 @@ setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) -#' @rdname from_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions @@ -1275,12 +1279,14 @@ setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_ra #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) -#' @rdname posexplode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) -#' @rdname posexplode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions @@ -1383,8 +1389,9 @@ setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUns #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) -#' @rdname size +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions @@ -1392,8 +1399,9 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) -#' @rdname sort_array +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions @@ -1456,8 +1464,9 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) -#' @rdname to_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions From 49d767d838691fc7d964be2c4349662f5500ff2b Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 30 Jun 2017 20:02:15 +0800 Subject: [PATCH 0831/1765] [SPARK-18710][ML] Add offset in GLM ## What changes were proposed in this pull request? Add support for offset in GLM. This is useful for at least two reasons: 1. Account for exposure: e.g., when modeling the number of accidents, we may need to use miles driven as an offset to access factors on frequency. 2. Test incremental effects of new variables: we can use predictions from the existing model as offset and run a much smaller model on only new variables. This avoids re-estimating the large model with all variables (old + new) and can be very important for efficient large-scaled analysis. ## How was this patch tested? New test. yanboliang srowen felixcheung sethah Author: actuaryzhang Closes #16699 from actuaryzhang/offset. --- .../apache/spark/ml/feature/Instance.scala | 21 + .../IterativelyReweightedLeastSquares.scala | 14 +- .../spark/ml/optim/WeightedLeastSquares.scala | 2 +- .../GeneralizedLinearRegression.scala | 184 +++-- ...erativelyReweightedLeastSquaresSuite.scala | 40 +- .../GeneralizedLinearRegressionSuite.scala | 634 ++++++++++-------- 6 files changed, 534 insertions(+), 361 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala index cce3ca45ccd8f..dd56fbbfa2b63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -27,3 +27,24 @@ import org.apache.spark.ml.linalg.Vector * @param features The vector of features for this data point. */ private[ml] case class Instance(label: Double, weight: Double, features: Vector) + +/** + * Case class that represents an instance of data point with + * label, weight, offset and features. + * This is mainly used in GeneralizedLinearRegression currently. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param offset The offset used for this data point. + * @param features The vector of features for this data point. + */ +private[ml] case class OffsetInstance( + label: Double, + weight: Double, + offset: Double, + features: Vector) { + + /** Converts to an [[Instance]] object by leaving out the offset. */ + def toInstance: Instance = Instance(label, weight, features) + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 9c495512422ba..6961b45f55e4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD @@ -43,7 +43,7 @@ private[ml] class IterativelyReweightedLeastSquaresModel( * find M-estimator in robust regression and other optimization problems. * * @param initialModel the initial guess model. - * @param reweightFunc the reweight function which is used to update offsets and weights + * @param reweightFunc the reweight function which is used to update working labels and weights * at each iteration. * @param fitIntercept whether to fit intercept. * @param regParam L2 regularization parameter used by WLS. @@ -57,13 +57,13 @@ private[ml] class IterativelyReweightedLeastSquaresModel( */ private[ml] class IterativelyReweightedLeastSquares( val initialModel: WeightedLeastSquaresModel, - val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double), val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, val tol: Double) extends Logging with Serializable { - def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -75,10 +75,10 @@ private[ml] class IterativelyReweightedLeastSquares( oldModel = model - // Update offsets and weights using reweightFunc + // Update working labels and weights using reweightFunc val newInstances = instances.map { instance => - val (newOffset, newWeight) = reweightFunc(instance, oldModel) - Instance(newOffset, newWeight, instance.features) + val (newLabel, newWeight) = reweightFunc(instance, oldModel) + Instance(newLabel, newWeight, instance.features) } // Estimate new model diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 56ab9675700a0..32b0af72ba9bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index bff0d9bbb46ff..ce3460ae43566 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -26,8 +26,8 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.{BLAS, Vector} +import org.apache.spark.ml.feature.{Instance, OffsetInstance} +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -138,6 +138,27 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getLinkPredictionCol: String = $(linkPredictionCol) + /** + * Param for offset column name. If this is not set or empty, we treat all instance offsets + * as 0.0. The feature specified as offset has a constant coefficient of 1.0. + * @group param + */ + @Since("2.3.0") + final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "The offset " + + "column name. If this is not set or empty, we treat all instance offsets as 0.0") + + /** @group getParam */ + @Since("2.3.0") + def getOffsetCol: String = $(offsetCol) + + /** Checks whether weight column is set and nonempty. */ + private[regression] def hasWeightCol: Boolean = + isSet(weightCol) && $(weightCol).nonEmpty + + /** Checks whether offset column is set and nonempty. */ + private[regression] def hasOffsetCol: Boolean = + isSet(offsetCol) && $(offsetCol).nonEmpty + /** Checks whether we should output link prediction. */ private[regression] def hasLinkPredictionCol: Boolean = { isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty @@ -172,6 +193,11 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam } val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + + if (hasOffsetCol) { + SchemaUtils.checkNumericType(schema, $(offsetCol)) + } + if (hasLinkPredictionCol) { SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) } else { @@ -306,6 +332,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val @Since("2.0.0") def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Sets the value of param [[offsetCol]]. + * If this is not set or empty, we treat all instance offsets as 0.0. + * Default is not set, so all instances have offset 0.0. + * + * @group setParam + */ + @Since("2.3.0") + def setOffsetCol(value: String): this.type = set(offsetCol, value) + /** * Sets the solver algorithm used for optimization. * Currently only supports "irls" which is also the default solver. @@ -329,7 +365,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, linkPredictionCol, + instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol) instr.logNumFeatures(numFeatures) @@ -343,15 +379,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + "set to false. To fit a model with 0 features, fitIntercept must be set to true." ) - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val w = if (!hasWeightCol) lit(1.0) else col($(weightCol)) + val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, offset: Double, features: Vector) => + Instance(label - offset, weight, features) + } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) val wlsModel = optimizer.fit(instances) @@ -362,6 +399,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val wlsModel.diagInvAtWA.toArray, 1, getSolver) model.setSummary(Some(trainingSummary)) } else { + val instances: RDD[OffsetInstance] = + dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, offset: Double, features: Vector) => + OffsetInstance(label, weight, offset, features) + } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, @@ -425,12 +467,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. */ def initialize( - instances: RDD[Instance], + instances: RDD[OffsetInstance], fitIntercept: Boolean, regParam: Double): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) - val eta = predict(mu) + val eta = predict(mu) - instance.offset Instance(eta, instance.weight, instance.features) } // TODO: Make standardizeFeatures and standardizeLabel configurable. @@ -441,16 +483,16 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine } /** - * The reweight function used to update offsets and weights + * The reweight function used to update working labels and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. */ - val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { - (instance: Instance, model: WeightedLeastSquaresModel) => { - val eta = model.predict(instance.features) + val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + instance.offset val mu = fitted(eta) - val offset = eta + (instance.label - mu) * link.deriv(mu) - val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) - (offset, weight) + val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) + val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (newLabel, newWeight) } } } @@ -950,15 +992,22 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) override protected def predict(features: Vector): Double = { - val eta = predictLink(features) + predict(features, 0.0) + } + + /** + * Calculates the predicted value when offset is set. + */ + private def predict(features: Vector, offset: Double): Double = { + val eta = predictLink(features, offset) familyAndLink.fitted(eta) } /** - * Calculate the link prediction (linear predictor) of the given instance. + * Calculates the link prediction (linear predictor) of the given instance. */ - private def predictLink(features: Vector): Double = { - BLAS.dot(features, coefficients) + intercept + private def predictLink(features: Vector, offset: Double): Double = { + BLAS.dot(features, coefficients) + intercept + offset } override def transform(dataset: Dataset[_]): DataFrame = { @@ -967,14 +1016,16 @@ class GeneralizedLinearRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector) => predict(features) } - val predictLinkUDF = udf { (features: Vector) => predictLink(features) } + val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } + val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) } + + val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) var output = dataset if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset)) } if (hasLinkPredictionCol) { - output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) + output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset)) } output.toDF() } @@ -1146,9 +1197,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** Degrees of freedom. */ @Since("2.0.0") - lazy val degreesOfFreedom: Long = { - numInstances - rank - } + lazy val degreesOfFreedom: Long = numInstances - rank /** The residual degrees of freedom. */ @Since("2.0.0") @@ -1156,18 +1205,20 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** The residual degrees of freedom for the null model. */ @Since("2.0.0") - lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) { - numInstances - 1 - } else { - numInstances + lazy val residualDegreeOfFreedomNull: Long = { + if (model.getFitIntercept) numInstances - 1 else numInstances } - private def weightCol: Column = { - if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) { - lit(1.0) - } else { - col(model.getWeightCol) - } + private def label: Column = col(model.getLabelCol).cast(DoubleType) + + private def prediction: Column = col(predictionCol) + + private def weight: Column = { + if (!model.hasWeightCol) lit(1.0) else col(model.getWeightCol) + } + + private def offset: Column = { + if (!model.hasOffsetCol) lit(0.0) else col(model.getOffsetCol).cast(DoubleType) } private[regression] lazy val devianceResiduals: DataFrame = { @@ -1175,25 +1226,23 @@ class GeneralizedLinearRegressionSummary private[regression] ( val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0)) if (y > mu) r else -1.0 * r } - val w = weightCol predictions.select( - drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals")) + drUDF(label, prediction, weight).as("devianceResiduals")) } private[regression] lazy val pearsonResiduals: DataFrame = { val prUDF = udf { mu: Double => family.variance(mu) } - val w = weightCol - predictions.select(col(model.getLabelCol).minus(col(predictionCol)) - .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals")) + predictions.select(label.minus(prediction) + .multiply(sqrt(weight)).divide(sqrt(prUDF(prediction))).as("pearsonResiduals")) } private[regression] lazy val workingResiduals: DataFrame = { val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) } - predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals")) + predictions.select(wrUDF(label, prediction).as("workingResiduals")) } private[regression] lazy val responseResiduals: DataFrame = { - predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals")) + predictions.select(label.minus(prediction).as("responseResiduals")) } /** @@ -1225,16 +1274,35 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val nullDeviance: Double = { - val w = weightCol - val wtdmu: Double = if (model.getFitIntercept) { - val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first() - agg.getDouble(0) / agg.getDouble(1) + val intercept: Double = if (!model.getFitIntercept) { + 0.0 } else { - link.unlink(0.0) + /* + Estimate intercept analytically when there is no offset, or when there is offset but + the model is Gaussian family with identity link. Otherwise, fit an intercept only model. + */ + if (!model.hasOffsetCol || + (model.hasOffsetCol && family == Gaussian && link == Identity)) { + val agg = predictions.agg(sum(weight.multiply( + label.minus(offset))), sum(weight)).first() + link.link(agg.getDouble(0) / agg.getDouble(1)) + } else { + // Create empty feature column and fit intercept only model using param setting from model + val featureNull = "feature_" + java.util.UUID.randomUUID.toString + val paramMap = model.extractParamMap() + paramMap.put(model.featuresCol, featureNull) + if (family.name != "tweedie") { + paramMap.remove(model.variancePower) + } + val emptyVectorUDF = udf{ () => Vectors.zeros(0) } + model.parent.fit( + dataset.withColumn(featureNull, emptyVectorUDF()), paramMap + ).intercept + } } - predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { - case Row(y: Double, weight: Double) => - family.deviance(y, wtdmu, weight) + predictions.select(label, offset, weight).rdd.map { + case Row(y: Double, offset: Double, weight: Double) => + family.deviance(y, link.unlink(intercept + offset), weight) }.sum() } @@ -1243,8 +1311,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val deviance: Double = { - val w = weightCol - predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + predictions.select(label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1269,10 +1336,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** Akaike Information Criterion (AIC) for the fitted model. */ @Since("2.0.0") lazy val aic: Double = { - val w = weightCol - val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) + val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0) val t = predictions.select( - col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => (label, pred, weight) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala index 50260952ecb66..6d143504fcf58 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -26,8 +26,8 @@ import org.apache.spark.rdd.RDD class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { - private var instances1: RDD[Instance] = _ - private var instances2: RDD[Instance] = _ + private var instances1: RDD[OffsetInstance] = _ + private var instances2: RDD[OffsetInstance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -39,10 +39,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes w <- c(1, 2, 3, 4) */ instances1 = sc.parallelize(Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(0.0, 2.0, 0.0, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 0.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) ), 2) /* R code: @@ -52,10 +52,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes w <- c(1, 2, 3, 4) */ instances2 = sc.parallelize(Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + OffsetInstance(2.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(8.0, 2.0, 0.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(3.0, 3.0, 0.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(9.0, 4.0, 0.0, Vectors.dense(3.0, 13.0)) ), 2) } @@ -156,7 +156,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes var idx = 0 for (fitIntercept <- Seq(false, true)) { val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(instances2) + standardizeFeatures = false, standardizeLabel = false).fit(instances2.map(_.toInstance)) val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) @@ -169,29 +169,29 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes object IterativelyReweightedLeastSquaresSuite { def BinomialReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) - val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val z = eta - instance.offset + (instance.label - mu) / (mu * (1.0 - mu)) val w = mu * (1 - mu) * instance.weight (z, w) } def PoissonReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val mu = math.exp(eta) - val z = eta + (instance.label - mu) / mu + val z = eta - instance.offset + (instance.label - mu) / mu val w = mu * instance.weight (z, w) } def L1RegressionReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val e = math.max(math.abs(eta - instance.label), 1e-7) val w = 1 / e val y = instance.label diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index f7c7c001a36af..cfaa57314bd66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} @@ -797,77 +797,160 @@ class GeneralizedLinearRegressionSuite } } - test("glm summary: gaussian family with weight") { + test("generalized linear regression with weight and offset") { /* - R code: + R code: + library(statmod) - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(17, 19, 23, 29) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) - */ - val datasetWithWeight = Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + df <- as.data.frame(matrix(c( + 0.2, 1.0, 2.0, 0.0, 5.0, + 0.5, 2.1, 0.5, 1.0, 2.0, + 0.9, 0.4, 1.0, 2.0, 1.0, + 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) + families <- list(gaussian, binomial, poisson, Gamma, tweedie(1.5)) + f1 <- V1 ~ -1 + V4 + V5 + f2 <- V1 ~ V4 + V5 + for (f in c(f1, f2)) { + for (fam in families) { + model <- glm(f, df, family = fam, weights = V2, offset = V3) + print(as.vector(coef(model))) + } + } + [1] 0.5169222 -0.3344444 + [1] 0.9419107 -0.6864404 + [1] 0.1812436 -0.6568422 + [1] -0.2869094 0.7857710 + [1] 0.1055254 0.2979113 + [1] -0.05990345 0.53188982 -0.32118415 + [1] -0.2147117 0.9911750 -0.6356096 + [1] -1.5616130 0.6646470 -0.3192581 + [1] 0.3390397 -0.3406099 0.6870259 + [1] 0.3665034 0.1039416 0.1484616 + */ + val dataset = Seq( + OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() + + val expected = Seq( + Vectors.dense(0, 0.5169222, -0.3344444), + Vectors.dense(0, 0.9419107, -0.6864404), + Vectors.dense(0, 0.1812436, -0.6568422), + Vectors.dense(0, -0.2869094, 0.785771), + Vectors.dense(0, 0.1055254, 0.2979113), + Vectors.dense(-0.05990345, 0.53188982, -0.32118415), + Vectors.dense(-0.2147117, 0.991175, -0.6356096), + Vectors.dense(-1.561613, 0.664647, -0.3192581), + Vectors.dense(0.3390397, -0.3406099, 0.6870259), + Vectors.dense(0.3665034, 0.1039416, 0.1484616)) + + import GeneralizedLinearRegression._ + + var idx = 0 + + for (fitIntercept <- Seq(false, true)) { + for (family <- Seq("gaussian", "binomial", "poisson", "gamma", "tweedie")) { + val trainer = new GeneralizedLinearRegression().setFamily(family) + .setFitIntercept(fitIntercept).setOffsetCol("offset") + .setWeightCol("weight").setLinkPredictionCol("linkPrediction") + if (family == "tweedie") trainer.setVariancePower(1.5) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, s"Model mismatch: GLM with family = $family," + + s" and fitIntercept = $fitIntercept.") + + val familyLink = FamilyAndLink(trainer) + model.transform(dataset).select("features", "offset", "prediction", "linkPrediction") + .collect().foreach { + case Row(features: DenseVector, offset: Double, prediction1: Double, + linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"family = $family, and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with family = $family, and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("glm summary: gaussian family with weight and offset") { /* - R code: + R code: - model <- glm(formula = "b ~ .", family="gaussian", data = df, weights = w) - summary(model) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + off <- c(2, 3, 1, 4) + df <- as.data.frame(cbind(A, b)) + */ + val dataset = Seq( + OffsetInstance(17.0, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(19.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(23.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(29.0, 4.0, 4.0, Vectors.dense(3.0, 13.0)) + ).toDF() + /* + R code: - Deviance Residuals: - 1 2 3 4 - 1.920 -1.358 -1.109 0.960 + model <- glm(formula = "b ~ .", family = "gaussian", data = df, + weights = w, offset = off) + summary(model) - Coefficients: - Estimate Std. Error t value Pr(>|t|) - (Intercept) 18.080 9.608 1.882 0.311 - V1 6.080 5.556 1.094 0.471 - V2 -0.600 1.960 -0.306 0.811 + Deviance Residuals: + 1 2 3 4 + 0.9600 -0.6788 -0.5543 0.4800 - (Dispersion parameter for gaussian family taken to be 7.68) + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 5.5400 4.8040 1.153 0.455 + V1 -0.9600 2.7782 -0.346 0.788 + V2 1.7000 0.9798 1.735 0.333 - Null deviance: 202.00 on 3 degrees of freedom - Residual deviance: 7.68 on 1 degrees of freedom - AIC: 18.783 + (Dispersion parameter for gaussian family taken to be 1.92) - Number of Fisher Scoring iterations: 2 + Null deviance: 152.10 on 3 degrees of freedom + Residual deviance: 1.92 on 1 degrees of freedom + AIC: 13.238 - residuals(model, type="pearson") - 1 2 3 4 - 1.920000 -1.357645 -1.108513 0.960000 + Number of Fisher Scoring iterations: 2 - residuals(model, type="working") + residuals(model, type = "pearson") + 1 2 3 4 + 0.9600000 -0.6788225 -0.5542563 0.4800000 + residuals(model, type = "working") 1 2 3 4 - 1.92 -0.96 -0.64 0.48 - - residuals(model, type="response") + 0.96 -0.48 -0.32 0.24 + residuals(model, type = "response") 1 2 3 4 - 1.92 -0.96 -0.64 0.48 + 0.96 -0.48 -0.32 0.24 */ val trainer = new GeneralizedLinearRegression() - .setWeightCol("weight") + .setWeightCol("weight").setOffsetCol("offset") + + val model = trainer.fit(dataset) - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(6.080, -0.600)) - val interceptR = 18.080 - val devianceResidualsR = Array(1.920, -1.358, -1.109, 0.960) - val pearsonResidualsR = Array(1.920000, -1.357645, -1.108513, 0.960000) - val workingResidualsR = Array(1.92, -0.96, -0.64, 0.48) - val responseResidualsR = Array(1.92, -0.96, -0.64, 0.48) - val seCoefR = Array(5.556, 1.960, 9.608) - val tValsR = Array(1.094, -0.306, 1.882) - val pValsR = Array(0.471, 0.811, 0.311) - val dispersionR = 7.68 - val nullDevianceR = 202.00 - val residualDevianceR = 7.68 + val coefficientsR = Vectors.dense(Array(-0.96, 1.7)) + val interceptR = 5.54 + val devianceResidualsR = Array(0.96, -0.67882, -0.55426, 0.48) + val pearsonResidualsR = Array(0.96, -0.67882, -0.55426, 0.48) + val workingResidualsR = Array(0.96, -0.48, -0.32, 0.24) + val responseResidualsR = Array(0.96, -0.48, -0.32, 0.24) + val seCoefR = Array(2.7782, 0.9798, 4.804) + val tValsR = Array(-0.34555, 1.73506, 1.15321) + val pValsR = Array(0.78819, 0.33286, 0.45478) + val dispersionR = 1.92 + val nullDevianceR = 152.1 + val residualDevianceR = 1.92 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 18.783 + val aicR = 13.23758 assert(model.hasSummary) val summary = model.summary @@ -912,7 +995,7 @@ class GeneralizedLinearRegressionSuite assert(summary.aic ~== aicR absTol 1E-3) assert(summary.solver === "irls") - val summary2: GeneralizedLinearRegressionSummary = model.evaluate(datasetWithWeight) + val summary2: GeneralizedLinearRegressionSummary = model.evaluate(dataset) assert(summary.predictions.columns.toSet === summary2.predictions.columns.toSet) assert(summary.predictionCol === summary2.predictionCol) assert(summary.rank === summary2.rank) @@ -925,79 +1008,79 @@ class GeneralizedLinearRegressionSuite assert(summary.aic === summary2.aic) } - test("glm summary: binomial family with weight") { + test("glm summary: binomial family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) - b <- c(1, 0.5, 1, 0) - w <- c(1, 2.0, 0.3, 4.7) - df <- as.data.frame(cbind(A, b)) + df <- as.data.frame(matrix(c( + 0.2, 1.0, 2.0, 0.0, 5.0, + 0.5, 2.1, 0.5, 1.0, 2.0, + 0.9, 0.4, 1.0, 2.0, 1.0, + 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) */ - val datasetWithWeight = Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 0.3, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.7, Vectors.dense(3.0, 3.0)) + val dataset = Seq( + OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() - /* - R code: - - model <- glm(formula = "b ~ . -1", family="binomial", data = df, weights = w) - summary(model) - - Deviance Residuals: - 1 2 3 4 - 0.2404 0.1965 1.2824 -0.6916 + R code: - Coefficients: - Estimate Std. Error z value Pr(>|z|) - x1 -1.6901 1.2764 -1.324 0.185 - x2 0.7059 0.9449 0.747 0.455 + model <- glm(formula = "V1 ~ V4 + V5", family = "binomial", data = df, + weights = V2, offset = V3) + summary(model) - (Dispersion parameter for binomial family taken to be 1) + Deviance Residuals: + 1 2 3 4 + 0.002584 -0.003800 0.012478 -0.001796 - Null deviance: 8.3178 on 4 degrees of freedom - Residual deviance: 2.2193 on 2 degrees of freedom - AIC: 5.9915 + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) -0.2147 3.5687 -0.060 0.952 + V4 0.9912 1.2344 0.803 0.422 + V5 -0.6356 0.9669 -0.657 0.511 - Number of Fisher Scoring iterations: 5 + (Dispersion parameter for binomial family taken to be 1) - residuals(model, type="pearson") - 1 2 3 4 - 0.171217 0.197406 2.085864 -0.495332 + Null deviance: 2.17560881 on 3 degrees of freedom + Residual deviance: 0.00018005 on 1 degrees of freedom + AIC: 10.245 - residuals(model, type="working") - 1 2 3 4 - 1.029315 0.281881 15.502768 -1.052203 + Number of Fisher Scoring iterations: 4 - residuals(model, type="response") - 1 2 3 4 - 0.028480 0.069123 0.935495 -0.049613 + residuals(model, type = "pearson") + 1 2 3 4 + 0.002586113 -0.003799744 0.012372235 -0.001796892 + residuals(model, type = "working") + 1 2 3 4 + 0.006477857 -0.005244163 0.063541250 -0.004691064 + residuals(model, type = "response") + 1 2 3 4 + 0.0010324375 -0.0013110318 0.0060225522 -0.0009832738 */ val trainer = new GeneralizedLinearRegression() .setFamily("Binomial") .setWeightCol("weight") - .setFitIntercept(false) - - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(-1.690134, 0.705929)) - val interceptR = 0.0 - val devianceResidualsR = Array(0.2404, 0.1965, 1.2824, -0.6916) - val pearsonResidualsR = Array(0.171217, 0.197406, 2.085864, -0.495332) - val workingResidualsR = Array(1.029315, 0.281881, 15.502768, -1.052203) - val responseResidualsR = Array(0.02848, 0.069123, 0.935495, -0.049613) - val seCoefR = Array(1.276417, 0.944934) - val tValsR = Array(-1.324124, 0.747068) - val pValsR = Array(0.185462, 0.455023) - val dispersionR = 1.0 - val nullDevianceR = 8.3178 - val residualDevianceR = 2.2193 - val residualDegreeOfFreedomNullR = 4 - val residualDegreeOfFreedomR = 2 - val aicR = 5.991537 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(0.99117, -0.63561)) + val interceptR = -0.21471 + val devianceResidualsR = Array(0.00258, -0.0038, 0.01248, -0.0018) + val pearsonResidualsR = Array(0.00259, -0.0038, 0.01237, -0.0018) + val workingResidualsR = Array(0.00648, -0.00524, 0.06354, -0.00469) + val responseResidualsR = Array(0.00103, -0.00131, 0.00602, -0.00098) + val seCoefR = Array(1.23439, 0.9669, 3.56866) + val tValsR = Array(0.80297, -0.65737, -0.06017) + val pValsR = Array(0.42199, 0.51094, 0.95202) + val dispersionR = 1 + val nullDevianceR = 2.17561 + val residualDevianceR = 0.00018 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 10.24453 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1040,81 +1123,79 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: poisson family with weight") { + test("glm summary: poisson family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(2, 8, 3, 9) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + off <- c(2, 3, 1, 4) + df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + val dataset = Seq( + OffsetInstance(2.0, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(8.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(3.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(9.0, 4.0, 4.0, Vectors.dense(3.0, 13.0)) ).toDF() /* - R code: - - model <- glm(formula = "b ~ .", family="poisson", data = df, weights = w) - summary(model) - - Deviance Residuals: - 1 2 3 4 - -0.28952 0.11048 0.14839 -0.07268 - - Coefficients: - Estimate Std. Error z value Pr(>|z|) - (Intercept) 6.2999 1.6086 3.916 8.99e-05 *** - V1 3.3241 1.0184 3.264 0.00110 ** - V2 -1.0818 0.3522 -3.071 0.00213 ** - --- - Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 - - (Dispersion parameter for poisson family taken to be 1) - - Null deviance: 15.38066 on 3 degrees of freedom - Residual deviance: 0.12333 on 1 degrees of freedom - AIC: 41.803 - - Number of Fisher Scoring iterations: 3 + R code: - residuals(model, type="pearson") - 1 2 3 4 - -0.28043145 0.11099310 0.14963714 -0.07253611 + model <- glm(formula = "b ~ .", family = "poisson", data = df, + weights = w, offset = off) + summary(model) - residuals(model, type="working") - 1 2 3 4 - -0.17960679 0.02813593 0.05113852 -0.01201650 + Deviance Residuals: + 1 2 3 4 + -2.0480 1.2315 1.8293 -0.7107 - residuals(model, type="response") - 1 2 3 4 - -0.4378554 0.2189277 0.1459518 -0.1094638 + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) -4.5678 1.9625 -2.328 0.0199 + V1 -2.8784 1.1683 -2.464 0.0137 + V2 0.8859 0.4170 2.124 0.0336 + + (Dispersion parameter for poisson family taken to be 1) + + Null deviance: 22.5585 on 3 degrees of freedom + Residual deviance: 9.5622 on 1 degrees of freedom + AIC: 51.242 + + Number of Fisher Scoring iterations: 5 + + residuals(model, type = "pearson") + 1 2 3 4 + -1.7480418 1.3037611 2.0750099 -0.6972966 + residuals(model, type = "working") + 1 2 3 4 + -0.6891489 0.3833588 0.9710682 -0.1096590 + residuals(model, type = "response") + 1 2 3 4 + -4.433948 2.216974 1.477983 -1.108487 */ val trainer = new GeneralizedLinearRegression() .setFamily("Poisson") .setWeightCol("weight") - .setFitIntercept(true) - - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(3.3241, -1.0818)) - val interceptR = 6.2999 - val devianceResidualsR = Array(-0.28952, 0.11048, 0.14839, -0.07268) - val pearsonResidualsR = Array(-0.28043145, 0.11099310, 0.14963714, -0.07253611) - val workingResidualsR = Array(-0.17960679, 0.02813593, 0.05113852, -0.01201650) - val responseResidualsR = Array(-0.4378554, 0.2189277, 0.1459518, -0.1094638) - val seCoefR = Array(1.0184, 0.3522, 1.6086) - val tValsR = Array(3.264, -3.071, 3.916) - val pValsR = Array(0.00110, 0.00213, 0.00009) - val dispersionR = 1.0 - val nullDevianceR = 15.38066 - val residualDevianceR = 0.12333 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(-2.87843, 0.88589)) + val interceptR = -4.56784 + val devianceResidualsR = Array(-2.04796, 1.23149, 1.82933, -0.71066) + val pearsonResidualsR = Array(-1.74804, 1.30376, 2.07501, -0.6973) + val workingResidualsR = Array(-0.68915, 0.38336, 0.97107, -0.10966) + val responseResidualsR = Array(-4.43395, 2.21697, 1.47798, -1.10849) + val seCoefR = Array(1.16826, 0.41703, 1.96249) + val tValsR = Array(-2.46387, 2.12428, -2.32757) + val pValsR = Array(0.01374, 0.03365, 0.01993) + val dispersionR = 1 + val nullDevianceR = 22.55853 + val residualDevianceR = 9.5622 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 41.803 + val aicR = 51.24218 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1157,78 +1238,79 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: gamma family with weight") { + test("glm summary: gamma family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(2, 8, 3, 9) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) + A <- matrix(c(0, 5, 1, 2, 2, 1, 3, 3), 4, 2, byrow = TRUE) + b <- c(1, 2, 1, 2) + w <- c(1, 2, 3, 4) + off <- c(0, 0.5, 1, 0) + df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + val dataset = Seq( + OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(2.0, 2.0, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(2.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() /* - R code: - - model <- glm(formula = "b ~ .", family="Gamma", data = df, weights = w) - summary(model) + R code: - Deviance Residuals: - 1 2 3 4 - -0.26343 0.05761 0.12818 -0.03484 + model <- glm(formula = "b ~ .", family = "Gamma", data = df, + weights = w, offset = off) + summary(model) - Coefficients: - Estimate Std. Error t value Pr(>|t|) - (Intercept) -0.81511 0.23449 -3.476 0.178 - V1 -0.72730 0.16137 -4.507 0.139 - V2 0.23894 0.05481 4.359 0.144 + Deviance Residuals: + 1 2 3 4 + -0.17095 0.19867 -0.23604 0.03241 - (Dispersion parameter for Gamma family taken to be 0.07986091) + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.56474 0.23866 -2.366 0.255 + V1 0.07695 0.06931 1.110 0.467 + V2 0.28068 0.07320 3.835 0.162 - Null deviance: 2.937462 on 3 degrees of freedom - Residual deviance: 0.090358 on 1 degrees of freedom - AIC: 23.202 + (Dispersion parameter for Gamma family taken to be 0.1212174) - Number of Fisher Scoring iterations: 4 + Null deviance: 2.02568 on 3 degrees of freedom + Residual deviance: 0.12546 on 1 degrees of freedom + AIC: 0.93388 - residuals(model, type="pearson") - 1 2 3 4 - -0.24082508 0.05839241 0.13135766 -0.03463621 + Number of Fisher Scoring iterations: 4 - residuals(model, type="working") + residuals(model, type = "pearson") + 1 2 3 4 + -0.16134949 0.20807694 -0.22544551 0.03258777 + residuals(model, type = "working") 1 2 3 4 - 0.091414181 -0.005374314 -0.027196998 0.001890910 - - residuals(model, type="response") - 1 2 3 4 - -0.6344390 0.3172195 0.2114797 -0.1586097 + 0.135315831 -0.084390309 0.113219135 -0.008279688 + residuals(model, type = "response") + 1 2 3 4 + -0.1923918 0.2565224 -0.1496381 0.0320653 */ val trainer = new GeneralizedLinearRegression() .setFamily("Gamma") .setWeightCol("weight") + .setOffsetCol("offset") + + val model = trainer.fit(dataset) - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(-0.72730, 0.23894)) - val interceptR = -0.81511 - val devianceResidualsR = Array(-0.26343, 0.05761, 0.12818, -0.03484) - val pearsonResidualsR = Array(-0.24082508, 0.05839241, 0.13135766, -0.03463621) - val workingResidualsR = Array(0.091414181, -0.005374314, -0.027196998, 0.001890910) - val responseResidualsR = Array(-0.6344390, 0.3172195, 0.2114797, -0.1586097) - val seCoefR = Array(0.16137, 0.05481, 0.23449) - val tValsR = Array(-4.507, 4.359, -3.476) - val pValsR = Array(0.139, 0.144, 0.178) - val dispersionR = 0.07986091 - val nullDevianceR = 2.937462 - val residualDevianceR = 0.090358 + val coefficientsR = Vectors.dense(Array(0.07695, 0.28068)) + val interceptR = -0.56474 + val devianceResidualsR = Array(-0.17095, 0.19867, -0.23604, 0.03241) + val pearsonResidualsR = Array(-0.16135, 0.20808, -0.22545, 0.03259) + val workingResidualsR = Array(0.13532, -0.08439, 0.11322, -0.00828) + val responseResidualsR = Array(-0.19239, 0.25652, -0.14964, 0.03207) + val seCoefR = Array(0.06931, 0.0732, 0.23866) + val tValsR = Array(1.11031, 3.83453, -2.3663) + val pValsR = Array(0.46675, 0.16241, 0.25454) + val dispersionR = 0.12122 + val nullDevianceR = 2.02568 + val residualDevianceR = 0.12546 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 23.202 + val aicR = 0.93388 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1271,77 +1353,81 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: tweedie family with weight") { + test("glm summary: tweedie family with weight and offset") { /* R code: - library(statmod) df <- as.data.frame(matrix(c( - 1.0, 1.0, 0.0, 5.0, - 0.5, 2.0, 1.0, 2.0, - 1.0, 3.0, 2.0, 1.0, - 0.0, 4.0, 3.0, 3.0), 4, 4, byrow = TRUE)) + 1.0, 1.0, 1.0, 0.0, 5.0, + 0.5, 2.0, 3.0, 1.0, 2.0, + 1.0, 3.0, 2.0, 2.0, 1.0, + 0.0, 4.0, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) + */ + val dataset = Seq( + OffsetInstance(1.0, 1.0, 1.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.0, 3.0, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 2.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) + ).toDF() + /* + R code: - model <- glm(V1 ~ -1 + V3 + V4, data = df, weights = V2, - family = tweedie(var.power = 1.6, link.power = 0)) + library(statmod) + model <- glm(V1 ~ V4 + V5, data = df, weights = V2, offset = V3, + family = tweedie(var.power = 1.6, link.power = 0.0)) summary(model) Deviance Residuals: 1 2 3 4 - 0.6210 -0.0515 1.6935 -3.2539 + 0.8917 -2.1396 1.2252 -1.7946 Coefficients: - Estimate Std. Error t value Pr(>|t|) - V3 -0.4087 0.5205 -0.785 0.515 - V4 -0.1212 0.4082 -0.297 0.794 + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.03047 3.65000 -0.008 0.995 + V4 -1.14577 1.41674 -0.809 0.567 + V5 -0.36585 0.97065 -0.377 0.771 - (Dispersion parameter for Tweedie family taken to be 3.830036) + (Dispersion parameter for Tweedie family taken to be 6.334961) - Null deviance: 20.702 on 4 degrees of freedom - Residual deviance: 13.844 on 2 degrees of freedom + Null deviance: 12.784 on 3 degrees of freedom + Residual deviance: 10.095 on 1 degrees of freedom AIC: NA - Number of Fisher Scoring iterations: 11 - - residuals(model, type="pearson") - 1 2 3 4 - 0.7383616 -0.0509458 2.2348337 -1.4552090 - residuals(model, type="working") - 1 2 3 4 - 0.83354150 -0.04103552 1.55676369 -1.00000000 - residuals(model, type="response") - 1 2 3 4 - 0.45460738 -0.02139574 0.60888055 -0.20392801 + Number of Fisher Scoring iterations: 18 + + residuals(model, type = "pearson") + 1 2 3 4 + 1.1472554 -1.4642569 1.4935199 -0.8025842 + residuals(model, type = "working") + 1 2 3 4 + 1.3624928 -0.8322375 0.9894580 -1.0000000 + residuals(model, type = "response") + 1 2 3 4 + 0.57671828 -2.48040354 0.49735052 -0.01040646 */ - val datasetWithWeight = Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)), - Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) - ).toDF() - val trainer = new GeneralizedLinearRegression() .setFamily("tweedie") .setVariancePower(1.6) .setLinkPower(0.0) .setWeightCol("weight") - .setFitIntercept(false) - - val model = trainer.fit(datasetWithWeight) - val coefficientsR = Vectors.dense(Array(-0.408746, -0.12125)) - val interceptR = 0.0 - val devianceResidualsR = Array(0.621047, -0.051515, 1.693473, -3.253946) - val pearsonResidualsR = Array(0.738362, -0.050946, 2.234834, -1.455209) - val workingResidualsR = Array(0.833541, -0.041036, 1.556764, -1.0) - val responseResidualsR = Array(0.454607, -0.021396, 0.608881, -0.203928) - val seCoefR = Array(0.520519, 0.408215) - val tValsR = Array(-0.785267, -0.297024) - val pValsR = Array(0.514549, 0.794457) - val dispersionR = 3.830036 - val nullDevianceR = 20.702 - val residualDevianceR = 13.844 - val residualDegreeOfFreedomNullR = 4 - val residualDegreeOfFreedomR = 2 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(-1.14577, -0.36585)) + val interceptR = -0.03047 + val devianceResidualsR = Array(0.89171, -2.13961, 1.2252, -1.79463) + val pearsonResidualsR = Array(1.14726, -1.46426, 1.49352, -0.80258) + val workingResidualsR = Array(1.36249, -0.83224, 0.98946, -1) + val responseResidualsR = Array(0.57672, -2.4804, 0.49735, -0.01041) + val seCoefR = Array(1.41674, 0.97065, 3.65) + val tValsR = Array(-0.80873, -0.37691, -0.00835) + val pValsR = Array(0.56707, 0.77053, 0.99468) + val dispersionR = 6.33496 + val nullDevianceR = 12.78358 + val residualDevianceR = 10.09488 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 val summary = model.summary From 3c2fc19d478256f8dc0ae7219fdd188030218c07 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 30 Jun 2017 20:30:26 +0800 Subject: [PATCH 0832/1765] [SPARK-18294][CORE] Implement commit protocol to support `mapred` package's committer ## What changes were proposed in this pull request? This PR makes the following changes: - Implement a new commit protocol `HadoopMapRedCommitProtocol` which support the old `mapred` package's committer; - Refactor SparkHadoopWriter and SparkHadoopMapReduceWriter, now they are combined together, thus we can support write through both mapred and mapreduce API by the new SparkHadoopWriter, a lot of duplicated codes are removed. After this change, it should be pretty easy for us to support the committer from both the new and the old hadoop API at high level. ## How was this patch tested? No major behavior change, passed the existing test cases. Author: Xingbo Jiang Closes #18438 from jiangxb1987/SparkHadoopWriter. --- .../io/HadoopMapRedCommitProtocol.scala | 36 ++ .../internal/io/HadoopWriteConfigUtil.scala | 79 ++++ .../io/SparkHadoopMapReduceWriter.scala | 181 -------- .../spark/internal/io/SparkHadoopWriter.scala | 393 ++++++++++++++---- .../apache/spark/rdd/PairRDDFunctions.scala | 72 +--- .../spark/rdd/PairRDDFunctionsSuite.scala | 2 +- .../OutputCommitCoordinatorSuite.scala | 35 +- 7 files changed, 461 insertions(+), 337 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala delete mode 100644 core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala new file mode 100644 index 0000000000000..ddbd624b380d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => NewTaskAttemptContext} + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapRedCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { + val config = context.getConfiguration.asInstanceOf[JobConf] + config.getOutputCommitter + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala new file mode 100644 index 0000000000000..9b987e0e1bb67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.SparkConf + +/** + * Interface for create output format/committer/writer used during saving an RDD using a Hadoop + * OutputFormat (both from the old mapred API and the new mapreduce API) + * + * Notes: + * 1. Implementations should throw [[IllegalArgumentException]] when wrong hadoop API is + * referenced; + * 2. Implementations must be serializable, as the instance instantiated on the driver + * will be used for tasks on executors; + * 3. Implementations should have a constructor with exactly one argument: + * (conf: SerializableConfiguration) or (conf: SerializableJobConf). + */ +abstract class HadoopWriteConfigUtil[K, V: ClassTag] extends Serializable { + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + def createJobContext(jobTrackerId: String, jobId: Int): JobContext + + def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): TaskAttemptContext + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol + + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + def initWriter(taskContext: TaskAttemptContext, splitId: Int): Unit + + def write(pair: (K, V)): Unit + + def closeWriter(taskContext: TaskAttemptContext): Unit + + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + def initOutputFormat(jobContext: JobContext): Unit + + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + def assertConf(jobContext: JobContext, conf: SparkConf): Unit +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala deleted file mode 100644 index 376ff9bb19f74..0000000000000 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.internal.io - -import java.text.SimpleDateFormat -import java.util.{Date, Locale} - -import scala.reflect.ClassTag -import scala.util.DynamicVariable - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.{JobConf, JobID} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl - -import org.apache.spark.{SparkConf, SparkException, TaskContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.OutputMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} - -/** - * A helper object that saves an RDD using a Hadoop OutputFormat - * (from the newer mapreduce API, not the old mapred API). - */ -private[spark] -object SparkHadoopMapReduceWriter extends Logging { - - /** - * Basic work flow of this command is: - * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to - * be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. - */ - def write[K, V: ClassTag]( - rdd: RDD[(K, V)], - hadoopConf: Configuration): Unit = { - // Extract context and configuration from RDD. - val sparkContext = rdd.context - val stageId = rdd.id - val sparkConf = rdd.conf - val conf = new SerializableConfiguration(hadoopConf) - - // Set up a job. - val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) - val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) - val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) - val format = jobContext.getOutputFormatClass - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { - // FileOutputFormat ignores the filesystem parameter - val jobFormat = format.newInstance - jobFormat.checkOutputSpecs(jobContext) - } - - val committer = FileCommitProtocol.instantiate( - className = classOf[HadoopMapReduceCommitProtocol].getName, - jobId = stageId.toString, - outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), - isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] - committer.setupJob(jobContext) - - // Try to write all RDD partitions as a Hadoop OutputFormat. - try { - val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { - executeTask( - context = context, - jobTrackerId = jobTrackerId, - sparkStageId = context.stageId, - sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, - committer = committer, - hadoopConf = conf.value, - outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], - iterator = iter) - }) - - committer.commitJob(jobContext, ret) - logInfo(s"Job ${jobContext.getJobID} committed.") - } catch { - case cause: Throwable => - logError(s"Aborting job ${jobContext.getJobID}.", cause) - committer.abortJob(jobContext) - throw new SparkException("Job aborted.", cause) - } - } - - /** Write an RDD partition out in a single Spark task. */ - private def executeTask[K, V: ClassTag]( - context: TaskContext, - jobTrackerId: String, - sparkStageId: Int, - sparkPartitionId: Int, - sparkAttemptNumber: Int, - committer: FileCommitProtocol, - hadoopConf: Configuration, - outputFormat: Class[_ <: OutputFormat[K, V]], - iterator: Iterator[(K, V)]): TaskCommitMessage = { - // Set up a task. - val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, - sparkPartitionId, sparkAttemptNumber) - val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) - committer.setupTask(taskContext) - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - // Initiate the writer. - val taskFormat = outputFormat.newInstance() - // If OutputFormat is Configurable, we should set conf to it. - taskFormat match { - case c: Configurable => c.setConf(hadoopConf) - case _ => () - } - var writer = taskFormat.getRecordWriter(taskContext) - .asInstanceOf[RecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - - // Write all rows in RDD partition. - try { - val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { - // Write rows out, release resource and commit the task. - while (iterator.hasNext) { - val pair = iterator.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - if (writer != null) { - writer.close(taskContext) - writer = null - } - committer.commitTask(taskContext) - }(catchBlock = { - // If there is an error, release resource and then abort the task. - try { - if (writer != null) { - writer.close(taskContext) - writer = null - } - } finally { - committer.abortTask(taskContext) - logError(s"Task ${taskContext.getTaskAttemptID} aborted.") - } - }) - - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - - ret - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index acc9c38571007..7d846f9354df6 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -17,143 +17,374 @@ package org.apache.spark.internal.io -import java.io.IOException -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.{Date, Locale} +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, +TaskAttemptContext => NewTaskAttemptContext, TaskAttemptID => NewTaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.{TaskAttemptContextImpl => NewTaskAttemptContextImpl} -import org.apache.spark.SerializableWritable +import org.apache.spark.{SerializableWritable, SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.HadoopRDD -import org.apache.spark.util.SerializableJobConf +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} /** - * Internal helper class that saves an RDD using a Hadoop OutputFormat. - * - * Saves the RDD using a JobConf, which should contain an output key class, an output value class, - * a filename to write to, etc, exactly like in a Hadoop MapReduce job. + * A helper object that saves an RDD using a Hadoop OutputFormat. + */ +private[spark] +object SparkHadoopWriter extends Logging { + import SparkHadoopWriterUtils._ + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + config: HadoopWriteConfigUtil[K, V]): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + + // Set up a job. + val jobTrackerId = createJobTrackerID(new Date()) + val jobContext = config.createJobContext(jobTrackerId, stageId) + config.initOutputFormat(jobContext) + + // Assert the output format/key/value class is set in JobConf. + config.assertConf(jobContext, rdd.conf) + + val committer = config.createCommitter(stageId) + committer.setupJob(jobContext) + + // Try to write all RDD partitions as a Hadoop OutputFormat. + try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + config = config, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write a RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + config: HadoopWriteConfigUtil[K, V], + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val taskContext = config.createTaskAttemptContext( + jobTrackerId, sparkStageId, sparkPartitionId, sparkAttemptNumber) + committer.setupTask(taskContext) + + val (outputMetrics, callback) = initHadoopOutputMetrics(context) + + // Initiate the writer. + config.initWriter(taskContext, sparkPartitionId) + var recordsWritten = 0L + + // Write all rows in RDD partition. + try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val pair = iterator.next() + config.write(pair) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) + recordsWritten += 1 + } + + config.closeWriter(taskContext) + committer.commitTask(taskContext) + }(catchBlock = { + // If there is an error, release resource and then abort the task. + try { + config.closeWriter(taskContext) + } finally { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + } + }) + + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} + +/** + * A helper class that reads JobConf from older mapred API, creates output Format/Committer/Writer. */ private[spark] -class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { +class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) + extends HadoopWriteConfigUtil[K, V] with Logging { - private val now = new Date() - private val conf = new SerializableJobConf(jobConf) + private var outputFormat: Class[_ <: OutputFormat[K, V]] = null + private var writer: RecordWriter[K, V] = null - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null + private def getConf: JobConf = conf.value - @transient private var writer: RecordWriter[AnyRef, AnyRef] = null - @transient private var format: OutputFormat[AnyRef, AnyRef] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- - def preSetup() { - setIDs(0, 0, 0) - HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value) + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new SerializableWritable(new JobID(jobTrackerId, jobId)) + new JobContextImpl(getConf, jobAttemptId.value) + } - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + // Update JobConf. + HadoopRDD.addLocalConfiguration(jobTrackerId, jobId, splitId, taskAttemptId, conf.value) + // Create taskContext. + val attemptId = new TaskAttemptID(jobTrackerId, jobId, TaskType.MAP, splitId, taskAttemptId) + new TaskAttemptContextImpl(getConf, attemptId) } + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), - jobid, splitID, attemptID, conf.value) + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + // Update JobConf. + HadoopRDD.addLocalConfiguration("", 0, 0, 0, getConf) + // Create commit protocol. + FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapred.output.dir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] } - def open() { + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) - val outputName = "part-" + numfmt.format(splitID) - val path = FileOutputFormat.getOutputPath(conf.value) + val outputName = "part-" + numfmt.format(splitId) + val path = FileOutputFormat.getOutputPath(getConf) val fs: FileSystem = { if (path != null) { - path.getFileSystem(conf.value) + path.getFileSystem(getConf) } else { - FileSystem.get(conf.value) + FileSystem.get(getConf) } } - getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) + writer = getConf.getOutputFormat + .getRecordWriter(fs, getConf, outputName, Reporter.NULL) + .asInstanceOf[RecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") } - def write(key: AnyRef, value: AnyRef) { + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) + } + + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { if (writer != null) { - writer.write(key, value) - } else { - throw new IOException("Writer is null, open() has not been called") + writer.close(Reporter.NULL) } } - def close() { - writer.close(Reporter.NULL) - } + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- - def commit() { - SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = getConf.getOutputFormat.getClass + .asInstanceOf[Class[_ <: OutputFormat[K, V]]] + } } - def commitJob() { - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) + private def getOutputFormat(): OutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - // ********* Private Functions ********* + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + val outputFormatInstance = getOutputFormat() + val keyClass = getConf.getOutputKeyClass + val valueClass = getConf.getOutputValueClass + if (outputFormatInstance == null) { + throw new SparkException("Output format class not set") + } + if (keyClass == null) { + throw new SparkException("Output key class not set") + } + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + SparkHadoopUtil.get.addCredentials(getConf) + + logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + + valueClass.getSimpleName + ")") - private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef, AnyRef]] + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + // FileOutputFormat ignores the filesystem parameter + val ignoredFs = FileSystem.get(getConf) + getOutputFormat().checkOutputSpecs(ignoredFs, getConf) } - format + } +} + +/** + * A helper class that reads Configuration from newer mapreduce API, creates output + * Format/Committer/Writer. + */ +private[spark] +class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfiguration) + extends HadoopWriteConfigUtil[K, V] with Logging { + + private var outputFormat: Class[_ <: NewOutputFormat[K, V]] = null + private var writer: NewRecordWriter[K, V] = null + + private def getConf: Configuration = conf.value + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new NewTaskAttemptID(jobTrackerId, jobId, TaskType.MAP, 0, 0) + new NewTaskAttemptContextImpl(getConf, jobAttemptId) + } + + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + val attemptId = new NewTaskAttemptID( + jobTrackerId, jobId, TaskType.REDUCE, splitId, taskAttemptId) + new NewTaskAttemptContextImpl(getConf, attemptId) + } + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapreduce.output.fileoutputformat.outputdir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] } - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { + val taskFormat = getOutputFormat() + // If OutputFormat is Configurable, we should set conf to it. + taskFormat match { + case c: Configurable => c.setConf(getConf) + case _ => () } - committer + + writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[NewRecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") + } + + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) } - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = new JobContextImpl(conf.value, jID.value) + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { + if (writer != null) { + writer.close(taskContext) + writer = null + } else { + logWarning("Writer has been closed.") } - jobContext } - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = jobContext.getOutputFormatClass + .asInstanceOf[Class[_ <: NewOutputFormat[K, V]]] } - taskContext } - protected def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - new TaskAttemptContextImpl(conf, attemptId) + private def getOutputFormat(): NewOutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { - jobID = jobid - splitID = splitid - attemptID = attemptid + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- - jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + getOutputFormat().checkOutputSpecs(jobContext) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 58762cc0838cd..4628fa8ba270e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -36,13 +35,11 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewO import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, - SparkHadoopWriterUtils} +import org.apache.spark.internal.io._ import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1082,9 +1079,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - SparkHadoopMapReduceWriter.write( + val config = new HadoopMapReduceWriteConfigUtil[K, V](new SerializableConfiguration(conf)) + SparkHadoopWriter.write( rdd = self, - hadoopConf = conf) + config = config) } /** @@ -1094,62 +1092,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). - val hadoopConf = conf - val outputFormatInstance = hadoopConf.getOutputFormat - val keyClass = hadoopConf.getOutputKeyClass - val valueClass = hadoopConf.getOutputValueClass - if (outputFormatInstance == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - SparkHadoopUtil.get.addCredentials(hadoopConf) - - logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + - valueClass.getSimpleName + ")") - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { - // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(hadoopConf) - hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) - } - - val writer = new SparkHadoopWriter(hadoopConf) - writer.preSetup() - - val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - writer.setup(context.stageId, context.partitionId, taskAttemptId) - writer.open() - var recordsWritten = 0L - - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val record = iter.next() - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close()) - writer.commit() - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - - self.context.runJob(self, writeToFile) - writer.commitJob() + val config = new HadoopMapRedWriteConfigUtil[K, V](new SerializableJobConf(conf)) + SparkHadoopWriter.write( + rdd = self, + config = config) } /** diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 02df157be377c..44dd955ce8690 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -561,7 +561,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsHadoopFile( "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf) } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e51e6a0d3ff6b..1579b614ea5b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.scheduler import java.io.File +import java.util.Date import java.util.concurrent.TimeoutException import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.TaskType import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -31,7 +33,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.internal.io.SparkHadoopWriter +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} @@ -214,6 +216,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { */ private case class OutputCommitFunctions(tempDirPath: String) { + private val jobId = new SerializableWritable(SparkHadoopWriterUtils.createJobID(new Date, 0)) + // Mock output committer that simulates a successful commit (after commit is authorized) private def successfulOutputCommitter = new FakeOutputCommitter { override def commitTask(context: TaskAttemptContext): Unit = { @@ -256,14 +260,23 @@ private case class OutputCommitFunctions(tempDirPath: String) { def jobConf = new JobConf { override def getOutputCommitter(): OutputCommitter = outputCommitter } - val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { - override def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - mock(classOf[TaskAttemptContext]) - } - } - sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) - sparkHadoopWriter.commit() + + // Instantiate committer. + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.value.getId.toString, + outputPath = jobConf.get("mapred.output.dir"), + isAppend = false) + + // Create TaskAttemptContext. + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val taskAttemptId = (ctx.taskAttemptId % Int.MaxValue).toInt + val attemptId = new TaskAttemptID( + new TaskID(jobId.value, TaskType.MAP, ctx.partitionId), taskAttemptId) + val taskContext = new TaskAttemptContextImpl(jobConf, attemptId) + + committer.setupTask(taskContext) + committer.commitTask(taskContext) } } From 528c9281aecc49e9bff204dd303962c705c6f237 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 30 Jun 2017 23:25:14 +0800 Subject: [PATCH 0833/1765] [ML] Fix scala-2.10 build failure of GeneralizedLinearRegressionSuite. ## What changes were proposed in this pull request? Fix scala-2.10 build failure of ```GeneralizedLinearRegressionSuite```. ## How was this patch tested? Build with scala-2.10. Author: Yanbo Liang Closes #18489 from yanboliang/glr. --- .../ml/regression/GeneralizedLinearRegressionSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index cfaa57314bd66..83f1344a7bcb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1075,7 +1075,7 @@ class GeneralizedLinearRegressionSuite val seCoefR = Array(1.23439, 0.9669, 3.56866) val tValsR = Array(0.80297, -0.65737, -0.06017) val pValsR = Array(0.42199, 0.51094, 0.95202) - val dispersionR = 1 + val dispersionR = 1.0 val nullDevianceR = 2.17561 val residualDevianceR = 0.00018 val residualDegreeOfFreedomNullR = 3 @@ -1114,7 +1114,7 @@ class GeneralizedLinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } - assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.dispersion === dispersionR) assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) assert(summary.deviance ~== residualDevianceR absTol 1E-3) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) @@ -1190,7 +1190,7 @@ class GeneralizedLinearRegressionSuite val seCoefR = Array(1.16826, 0.41703, 1.96249) val tValsR = Array(-2.46387, 2.12428, -2.32757) val pValsR = Array(0.01374, 0.03365, 0.01993) - val dispersionR = 1 + val dispersionR = 1.0 val nullDevianceR = 22.55853 val residualDevianceR = 9.5622 val residualDegreeOfFreedomNullR = 3 @@ -1229,7 +1229,7 @@ class GeneralizedLinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } - assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.dispersion === dispersionR) assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) assert(summary.deviance ~== residualDevianceR absTol 1E-3) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) From 1fe08d62f022e12f2f0161af5d8f9eac51baf1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9B=BE=E6=9E=97=E8=A5=BF?= Date: Fri, 30 Jun 2017 19:28:43 +0100 Subject: [PATCH 0834/1765] [SPARK-21223] Change fileToAppInfo in FsHistoryProvider to fix concurrent issue. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What issue does this PR address ? Jira:https://issues.apache.org/jira/browse/SPARK-21223 fix the Thread-safety issue in FsHistoryProvider Currently, Spark HistoryServer use a HashMap named fileToAppInfo in class FsHistoryProvider to store the map of eventlog path and attemptInfo. When use ThreadPool to Replay the log files in the list and merge the list of old applications with new ones, multi thread may update fileToAppInfo at the same time, which may cause Thread-safety issues, such as falling into an infinite loop because of calling resize func of the hashtable. Author: 曾林西 Closes #18430 from zenglinxi0615/master. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index d05ca142b618b..b2a50bd055712 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -122,7 +122,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() - val fileToAppInfo = new mutable.HashMap[Path, FsApplicationAttemptInfo]() + val fileToAppInfo = new ConcurrentHashMap[Path, FsApplicationAttemptInfo]() // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] @@ -321,7 +321,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + val fileInfo = fileToAppInfo.get(entry.getPath()) + val prevFileSize = if (fileInfo != null) fileInfo.fileSize else 0L !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to @@ -475,7 +476,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) fileStatus.getLen(), appListener.appSparkVersion.getOrElse("") ) - fileToAppInfo(logPath) = attemptInfo + fileToAppInfo.put(logPath, attemptInfo) logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") Some(attemptInfo) } else { From eed9c4ef859fdb75a816a3e0ce2d593b34b23444 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 30 Jun 2017 14:23:56 -0700 Subject: [PATCH 0835/1765] [SPARK-21129][SQL] Arguments of SQL function call should not be named expressions ### What changes were proposed in this pull request? Function argument should not be named expressions. It could cause two issues: - Misleading error message - Unexpected query results when the column name is `distinct`, which is not a reserved word in our parser. ``` spark-sql> select count(distinct c1, distinct c2) from t1; Error in query: cannot resolve '`distinct`' given input columns: [c1, c2]; line 1 pos 26; 'Project [unresolvedalias('count(c1#30, 'distinct), None)] +- SubqueryAlias t1 +- CatalogRelation `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#30, c2#31] ``` After the fix, the error message becomes ``` spark-sql> select count(distinct c1, distinct c2) from t1; Error in query: extraneous input 'c2' expecting {')', ',', '.', '[', 'OR', 'AND', 'IN', NOT, 'BETWEEN', 'LIKE', RLIKE, 'IS', EQ, '<=>', '<>', '!=', '<', LTE, '>', GTE, '+', '-', '*', '/', '%', 'DIV', '&', '|', '||', '^'}(line 1, pos 35) == SQL == select count(distinct c1, distinct c2) from t1 -----------------------------------^^^ ``` ### How was this patch tested? Added a test case to parser suite. Author: Xiao Li Author: gatorsmile Closes #18338 from gatorsmile/parserDistinctAggFunc. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../spark/sql/catalyst/dsl/package.scala | 1 + .../sql/catalyst/parser/AstBuilder.scala | 9 +++++- .../parser/ExpressionParserSuite.scala | 6 ++-- .../sql/catalyst/parser/PlanParserSuite.scala | 6 ++++ .../resources/sql-tests/inputs/struct.sql | 7 ++++ .../sql-tests/results/struct.sql.out | 32 ++++++++++++++++++- 7 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9456031736528..7ffa150096333 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -561,6 +561,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct | FIRST '(' expression (IGNORE NULLS)? ')' #first | LAST '(' expression (IGNORE NULLS)? ')' #last | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position @@ -569,7 +570,7 @@ primaryExpression | qualifiedName '.' ASTERISK #star | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index f6792569b704e..7c100afcd738f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -170,6 +170,7 @@ package object dsl { case Seq() => UnresolvedStar(None) case target => UnresolvedStar(Option(target)) } + def namedStruct(e: Expression*): Expression = CreateNamedStruct(e) def callFunction[T, U]( func: T => U, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ef79cbcaa0ce6..8eac3ef2d3568 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1061,6 +1061,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[CreateStruct]] expression. + */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.argument.asScala.map(expression)) + } + /** * Create a [[First]] expression. */ @@ -1091,7 +1098,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.namedExpression().asScala.map(expression) match { + val arguments = ctx.argument.asScala.map(expression) match { case Seq(UnresolvedStar(None)) if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 4d08f016a4a16..45f9f72dccc45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -231,7 +231,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) - assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) + intercept("foo(a x)", "extraneous input 'x'") } test("window function expressions") { @@ -330,7 +330,9 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a.b", UnresolvedAttribute("a.b")) assertEqual("`select`.b", UnresolvedAttribute("select.b")) assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. - assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + assertEqual( + "struct(a, b).b", + namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b")) } test("reference") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index bf15b85d5b510..5b2573fa4d601 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -223,6 +223,12 @@ class PlanParserSuite extends AnalysisTest { assertEqual(s"$sql grouping sets((a, b), (a), ())", GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + + val m = intercept[ParseException] { + parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b") + }.getMessage + assert(m.contains("extraneous input 'b'")) + } test("limit") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql index e56344dc4de80..93a1238ab18c2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/struct.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -18,3 +18,10 @@ SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; -- Prepend a column to a struct SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; + +-- Select a column from a struct +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x; +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x; + +-- Select an alias from a struct +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index 3e32f46195464..1da33bc736f0b 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 9 -- !query 0 @@ -58,3 +58,33 @@ struct> 1 {"AA":"1","C":"gamma","D":"delta"} 2 {"AA":"2","C":"epsilon","D":"eta"} 3 {"AA":"3","C":"theta","D":"iota"} + + +-- !query 6 +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x +-- !query 6 schema +struct +-- !query 6 output +1 gamma +2 epsilon +3 theta + + +-- !query 7 +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x +-- !query 7 schema +struct +-- !query 7 output +1 delta +2 eta +3 iota + + +-- !query 8 +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x +-- !query 8 schema +struct +-- !query 8 output +1 delta +2 eta +3 iota From fd1325522549937232f37215db53d6478f48644c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Jun 2017 15:11:27 -0700 Subject: [PATCH 0836/1765] [SPARK-21052][SQL][FOLLOW-UP] Add hash map metrics to join ## What changes were proposed in this pull request? Remove `numHashCollisions` in `BytesToBytesMap`. And change `getAverageProbesPerLookup()` to `getAverageProbesPerLookup` as suggested. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #18480 from viirya/SPARK-21052-followup. --- .../spark/unsafe/map/BytesToBytesMap.java | 33 ------------------- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../sql/execution/joins/HashedRelation.scala | 8 ++--- 3 files changed, 5 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 4bef21b6b4e4d..3b6200e74f1e1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -160,14 +160,10 @@ public final class BytesToBytesMap extends MemoryConsumer { private final boolean enablePerfMetrics; - private long timeSpentResizingNs = 0; - private long numProbes = 0; private long numKeyLookups = 0; - private long numHashCollisions = 0; - private long peakMemoryUsedBytes = 0L; private final int initialCapacity; @@ -489,10 +485,6 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l ); if (areEqual) { return; - } else { - if (enablePerfMetrics) { - numHashCollisions++; - } } } } @@ -859,16 +851,6 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } - /** - * Returns the total amount of time spent resizing this map (in nanoseconds). - */ - public long getTimeSpentResizingNs() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } - return timeSpentResizingNs; - } - /** * Returns the average number of probes per key lookup. */ @@ -879,13 +861,6 @@ public double getAverageProbesPerLookup() { return (1.0 * numProbes) / numKeyLookups; } - public long getNumHashCollisions() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } - return numHashCollisions; - } - @VisibleForTesting public int getNumDataPages() { return dataPages.size(); @@ -923,10 +898,6 @@ public void reset() { void growAndRehash() { assert(longArray != null); - long resizeStartTime = -1; - if (enablePerfMetrics) { - resizeStartTime = System.nanoTime(); - } // Store references to the old data structures to be used when we re-hash final LongArray oldLongArray = longArray; final int oldCapacity = (int) oldLongArray.size() / 2; @@ -951,9 +922,5 @@ void growAndRehash() { longArray.set(newPos * 2 + 1, hashcode); } freeArray(oldLongArray); - - if (enablePerfMetrics) { - timeSpentResizingNs += System.nanoTime() - resizeStartTime; - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b09edf380c2d4..0396168d3f311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -215,7 +215,7 @@ trait HashJoin { // At the end of the task, we update the avg hash probe. TaskContext.get().addTaskCompletionListener(_ => - avgHashProbe.set(hashed.getAverageProbesPerLookup())) + avgHashProbe.set(hashed.getAverageProbesPerLookup)) val resultProj = createResultProjection joinedIter.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3c702856114f9..2038cb9edb67d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -83,7 +83,7 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns the average number of probes per key lookup. */ - def getAverageProbesPerLookup(): Double + def getAverageProbesPerLookup: Double } private[execution] object HashedRelation { @@ -280,7 +280,7 @@ private[joins] class UnsafeHashedRelation( read(in.readInt, in.readLong, in.readBytes) } - override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup() + override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup } private[joins] object UnsafeHashedRelation { @@ -776,7 +776,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap /** * Returns the average number of probes per key lookup. */ - def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups + def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups } private[joins] class LongHashedRelation( @@ -829,7 +829,7 @@ private[joins] class LongHashedRelation( map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } - override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup() + override def getAverageProbesPerLookup: Double = map.getAverageProbesPerLookup } /** From 4eb41879ce774dec1d16b2281ab1fbf41f9d418a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 1 Jul 2017 09:25:29 +0800 Subject: [PATCH 0837/1765] [SPARK-17528][SQL] data should be copied properly before saving into InternalRow ## What changes were proposed in this pull request? For performance reasons, `UnsafeRow.getString`, `getStruct`, etc. return a "pointer" that points to a memory region of this unsafe row. This makes the unsafe projection a little dangerous, because all of its output rows share one instance. When we implement SQL operators, we should be careful to not cache the input rows because they may be produced by unsafe projection from child operator and thus its content may change overtime. However, when we updating values of InternalRow(e.g. in mutable projection and safe projection), we only copy UTF8String, we should also copy InternalRow, ArrayData and MapData. This PR fixes this, and also fixes the copy of vairous InternalRow, ArrayData and MapData implementations. ## How was this patch tested? new regression tests Author: Wenchen Fan Closes #18483 from cloud-fan/fix-copy. --- .../apache/spark/unsafe/types/UTF8String.java | 6 + .../spark/sql/catalyst/InternalRow.scala | 27 ++++- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/SpecificInternalRow.scala | 12 -- .../expressions/aggregate/collect.scala | 2 +- .../expressions/aggregate/interfaces.scala | 6 + .../expressions/codegen/CodeGenerator.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 2 - .../spark/sql/catalyst/expressions/rows.scala | 23 ++-- .../sql/catalyst/util/GenericArrayData.scala | 10 +- .../scala/org/apache/spark/sql/RowTest.scala | 4 - .../catalyst/expressions/MapDataSuite.scala | 57 ---------- .../codegen/GeneratedProjectionSuite.scala | 36 ++++++ .../sql/catalyst/util/ComplexDataSuite.scala | 107 ++++++++++++++++++ .../execution/vectorized/ColumnarBatch.java | 2 +- .../SortBasedAggregationIterator.scala | 15 +-- .../columnar/GenerateColumnAccessor.scala | 1 - .../execution/window/AggregateProcessor.scala | 7 +- 18 files changed, 212 insertions(+), 113 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 40b9fc9534f44..9de4ca71ff6d4 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1088,6 +1088,12 @@ public UTF8String clone() { return fromBytes(getBytes()); } + public UTF8String copy() { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return fromBytes(bytes); + } + @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 256f64e320be8..29110640d64f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internally in Spark SQL, which only contains the columns as @@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def setNullAt(i: Int): Unit + /** + * Updates the value at column `i`. Note that after updating, the given value will be kept in this + * row, and the caller side should guarantee that this value won't be changed afterwards. + */ def update(i: Int, value: Any): Unit // default implementation (slow) @@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } /* ---------------------- utility methods for Scala ---------------------- */ @@ -94,4 +108,15 @@ object InternalRow { /** Returns an empty [[InternalRow]]. */ val empty = apply() + + /** + * Copies the given value if it's string/struct/array/map type. + */ + def copyValue(value: Any): Any = value match { + case v: UTF8String => v.copy() + case v: InternalRow => v.copy() + case v: ArrayData => v.copy() + case v: MapData => v.copy() + case _ => value + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43df19ba009a8..3862e64b9d828 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; $fieldsEvalCode - $evPrim = $result.copy(); + $evPrim = $result; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 74e0b4691d4cc..75feaf670c84a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): InternalRow = { - val newValues = new Array[Any](values.length) - var i = 0 - while (i < values.length) { - newValues(i) = values(i).boxed - i += 1 - } - - new GenericInternalRow(newValues) - } - override protected def genericGet(i: Int): Any = values(i).boxed override def update(ordinal: Int, value: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 26cd9ab665383..0d2f9889a27d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator if (value != null) { - buffer += value + buffer += InternalRow.copyValue(value) } buffer } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index fffcc7c9ef53a..7af4901435857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit @@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5158949b95629..b15bf2ca7c116 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -408,9 +408,11 @@ class CodegenContext { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) - case StringType => s"$row.update($ordinal, $value.clone())" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f708aeff2b146..dd0419d2286d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) - // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => ExprCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 751b821e1b009..65539a2f00e6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow { override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } - override def toString: String = { if (numFields == 0) { "[empty row]" @@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow { } } + override def copy(): GenericInternalRow = { + val len = numFields + val newValues = new Array[Any](len) + var i = 0 + while (i < len) { + newValues(i) = InternalRow.copyValue(genericGet(i)) + i += 1 + } + new GenericInternalRow(newValues) + } + override def equals(o: Any): Boolean = { if (!o.isInstanceOf[BaseGenericInternalRow]) { return false @@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } - - override def copy(): GenericInternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index dd660c80a9c3c..9e39ed9c3a778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + override def copy(): ArrayData = { + val newValues = new Array[Any](array.length) + var i = 0 + while (i < array.length) { + newValues(i) = InternalRow.copyValue(array(i)) + i += 1 + } + new GenericArrayData(newValues) + } override def numElements(): Int = array.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index c9c9599e7f463..25699de33d717 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for internal rows") { - internalRow should be theSameInstanceAs internalRow.copy() - } - it("toSeq should not expose internal state for external rows") { val modifiedValues = modifyValues(externalRow.toSeq) externalRow.toSeq should not equal modifiedValues diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala deleted file mode 100644 index 25a675a90276d..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.collection._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class MapDataSuite extends SparkFunSuite { - - test("inequality tests") { - def u(str: String): UTF8String = UTF8String.fromString(str) - - // test data - val testMap1 = Map(u("key1") -> 1) - val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) - val testMap3 = Map(u("key1") -> 1) - val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) - - // ArrayBasedMapData - val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) - val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) - val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) - val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) - assert(testArrayMap1 !== testArrayMap3) - assert(testArrayMap2 !== testArrayMap4) - - // UnsafeMapData - val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericInternalRow(1) - def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { - row.update(0, map) - val unsafeRow = unsafeConverter.apply(row) - unsafeRow.getMap(0).copy - } - assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) - assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 58ea5b9cb52d3..0cd0d8859145f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafe1 === unsafe3) assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) } + + test("MutableProjection should not cache content from the input row") { + val mutableProj = GenerateMutableProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val row = new GenericInternalRow(1) + mutableProj.target(row) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + mutableProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } + + test("SafeProjection should not cache content from the input row") { + val safeProj = GenerateSafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + val row = safeProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala new file mode 100644 index 0000000000000..9d285916bcf42 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ComplexDataSuite extends SparkFunSuite { + def utf8(str: String): UTF8String = UTF8String.fromString(str) + + test("inequality tests for MapData") { + // test data + val testMap1 = Map(utf8("key1") -> 1) + val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + val testMap3 = Map(utf8("key1") -> 1) + val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + + // ArrayBasedMapData + val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) + val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) + val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) + val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) + assert(testArrayMap1 !== testArrayMap3) + assert(testArrayMap2 !== testArrayMap4) + + // UnsafeMapData + val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) + val row = new GenericInternalRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) + assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) + } + + test("GenericInternalRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericRow = genericRow.copy() + assert(copiedGenericRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedGenericRow.getString(0) == "a") + } + + test("SpecificMutableRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val mutableRow = new SpecificInternalRow(Seq(StringType)) + mutableRow(0) = unsafeRow.getUTF8String(0) + val copiedMutableRow = mutableRow.copy() + assert(copiedMutableRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedMutableRow.getString(0) == "a") + } + + test("GenericArrayData.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericArray = genericArray.copy() + assert(copiedGenericArray.getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied array data should not be changed externally. + assert(copiedGenericArray.getUTF8String(0).toString == "a") + } + + test("copy on nested complex type") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0)))) + val copied = arrayOfRow.copy() + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied data should not be changed externally. + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e23a64350cbc5..34dc3af9b85c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -149,7 +149,7 @@ public InternalRow copy() { } else if (dt instanceof DoubleType) { row.setDouble(i, getDouble(i)); } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i)); + row.update(i, getUTF8String(i).copy()); } else if (dt instanceof BinaryType) { row.update(i, getBinary(i)); } else if (dt instanceof DecimalType) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index bea2dce1a7657..a5a444b160c63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,17 +86,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer - // This safe projection is used to turn the input row into safe row. This is necessary - // because the input row may be produced by unsafe projection in child operator and all the - // produced rows share one byte array. However, when we update the aggregate buffer according to - // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from - // input row via MutableProjection, `CollectList` will keep all values in an array via - // ImperativeAggregate framework. These values may get changed unexpectedly if the underlying - // unsafe projection update the shared byte array. By applying a safe projection to the input row, - // we can cut down the connection from input row to the shared byte array, and thus it's safe to - // cache values from input row while updating the aggregation buffer. - private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -119,7 +108,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -130,7 +119,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, safeProj(currentRow)) + processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index d3fa0dcd2d7c3..fc977f2fd5530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -56,7 +56,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR // all other methods inherited from GenericMutableRow are not need override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException override def numFields: Int = throw new UnsupportedOperationException - override def copy(): InternalRow = throw new UnsupportedOperationException } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index 2195c6ea95948..bc141b36e63b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,13 +145,10 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that - // MutableProjection makes copies of the complex input objects it buffer. - val copy = input.copy() - updateProjection(join(buffer, copy)) + updateProjection(join(buffer, input)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, copy) + imperatives(i).update(buffer, input) i += 1 } } From 61b5df567eb8ae0df4059cb0e334316fff462de9 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 1 Jul 2017 10:01:44 +0800 Subject: [PATCH 0838/1765] [SPARK-21127][SQL] Update statistics after data changing commands ## What changes were proposed in this pull request? Update stats after the following data changing commands: - InsertIntoHadoopFsRelationCommand - InsertIntoHiveTable - LoadDataCommand - TruncateTableCommand - AlterTableSetLocationCommand - AlterTableDropPartitionCommand ## How was this patch tested? Added new test cases. Author: wangzhenhua Author: Zhenhua Wang Closes #18334 from wzhfy/changeStatsForOperation. --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../sql/execution/command/CommandUtils.scala | 17 +- .../spark/sql/execution/command/ddl.scala | 15 +- .../spark/sql/StatisticsCollectionSuite.scala | 77 +++++--- .../spark/sql/hive/StatisticsSuite.scala | 187 +++++++++++------- 5 files changed, 207 insertions(+), 99 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c641e4d3a23e1..25152f3e32d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -774,6 +774,14 @@ object SQLConf { .doubleConf .createWithDefault(0.05) + val AUTO_UPDATE_SIZE = + buildConf("spark.sql.statistics.autoUpdate.size") + .doc("Enables automatic update for table size once table's data is changed. Note that if " + + "the total number of files of the table is very large, this can be expensive and slow " + + "down data change commands.") + .booleanConf + .createWithDefault(false) + val CBO_ENABLED = buildConf("spark.sql.cbo.enabled") .doc("Enables CBO for estimation of plan statistics when set true.") @@ -1083,6 +1091,8 @@ class SQLConf extends Serializable with Logging { def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + def autoUpdateSize: Boolean = getConf(SQLConf.AUTO_UPDATE_SIZE) + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 92397607f38fd..fce12cc96620c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -36,7 +36,14 @@ object CommandUtils extends Logging { def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { if (table.stats.nonEmpty) { val catalog = sparkSession.sessionState.catalog - catalog.alterTableStats(table.identifier, None) + if (sparkSession.sessionState.conf.autoUpdateSize) { + val newTable = catalog.getTableMetadata(table.identifier) + val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) + val newStats = CatalogStatistics(sizeInBytes = newSize) + catalog.alterTableStats(table.identifier, Some(newStats)) + } else { + catalog.alterTableStats(table.identifier, None) + } } } @@ -84,7 +91,9 @@ object CommandUtils extends Logging { size } - locationUri.map { p => + val startTime = System.nanoTime() + logInfo(s"Starting to calculate the total file size under path $locationUri.") + val size = locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) @@ -97,6 +106,10 @@ object CommandUtils extends Logging { 0L } }.getOrElse(0L) + val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) + logInfo(s"It took $durationInMs ms to calculate the total file size under path $locationUri.") + + size } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index ac897c1b22d77..ba7ca84f229fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -437,7 +437,20 @@ case class AlterTableAddPartitionCommand( } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) - CommandUtils.updateTableStats(sparkSession, table) + if (table.stats.nonEmpty) { + if (sparkSession.sessionState.conf.autoUpdateSize) { + val addedSize = parts.map { part => + CommandUtils.calculateLocationSize(sparkSession.sessionState, table.identifier, + part.storage.locationUri) + }.sum + if (addedSize > 0) { + val newStats = CatalogStatistics(sizeInBytes = table.stats.get.sizeInBytes + addedSize) + catalog.alterTableStats(table.identifier, Some(newStats)) + } + } else { + catalog.alterTableStats(table.identifier, None) + } + } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b031c52dad8b5..d9392de37a815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ @@ -178,36 +178,63 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" - withTable(table) { - spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") - val fetched1 = checkTableStats( - table, hasSizeInBytes = true, expectedRowCounts = Some(100)) - assert(fetched1.get.sizeInBytes > 0) - assert(fetched1.get.colStats.size == 2) - - // set location command - withTempDir { newLocation => - sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") - checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // set location command + val initLocation = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + .storage.locationUri.get.toString + withTempDir { newLocation => + sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + + // set back to the initial location + sql(s"ALTER TABLE $table SET LOCATION '$initLocation'") + val fetched3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) + } else { + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } } } } test("change stats after insert command for datasource table") { val table = "change_stats_insert_datasource_table" - withTable(table) { - sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetched1.get.sizeInBytes == 0) - assert(fetched1.get.colStats.size == 2) - - // insert into command - sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") - checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + } else { + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 5fd266c2d033c..c601038a2b0af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -444,88 +444,133 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after insert command for hive table") { val table = s"change_stats_insert_hive_table" - withTable(table) { - sql(s"CREATE TABLE $table (i int, j string)") - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetched1.get.sizeInBytes == 0) - assert(fetched1.get.colStats.size == 2) - - // insert into command - sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") - assert(getStatsProperties(table).isEmpty) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string)") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } } } test("change stats after load data command") { val table = "change_stats_load_table" - withTable(table) { - sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetched1.get.sizeInBytes == 0) - assert(fetched1.get.colStats.size == 2) - - withTempDir { loadPath => - // load data command - val file = new File(loadPath + "/data") - val writer = new PrintWriter(file) - writer.write("2,xyz") - writer.close() - sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") - assert(getStatsProperties(table).isEmpty) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + withTempDir { loadPath => + // load data command + val file = new File(loadPath + "/data") + val writer = new PrintWriter(file) + writer.write("2,xyz") + writer.close() + sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } } } } test("change stats after add/drop partition command") { val table = "change_stats_part_table" - withTable(table) { - sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") - // table has two partitions initially - for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { - sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") - } - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) - assert(fetched1.get.sizeInBytes > 0) - assert(fetched1.get.colStats.size == 2) - - withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => - val file1 = new File(dir1 + "/data") - val writer1 = new PrintWriter(file1) - writer1.write("1,a") - writer1.close() - - val file2 = new File(dir2 + "/data") - val writer2 = new PrintWriter(file2) - writer2.write("1,a") - writer2.close() - - // add partition command - sql( - s""" - |ALTER TABLE $table ADD - |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' - |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' - """.stripMargin) - assert(getStatsProperties(table).isEmpty) - - // generate stats again - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(4)) - assert(fetched2.get.sizeInBytes > 0) - assert(fetched2.get.colStats.size == 2) - - // drop partition command - sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") - // only one partition left - assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) - .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) - assert(getStatsProperties(table).isEmpty) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") + // table has two partitions initially + for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { + sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") + } + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => + val file1 = new File(dir1 + "/data") + val writer1 = new PrintWriter(file1) + writer1.write("1,a") + writer1.close() + + val file2 = new File(dir2 + "/data") + val writer2 = new PrintWriter(file2) + writer2.write("1,a") + writer2.close() + + // add partition command + sql( + s""" + |ALTER TABLE $table ADD + |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' + |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' + """.stripMargin) + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > fetched1.get.sizeInBytes) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + + // now the table has four partitions, generate stats again + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched3 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(fetched3.get.sizeInBytes > 0) + assert(fetched3.get.colStats.size == 2) + + // drop partition command + sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") + assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + // only one partition left + if (autoUpdate) { + val fetched4 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched4.get.sizeInBytes < fetched1.get.sizeInBytes) + assert(fetched4.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched4.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } } } } From b1d719e7c9faeb5661a7e712b3ecefca56bf356f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 30 Jun 2017 21:10:23 -0700 Subject: [PATCH 0839/1765] [SPARK-21273][SQL] Propagate logical plan stats using visitor pattern and mixin ## What changes were proposed in this pull request? We currently implement statistics propagation directly in logical plan. Given we already have two different implementations, it'd make sense to actually decouple the two and add stats propagation using mixin. This would reduce the coupling between logical plan and statistics handling. This can also be a powerful pattern in the future to add additional properties (e.g. constraints). ## How was this patch tested? Should be covered by existing test cases. Author: Reynold Xin Closes #18479 from rxin/stats-trait. --- .../sql/catalyst/catalog/interface.scala | 2 +- .../plans/logical/LocalRelation.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 61 +------ .../plans/logical/LogicalPlanVisitor.scala | 87 ++++++++++ .../plans/logical/basicLogicalOperators.scala | 128 +------------- .../sql/catalyst/plans/logical/hints.scala | 5 - .../BasicStatsPlanVisitor.scala | 82 +++++++++ .../statsEstimation/LogicalPlanStats.scala | 50 ++++++ .../SizeInBytesOnlyStatsPlanVisitor.scala | 163 ++++++++++++++++++ .../BasicStatsEstimationSuite.scala | 44 ----- .../StatsEstimationTestBase.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../execution/columnar/InMemoryRelation.scala | 2 +- .../datasources/LogicalRelation.scala | 7 +- .../sql/execution/streaming/memory.scala | 3 +- .../PruneFileSourcePartitionsSuite.scala | 2 +- 16 files changed, 409 insertions(+), 238 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index da50b0e7e8e42..9531456434a15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -438,7 +438,7 @@ case class CatalogRelation( case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) }) - override def computeStats: Statistics = { + override def computeStats(): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index dc2add64b68b7..1c986fbde7ada 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -66,9 +66,8 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats: Statistics = - Statistics(sizeInBytes = - output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + override def computeStats(): Statistics = + Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0d30aa76049a5..8649603b1a9f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,11 +22,16 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType -abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstraints with Logging { +abstract class LogicalPlan + extends QueryPlan[LogicalPlan] + with LogicalPlanStats + with QueryPlanConstraints + with Logging { private var _analyzed: Boolean = false @@ -80,40 +85,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai } } - /** A cache for the estimated statistics, such that it will only be computed once. */ - private var statsCache: Option[Statistics] = None - - /** - * Returns the estimated statistics for the current logical plan node. Under the hood, this - * method caches the return value, which is computed based on the configuration passed in the - * first time. If the configuration changes, the cache can be invalidated by calling - * [[invalidateStatsCache()]]. - */ - final def stats: Statistics = statsCache.getOrElse { - statsCache = Some(computeStats) - statsCache.get - } - - /** Invalidates the stats cache. See [[stats]] for more information. */ - final def invalidateStatsCache(): Unit = { - statsCache = None - children.foreach(_.invalidateStatsCache()) - } - - /** - * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of all child plan's cardinality, i.e. applies in the case - * of cartesian joins. - * - * [[LeafNode]]s must override this. - */ - protected def computeStats: Statistics = { - if (children.isEmpty) { - throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") - } - Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product) - } - override def verboseStringWithSuffix: String = { super.verboseString + statsCache.map(", " + _.toString).getOrElse("") } @@ -300,6 +271,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai abstract class LeafNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Nil override def producedAttributes: AttributeSet = outputSet + + /** Leaf nodes that can survive analysis must define their own statistics. */ + def computeStats(): Statistics = throw new UnsupportedOperationException } /** @@ -331,23 +305,6 @@ abstract class UnaryNode extends LogicalPlan { } override protected def validConstraints: Set[Expression] = child.constraints - - override def computeStats: Statistics = { - // There should be some overhead in Row object, the size should not be zero when there is - // no columns, this help to prevent divide-by-zero error. - val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 - // Assume there will be the same number of rows as child has. - var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize - if (sizeInBytes == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - sizeInBytes = 1 - } - - // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala new file mode 100644 index 0000000000000..b23045810a4f6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * A visitor pattern for traversing a [[LogicalPlan]] tree and compute some properties. + */ +trait LogicalPlanVisitor[T] { + + def visit(p: LogicalPlan): T = p match { + case p: Aggregate => visitAggregate(p) + case p: Distinct => visitDistinct(p) + case p: Except => visitExcept(p) + case p: Expand => visitExpand(p) + case p: Filter => visitFilter(p) + case p: Generate => visitGenerate(p) + case p: GlobalLimit => visitGlobalLimit(p) + case p: Intersect => visitIntersect(p) + case p: Join => visitJoin(p) + case p: LocalLimit => visitLocalLimit(p) + case p: Pivot => visitPivot(p) + case p: Project => visitProject(p) + case p: Range => visitRange(p) + case p: Repartition => visitRepartition(p) + case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: Sample => visitSample(p) + case p: ScriptTransformation => visitScriptTransform(p) + case p: Union => visitUnion(p) + case p: ResolvedHint => visitHint(p) + case p: LogicalPlan => default(p) + } + + def default(p: LogicalPlan): T + + def visitAggregate(p: Aggregate): T + + def visitDistinct(p: Distinct): T + + def visitExcept(p: Except): T + + def visitExpand(p: Expand): T + + def visitFilter(p: Filter): T + + def visitGenerate(p: Generate): T + + def visitGlobalLimit(p: GlobalLimit): T + + def visitHint(p: ResolvedHint): T + + def visitIntersect(p: Intersect): T + + def visitJoin(p: Join): T + + def visitLocalLimit(p: LocalLimit): T + + def visitPivot(p: Pivot): T + + def visitProject(p: Project): T + + def visitRange(p: Range): T + + def visitRepartition(p: Repartition): T + + def visitRepartitionByExpr(p: RepartitionByExpression): T + + def visitSample(p: Sample): T + + def visitScriptTransform(p: ScriptTransformation): T + + def visitUnion(p: Union): T +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index e89caabf252d7..0bd3166352d35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -63,14 +63,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - - override def computeStats: Statistics = { - if (conf.cboEnabled) { - ProjectEstimation.estimate(this).getOrElse(super.computeStats) - } else { - super.computeStats - } - } } /** @@ -137,14 +129,6 @@ case class Filter(condition: Expression, child: LogicalPlan) .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } - - override def computeStats: Statistics = { - if (conf.cboEnabled) { - FilterEstimation(this).estimate.getOrElse(super.computeStats) - } else { - super.computeStats - } - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -190,15 +174,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation Some(children.flatMap(_.maxRows).min) } } - - override def computeStats: Statistics = { - val leftSize = left.stats.sizeInBytes - val rightSize = right.stats.sizeInBytes - val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - Statistics( - sizeInBytes = sizeInBytes, - hints = left.stats.hints.resetForJoin()) - } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { @@ -207,10 +182,6 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints - - override def computeStats: Statistics = { - left.stats.copy() - } } /** Factory for constructing new `Union` nodes. */ @@ -247,11 +218,6 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats: Statistics = { - val sizeInBytes = children.map(_.stats.sizeInBytes).sum - Statistics(sizeInBytes = sizeInBytes) - } - /** * Maps the constraints containing a given (original) sequence of attributes to those with a * given (reference) sequence of attributes. Given the nature of union, we expect that the @@ -355,25 +321,6 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } - - override def computeStats: Statistics = { - def simpleEstimation: Statistics = joinType match { - case LeftAnti | LeftSemi => - // LeftSemi and LeftAnti won't ever be bigger than left - left.stats - case _ => - // Make sure we don't propagate isBroadcastable in other joins, because - // they could explode the size. - val stats = super.computeStats - stats.copy(hints = stats.hints.resetForJoin()) - } - - if (conf.cboEnabled) { - JoinEstimation.estimate(this).getOrElse(simpleEstimation) - } else { - simpleEstimation - } - } } /** @@ -522,14 +469,13 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats: Statistics = { - val sizeInBytes = LongType.defaultSize * numElements - Statistics( sizeInBytes = sizeInBytes ) - } - override def simpleString: String = { s"Range ($start, $end, step=$step, splits=$numSlices)" } + + override def computeStats(): Statistics = { + Statistics(sizeInBytes = LongType.defaultSize * numElements) + } } case class Aggregate( @@ -554,25 +500,6 @@ case class Aggregate( val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) child.constraints.union(getAliasedConstraints(nonAgg)) } - - override def computeStats: Statistics = { - def simpleEstimation: Statistics = { - if (groupingExpressions.isEmpty) { - Statistics( - sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), - rowCount = Some(1), - hints = child.stats.hints) - } else { - super.computeStats - } - } - - if (conf.cboEnabled) { - AggregateEstimation.estimate(this).getOrElse(simpleEstimation) - } else { - simpleEstimation - } - } } case class Window( @@ -671,11 +598,6 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats: Statistics = { - val sizeInBytes = super.computeStats.sizeInBytes * projections.length - Statistics(sizeInBytes = sizeInBytes) - } - // This operator can reuse attributes (for example making them null when doing a roll up) so // the constraints of the child may no longer be valid. override protected def validConstraints: Set[Expression] = Set.empty[Expression] @@ -742,16 +664,6 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats: Statistics = { - val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats - val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) - // Don't propagate column stats, because we don't know the distribution after a limit operation - Statistics( - sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), - rowCount = Some(rowCount), - hints = childStats.hints) - } } case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { @@ -762,24 +674,6 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats: Statistics = { - val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats - if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - Statistics( - sizeInBytes = 1, - rowCount = Some(0), - hints = childStats.hints) - } else { - // The output row count of LocalLimit should be the sum of row counts from each partition. - // However, since the number of partitions is not available here, we just use statistics of - // the child. Because the distribution after a limit operation is unknown, we do not propagate - // the column stats. - childStats.copy(attributeStats = AttributeMap(Nil)) - } - } } /** @@ -828,18 +722,6 @@ case class Sample( } override def output: Seq[Attribute] = child.output - - override def computeStats: Statistics = { - val ratio = upperBound - lowerBound - val childStats = child.stats - var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) - if (sizeInBytes == 0) { - sizeInBytes = 1 - } - val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) - // Don't propagate column stats, because we don't know the distribution after a sample operation - Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) - } } /** @@ -893,7 +775,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats: Statistics = Statistics(sizeInBytes = 1) + override def computeStats(): Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 8479c702d7561..29a43528124d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -42,11 +42,6 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def output: Seq[Attribute] = child.output override lazy val canonicalized: LogicalPlan = child.canonicalized - - override def computeStats: Statistics = { - val stats = child.stats - stats.copy(hints = hints) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala new file mode 100644 index 0000000000000..93908b04fb643 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.LongType + +/** + * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. + */ +object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { + + /** Falls back to the estimation computed by [[SizeInBytesOnlyStatsPlanVisitor]]. */ + private def fallback(p: LogicalPlan): Statistics = SizeInBytesOnlyStatsPlanVisitor.visit(p) + + override def default(p: LogicalPlan): Statistics = fallback(p) + + override def visitAggregate(p: Aggregate): Statistics = { + AggregateEstimation.estimate(p).getOrElse(fallback(p)) + } + + override def visitDistinct(p: Distinct): Statistics = fallback(p) + + override def visitExcept(p: Except): Statistics = fallback(p) + + override def visitExpand(p: Expand): Statistics = fallback(p) + + override def visitFilter(p: Filter): Statistics = { + FilterEstimation(p).estimate.getOrElse(fallback(p)) + } + + override def visitGenerate(p: Generate): Statistics = fallback(p) + + override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p) + + override def visitHint(p: ResolvedHint): Statistics = fallback(p) + + override def visitIntersect(p: Intersect): Statistics = fallback(p) + + override def visitJoin(p: Join): Statistics = { + JoinEstimation.estimate(p).getOrElse(fallback(p)) + } + + override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p) + + override def visitPivot(p: Pivot): Statistics = fallback(p) + + override def visitProject(p: Project): Statistics = { + ProjectEstimation.estimate(p).getOrElse(fallback(p)) + } + + override def visitRange(p: logical.Range): Statistics = { + val sizeInBytes = LongType.defaultSize * p.numElements + Statistics(sizeInBytes = sizeInBytes) + } + + override def visitRepartition(p: Repartition): Statistics = fallback(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) + + override def visitSample(p: Sample): Statistics = fallback(p) + + override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p) + + override def visitUnion(p: Union): Statistics = fallback(p) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala new file mode 100644 index 0000000000000..8660d93550192 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * A trait to add statistics propagation to [[LogicalPlan]]. + */ +trait LogicalPlanStats { self: LogicalPlan => + + /** + * Returns the estimated statistics for the current logical plan node. Under the hood, this + * method caches the return value, which is computed based on the configuration passed in the + * first time. If the configuration changes, the cache can be invalidated by calling + * [[invalidateStatsCache()]]. + */ + def stats: Statistics = statsCache.getOrElse { + if (conf.cboEnabled) { + statsCache = Option(BasicStatsPlanVisitor.visit(self)) + } else { + statsCache = Option(SizeInBytesOnlyStatsPlanVisitor.visit(self)) + } + statsCache.get + } + + /** A cache for the estimated statistics, such that it will only be computed once. */ + protected var statsCache: Option[Statistics] = None + + /** Invalidates the stats cache. See [[stats]] for more information. */ + final def invalidateStatsCache(): Unit = { + statsCache = None + children.foreach(_.invalidateStatsCache()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala new file mode 100644 index 0000000000000..559f12072e448 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * An [[LogicalPlanVisitor]] that computes a single dimension for plan stats: size in bytes. + */ +object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { + + /** + * A default, commonly used estimation for unary nodes. We assume the input row number is the + * same as the output row number, and compute sizes based on the column types. + */ + private def visitUnaryNode(p: UnaryNode): Statistics = { + // There should be some overhead in Row object, the size should not be zero when there is + // no columns, this help to prevent divide-by-zero error. + val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 + val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + // Assume there will be the same number of rows as child has. + var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize + if (sizeInBytes == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + sizeInBytes = 1 + } + + // Don't propagate rowCount and attributeStats, since they are not estimated here. + Statistics(sizeInBytes = sizeInBytes, hints = p.child.stats.hints) + } + + /** + * For leaf nodes, use its computeStats. For other nodes, we assume the size in bytes is the + * sum of all of the children's. + */ + override def default(p: LogicalPlan): Statistics = p match { + case p: LeafNode => p.computeStats() + case _: LogicalPlan => Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).product) + } + + override def visitAggregate(p: Aggregate): Statistics = { + if (p.groupingExpressions.isEmpty) { + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(p.output, outputRowCount = 1), + rowCount = Some(1), + hints = p.child.stats.hints) + } else { + visitUnaryNode(p) + } + } + + override def visitDistinct(p: Distinct): Statistics = default(p) + + override def visitExcept(p: Except): Statistics = p.left.stats.copy() + + override def visitExpand(p: Expand): Statistics = { + val sizeInBytes = visitUnaryNode(p).sizeInBytes * p.projections.length + Statistics(sizeInBytes = sizeInBytes) + } + + override def visitFilter(p: Filter): Statistics = visitUnaryNode(p) + + override def visitGenerate(p: Generate): Statistics = default(p) + + override def visitGlobalLimit(p: GlobalLimit): Statistics = { + val limit = p.limitExpr.eval().asInstanceOf[Int] + val childStats = p.child.stats + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after limit + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + hints = childStats.hints) + } + + override def visitHint(p: ResolvedHint): Statistics = p.child.stats.copy(hints = p.hints) + + override def visitIntersect(p: Intersect): Statistics = { + val leftSize = p.left.stats.sizeInBytes + val rightSize = p.right.stats.sizeInBytes + val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize + Statistics( + sizeInBytes = sizeInBytes, + hints = p.left.stats.hints.resetForJoin()) + } + + override def visitJoin(p: Join): Statistics = { + p.joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + p.left.stats + case _ => + // Make sure we don't propagate isBroadcastable in other joins, because + // they could explode the size. + val stats = default(p) + stats.copy(hints = stats.hints.resetForJoin()) + } + } + + override def visitLocalLimit(p: LocalLimit): Statistics = { + val limit = p.limitExpr.eval().asInstanceOf[Int] + val childStats = p.child.stats + if (limit == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + Statistics(sizeInBytes = 1, rowCount = Some(0), hints = childStats.hints) + } else { + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. + childStats.copy(attributeStats = AttributeMap(Nil)) + } + } + + override def visitPivot(p: Pivot): Statistics = default(p) + + override def visitProject(p: Project): Statistics = visitUnaryNode(p) + + override def visitRange(p: logical.Range): Statistics = { + p.computeStats() + } + + override def visitRepartition(p: Repartition): Statistics = default(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p) + + override def visitSample(p: Sample): Statistics = { + val ratio = p.upperBound - p.lowerBound + var sizeInBytes = EstimationUtils.ceil(BigDecimal(p.child.stats.sizeInBytes) * ratio) + if (sizeInBytes == 0) { + sizeInBytes = 1 + } + val sampleRows = p.child.stats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampleRows, hints = p.child.stats.hints) + } + + override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) + + override def visitUnion(p: Union): Statistics = { + Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).sum) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 912c5fed63450..31a8cbdee9777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -77,37 +77,6 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { checkStats(globalLimit, stats) } - test("sample estimation") { - val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) - checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) - - // Child doesn't have rowCount in stats - val childStats = Statistics(sizeInBytes = 120) - val childPlan = DummyLogicalPlan(childStats, childStats) - val sample2 = - Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) - checkStats(sample2, Statistics(sizeInBytes = 14)) - } - - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - checkStats( - plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) - } - /** Check estimated stats when cbo is turned on/off. */ private def checkStats( plan: LogicalPlan, @@ -132,16 +101,3 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = checkStats(plan, expectedStats, expectedStats) } - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats: Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index eaa33e44a6a5a..31dea2e3e7f1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -65,7 +65,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats: Statistics = Statistics( + override def computeStats(): Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 66f66a289a065..dcb918eeb9d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -88,7 +88,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats: Statistics = Statistics( + override def computeStats(): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -156,7 +156,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats: Statistics = Statistics( + override def computeStats(): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2972132336de0..39cf8fcac5116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -69,7 +69,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats: Statistics = { + override def computeStats(): Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 6ba190b9e5dcf..699f1bad9c4ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -48,9 +48,10 @@ case class LogicalRelation( output = output.map(QueryPlan.normalizeExprId(_, output)), catalogTable = None) - @transient override def computeStats: Statistics = { - catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( - Statistics(sizeInBytes = relation.sizeInBytes)) + override def computeStats(): Statistics = { + catalogTable + .flatMap(_.stats.map(_.toPlanStats(output))) + .getOrElse(Statistics(sizeInBytes = relation.sizeInBytes)) } /** Used to lookup original attribute capitalization */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 4979873ee3c7f..587ae2bfb63fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -230,6 +230,5 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats: Statistics = - Statistics(sizePerRow * sink.allData.size) + override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 3a724aa14f2a9..94384185d190a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te case relation: LogicalRelation => relation } assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") - val size2 = relations(0).computeStats.sizeInBytes + val size2 = relations(0).stats.sizeInBytes assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) assert(size2 < tableStats.get.sizeInBytes) } From 37ef32e515ea071afe63b56ba0d4299bb76e8a75 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Sat, 1 Jul 2017 14:57:57 +0800 Subject: [PATCH 0840/1765] [SPARK-21275][ML] Update GLM test to use supportedFamilyNames ## What changes were proposed in this pull request? Update GLM test to use supportedFamilyNames as suggested here: https://github.com/apache/spark/pull/16699#discussion-diff-100574976R855 Author: actuaryzhang Closes #18495 from actuaryzhang/mlGlmTest2. --- .../GeneralizedLinearRegressionSuite.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 83f1344a7bcb1..a47bd17f47bb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -749,15 +749,15 @@ class GeneralizedLinearRegressionSuite library(statmod) y <- c(1.0, 0.5, 0.7, 0.3) w <- c(1, 2, 3, 4) - for (fam in list(gaussian(), poisson(), binomial(), Gamma(), tweedie(1.6))) { + for (fam in list(binomial(), Gamma(), gaussian(), poisson(), tweedie(1.6))) { model1 <- glm(y ~ 1, family = fam) model2 <- glm(y ~ 1, family = fam, weights = w) print(as.vector(c(coef(model1), coef(model2)))) } - [1] 0.625 0.530 - [1] -0.4700036 -0.6348783 [1] 0.5108256 0.1201443 [1] 1.600000 1.886792 + [1] 0.625 0.530 + [1] -0.4700036 -0.6348783 [1] 1.325782 1.463641 */ @@ -768,13 +768,13 @@ class GeneralizedLinearRegressionSuite Instance(0.3, 4.0, Vectors.zeros(0)) ).toDF() - val expected = Seq(0.625, 0.530, -0.4700036, -0.6348783, 0.5108256, 0.1201443, - 1.600000, 1.886792, 1.325782, 1.463641) + val expected = Seq(0.5108256, 0.1201443, 1.600000, 1.886792, 0.625, 0.530, + -0.4700036, -0.6348783, 1.325782, 1.463641) import GeneralizedLinearRegression._ var idx = 0 - for (family <- Seq("gaussian", "poisson", "binomial", "gamma", "tweedie")) { + for (family <- GeneralizedLinearRegression.supportedFamilyNames.sortWith(_ < _)) { for (useWeight <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily(family) if (useWeight) trainer.setWeightCol("weight") @@ -807,7 +807,7 @@ class GeneralizedLinearRegressionSuite 0.5, 2.1, 0.5, 1.0, 2.0, 0.9, 0.4, 1.0, 2.0, 1.0, 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) - families <- list(gaussian, binomial, poisson, Gamma, tweedie(1.5)) + families <- list(binomial, Gamma, gaussian, poisson, tweedie(1.5)) f1 <- V1 ~ -1 + V4 + V5 f2 <- V1 ~ V4 + V5 for (f in c(f1, f2)) { @@ -816,15 +816,15 @@ class GeneralizedLinearRegressionSuite print(as.vector(coef(model))) } } - [1] 0.5169222 -0.3344444 [1] 0.9419107 -0.6864404 - [1] 0.1812436 -0.6568422 [1] -0.2869094 0.7857710 + [1] 0.5169222 -0.3344444 + [1] 0.1812436 -0.6568422 [1] 0.1055254 0.2979113 - [1] -0.05990345 0.53188982 -0.32118415 [1] -0.2147117 0.9911750 -0.6356096 - [1] -1.5616130 0.6646470 -0.3192581 [1] 0.3390397 -0.3406099 0.6870259 + [1] -0.05990345 0.53188982 -0.32118415 + [1] -1.5616130 0.6646470 -0.3192581 [1] 0.3665034 0.1039416 0.1484616 */ val dataset = Seq( @@ -835,23 +835,22 @@ class GeneralizedLinearRegressionSuite ).toDF() val expected = Seq( - Vectors.dense(0, 0.5169222, -0.3344444), Vectors.dense(0, 0.9419107, -0.6864404), - Vectors.dense(0, 0.1812436, -0.6568422), Vectors.dense(0, -0.2869094, 0.785771), + Vectors.dense(0, 0.5169222, -0.3344444), + Vectors.dense(0, 0.1812436, -0.6568422), Vectors.dense(0, 0.1055254, 0.2979113), - Vectors.dense(-0.05990345, 0.53188982, -0.32118415), Vectors.dense(-0.2147117, 0.991175, -0.6356096), - Vectors.dense(-1.561613, 0.664647, -0.3192581), Vectors.dense(0.3390397, -0.3406099, 0.6870259), + Vectors.dense(-0.05990345, 0.53188982, -0.32118415), + Vectors.dense(-1.561613, 0.664647, -0.3192581), Vectors.dense(0.3665034, 0.1039416, 0.1484616)) import GeneralizedLinearRegression._ var idx = 0 - for (fitIntercept <- Seq(false, true)) { - for (family <- Seq("gaussian", "binomial", "poisson", "gamma", "tweedie")) { + for (family <- GeneralizedLinearRegression.supportedFamilyNames.sortWith(_ < _)) { val trainer = new GeneralizedLinearRegression().setFamily(family) .setFitIntercept(fitIntercept).setOffsetCol("offset") .setWeightCol("weight").setLinkPredictionCol("linkPrediction") From e0b047eafed92eadf6842a9df964438095e12d41 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 1 Jul 2017 15:37:41 +0800 Subject: [PATCH 0841/1765] [SPARK-18518][ML] HasSolver supports override ## What changes were proposed in this pull request? 1, make param support non-final with `finalFields` option 2, generate `HasSolver` with `finalFields = false` 3, override `solver` in LiR, GLR, and make MLPC inherit `HasSolver` ## How was this patch tested? existing tests Author: Ruifeng Zheng Author: Zheng RuiFeng Closes #16028 from zhengruifeng/param_non_final. --- .../MultilayerPerceptronClassifier.scala | 19 ++++---- .../ml/param/shared/SharedParamsCodeGen.scala | 11 +++-- .../spark/ml/param/shared/sharedParams.scala | 8 ++-- .../GeneralizedLinearRegression.scala | 21 ++++++++- .../ml/regression/LinearRegression.scala | 46 +++++++++++++++---- python/pyspark/ml/classification.py | 18 +------- python/pyspark/ml/regression.py | 5 ++ 7 files changed, 82 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index ec39f964e213a..ceba11edc93be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset /** Params for Multilayer Perceptron. */ private[classification] trait MultilayerPerceptronParams extends PredictorParams - with HasSeed with HasMaxIter with HasTol with HasStepSize { + with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver { + + import MultilayerPerceptronClassifier._ + /** * Layer sizes including input size and output size. * @@ -78,14 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams * @group expertParam */ @Since("2.0.0") - final val solver: Param[String] = new Param[String](this, "solver", + final override val solver: Param[String] = new Param[String](this, "solver", "The solver algorithm for optimization. Supported options: " + - s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)", - ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers)) - - /** @group expertGetParam */ - @Since("2.0.0") - final def getSolver: String = $(solver) + s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)", + ParamValidators.inArray[String](supportedSolvers)) /** * The initial weights of the model. @@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams final def getInitialWeights: Vector = $(initialWeights) setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128, - solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03) + solver -> LBFGS, stepSize -> 0.03) } /** Label to vector converter. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 013817a41baf5..23e0d45d943a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen { " 0)", isValid = "ParamValidators.gt(0)"), ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + "all instance weights as 1.0"), - ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + - "empty, default value is 'auto'", Some("\"auto\"")), + ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) @@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen { defaultValueStr: Option[String] = None, isValid: String = "", finalMethods: Boolean = true, + finalFields: Boolean = true, isExpertParam: Boolean = false) { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") @@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen { } else { "def" } + val fieldStr = if (param.finalFields) { + "final val" + } else { + "val" + } val htmlCompliantDoc = Utility.escape(doc) @@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen { | * Param for $htmlCompliantDoc. | * @group ${groupStr(0)} | */ - | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) + | $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group ${groupStr(1)} */ | $methodStr get$Name: $T = $$($name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 50619607a5054..1a8f499798b80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params { } /** - * Trait for shared param solver (default: "auto"). + * Trait for shared param solver. */ private[ml] trait HasSolver extends Params { /** - * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + * Param for the solver algorithm for optimization. * @group param */ - final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'") - - setDefault(solver, "auto") + val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization") /** @group getParam */ final def getSolver: String = $(solver) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index ce3460ae43566..c600b87bdc64a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -164,7 +164,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty } - import GeneralizedLinearRegression._ + /** + * The solver algorithm for optimization. + * Supported options: "irls" (iteratively reweighted least squares). + * Default: "irls" + * + * @group param + */ + @Since("2.3.0") + final override val solver: Param[String] = new Param[String](this, "solver", + "The solver algorithm for optimization. Supported options: " + + s"${supportedSolvers.mkString(", ")}. (Default irls)", + ParamValidators.inArray[String](supportedSolvers)) @Since("2.0.0") override def validateAndTransformSchema( @@ -350,7 +361,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setSolver(value: String): this.type = set(solver, value) - setDefault(solver -> "irls") + setDefault(solver -> IRLS) /** * Sets the link prediction (linear predictor) column name. @@ -442,6 +453,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine Gamma -> Inverse, Gamma -> Identity, Gamma -> Log ) + /** String name for "irls" (iteratively reweighted least squares) solver. */ + private[regression] val IRLS = "irls" + + /** Set of solvers that GeneralizedLinearRegression supports. */ + private[regression] val supportedSolvers = Array(IRLS) + /** Set of family names that GeneralizedLinearRegression supports. */ private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie" diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index db5ac4f14bd3b..ce5e0797915df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver - with HasAggregationDepth + with HasAggregationDepth { + + import LinearRegression._ + + /** + * The solver algorithm for optimization. + * Supported options: "l-bfgs", "normal" and "auto". + * Default: "auto" + * + * @group param + */ + @Since("2.3.0") + final override val solver: Param[String] = new Param[String](this, "solver", + "The solver algorithm for optimization. Supported options: " + + s"${supportedSolvers.mkString(", ")}. (Default auto)", + ParamValidators.inArray[String](supportedSolvers)) +} /** * Linear regression. @@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with DefaultParamsWritable with Logging { + import LinearRegression._ + @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String * @group setParam */ @Since("1.6.0") - def setSolver(value: String): this.type = { - require(Set("auto", "l-bfgs", "normal").contains(value), - s"Solver $value was not supported. Supported options: auto, l-bfgs, normal") - set(solver, value) - } - setDefault(solver -> "auto") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> AUTO) /** * Suggested depth for treeAggregate (greater than or equal to 2). @@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) instr.logNumFeatures(numFeatures) - if (($(solver) == "auto" && - numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { + if (($(solver) == AUTO && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) { // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) @@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { */ @Since("2.1.0") val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES + + /** String name for "auto". */ + private[regression] val AUTO = "auto" + + /** String name for "normal". */ + private[regression] val NORMAL = "normal" + + /** String name for "l-bfgs". */ + private[regression] val LBFGS = "l-bfgs" + + /** Set of solvers that LinearRegression supports. */ + private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS) } /** diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 9b345ac73f3d9..948806a5c936c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1265,8 +1265,8 @@ def theta(self): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable, - JavaMLReadable): + HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver, + JavaMLWritable, JavaMLReadable): """ Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. @@ -1407,20 +1407,6 @@ def getStepSize(self): """ return self.getOrDefault(self.stepSize) - @since("2.0.0") - def setSolver(self, value): - """ - Sets the value of :py:attr:`solver`. - """ - return self._set(solver=value) - - @since("2.0.0") - def getSolver(self): - """ - Gets the value of solver or its default value. - """ - return self.getOrDefault(self.solver) - @since("2.0.0") def setInitialWeights(self, value): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2d17f95b0c44f..84d843369e105 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction .. versionadded:: 1.4.0 """ + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, @@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + "Only applicable to the Tweedie family.", typeConverter=TypeConverters.toFloat) + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: irls.", typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", From 6beca9ce94f484de2f9ffb946bef8334781b3122 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Sat, 1 Jul 2017 15:53:49 +0100 Subject: [PATCH 0842/1765] [SPARK-21170][CORE] Utils.tryWithSafeFinallyAndFailureCallbacks throws IllegalArgumentException: Self-suppression not permitted ## What changes were proposed in this pull request? Not adding the exception to the suppressed if it is the same instance as originalThrowable. ## How was this patch tested? Added new tests to verify this, these tests fail without source code changes and passes with the change. Author: Devaraj K Closes #18384 from devaraj-kavali/SPARK-21170. --- .../scala/org/apache/spark/util/Utils.scala | 30 +++---- .../org/apache/spark/util/UtilsSuite.scala | 88 ++++++++++++++++++- 2 files changed, 99 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bbb7999e2a144..26f61e25da4d3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1348,14 +1348,10 @@ private[spark] object Utils extends Logging { try { finallyBlock } catch { - case t: Throwable => - if (originalThrowable != null) { - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: " + t.getMessage, t) - throw originalThrowable - } else { - throw t - } + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable } } } @@ -1387,22 +1383,20 @@ private[spark] object Utils extends Logging { catchBlock } catch { case t: Throwable => - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in catch: " + t.getMessage, t) + if (originalThrowable != t) { + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in catch: ${t.getMessage}", t) + } } throw originalThrowable } finally { try { finallyBlock } catch { - case t: Throwable => - if (originalThrowable != null) { - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: " + t.getMessage, t) - throw originalThrowable - } else { - throw t - } + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable } } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f7bc8f888b0d5..4ce143f18bbf1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -38,7 +38,7 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit @@ -1024,4 +1024,90 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(redactedConf("spark.sensitive.property") === Utils.REDACTION_REPLACEMENT_TEXT) } + + test("tryWithSafeFinally") { + var e = new Error("Block0") + val finallyBlockError = new Error("Finally Block") + var isErrorOccurred = false + // if the try and finally blocks throw different exception instances + try { + Utils.tryWithSafeFinally { throw e }(finallyBlock = { throw finallyBlockError }) + } catch { + case t: Error => + assert(t.getSuppressed.head == finallyBlockError) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try and finally blocks throw the same exception instance then it should not + // try to add to suppressed and get IllegalArgumentException + e = new Error("Block1") + isErrorOccurred = false + try { + Utils.tryWithSafeFinally { throw e }(finallyBlock = { throw e }) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try throws the exception and finally doesn't throw exception + e = new Error("Block2") + isErrorOccurred = false + try { + Utils.tryWithSafeFinally { throw e }(finallyBlock = {}) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try and finally block don't throw exception + Utils.tryWithSafeFinally {}(finallyBlock = {}) + } + + test("tryWithSafeFinallyAndFailureCallbacks") { + var e = new Error("Block0") + val catchBlockError = new Error("Catch Block") + val finallyBlockError = new Error("Finally Block") + var isErrorOccurred = false + TaskContext.setTaskContext(TaskContext.empty()) + // if the try, catch and finally blocks throw different exception instances + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { throw e }( + catchBlock = { throw catchBlockError }, finallyBlock = { throw finallyBlockError }) + } catch { + case t: Error => + assert(t.getSuppressed.head == catchBlockError) + assert(t.getSuppressed.last == finallyBlockError) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try, catch and finally blocks throw the same exception instance then it should not + // try to add to suppressed and get IllegalArgumentException + e = new Error("Block1") + isErrorOccurred = false + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { throw e }(catchBlock = { throw e }, + finallyBlock = { throw e }) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try throws the exception, catch and finally don't throw exceptions + e = new Error("Block2") + isErrorOccurred = false + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { throw e }(catchBlock = {}, finallyBlock = {}) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try, catch and finally blocks don't throw exceptions + Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {}) + TaskContext.unset + } } From c605fee01f180588ecb2f48710a7b84073bd3b9a Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 2 Jul 2017 08:50:48 +0100 Subject: [PATCH 0843/1765] [SPARK-21260][SQL][MINOR] Remove the unused OutputFakerExec ## What changes were proposed in this pull request? OutputFakerExec was added long ago and is not used anywhere now so we should remove it. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #18473 from jiangxb1987/OutputFakerExec. --- .../spark/sql/execution/basicPhysicalOperators.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f3ca8397047fe..2151c339b9b87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -584,17 +584,6 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } } -/** - * A plan node that does nothing but lie about the output of its child. Used to spice a - * (hopefully structurally equivalent) tree from a different optimization sequence into an already - * resolved tree. - */ -case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children: Seq[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = child.execute() -} - /** * Physical plan for a subquery. */ From c19680be1c532dded1e70edce7a981ba28af09ad Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 2 Jul 2017 16:17:03 +0800 Subject: [PATCH 0844/1765] [SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to handle invalid data ## What changes were proposed in this pull request? This PR is to maintain API parity with changes made in SPARK-17498 to support a new option 'keep' in StringIndexer to handle unseen labels or NULL values with PySpark. Note: This is updated version of #17237 , the primary author of this PR is VinceShieh . ## How was this patch tested? Unit tests. Author: VinceShieh Author: Yanbo Liang Closes #18453 from yanboliang/spark-19852. --- python/pyspark/ml/feature.py | 6 ++++++ python/pyspark/ml/tests.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 77de1cc18246d..25ad06f682ed9 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). Options are 'skip' (filter out rows with " + + "invalid data), error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 17a39472e1fe5..ffb8b0a890ff8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self): for i in range(0, len(expected)): self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + class HasInducedError(Params): From d4107196d59638845bd19da6aab074424d90ddaf Mon Sep 17 00:00:00 2001 From: Rui Zha Date: Sun, 2 Jul 2017 17:37:47 -0700 Subject: [PATCH 0845/1765] [SPARK-18004][SQL] Make sure the date or timestamp related predicate can be pushed down to Oracle correctly ## What changes were proposed in this pull request? Move `compileValue` method in JDBCRDD to JdbcDialect, and override the `compileValue` method in OracleDialect to rewrite the Oracle-specific timestamp and date literals in where clause. ## How was this patch tested? An integration test has been added. Author: Rui Zha Author: Zharui Closes #18451 from SharpRay/extend-compileValue-to-dialects. --- .../sql/jdbc/OracleIntegrationSuite.scala | 45 +++++++++++++++++++ .../execution/datasources/jdbc/JDBCRDD.scala | 35 +++++---------- .../apache/spark/sql/jdbc/JdbcDialects.scala | 27 ++++++++++- .../apache/spark/sql/jdbc/OracleDialect.scala | 15 ++++++- 4 files changed, 95 insertions(+), 27 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index b2f096964427e..e14810a32edc6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -223,4 +223,49 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val types = rows(0).toSeq.map(x => x.getClass.toString) assert(types(1).equals("class java.sql.Timestamp")) } + + test("SPARK-18004: Make sure date or timestamp related predicate is pushed down correctly") { + val props = new Properties() + props.put("oracle.jdbc.mapDateToTimestamp", "false") + + val schema = StructType(Seq( + StructField("date_type", DateType, true), + StructField("timestamp_type", TimestampType, true) + )) + + val tableName = "test_date_timestamp_pushdown" + val dateVal = Date.valueOf("2017-06-22") + val timestampVal = Timestamp.valueOf("2017-06-22 21:30:07") + + val data = spark.sparkContext.parallelize(Seq( + Row(dateVal, timestampVal) + )) + + val dfWrite = spark.createDataFrame(data, schema) + dfWrite.write.jdbc(jdbcUrl, tableName, props) + + val dfRead = spark.read.jdbc(jdbcUrl, tableName, props) + + val millis = System.currentTimeMillis() + val dt = new java.sql.Date(millis) + val ts = new java.sql.Timestamp(millis) + + // Query Oracle table with date and timestamp predicates + // which should be pushed down to Oracle. + val df = dfRead.filter(dfRead.col("date_type").lt(dt)) + .filter(dfRead.col("timestamp_type").lt(ts)) + + val metadata = df.queryExecution.sparkPlan.metadata + // The "PushedFilters" part should be exist in Datafrome's + // physical plan and the existence of right literals in + // "PushedFilters" is used to prove that the predicates + // pushing down have been effective. + assert(metadata.get("PushedFilters").ne(None)) + assert(metadata("PushedFilters").contains(dt.toString)) + assert(metadata("PushedFilters").contains(ts.toString)) + + val row = df.collect()(0) + assert(row.getDate(0).equals(dateVal)) + assert(row.getTimestamp(1).equals(timestampVal)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 2bdc43254133e..0f53b5c7c6f0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp} +import java.sql.{Connection, PreparedStatement, ResultSet, SQLException} import scala.util.control.NonFatal -import org.apache.commons.lang3.StringUtils - import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -86,20 +84,6 @@ object JDBCRDD extends Logging { new StructType(columns.map(name => fieldMap(name))) } - /** - * Converts value to SQL expression. - */ - private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" - case timestampValue: Timestamp => "'" + timestampValue + "'" - case dateValue: Date => "'" + dateValue + "'" - case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") - case _ => value - } - - private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") - /** * Turns a single Filter into a String representing a SQL expression. * Returns None for an unhandled filter. @@ -108,15 +92,16 @@ object JDBCRDD extends Logging { def quote(colName: String): String = dialect.quoteIdentifier(colName) Option(f match { - case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}" + case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" case EqualNullSafe(attr, value) => val col = quote(attr) - s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " + - s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))" - case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}" - case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}" + s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " + + s"${dialect.compileValue(value)} IS NULL) OR " + + s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))" + case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}" + case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}" + case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}" case IsNull(attr) => s"${quote(attr)} IS NULL" case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL" case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'" @@ -124,7 +109,7 @@ object JDBCRDD extends Logging { case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'" case In(attr, value) if value.isEmpty => s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END" - case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})" + case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})" case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null) case Or(f1, f2) => // We can't compile Or filter unless both sub-filters are compiled successfully. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index a86a86d408906..7c38ed68c0413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.jdbc -import java.sql.Connection +import java.sql.{Connection, Date, Timestamp} + +import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.sql.types._ @@ -123,6 +125,29 @@ abstract class JdbcDialect extends Serializable { def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { } + /** + * Escape special characters in SQL string literals. + * @param value The string to be escaped. + * @return Escaped string. + */ + @Since("2.3.0") + protected[jdbc] def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + + /** + * Converts value to SQL expression. + * @param value The value to be converted. + * @return Converted value. + */ + @Since("2.3.0") + def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 20e634c06b610..3b44c1de93a61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Date, Timestamp, Types} import org.apache.spark.sql.types._ @@ -64,5 +64,18 @@ private case object OracleDialect extends JdbcDialect { case _ => None } + override def compileValue(value: Any): Any = value match { + // The JDBC drivers support date literals in SQL statements written in the + // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written + // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see + // 'Oracle Database JDBC Developer’s Guide and Reference, 11g Release 1 (11.1)' + // Appendix A Reference Information. + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "{ts '" + timestampValue + "'}" + case dateValue: Date => "{d '" + dateValue + "'}" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } From d913db16a0de0983961f9d0c5f9b146be7226ac1 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 3 Jul 2017 13:31:01 +0800 Subject: [PATCH 0846/1765] [SPARK-21250][WEB-UI] Add a url in the table of 'Running Executors' in worker page to visit job page. ## What changes were proposed in this pull request? Add a url in the table of 'Running Executors' in worker page to visit job page. When I click URL of 'Name', the current page jumps to the job page. Of course this is only in the table of 'Running Executors'. This URL of 'Name' is in the table of 'Finished Executors' does not exist, the click will not jump to any page. fix before: ![1](https://user-images.githubusercontent.com/26266482/27679397-30ddc262-5ceb-11e7-839b-0889d1f42480.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/27679405-3588ef12-5ceb-11e7-9756-0a93815cd698.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #18464 from guoxiaolongzte/SPARK-21250. --- .../apache/spark/deploy/worker/ui/WorkerPage.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 1ad973122b609..ea39b0dce0a41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -23,8 +23,8 @@ import scala.xml.Node import org.json4s.JValue +import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.{UIUtils, WebUIPage} @@ -112,7 +112,15 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
    • ID: {executor.appId}
    • -
    • Name: {executor.appDesc.name}
    • +
    • Name: + { + if ({executor.state == ExecutorState.RUNNING} && executor.appDesc.appUiUrl.nonEmpty) { + {executor.appDesc.name} + } else { + {executor.appDesc.name} + } + } +
    • User: {executor.appDesc.user}
    From a9339db99f0620d4828eb903523be55dfbf2fb64 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 3 Jul 2017 19:52:39 +0800 Subject: [PATCH 0847/1765] [SPARK-21137][CORE] Spark reads many small files slowly ## What changes were proposed in this pull request? Parallelize FileInputFormat.listStatus in Hadoop API via LIST_STATUS_NUM_THREADS to speed up examination of file sizes for wholeTextFiles et al ## How was this patch tested? Existing tests, which will exercise the key path here: using a local file system. Author: Sean Owen Closes #18441 from srowen/SPARK-21137. --- .../main/scala/org/apache/spark/rdd/BinaryFileRDD.scala | 7 ++++++- .../main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 50d977a92da51..a14bad47dfe10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} @@ -35,8 +36,12 @@ private[spark] class BinaryFileRDD[T]( extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance val conf = getConf + // setMinPartitions below will call FileInputFormat.listStatus(), which can be quite slow when + // traversing a large number of directories and files. Parallelize it. + conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, + Runtime.getRuntime.availableProcessors().toString) + val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index 8e1baae796fc5..9f3d0745c33c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.{Text, Writable} import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} @@ -38,8 +39,12 @@ private[spark] class WholeTextFileRDD( extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance val conf = getConf + // setMinPartitions below will call FileInputFormat.listStatus(), which can be quite slow when + // traversing a large number of directories and files. Parallelize it. + conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, + Runtime.getRuntime.availableProcessors().toString) + val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => configurable.setConf(conf) From eb7a5a66bbd5837c01f13c76b68de2a6034976f3 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 3 Jul 2017 09:01:42 -0700 Subject: [PATCH 0848/1765] [TEST] Load test table based on case sensitivity ## What changes were proposed in this pull request? It is strange that we will get "table not found" error if **the first sql** uses upper case table names, when developers write tests with `TestHiveSingleton`, **although case insensitivity**. This is because in `TestHiveQueryExecution`, test tables are loaded based on exact matching instead of case sensitivity. ## How was this patch tested? Added a new test case. Author: Zhenhua Wang Closes #18504 from wzhfy/testHive. --- .../apache/spark/sql/hive/test/TestHive.scala | 7 ++- .../apache/spark/sql/hive/TestHiveSuite.scala | 45 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4e1792321c89b..801f9b9923641 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -449,6 +449,8 @@ private[hive] class TestHiveSparkSession( private val loadedTables = new collection.mutable.HashSet[String] + def getLoadedTables: collection.mutable.HashSet[String] = loadedTables + def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infinite mutually recursive table loading. @@ -553,7 +555,10 @@ private[hive] class TestHiveQueryExecution( val referencedTables = describedTables ++ logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } - val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) + val resolver = sparkSession.sessionState.conf.resolver + val referencedTestTables = sparkSession.testTables.keys.filter { testTable => + referencedTables.exists(resolver(_, testTable)) + } logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala new file mode 100644 index 0000000000000..193fa83dbad99 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.{TestHiveSingleton, TestHiveSparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + + +class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { + test("load test table based on case sensitivity") { + val testHiveSparkSession = spark.asInstanceOf[TestHiveSparkSession] + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT * FROM SRC").queryExecution.analyzed + assert(testHiveSparkSession.getLoadedTables.contains("src")) + assert(testHiveSparkSession.getLoadedTables.size == 1) + } + testHiveSparkSession.reset() + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val err = intercept[AnalysisException] { + sql("SELECT * FROM SRC").queryExecution.analyzed + } + assert(err.message.contains("Table or view not found")) + } + testHiveSparkSession.reset() + } +} From 17bdc36ef16a544b693c628db276fe32db87fe7a Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 3 Jul 2017 09:35:49 -0700 Subject: [PATCH 0849/1765] [SPARK-21102][SQL] Refresh command is too aggressive in parsing ### Idea This PR adds validation to REFRESH sql statements. Currently, users can specify whatever they want as resource path. For example, spark.sql("REFRESH ! $ !") will be executed without any exceptions. ### Implementation I am not sure that my current implementation is the most optimal, so any feedback is appreciated. My first idea was to make the grammar as strict as possible. Unfortunately, there were some problems. I tried the approach below: SqlBase.g4 ``` ... | REFRESH TABLE tableIdentifier #refreshTable | REFRESH resourcePath #refreshResource ... resourcePath : STRING | (IDENTIFIER | number | nonReserved | '/' | '-')+ // other symbols can be added if needed ; ``` It is not flexible enough and requires to explicitly mention all possible symbols. Therefore, I came up with the current approach that is implemented in the code. Let me know your opinion on which one is better. Author: aokolnychyi Closes #18368 from aokolnychyi/spark-21102. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../spark/sql/execution/SparkSqlParser.scala | 20 +++++++++++++++--- .../sql/execution/SparkSqlParserSuite.scala | 21 ++++++++++++++++++- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7ffa150096333..29f554451ed4a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -149,7 +149,7 @@ statement | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable | REFRESH TABLE tableIdentifier #refreshTable - | REFRESH .*? #refreshResource + | REFRESH (STRING | .*?) #refreshResource | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable | UNCACHE TABLE (IF EXISTS)? tableIdentifier #uncacheTable | CLEAR CACHE #clearCache diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3c58c6e1b6780..2b79eb5eac0f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -230,11 +230,25 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[RefreshTable]] logical plan. + * Create a [[RefreshResource]] logical plan. */ override def visitRefreshResource(ctx: RefreshResourceContext): LogicalPlan = withOrigin(ctx) { - val resourcePath = remainder(ctx.REFRESH.getSymbol).trim - RefreshResource(resourcePath) + val path = if (ctx.STRING != null) string(ctx.STRING) else extractUnquotedResourcePath(ctx) + RefreshResource(path) + } + + private def extractUnquotedResourcePath(ctx: RefreshResourceContext): String = withOrigin(ctx) { + val unquotedPath = remainder(ctx.REFRESH.getSymbol).trim + validate( + unquotedPath != null && !unquotedPath.isEmpty, + "Resource paths cannot be empty in REFRESH statements. Use / to match everything", + ctx) + val forbiddenSymbols = Seq(" ", "\n", "\r", "\t") + validate( + !forbiddenSymbols.exists(unquotedPath.contains(_)), + "REFRESH statements cannot contain ' ', '\\n', '\\r', '\\t' inside unquoted resource paths", + ctx) + unquotedPath } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index bd9c2ebd6fab9..d238c76fbeeff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable, RefreshResource} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -66,6 +66,25 @@ class SparkSqlParserSuite extends AnalysisTest { } } + test("refresh resource") { + assertEqual("REFRESH prefix_path", RefreshResource("prefix_path")) + assertEqual("REFRESH /", RefreshResource("/")) + assertEqual("REFRESH /path///a", RefreshResource("/path///a")) + assertEqual("REFRESH pat1h/112/_1a", RefreshResource("pat1h/112/_1a")) + assertEqual("REFRESH pat1h/112/_1a/a-1", RefreshResource("pat1h/112/_1a/a-1")) + assertEqual("REFRESH path-with-dash", RefreshResource("path-with-dash")) + assertEqual("REFRESH \'path with space\'", RefreshResource("path with space")) + assertEqual("REFRESH \"path with space 2\"", RefreshResource("path with space 2")) + intercept("REFRESH a b", "REFRESH statements cannot contain") + intercept("REFRESH a\tb", "REFRESH statements cannot contain") + intercept("REFRESH a\nb", "REFRESH statements cannot contain") + intercept("REFRESH a\rb", "REFRESH statements cannot contain") + intercept("REFRESH a\r\nb", "REFRESH statements cannot contain") + intercept("REFRESH @ $a$", "REFRESH statements cannot contain") + intercept("REFRESH ", "Resource paths cannot be empty in REFRESH statements") + intercept("REFRESH", "Resource paths cannot be empty in REFRESH statements") + } + test("show functions") { assertEqual("show functions", ShowFunctionsCommand(None, None, true, true)) assertEqual("show all functions", ShowFunctionsCommand(None, None, true, true)) From 363bfe30ba44852a8fac946a37032f76480f6f1b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 3 Jul 2017 10:14:03 -0700 Subject: [PATCH 0850/1765] [SPARK-20073][SQL] Prints an explicit warning message in case of NULL-safe equals ## What changes were proposed in this pull request? This pr added code to print the same warning messages with `===` cases when using NULL-safe equals (`<=>`). ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #18436 from maropu/SPARK-20073. --- .../src/main/scala/org/apache/spark/sql/Column.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7e1f1d83cb3de..bd1669b6dba69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -464,7 +464,15 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <=> (other: Any): Column = withExpr { EqualNullSafe(expr, lit(other).expr) } + def <=> (other: Any): Column = withExpr { + val right = lit(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} <=> $right'. " + + "Perhaps you need to use aliases.") + } + EqualNullSafe(expr, right) + } /** * Equality test that is safe for null values. From f953ca56eccdaef29ac580d44613a028415ba3f5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 3 Jul 2017 10:51:44 -0700 Subject: [PATCH 0851/1765] [SPARK-21284][SQL] rename SessionCatalog.registerFunction parameter name ## What changes were proposed in this pull request? Looking at the code in `SessionCatalog.registerFunction`, the parameter `ignoreIfExists` is a wrong name. When `ignoreIfExists` is true, we will override the function if it already exists. So `overrideIfExists` should be the corrected name. ## How was this patch tested? N/A Author: Wenchen Fan Closes #18510 from cloud-fan/minor. --- .../sql/catalyst/catalog/SessionCatalog.scala | 6 +++--- .../catalog/SessionCatalogSuite.scala | 20 ++++++++++--------- .../sql/execution/command/functions.scala | 3 +-- .../spark/sql/internal/CatalogSuite.scala | 2 +- .../spark/sql/hive/HiveSessionCatalog.scala | 2 +- .../ObjectHashAggregateExecBenchmark.scala | 2 +- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 7ece77df7fc14..a86604e4353ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1104,10 +1104,10 @@ class SessionCatalog( */ def registerFunction( funcDefinition: CatalogFunction, - ignoreIfExists: Boolean, + overrideIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier - if (functionRegistry.functionExists(func) && !ignoreIfExists) { + if (functionRegistry.functionExists(func) && !overrideIfExists) { throw new AnalysisException(s"Function $func already exists") } val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) @@ -1219,7 +1219,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), overrideIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName, children) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index fc3893e197792..8f856a0daad15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1175,9 +1175,9 @@ abstract class SessionCatalogSuite extends AnalysisTest { val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc1)) catalog.registerFunction( - newFunc("temp2", None), ignoreIfExists = false, functionBuilder = Some(tempFunc2)) + newFunc("temp2", None), overrideIfExists = false, functionBuilder = Some(tempFunc2)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) @@ -1189,12 +1189,12 @@ abstract class SessionCatalogSuite extends AnalysisTest { // Temporary function already exists val e = intercept[AnalysisException] { catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc3)) + newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc3)) }.getMessage assert(e.contains("Function temp1 already exists")) // Temporary function is overridden catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = true, functionBuilder = Some(tempFunc3)) + newFunc("temp1", None), overrideIfExists = true, functionBuilder = Some(tempFunc3)) assert( catalog.lookupFunction( FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) @@ -1208,7 +1208,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { val tempFunc1 = (e: Seq[Expression]) => e.head catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc1)) // Returns true when the function is temporary assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) @@ -1259,7 +1259,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { withBasicCatalog { catalog => val tempFunc = (e: Seq[Expression]) => e.head catalog.registerFunction( - newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc)) + newFunc("func1", None), overrideIfExists = false, functionBuilder = Some(tempFunc)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1300,7 +1300,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { withBasicCatalog { catalog => val tempFunc1 = (e: Seq[Expression]) => e.head catalog.registerFunction( - newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + newFunc("func1", None), overrideIfExists = false, functionBuilder = Some(tempFunc1)) assert(catalog.lookupFunction( FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1318,8 +1318,10 @@ abstract class SessionCatalogSuite extends AnalysisTest { val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.registerFunction(funcMeta1, ignoreIfExists = false, functionBuilder = Some(tempFunc1)) - catalog.registerFunction(funcMeta2, ignoreIfExists = false, functionBuilder = Some(tempFunc2)) + catalog.registerFunction( + funcMeta1, overrideIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction( + funcMeta2, overrideIfExists = false, functionBuilder = Some(tempFunc2)) assert(catalog.listFunctions("db1", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index f39a3269efaf1..a91ad413f4d1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -58,9 +58,8 @@ case class CreateFunctionCommand( s"is not allowed: '${databaseName.get}'") } // We first load resources and then put the builder in the function registry. - // Please note that it is allowed to overwrite an existing temp function. catalog.loadFunctionResources(resources) - catalog.registerFunction(func, ignoreIfExists = false) + catalog.registerFunction(func, overrideIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index b2d568ce320e6..6acac1a9aa317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -79,7 +79,7 @@ class CatalogSuite val tempFunc = (e: Seq[Expression]) => e.head val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) sessionCatalog.registerFunction( - funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) + funcMeta, overrideIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index da87f0218e3ad..0d0269f694300 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -161,7 +161,7 @@ private[sql] class HiveSessionCatalog( FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. - registerFunction(func, ignoreIfExists = false) + registerFunction(func, overrideIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(functionIdentifier, children) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 73383ae4d4118..e599d1ab1d486 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -221,7 +221,7 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] val functionIdentifier = FunctionIdentifier(functionName, database = None) val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) - sessionCatalog.registerFunction(func, ignoreIfExists = false) + sessionCatalog.registerFunction(func, overrideIfExists = false) } private def percentile_approx( From c79c10ebaf3d63b697b8d6d1a7e55aa2d406af69 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 3 Jul 2017 16:18:54 -0700 Subject: [PATCH 0852/1765] [TEST] Different behaviors of SparkContext Conf when building SparkSession ## What changes were proposed in this pull request? If the created ACTIVE sparkContext is not EXPLICITLY passed through the Builder's API `sparkContext()`, the conf of this sparkContext will also contain the conf set through the API `config()`; otherwise, the conf of this sparkContext will NOT contain the conf set through the API `config()` ## How was this patch tested? N/A Author: gatorsmile Closes #18517 from gatorsmile/fixTestCase2. --- .../spark/sql/SparkSessionBuilderSuite.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 386d13d07a95f..4f6d5f79d466e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -98,12 +98,31 @@ class SparkSessionBuilderSuite extends SparkFunSuite { val session = SparkSession.builder().config("key2", "value2").getOrCreate() assert(session.conf.get("key1") == "value1") assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == sparkContext2) assert(session.sparkContext.conf.get("key1") == "value1") + // If the created sparkContext is not passed through the Builder's API sparkContext, + // the conf of this sparkContext will also contain the conf set through the API config. assert(session.sparkContext.conf.get("key2") == "value2") assert(session.sparkContext.conf.get("spark.app.name") == "test") session.stop() } + test("create SparkContext first then pass context to SparkSession") { + sparkContext.stop() + val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") + val newSC = new SparkContext(conf) + val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() + assert(session.conf.get("key1") == "value1") + assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == newSC) + assert(session.sparkContext.conf.get("key1") == "value1") + // If the created sparkContext is passed through the Builder's API sparkContext, + // the conf of this sparkContext will not contain the conf set through the API config. + assert(!session.sparkContext.conf.contains("key2")) + assert(session.sparkContext.conf.get("spark.app.name") == "test") + session.stop() + } + test("SPARK-15887: hive-site.xml should be loaded") { val session = SparkSession.builder().master("local").getOrCreate() assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") From 6657e00de36b59011d3fe78e8613fb64e54c957a Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 4 Jul 2017 09:16:40 +0800 Subject: [PATCH 0853/1765] [SPARK-21283][CORE] FileOutputStream should be created as append mode ## What changes were proposed in this pull request? `FileAppender` is used to write `stderr` and `stdout` files in `ExecutorRunner`, But before writing `ErrorStream` into the the `stderr` file, the header information has been written into ,if FileOutputStream is not created as append mode, the header information will be lost ## How was this patch tested? unit test case Author: liuxian Closes #18507 from 10110346/wip-lx-0703. --- .../scala/org/apache/spark/util/logging/FileAppender.scala | 2 +- .../test/scala/org/apache/spark/util/FileAppenderSuite.scala | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index fdb1495899bc3..8a0cc709bccc5 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -94,7 +94,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi /** Open the file output stream */ protected def openFile() { - outputStream = new FileOutputStream(file, false) + outputStream = new FileOutputStream(file, true) logDebug(s"Opened file $file") } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 7e2da8e141532..cd0ed5b036bf9 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -52,10 +52,13 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) + // The `header` should not be covered + val header = "Add header" + Files.write(header, testFile, StandardCharsets.UTF_8) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, StandardCharsets.UTF_8) === testString) + assert(Files.toString(testFile, StandardCharsets.UTF_8) === header + testString) } test("rolling file appender - time-based rolling") { From a848d552ef6b5d0d3bb3b2da903478437a8b10aa Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Jul 2017 11:35:08 +0900 Subject: [PATCH 0854/1765] [SPARK-21264][PYTHON] Call cross join path in join without 'on' and with 'how' ## What changes were proposed in this pull request? Currently, it throws a NPE when missing columns but join type is speicified in join at PySpark as below: ```python spark.conf.set("spark.sql.crossJoin.enabled", "false") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` Traceback (most recent call last): ... py4j.protocol.Py4JJavaError: An error occurred while calling o66.join. : java.lang.NullPointerException at org.apache.spark.sql.Dataset.join(Dataset.scala:931) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ... ``` ```python spark.conf.set("spark.sql.crossJoin.enabled", "true") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` ... py4j.protocol.Py4JJavaError: An error occurred while calling o84.join. : java.lang.NullPointerException at org.apache.spark.sql.Dataset.join(Dataset.scala:931) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ... ``` This PR suggests to follow Scala's one as below: ```scala scala> spark.conf.set("spark.sql.crossJoin.enabled", "false") scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show() ``` ``` org.apache.spark.sql.AnalysisException: Detected cartesian product for INNER join between logical plans Range (0, 1, step=1, splits=Some(8)) and Range (0, 1, step=1, splits=Some(8)) Join condition is missing or trivial. Use the CROSS JOIN syntax to allow cartesian products between these relations.; ... ``` ```scala scala> spark.conf.set("spark.sql.crossJoin.enabled", "true") scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show() ``` ``` +---+---+ | id| id| +---+---+ | 0| 0| +---+---+ ``` **After** ```python spark.conf.set("spark.sql.crossJoin.enabled", "false") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` Traceback (most recent call last): ... pyspark.sql.utils.AnalysisException: u'Detected cartesian product for INNER join between logical plans\nRange (0, 1, step=1, splits=Some(8))\nand\nRange (0, 1, step=1, splits=Some(8))\nJoin condition is missing or trivial.\nUse the CROSS JOIN syntax to allow cartesian products between these relations.;' ``` ```python spark.conf.set("spark.sql.crossJoin.enabled", "true") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` +---+---+ | id| id| +---+---+ | 0| 0| +---+---+ ``` ## How was this patch tested? Added tests in `python/pyspark/sql/tests.py`. Author: hyukjinkwon Closes #18484 from HyukjinKwon/SPARK-21264. --- python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/tests.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed2246..27a6dad8917d3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -833,6 +833,8 @@ def join(self, other, on=None, how=None): else: if how is None: how = "inner" + if on is None: + on = self._jseq([]) assert isinstance(how, basestring), "how should be basestring" jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b8e8..c105969b26b97 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2021,6 +2021,22 @@ def test_toDF_with_schema_string(self): self.assertEqual(df.schema.simpleString(), "struct") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + def test_join_without_on(self): + df1 = self.spark.range(1).toDF("a") + df2 = self.spark.range(1).toDF("b") + + try: + self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) + + self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + actual = df1.join(df2, how="inner").collect() + expected = [Row(a=0, b=0)] + self.assertEqual(actual, expected) + finally: + # We should unset this. Otherwise, other tests are affected. + self.spark.conf.unset("spark.sql.crossJoin.enabled") + # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) From 8ca4ebefa6301d9cb633ea15cf71f49c2d7f8607 Mon Sep 17 00:00:00 2001 From: Thomas Decaux Date: Tue, 4 Jul 2017 12:17:48 +0100 Subject: [PATCH 0855/1765] [MINOR] Add french stop word "les" ## What changes were proposed in this pull request? Added "les" as french stop word (plurial of le) Author: Thomas Decaux Closes #18514 from ebuildy/patch-1. --- .../resources/org/apache/spark/ml/feature/stopwords/french.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt index 94b8f8f39a3e1..a59a0424616cc 100644 --- a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt @@ -15,6 +15,7 @@ il je la le +les leur lui ma @@ -152,4 +153,4 @@ eusses eût eussions eussiez -eussent \ No newline at end of file +eussent From 2b1e94b9add82b30bc94f639fa97492624bf0dce Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Jul 2017 12:18:42 +0100 Subject: [PATCH 0856/1765] [MINOR][SPARK SUBMIT] Print out R file usage in spark-submit ## What changes were proposed in this pull request? Currently, running the shell below: ```bash $ ./bin/spark-submit tmp.R a b c ``` with R file, `tmp.R` as below: ```r #!/usr/bin/env Rscript library(SparkR) sparkRSQL.init(sparkR.init(master = "local")) collect(createDataFrame(list(list(1)))) print(commandArgs(trailingOnly = TRUE)) ``` working fine as below: ```bash _1 1 1 [1] "a" "b" "c" ``` However, it looks not printed in usage documentation as below: ```bash $ ./bin/spark-submit ``` ``` Usage: spark-submit [options] [app arguments] ... ``` For `./bin/sparkR`, it looks fine as below: ```bash $ ./bin/sparkR tmp.R ``` ``` Running R applications through 'sparkR' is not supported as of Spark 2.0. Use ./bin/spark-submit ``` Running the script below: ```bash $ ./bin/spark-submit ``` **Before** ``` Usage: spark-submit [options] [app arguments] ... ``` **After** ``` Usage: spark-submit [options] [app arguments] ... ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #18505 from HyukjinKwon/minor-doc-summit. --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3d9a14c51618b..7800d3d624e3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -504,7 +504,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( - """Usage: spark-submit [options] [app arguments] + """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) From d492cc5a21cd67b3999b85d97f5c41c3734b1ba3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Jul 2017 20:45:58 +0800 Subject: [PATCH 0857/1765] [SPARK-19507][SPARK-21296][PYTHON] Avoid per-record type dispatch in schema verification and improve exception message ## What changes were proposed in this pull request? **Context** While reviewing https://github.com/apache/spark/pull/17227, I realised here we type-dispatch per record. The PR itself is fine in terms of performance as is but this prints a prefix, `"obj"` in exception message as below: ``` from pyspark.sql.types import * schema = StructType([StructField('s', IntegerType(), nullable=False)]) spark.createDataFrame([["1"]], schema) ... TypeError: obj.s: IntegerType can not accept object '1' in type ``` I suggested to get rid of this but during investigating this, I realised my approach might bring a performance regression as it is a hot path. Only for SPARK-19507 and https://github.com/apache/spark/pull/17227, It needs more changes to cleanly get rid of the prefix and I rather decided to fix both issues together. **Propersal** This PR tried to - get rid of per-record type dispatch as we do in many code paths in Scala so that it improves the performance (roughly ~25% improvement) - SPARK-21296 This was tested with a simple code `spark.createDataFrame(range(1000000), "int")`. However, I am quite sure the actual improvement in practice is larger than this, in particular, when the schema is complicated. - improve error message in exception describing field information as prose - SPARK-19507 ## How was this patch tested? Manually tested and unit tests were added in `python/pyspark/sql/tests.py`. Benchmark - codes: https://gist.github.com/HyukjinKwon/c3397469c56cb26c2d7dd521ed0bc5a3 Error message - codes: https://gist.github.com/HyukjinKwon/b1b2c7f65865444c4a8836435100e398 **Before** Benchmark: - Results: https://gist.github.com/HyukjinKwon/4a291dab45542106301a0c1abcdca924 Error message - Results: https://gist.github.com/HyukjinKwon/57b1916395794ce924faa32b14a3fe19 **After** Benchmark - Results: https://gist.github.com/HyukjinKwon/21496feecc4a920e50c4e455f836266e Error message - Results: https://gist.github.com/HyukjinKwon/7a494e4557fe32a652ce1236e504a395 Closes #17227 Author: hyukjinkwon Author: David Gingrich Closes #18521 from HyukjinKwon/python-type-dispatch. --- python/pyspark/rdd.py | 1 - python/pyspark/sql/session.py | 12 +- python/pyspark/sql/tests.py | 203 ++++++++++++++++++++++++++++++- python/pyspark/sql/types.py | 219 +++++++++++++++++++++++----------- 4 files changed, 352 insertions(+), 83 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 60141792d499b..7dfa17f68a943 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -627,7 +627,6 @@ def sortPartition(iterator): def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. - # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey().first() diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e3bf0f35ea15e..2cc0e2d1d7b8d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -33,7 +33,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ +from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string from pyspark.sql.utils import install_exception_handler @@ -514,17 +514,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] - verify_func = _verify_type if verifySchema else lambda _, t: True if isinstance(schema, StructType): + verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True + def prepare(obj): - verify_func(obj, schema) + verify_func(obj) return obj elif isinstance(schema, DataType): dataType = schema schema = StructType().add("value", schema) + verify_func = _make_type_verifier( + dataType, name="field value") if verifySchema else lambda _: True + def prepare(obj): - verify_func(obj, dataType) + verify_func(obj) return obj, else: if isinstance(schema, list): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c105969b26b97..16ba8bd73f400 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -57,7 +57,7 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -852,7 +852,7 @@ def test_convert_row_to_dict(self): self.assertEqual(1.0, row.asDict()['d']['key'].c) def test_udt(self): - from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier from pyspark.sql.tests import ExamplePointUDT, ExamplePoint def check_datatype(datatype): @@ -868,8 +868,8 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = ExamplePoint(1.0, 2.0) self.assertEqual(_infer_type(p), ExamplePointUDT()) - _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) + self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) check_datatype(PythonOnlyUDT()) structtype_with_udt = StructType([StructField("label", DoubleType(), False), @@ -877,8 +877,10 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = PythonOnlyPoint(1.0, 2.0) self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) - self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) + self.assertRaises( + ValueError, + lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) def test_simple_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) @@ -2636,6 +2638,195 @@ def range_frame_match(): importlib.reload(window) + +class DataTypeVerificationTests(unittest.TestCase): + + def test_verify_type_exception_msg(self): + self.assertRaisesRegexp( + ValueError, + "test_name", + lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) + + schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) + self.assertRaisesRegexp( + TypeError, + "field b in field a", + lambda: _make_type_verifier(schema)([["data"]])) + + def test_verify_type_ok_nullable(self): + obj = None + types = [IntegerType(), FloatType(), StringType(), StructType([])] + for data_type in types: + try: + _make_type_verifier(data_type, nullable=True)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + schema = StructType([ + StructField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyObj: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + # obj, data_type + success_spec = [ + # String + ("", StringType()), + (u"", StringType()), + (1, StringType()), + (1.0, StringType()), + ([], StringType()), + ({}, StringType()), + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT()), + + # Boolean + (True, BooleanType()), + + # Byte + (-(2**7), ByteType()), + (2**7 - 1, ByteType()), + + # Short + (-(2**15), ShortType()), + (2**15 - 1, ShortType()), + + # Integer + (-(2**31), IntegerType()), + (2**31 - 1, IntegerType()), + + # Long + (2**64, LongType()), + + # Float & Double + (1.0, FloatType()), + (1.0, DoubleType()), + + # Decimal + (decimal.Decimal("1.0"), DecimalType()), + + # Binary + (bytearray([1, 2]), BinaryType()), + + # Date/Timestamp + (datetime.date(2000, 1, 2), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()), + + # Array + ([], ArrayType(IntegerType())), + (["1", None], ArrayType(StringType(), containsNull=True)), + ([1, 2], ArrayType(IntegerType())), + ((1, 2), ArrayType(IntegerType())), + (array.array('h', [1, 2]), ArrayType(IntegerType())), + + # Map + ({}, MapType(StringType(), IntegerType())), + ({"a": 1}, MapType(StringType(), IntegerType())), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)), + + # Struct + ({"s": "a", "i": 1}, schema), + ({"s": "a", "i": None}, schema), + ({"s": "a"}, schema), + ({"s": "a", "f": 1.0}, schema), + (Row(s="a", i=1), schema), + (Row(s="a", i=None), schema), + (Row(s="a", i=1, f=1.0), schema), + (["a", 1], schema), + (["a", None], schema), + (("a", 1), schema), + (MyObj(s="a", i=1), schema), + (MyObj(s="a", i=None), schema), + (MyObj(s="a"), schema), + ] + + # obj, data_type, exception class + failure_spec = [ + # String (match anything but None) + (None, StringType(), ValueError), + + # UDT + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Byte + (-(2**7) - 1, ByteType(), ValueError), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Short + (-(2**15) - 1, ShortType(), ValueError), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (2**31, IntegerType(), ValueError), + + # Float & Double + (1, FloatType(), TypeError), + (1, DoubleType(), TypeError), + + # Decimal + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (1, BinaryType(), TypeError), + + # Date/Timestamp + ("2000-01-02", DateType(), TypeError), + (946811040, TimestampType(), TypeError), + + # Array + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), + ([1, "2"], ArrayType(IntegerType()), TypeError), + + # Map + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), + + # Struct + ({"s": "a", "i": "1"}, schema, TypeError), + (Row(s="a"), schema, ValueError), # Row can't have missing field + (Row(s="a", i="1"), schema, TypeError), + (["a"], schema, ValueError), + (["a", "1"], schema, TypeError), + (MyObj(s="a", i="1"), schema, TypeError), + (MyObj(s=None, i="1"), schema, ValueError), + ] + + # Check success cases + for obj, data_type in success_spec: + try: + _make_type_verifier(data_type, nullable=False)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) + + # Check failure cases + for obj, data_type, exp in failure_spec: + msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + with self.assertRaises(exp, msg=msg): + _make_type_verifier(data_type, nullable=False)(obj) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb3709..f5505ed4722ad 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1249,121 +1249,196 @@ def _infer_schema_type(obj, dataType): } -def _verify_type(obj, dataType, nullable=True): +def _make_type_verifier(dataType, nullable=True, name=None): """ - Verify the type of obj against dataType, raise a TypeError if they do not match. - - Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed - range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it - will become infinity when cast to Java float if it overflows. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, LongType()) - >>> _verify_type(list(range(3)), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Make a verifier that checks the type of obj against dataType and raises a TypeError if they do + not match. + + This verifier also checks the value of obj against datatype and raises a ValueError if it's not + within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is + not checked, so it will become infinity when cast to Java float if it overflows. + + >>> _make_type_verifier(StructType([]))(None) + >>> _make_type_verifier(StringType())("") + >>> _make_type_verifier(LongType())(0) + >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3))) + >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(MapType(StringType(), IntegerType()))({}) + >>> _make_type_verifier(StructType([]))(()) + >>> _make_type_verifier(StructType([]))([]) + >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... >>> # Check if numeric values are within the allowed range. - >>> _verify_type(12, ByteType()) - >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(ByteType())(12) + >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier( + ... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type({None: 1}, MapType(StringType(), IntegerType())) + >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1}) Traceback (most recent call last): ... ValueError:... >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) - >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... """ - if obj is None: - if nullable: - return - else: - raise ValueError("This field is not nullable, but got None") - # StringType can work with any types - if isinstance(dataType, StringType): - return + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.toInternal(obj), dataType.sqlType()) - return + def verify_nullability(obj): + if obj is None: + if nullable: + return True + else: + raise ValueError(new_msg("This field is not nullable, but got None")) + else: + return False _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) - if _type is StructType: - # check the type and fields later - pass - else: + def assert_acceptable_types(obj): + assert _type in _acceptable_types, \ + new_msg("unknown datatype: %s for object %r" % (dataType, obj)) + + def verify_acceptable_types(obj): # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) + raise TypeError(new_msg("%s can not accept object %r in type %s" + % (dataType, obj, type(obj)))) + + if isinstance(dataType, StringType): + # StringType can work with any types + verify_value = lambda _: _ + + elif isinstance(dataType, UserDefinedType): + verifier = _make_type_verifier(dataType.sqlType(), name=name) - if isinstance(dataType, ByteType): - if obj < -128 or obj > 127: - raise ValueError("object of ByteType out of range, got: %s" % obj) + def verify_udf(obj): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType))) + verifier(dataType.toInternal(obj)) + + verify_value = verify_udf + + elif isinstance(dataType, ByteType): + def verify_byte(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -128 or obj > 127: + raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj)) + + verify_value = verify_byte elif isinstance(dataType, ShortType): - if obj < -32768 or obj > 32767: - raise ValueError("object of ShortType out of range, got: %s" % obj) + def verify_short(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -32768 or obj > 32767: + raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj)) + + verify_value = verify_short elif isinstance(dataType, IntegerType): - if obj < -2147483648 or obj > 2147483647: - raise ValueError("object of IntegerType out of range, got: %s" % obj) + def verify_integer(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -2147483648 or obj > 2147483647: + raise ValueError( + new_msg("object of IntegerType out of range, got: %s" % obj)) + + verify_value = verify_integer elif isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType, dataType.containsNull) + element_verifier = _make_type_verifier( + dataType.elementType, dataType.containsNull, name="element in array %s" % name) + + def verify_array(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for i in obj: + element_verifier(i) + + verify_value = verify_array elif isinstance(dataType, MapType): - for k, v in obj.items(): - _verify_type(k, dataType.keyType, False) - _verify_type(v, dataType.valueType, dataType.valueContainsNull) + key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name) + value_verifier = _make_type_verifier( + dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name) + + def verify_map(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for k, v in obj.items(): + key_verifier(k) + value_verifier(v) + + verify_value = verify_map elif isinstance(dataType, StructType): - if isinstance(obj, dict): - for f in dataType.fields: - _verify_type(obj.get(f.name), f.dataType, f.nullable) - elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): - # the order in obj could be different than dataType.fields - for f in dataType.fields: - _verify_type(obj[f.name], f.dataType, f.nullable) - elif isinstance(obj, (tuple, list)): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) - elif hasattr(obj, "__dict__"): - d = obj.__dict__ - for f in dataType.fields: - _verify_type(d.get(f.name), f.dataType, f.nullable) - else: - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + verifiers = [] + for f in dataType.fields: + verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name)) + verifiers.append((f.name, verifier)) + + def verify_struct(obj): + assert_acceptable_types(obj) + + if isinstance(obj, dict): + for f, verifier in verifiers: + verifier(obj.get(f)) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + # the order in obj could be different than dataType.fields + for f, verifier in verifiers: + verifier(obj[f]) + elif isinstance(obj, (tuple, list)): + if len(obj) != len(verifiers): + raise ValueError( + new_msg("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(verifiers)))) + for v, (_, verifier) in zip(obj, verifiers): + verifier(v) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + for f, verifier in verifiers: + verifier(d.get(f)) + else: + raise TypeError(new_msg("StructType can not accept object %r in type %s" + % (obj, type(obj)))) + verify_value = verify_struct + + else: + def verify_default(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + + verify_value = verify_default + + def verify(obj): + if not verify_nullability(obj): + verify_value(obj) + + return verify # This is used to unpickle a Row from JVM From 29b1f6b09f98e216af71e893a9da0c4717c80679 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 4 Jul 2017 08:54:07 -0700 Subject: [PATCH 0858/1765] [SPARK-21256][SQL] Add withSQLConf to Catalyst Test ### What changes were proposed in this pull request? SQLConf is moved to Catalyst. We are adding more and more test cases for verifying the conf-specific behaviors. It is nice to add a helper function to simplify the test cases. ### How was this patch tested? N/A Author: gatorsmile Closes #18469 from gatorsmile/withSQLConf. --- .../InferFiltersFromConstraintsSuite.scala | 5 +-- .../optimizer/OuterJoinEliminationSuite.scala | 6 +--- .../optimizer/PruneFiltersSuite.scala | 6 +--- .../plans/ConstraintPropagationSuite.scala | 24 +++++++------- .../spark/sql/catalyst/plans/PlanTest.scala | 32 ++++++++++++++++++- .../AggregateEstimationSuite.scala | 9 ++---- .../BasicStatsEstimationSuite.scala | 12 +++---- .../spark/sql/SparkSessionBuilderSuite.scala | 3 ++ .../apache/spark/sql/test/SQLTestUtils.scala | 30 ++++------------- 9 files changed, 64 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index cdc9f25cf8777..d2dd469e2d74f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -206,13 +206,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("No inferred filter when constraint propagation is disabled") { - try { - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index 623ff3d446a5f..893c111c2906b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -234,9 +234,7 @@ class OuterJoinEliminationSuite extends PlanTest { } test("no outer join elimination if constraint propagation is disabled") { - try { - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) - + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) @@ -251,8 +249,6 @@ class OuterJoinEliminationSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 706634cdd29b8..6d1a05f3c998e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -149,8 +148,7 @@ class PruneFiltersSuite extends PlanTest { ("tr1.a".attr > 10 || "tr1.c".attr < 10) && 'd.attr < 100) - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) - try { + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val optimized = Optimize.execute(queryWithUselessFilter.analyze) // When constraint propagation is disabled, the useless filter won't be pruned. // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant @@ -160,8 +158,6 @@ class PruneFiltersSuite extends PlanTest { .join(tr2.where('d.attr < 100).where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze comparePlans(optimized, correctAnswer) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a3948d90b0e4d..a37e06d922642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType} -class ConstraintPropagationSuite extends SparkFunSuite { +class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = resolveColumn(tr.analyze, columnName) @@ -400,26 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite { } test("enable/disable constraint propagation") { - try { - val tr = LocalRelation('a.int, 'b.string, 'c.int) - val filterRelation = tr.where('a.attr > 10) + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { assert(filterRelation.analyze.constraints.nonEmpty) + } - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { assert(filterRelation.analyze.constraints.isEmpty) + } - val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { assert(aliasedRelation.analyze.constraints.nonEmpty) + } - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { assert(aliasedRelation.analyze.constraints.isEmpty) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6883d23d477e4..e9679d3361509 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -28,8 +29,9 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -abstract class PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PredicateHelper { + // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) /** @@ -142,4 +144,32 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { plan1 == plan2 } } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.staticConfKeys.contains(k)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 30ddf03bd3c4f..23f95a6cc2ac2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.internal.SQLConf -class AggregateEstimationSuite extends StatsEstimationTestBase { +class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( @@ -100,9 +101,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { size = Some(4 * (8 + 4)), attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) - val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) - try { - SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { val noGroupAgg = Aggregate(groupingExpressions = Nil, aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) assert(noGroupAgg.stats == @@ -114,8 +113,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { assert(hasGroupAgg.stats == // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) - } finally { - SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 31a8cbdee9777..5fd21a06a109d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType -class BasicStatsEstimationSuite extends StatsEstimationTestBase { +class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) @@ -82,18 +83,15 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) - try { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { // Invalidate statistics plan.invalidateStatsCache() - SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) assert(plan.stats == expectedStatsCboOn) + } + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { plan.invalidateStatsCache() - SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) assert(plan.stats == expectedStatsCboOff) - } finally { - SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 4f6d5f79d466e..cdac6827082c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.internal.SQLConf /** * Test cases for the builder pattern of [[SparkSession]]. @@ -67,6 +68,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(activeSession != defaultSession) assert(session == activeSession) assert(session.conf.get("spark-config2") == "a") + assert(session.sessionState.conf == SQLConf.get) + assert(SQLConf.get.getConfString("spark-config2") == "a") SparkSession.clearActiveSession() assert(SparkSession.builder().getOrCreate() == defaultSession) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d74a7cce25ed6..92ee7d596acd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -35,9 +35,11 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{UninterruptibleThread, Utils} /** @@ -53,7 +55,8 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} private[sql] trait SQLTestUtils extends SparkFunSuite with Eventually with BeforeAndAfterAll - with SQLTestData { self => + with SQLTestData + with PlanTest { self => protected def sparkContext = spark.sparkContext @@ -89,28 +92,9 @@ private[sql] trait SQLTestUtils } } - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (spark.conf.contains(key)) { - Some(spark.conf.get(key)) - } else { - None - } - } - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + SparkSession.setActiveSession(spark) + super.withSQLConf(pairs: _*)(f) } /** From a3c29fcbbda02c1528b4185bcb880c91077d480c Mon Sep 17 00:00:00 2001 From: "YIHAODIAN\\wangshuangshuang" Date: Tue, 4 Jul 2017 09:44:27 -0700 Subject: [PATCH 0859/1765] [SPARK-19726][SQL] Faild to insert null timestamp value to mysql using spark jdbc ## What changes were proposed in this pull request? when creating table like following: > create table timestamp_test(id int(11), time_stamp timestamp not null default current_timestamp); The result of Excuting "insert into timestamp_test values (111, null)" is different between Spark and JDBC. ``` mysql> select * from timestamp_test; +------+---------------------+ | id | time_stamp | +------+---------------------+ | 111 | 1970-01-01 00:00:00 | -> spark | 111 | 2017-06-27 19:32:38 | -> mysql +------+---------------------+ 2 rows in set (0.00 sec) ``` Because in such case ```StructField.nullable``` is false, so the generated codes of ```InvokeLike``` and ```BoundReference``` don't check whether the field is null or not. Instead, they directly use ```CodegenContext.INPUT_ROW.getLong(1)```, however, ```UnsafeRow.setNullAt(1)``` will put 0 in the underlying memory. The PR will ```always``` set ```StructField.nullable``` true after obtaining metadata from jdbc connection, Since we can insert null to not null timestamp column in MySQL. In this way, spark will propagate null to underlying DB engine, and let DB to choose how to process NULL. ## How was this patch tested? Added tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: YIHAODIAN\wangshuangshuang Author: Shuangshuang Wang Closes #18445 from shuangshuangwang/SPARK-19726. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 12 ++++++++++-- .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 8 ++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 0f53b5c7c6f0f..57e9bc9b70454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -59,7 +59,7 @@ object JDBCRDD extends Logging { try { val rs = statement.executeQuery() try { - JdbcUtils.getSchema(rs, dialect) + JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) } finally { rs.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index ca61c2efe2ddf..55b2539c13381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -266,10 +266,14 @@ object JdbcUtils extends Logging { /** * Takes a [[ResultSet]] and returns its Catalyst schema. * + * @param alwaysNullable If true, all the columns are nullable. * @return A [[StructType]] giving the Catalyst schema. * @throws SQLException if the schema contains an unsupported type. */ - def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = { + def getSchema( + resultSet: ResultSet, + dialect: JdbcDialect, + alwaysNullable: Boolean = false): StructType = { val rsmd = resultSet.getMetaData val ncols = rsmd.getColumnCount val fields = new Array[StructField](ncols) @@ -290,7 +294,11 @@ object JdbcUtils extends Logging { rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true } } - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val nullable = if (alwaysNullable) { + true + } else { + rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + } val metadata = new MetadataBuilder() .putString("name", columnName) .putLong("scale", fieldScale) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index bf1fd160704fa..92f50a095f19b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} @@ -506,4 +507,11 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { "schema struct")) } } + + test("SPARK-19726: INSERT null to a NOT NULL column") { + val e = intercept[SparkException] { + sql("INSERT INTO PEOPLE1 values (null, null)") + }.getMessage + assert(e.contains("NULL not allowed for column \"NAME\"")) + } } From 1b50e0e0d6fd9d1b815a3bb37647ea659222e3f1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 4 Jul 2017 09:48:40 -0700 Subject: [PATCH 0860/1765] [SPARK-20256][SQL] SessionState should be created more lazily ## What changes were proposed in this pull request? `SessionState` is designed to be created lazily. However, in reality, it created immediately in `SparkSession.Builder.getOrCreate` ([here](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala#L943)). This PR aims to recover the lazy behavior by keeping the options into `initialSessionOptions`. The benefit is like the following. Users can start `spark-shell` and use RDD operations without any problems. **BEFORE** ```scala $ bin/spark-shell java.lang.IllegalArgumentException: Error while instantiating 'org.apache.spark.sql.hive.HiveSessionStateBuilder' ... Caused by: org.apache.spark.sql.AnalysisException: org.apache.hadoop.hive.ql.metadata.HiveException: MetaException(message:java.security.AccessControlException: Permission denied: user=spark, access=READ, inode="/apps/hive/warehouse":hive:hdfs:drwx------ ``` As reported in SPARK-20256, this happens when the warehouse directory is not allowed for this user. **AFTER** ```scala $ bin/spark-shell ... Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.3.0-SNAPSHOT /_/ Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_112) Type in expressions to have them evaluated. Type :help for more information. scala> sc.range(0, 10, 1).count() res0: Long = 10 ``` ## How was this patch tested? Manual. This closes #18512 . Author: Dongjoon Hyun Closes #18501 from dongjoon-hyun/SPARK-20256. --- .../scala/org/apache/spark/sql/SparkSession.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2c38f7d7c88da..0ddcd2111aa58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -117,6 +117,12 @@ class SparkSession private( existingSharedState.getOrElse(new SharedState(sparkContext)) } + /** + * Initial options for session. This options are applied once when sessionState is created. + */ + @transient + private[sql] val initialSessionOptions = new scala.collection.mutable.HashMap[String, String] + /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. @@ -132,9 +138,11 @@ class SparkSession private( parentSessionState .map(_.clone(this)) .getOrElse { - SparkSession.instantiateSessionState( + val state = SparkSession.instantiateSessionState( SparkSession.sessionStateClassName(sparkContext.conf), self) + initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) } + state } } @@ -940,7 +948,7 @@ object SparkSession { } session = new SparkSession(sparkContext, None, None, extensions) - options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } + options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the From 4d6d8192c807006ff89488a1d38bc6f7d41de5cf Mon Sep 17 00:00:00 2001 From: dardelet Date: Tue, 4 Jul 2017 17:58:44 +0100 Subject: [PATCH 0861/1765] [SPARK-21268][MLLIB] Move center calculations to a distributed map in KMeans ## What changes were proposed in this pull request? The scal() and creation of newCenter vector is done in the driver, after a collectAsMap operation while it could be done in the distributed RDD. This PR moves this code before the collectAsMap for more efficiency ## How was this patch tested? This was tested manually by running the KMeansExample and verifying that the new code ran without error and gave same output as before. Author: dardelet Author: Guillaume Dardelet Closes #18491 from dardelet/move-center-calculation-to-distributed-map-kmean. --- .../org/apache/spark/mllib/clustering/KMeans.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fa72b72e2d921..98e50c5b45cfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -272,8 +272,8 @@ class KMeans private ( val costAccum = sc.doubleAccumulator val bcCenters = sc.broadcast(centers) - // Find the sum and count of points mapping to each center - val totalContribs = data.mapPartitions { points => + // Find the new centers + val newCenters = data.mapPartitions { points => val thisCenters = bcCenters.value val dims = thisCenters.head.vector.size @@ -292,15 +292,16 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) + }.mapValues { case (sum, count) => + scal(1.0 / count, sum) + new VectorWithNorm(sum) }.collectAsMap() bcCenters.destroy(blocking = false) // Update the cluster centers and costs converged = true - totalContribs.foreach { case (j, (sum, count)) => - scal(1.0 / count, sum) - val newCenter = new VectorWithNorm(sum) + newCenters.foreach { case (j, newCenter) => if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { converged = false } From cec392150451a64c9c2902b7f8f4b3b38f25cbea Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 4 Jul 2017 12:18:51 -0700 Subject: [PATCH 0862/1765] [SPARK-20889][SPARKR] Grouped documentation for WINDOW column methods ## What changes were proposed in this pull request? Grouped documentation for column window methods. Author: actuaryzhang Closes #18481 from actuaryzhang/sparkRDocWindow. --- R/pkg/R/functions.R | 225 ++++++++++++++------------------------------ R/pkg/R/generics.R | 28 +++--- 2 files changed, 88 insertions(+), 165 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a1f5c4f8cc18d..8c12308c1d7c1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -200,6 +200,34 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))} NULL +#' Window functions for Column operations +#' +#' Window functions defined for \code{Column}. +#' +#' @param x In \code{lag} and \code{lead}, it is the column as a character string or a Column +#' to compute on. In \code{ntile}, it is the number of ntile groups. +#' @param offset In \code{lag}, the number of rows back from the current row from which to obtain +#' a value. In \code{lead}, the number of rows after the current row from which to +#' obtain a value. If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. +#' @param ... additional argument(s). +#' @name column_window_functions +#' @rdname column_window_functions +#' @family window functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' tmp <- mutate(df, dist = over(cume_dist(), ws), dense_rank = over(dense_rank(), ws), +#' lag = over(lag(df$mpg), ws), lead = over(lead(df$mpg, 1), ws), +#' percent_rank = over(percent_rank(), ws), +#' rank = over(rank(), ws), row_number = over(row_number(), ws)) +#' # Get ntile group id (1-4) for hp +#' tmp <- mutate(tmp, ntile = over(ntile(4), ws)) +#' head(tmp)} +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -2844,27 +2872,16 @@ setMethod("ifelse", ###################### Window functions###################### -#' cume_dist -#' -#' Window function: returns the cumulative distribution of values within a window partition, -#' i.e. the fraction of rows that are below the current row. -#' -#' N = total number of rows in the partition -#' cume_dist(x) = number of values before (and including) x / N -#' +#' @details +#' \code{cume_dist}: Returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row: +#' (number of values before and including x) / (total number of rows in the partition). #' This is equivalent to the \code{CUME_DIST} function in SQL. +#' The method should be used with no argument. #' -#' @rdname cume_dist -#' @name cume_dist -#' @family window functions -#' @aliases cume_dist,missing-method +#' @rdname column_window_functions +#' @aliases cume_dist cume_dist,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(cume_dist(), ws), df$hp, df$am) -#' } #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2873,28 +2890,19 @@ setMethod("cume_dist", column(jc) }) -#' dense_rank -#' -#' Window function: returns the rank of rows within a window partition, without any gaps. +#' @details +#' \code{dense_rank}: Returns the rank of rows within a window partition, without any gaps. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. -#' #' This is equivalent to the \code{DENSE_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname dense_rank -#' @name dense_rank -#' @family window functions -#' @aliases dense_rank,missing-method +#' @rdname column_window_functions +#' @aliases dense_rank dense_rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(dense_rank(), ws), df$hp, df$am) -#' } #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -2903,34 +2911,15 @@ setMethod("dense_rank", column(jc) }) -#' lag -#' -#' Window function: returns the value that is \code{offset} rows before the current row, and +#' @details +#' \code{lag}: Returns the value that is \code{offset} rows before the current row, and #' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example, #' an \code{offset} of one will return the previous row at any given point in the window partition. -#' #' This is equivalent to the \code{LAG} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows back from the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' @param ... further arguments to be passed to or from other methods. -#' @rdname lag -#' @name lag -#' @aliases lag,characterOrColumn-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases lag lag,characterOrColumn-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lag mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -2946,35 +2935,16 @@ setMethod("lag", column(jc) }) -#' lead -#' -#' Window function: returns the value that is \code{offset} rows after the current row, and +#' @details +#' \code{lead}: Returns the value that is \code{offset} rows after the current row, and #' \code{defaultValue} if there is less than \code{offset} rows after the current row. #' For example, an \code{offset} of one will return the next row at any given point #' in the window partition. -#' #' This is equivalent to the \code{LEAD} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows after the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' -#' @rdname lead -#' @name lead -#' @family window functions -#' @aliases lead,characterOrColumn,numeric-method +#' @rdname column_window_functions +#' @aliases lead lead,characterOrColumn,numeric-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lead mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -2990,31 +2960,15 @@ setMethod("lead", column(jc) }) -#' ntile -#' -#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window +#' @details +#' \code{ntile}: Returns the ntile group id (from 1 to n inclusive) in an ordered window #' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' #' This is equivalent to the \code{NTILE} function in SQL. #' -#' @param x Number of ntile groups -#' -#' @rdname ntile -#' @name ntile -#' @aliases ntile,numeric-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases ntile ntile,numeric-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Get ntile group id (1-4) for hp -#' out <- select(df, over(ntile(4), ws), df$hp, df$am) -#' } #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3023,27 +2977,15 @@ setMethod("ntile", column(jc) }) -#' percent_rank -#' -#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' -#' This is computed by: -#' -#' (rank of row in its partition - 1) / (number of rows in the partition - 1) -#' -#' This is equivalent to the PERCENT_RANK function in SQL. +#' @details +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). +#' This is equivalent to the \code{PERCENT_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname percent_rank -#' @name percent_rank -#' @family window functions -#' @aliases percent_rank,missing-method +#' @rdname column_window_functions +#' @aliases percent_rank percent_rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(percent_rank(), ws), df$hp, df$am) -#' } #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3052,29 +2994,19 @@ setMethod("percent_rank", column(jc) }) -#' rank -#' -#' Window function: returns the rank of rows within a window partition. -#' +#' @details +#' \code{rank}: Returns the rank of rows within a window partition. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. +#' This is equivalent to the \code{RANK} function in SQL. +#' The method should be used with no argument. #' -#' This is equivalent to the RANK function in SQL. -#' -#' @rdname rank -#' @name rank -#' @family window functions -#' @aliases rank,missing-method +#' @rdname column_window_functions +#' @aliases rank rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(rank(), ws), df$hp, df$am) -#' } #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3083,11 +3015,7 @@ setMethod("rank", column(jc) }) -# Expose rank() in the R base package -#' @param x a numeric, complex, character or logical vector. -#' @param ... additional argument(s) passed to the method. -#' @name rank -#' @rdname rank +#' @rdname column_window_functions #' @aliases rank,ANY-method #' @export setMethod("rank", @@ -3096,23 +3024,14 @@ setMethod("rank", base::rank(x, ...) }) -#' row_number -#' -#' Window function: returns a sequential number starting at 1 within a window partition. -#' -#' This is equivalent to the ROW_NUMBER function in SQL. +#' @details +#' \code{row_number}: Returns a sequential number starting at 1 within a window partition. +#' This is equivalent to the \code{ROW_NUMBER} function in SQL. +#' The method should be used with no argument. #' -#' @rdname row_number -#' @name row_number -#' @aliases row_number,missing-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases row_number row_number,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(row_number(), ws), df$hp, df$am) -#' } #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b901b74e4728d..beac18e412736 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1013,9 +1013,9 @@ setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) -#' @param x empty. Should be used with no argument. -#' @rdname cume_dist +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_diff_functions @@ -1053,9 +1053,9 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) -#' @param x empty. Should be used with no argument. -#' @rdname dense_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions @@ -1159,8 +1159,9 @@ setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) -#' @rdname lag +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last @@ -1172,8 +1173,9 @@ setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) -#' @rdname lead +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions @@ -1260,8 +1262,9 @@ setGeneric("not", function(x) { standardGeneric("not") }) #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) -#' @rdname ntile +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions @@ -1269,9 +1272,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @param x empty. Should be used with no argument. -#' @rdname percent_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions @@ -1304,8 +1307,9 @@ setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) -#' @rdname rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions @@ -1334,9 +1338,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) -#' @param x empty. Should be used with no argument. -#' @rdname row_number +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions From daabf425ec0272951b11f286e4bec7a48f42cc0d Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Tue, 4 Jul 2017 12:37:29 -0700 Subject: [PATCH 0863/1765] [MINOR][SPARKR] ignore Rplots.pdf test output after running R tests ## What changes were proposed in this pull request? After running R tests in local build, it outputs Rplots.pdf. This one should be ignored in the git repository. Author: wangmiao1981 Closes #18518 from wangmiao1981/ignore. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1d91b43c23fa7..cf9780db37ad7 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ R-unit-tests.log R/unit-tests.out R/cran-check.out R/pkg/vignettes/sparkr-vignettes.html +R/pkg/tests/fulltests/Rplots.pdf build/*.jar build/apache-maven* build/scala* From de14086e1f6a2474bb9ba1452ada94e0ce58cf9c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 5 Jul 2017 10:40:02 +0800 Subject: [PATCH 0864/1765] [SPARK-21295][SQL] Use qualified names in error message for missing references ### What changes were proposed in this pull request? It is strange to see the following error message. Actually, the column is from another table. ``` cannot resolve '`right.a`' given input columns: [a, c, d]; ``` After the PR, the error message looks like ``` cannot resolve '`right.a`' given input columns: [left.a, right.c, right.d]; ``` ### How was this patch tested? Added a test case Author: gatorsmile Closes #18520 from gatorsmile/removeSQLConf. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../results/columnresolution-negative.sql.out | 10 +++++----- .../results/columnresolution-views.sql.out | 2 +- .../sql-tests/results/columnresolution.sql.out | 18 +++++++++--------- .../sql-tests/results/group-by.sql.out | 2 +- .../sql-tests/results/table-aliases.sql.out | 2 +- .../org/apache/spark/sql/SubquerySuite.scala | 4 ++-- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fb81a7006bc5e..85c52792ef659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -86,7 +86,7 @@ trait CheckAnalysis extends PredicateHelper { case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => - val from = operator.inputSet.map(_.name).mkString(", ") + val from = operator.inputSet.map(_.qualifiedName).mkString(", ") a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 60bd8e9cc99db..9e60e592c2bd1 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -90,7 +90,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 11 @@ -161,7 +161,7 @@ SELECT db1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`db1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`db1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 19 @@ -186,7 +186,7 @@ SELECT mydb1.t1 FROM t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 22 @@ -204,7 +204,7 @@ SELECT t1 FROM mydb1.t1 struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '`t1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`t1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 24 @@ -221,7 +221,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 616421d6f2b28..7c451c2aa5b5c 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -105,7 +105,7 @@ SELECT global_temp.view1.i1 FROM global_temp.view1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '`global_temp.view1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`global_temp.view1.i1`' given input columns: [view1.i1]; line 1 pos 7 -- !query 13 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index 764cad0e3943c..d3ca4443cce55 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -96,7 +96,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 12 @@ -105,7 +105,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 13 @@ -154,7 +154,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 19 @@ -270,7 +270,7 @@ SELECT * FROM mydb1.t3 WHERE c1 IN struct<> -- !query 32 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t4.c3`' given input columns: [c2, c3]; line 2 pos 42 +cannot resolve '`mydb1.t4.c3`' given input columns: [t4.c2, t4.c3]; line 2 pos 42 -- !query 33 @@ -287,7 +287,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 35 @@ -296,7 +296,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 36 @@ -313,7 +313,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 37 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 38 @@ -402,7 +402,7 @@ SELECT mydb1.t5.t5.i1 FROM mydb1.t5 struct<> -- !query 48 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i1`' given input columns: [i1, t5]; line 1 pos 7 +cannot resolve '`mydb1.t5.t5.i1`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 -- !query 49 @@ -411,7 +411,7 @@ SELECT mydb1.t5.t5.i2 FROM mydb1.t5 struct<> -- !query 49 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i2`' given input columns: [i1, t5]; line 1 pos 7 +cannot resolve '`mydb1.t5.t5.i2`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 -- !query 50 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 14679850c692e..e23ebd4e822fa 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -202,7 +202,7 @@ SELECT a AS k, COUNT(b) FROM testData GROUP BY k struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 +cannot resolve '`k`' given input columns: [testdata.a, testdata.b]; line 1 pos 47 -- !query 22 diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out index c318018dced29..7abbcd834a523 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -60,4 +60,4 @@ SELECT a AS col1, b AS col2 FROM testData AS t(c, d) struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve '`a`' given input columns: [c, d]; line 1 pos 7 +cannot resolve '`a`' given input columns: [t.c, t.d]; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 820cff655c4ff..c0a3b5add313a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -870,9 +870,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-20688: correctly check analysis for scalar sub-queries") { withTempView("t") { - Seq(1 -> "a").toDF("i", "j").createTempView("t") + Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t") val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)")) - assert(e.message.contains("cannot resolve '`a`' given input columns: [i, j]")) + assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]")) } } } From ce10545d3401c555e56a214b7c2f334274803660 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 11:24:38 +0800 Subject: [PATCH 0865/1765] [SPARK-21300][SQL] ExternalMapToCatalyst should null-check map key prior to converting to internal value. ## What changes were proposed in this pull request? `ExternalMapToCatalyst` should null-check map key prior to converting to internal value to throw an appropriate Exception instead of something like NPE. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #18524 from ueshin/issues/SPARK-21300. --- .../spark/sql/catalyst/JavaTypeInference.scala | 1 + .../spark/sql/catalyst/ScalaReflection.scala | 1 + .../catalyst/expressions/objects/objects.scala | 16 +++++++++++++++- .../encoders/ExpressionEncoderSuite.scala | 8 +++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7683ee7074e7d..90ec699877dec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -418,6 +418,7 @@ object JavaTypeInference { inputObject, ObjectType(keyType.getRawType), serializerFor(_, keyType), + keyNullable = true, ObjectType(valueType.getRawType), serializerFor(_, valueType), valueNullable = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d580cf4d3391c..f3c1e4150017d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -494,6 +494,7 @@ object ScalaReflection extends ScalaReflection { inputObject, dataTypeFor(keyType), serializerFor(_, keyType, keyPath, seenTypeSet), + keyNullable = !keyType.typeSymbol.asClass.isPrimitive, dataTypeFor(valueType), serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4b651836ff4d2..d6d06aecc077b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -841,18 +841,21 @@ object ExternalMapToCatalyst { inputMap: Expression, keyType: DataType, keyConverter: Expression => Expression, + keyNullable: Boolean, valueType: DataType, valueConverter: Expression => Expression, valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id + val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id val valueName = "ExternalMapToCatalyst_value" + id val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id ExternalMapToCatalyst( keyName, + keyIsNull, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType, false)), + keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)), valueName, valueIsNull, valueType, @@ -868,6 +871,8 @@ object ExternalMapToCatalyst { * * @param key the name of the map key variable that used when iterate the map, and used as input for * the `keyConverter` + * @param keyIsNull the nullability of the map key variable that used when iterate the map, and + * used as input for the `keyConverter` * @param keyType the data type of the map key variable that used when iterate the map, and used as * input for the `keyConverter` * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. @@ -883,6 +888,7 @@ object ExternalMapToCatalyst { */ case class ExternalMapToCatalyst private( key: String, + keyIsNull: String, keyType: DataType, keyConverter: Expression, value: String, @@ -913,6 +919,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") @@ -950,6 +957,12 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } + val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { + s"$keyIsNull = false;" + } else { + s"$keyIsNull = $key == null;" + } + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { s"$valueIsNull = false;" } else { @@ -972,6 +985,7 @@ case class ExternalMapToCatalyst private( $defineEntries while($entries.hasNext()) { $defineKeyValue + $keyNullCheck $valueNullCheck ${genKeyConverter.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b769388..bb1955a1ae242 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { checkNullable[String](true) } - test("null check for map key") { + test("null check for map key: String") { val encoder = ExpressionEncoder[Map[String, Int]]() val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) assert(e.getMessage.contains("Cannot use null as map key")) } + test("null check for map key: Integer") { + val encoder = ExpressionEncoder[Map[Integer, String]]() + val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b")))) + assert(e.getMessage.contains("Cannot use null as map key")) + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { From e9a93f8140c913b91781b35e0e1b051c30244882 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 4 Jul 2017 21:05:05 -0700 Subject: [PATCH 0866/1765] [SPARK-20889][SPARKR][FOLLOWUP] Clean up grouped doc for column methods ## What changes were proposed in this pull request? Add doc for methods that were left out, and fix various style and consistency issues. Author: actuaryzhang Closes #18493 from actuaryzhang/sparkRDocCleanup. --- R/pkg/R/functions.R | 100 ++++++++++++++++++++------------------------ R/pkg/R/generics.R | 7 ++-- 2 files changed, 49 insertions(+), 58 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 8c12308c1d7c1..c529d83060f50 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,10 +38,10 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. +#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. #' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse -#' x Column to DateType or TimestampType. For \code{trunc}, it is the string used -#' for specifying the truncation method. For example, "year", "yyyy", "yy" for +#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string +#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for #' truncate by year, or "month", "mon", "mm" for truncate by month. #' @param ... additional argument(s). #' @name column_datetime_functions @@ -122,7 +122,7 @@ NULL #' format to. See 'Details'. #' } #' @param y Column to compute on. -#' @param ... additional columns. +#' @param ... additional Columns. #' @name column_string_functions #' @rdname column_string_functions #' @family string functions @@ -167,8 +167,7 @@ NULL #' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), #' v3 = hash(df$model, df$mpg), v4 = md5(df$model), #' v5 = sha1(df$model), v6 = sha2(df$model, 256)) -#' head(tmp) -#' } +#' head(tmp)} NULL #' Collection functions for Column operations @@ -190,7 +189,6 @@ NULL #' \dontrun{ #' # Dataframe used throughout this doc #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) -#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) @@ -394,7 +392,7 @@ setMethod("base64", }) #' @details -#' \code{bin}: An expression that returns the string representation of the binary value +#' \code{bin}: Returns the string representation of the binary value #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions @@ -722,7 +720,7 @@ setMethod("dayofyear", #' \code{decode}: Computes the first argument into a string from a binary using the provided #' character set. #' -#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' @param charset character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", #' "UTF-16LE", "UTF-16"). #' #' @rdname column_string_functions @@ -855,7 +853,7 @@ setMethod("hex", }) #' @details -#' \code{hour}: Extracts the hours as an integer from a given date/timestamp/string. +#' \code{hour}: Extracts the hour as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method @@ -1177,7 +1175,7 @@ setMethod("min", }) #' @details -#' \code{minute}: Extracts the minutes as an integer from a given date/timestamp/string. +#' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method @@ -1354,7 +1352,7 @@ setMethod("sd", }) #' @details -#' \code{second}: Extracts the seconds as an integer from a given date/timestamp/string. +#' \code{second}: Extracts the second as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases second second,Column-method @@ -1464,20 +1462,18 @@ setMethod("soundex", column(jc) }) -#' Return the partition ID as a column -#' -#' Return the partition ID as a SparkDataFrame column. +#' @details +#' \code{spark_partition_id}: Returns the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. +#' This is equivalent to the \code{SPARK_PARTITION_ID} function in SQL. #' -#' This is equivalent to the SPARK_PARTITION_ID function in SQL. -#' -#' @rdname spark_partition_id -#' @name spark_partition_id -#' @aliases spark_partition_id,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases spark_partition_id spark_partition_id,missing-method #' @export #' @examples -#' \dontrun{select(df, spark_partition_id())} +#' +#' \dontrun{head(select(df, spark_partition_id()))} #' @note spark_partition_id since 2.0.0 setMethod("spark_partition_id", signature("missing"), @@ -2028,7 +2024,7 @@ setMethod("pmod", signature(y = "Column"), column(jc) }) -#' @param rsd maximum estimation error allowed (default = 0.05) +#' @param rsd maximum estimation error allowed (default = 0.05). #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method @@ -2220,8 +2216,8 @@ setMethod("from_json", signature(x = "Column", schema = "structType"), #' @examples #' #' \dontrun{ -#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, 'PST'), -#' to_utc = to_utc_timestamp(df$time, 'PST')) +#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, "PST"), +#' to_utc = to_utc_timestamp(df$time, "PST")) #' head(tmp)} #' @note from_utc_timestamp since 1.5.0 setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), @@ -2255,7 +2251,7 @@ setMethod("instr", signature(y = "Column", x = "character"), #' @details #' \code{next_day}: Given a date column, returns the first date which is later than the value of #' the date column that is on the specified day of the week. For example, -#' \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first Sunday +#' \code{next_day("2015-07-27", "Sunday")} returns 2015-08-02 because that is the first Sunday #' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or #' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' @@ -2295,7 +2291,7 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' tmp <- mutate(df, t1 = add_months(df$time, 1), #' t2 = date_add(df$time, 2), #' t3 = date_sub(df$time, 3), -#' t4 = next_day(df$time, 'Sun')) +#' t4 = next_day(df$time, "Sun")) #' head(tmp)} #' @note add_months since 1.5.0 setMethod("add_months", signature(y = "Column", x = "numeric"), @@ -2404,8 +2400,8 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), }) #' @details -#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. +#' \code{shiftRightUnsigned}: (Unigned) shifts the given value numBits right. If the given value is +#' a long value, it will return a long value else it will return an integer value. #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method @@ -2513,14 +2509,13 @@ setMethod("from_unixtime", signature(x = "Column"), column(jc) }) -#' window -#' -#' Bucketize rows into one or more time windows given a timestamp specifying column. Window -#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window +#' @details +#' \code{window}: Bucketizes rows into one or more time windows given a timestamp specifying column. +#' Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window #' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in -#' the order of months are not supported. +#' the order of months are not supported. It returns an output column of struct called 'window' +#' by default with the nested columns 'start' and 'end' #' -#' @param x a time Column. Must be of TimestampType. #' @param windowDuration a string specifying the width of the window, e.g. '1 second', #' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', #' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that @@ -2536,27 +2531,22 @@ setMethod("from_unixtime", signature(x = "Column"), #' window intervals. For example, in order to have hourly tumbling windows #' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide #' \code{startTime} as \code{"15 minutes"}. -#' @param ... further arguments to be passed to or from other methods. -#' @return An output column of struct called 'window' by default with the nested columns 'start' -#' and 'end'. -#' @family date time functions -#' @rdname window -#' @name window -#' @aliases window,Column-method +#' @rdname column_datetime_functions +#' @aliases window window,Column-method #' @export #' @examples -#'\dontrun{ -#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, -#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... -#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, -#' # 09:01:15-09:02:15... -#' window(df$time, "1 minute", startTime = "15 seconds") +#' \dontrun{ +#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, +#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... +#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... -#' window(df$time, "30 seconds", "10 seconds") -#'} +#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, +#' # 09:01:15-09:02:15... +#' window(df$time, "1 minute", startTime = "15 seconds") +#' +#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' window(df$time, "30 seconds", "10 seconds")} #' @note window since 2.0.0 setMethod("window", signature(x = "Column"), function(x, windowDuration, slideDuration = NULL, startTime = NULL) { @@ -3046,7 +3036,7 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value A value to be checked if contained in the column +#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @export @@ -3091,7 +3081,7 @@ setMethod("size", #' to the natural ordering of the array elements. #' #' @rdname column_collection_functions -#' @param asc A logical flag indicating the sorting order. +#' @param asc a logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method @@ -3218,7 +3208,7 @@ setMethod("split_string", #' \code{repeat_string}: Repeats string n times. #' Equivalent to \code{repeat} SQL function. #' -#' @param n Number of repetitions +#' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method #' @export @@ -3347,7 +3337,7 @@ setMethod("grouping_bit", #' \code{grouping_id}: Returns the level of grouping. #' Equals to \code{ #' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) -#' } +#' }. #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index beac18e412736..92098741f72f9 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1418,9 +1418,9 @@ setGeneric("split_string", function(x, pattern) { standardGeneric("split_string" #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) -#' @param x empty. Should be used with no argument. -#' @rdname spark_partition_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions @@ -1538,8 +1538,9 @@ setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) -#' @rdname window +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions From f2c3b1dd69423cf52880e0ffa5f673ad6041b40e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Jul 2017 14:17:26 +0800 Subject: [PATCH 0867/1765] [SPARK-21304][SQL] remove unnecessary isNull variable for collection related encoder expressions ## What changes were proposed in this pull request? For these collection-related encoder expressions, we don't need to create `isNull` variable if the loop element is not nullable. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18529 from cloud-fan/minor. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../expressions/objects/objects.scala | 77 ++++++++++++------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f3c1e4150017d..bea0de4d90c2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -335,7 +335,7 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - CollectObjectsToMap( + CatalystToExternalMap( p => deserializerFor(keyType, Some(p), walkedTypePath), p => deserializerFor(valueType, Some(p), walkedTypePath), getPath, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d6d06aecc077b..ce07f4a25c189 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -465,7 +465,11 @@ object MapObjects { customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" - val loopIsNull = s"MapObjects_loopIsNull$id" + val loopIsNull = if (elementNullable) { + s"MapObjects_loopIsNull$id" + } else { + "false" + } val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) @@ -517,7 +521,6 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState("boolean", loopIsNull, "") ctx.addMutableState(elementJavaType, loopValue, "") val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -588,12 +591,14 @@ case class MapObjects private( case _ => genFunction.value } - val loopNullCheck = inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - // The element of primitive array will never be null. - case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"$loopIsNull = false" - case _ => s"$loopIsNull = $loopValue == null;" + val loopNullCheck = if (loopIsNull != "false") { + ctx.addMutableState("boolean", loopIsNull, "") + inputDataType match { + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + case _ => s"$loopIsNull = $loopValue == null;" + } + } else { + "" } val (initCollection, addElement, getResult): (String, String => String, String) = @@ -667,11 +672,11 @@ case class MapObjects private( } } -object CollectObjectsToMap { +object CatalystToExternalMap { private val curId = new java.util.concurrent.atomic.AtomicInteger() /** - * Construct an instance of CollectObjectsToMap case class. + * Construct an instance of CatalystToExternalMap case class. * * @param keyFunction The function applied on the key collection elements. * @param valueFunction The function applied on the value collection elements. @@ -682,15 +687,19 @@ object CollectObjectsToMap { keyFunction: Expression => Expression, valueFunction: Expression => Expression, inputData: Expression, - collClass: Class[_]): CollectObjectsToMap = { + collClass: Class[_]): CatalystToExternalMap = { val id = curId.getAndIncrement() - val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id" val mapType = inputData.dataType.asInstanceOf[MapType] val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) - val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" - val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id" + val valueLoopIsNull = if (mapType.valueContainsNull) { + s"CatalystToExternalMap_valueLoopIsNull$id" + } else { + "false" + } val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) - CollectObjectsToMap( + CatalystToExternalMap( keyLoopValue, keyFunction(keyLoopVar), valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), inputData, collClass) @@ -716,7 +725,7 @@ object CollectObjectsToMap { * @param inputData An expression that when evaluated returns a map object. * @param collClass The type of the resulting collection. */ -case class CollectObjectsToMap private( +case class CatalystToExternalMap private( keyLoopValue: String, keyLambdaFunction: Expression, valueLoopValue: String, @@ -748,7 +757,6 @@ case class CollectObjectsToMap private( ctx.addMutableState(keyElementJavaType, keyLoopValue, "") val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState("boolean", valueLoopIsNull, "") ctx.addMutableState(valueElementJavaType, valueLoopValue, "") val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) @@ -781,7 +789,12 @@ case class CollectObjectsToMap private( val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) - val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + val valueLoopNullCheck = if (valueLoopIsNull != "false") { + ctx.addMutableState("boolean", valueLoopIsNull, "") + s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + } else { + "" + } val builderClass = classOf[Builder[_, _]].getName val constructBuilder = s""" @@ -847,9 +860,17 @@ object ExternalMapToCatalyst { valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id - val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id + val keyIsNull = if (keyNullable) { + "ExternalMapToCatalyst_key_isNull" + id + } else { + "false" + } val valueName = "ExternalMapToCatalyst_value" + id - val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + val valueIsNull = if (valueNullable) { + "ExternalMapToCatalyst_value_isNull" + id + } else { + "false" + } ExternalMapToCatalyst( keyName, @@ -919,9 +940,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") - ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") val (defineEntries, defineKeyValue) = child.dataType match { @@ -957,16 +976,18 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } - val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { - s"$keyIsNull = false;" - } else { + val keyNullCheck = if (keyIsNull != "false") { + ctx.addMutableState("boolean", keyIsNull, "") s"$keyIsNull = $key == null;" + } else { + "" } - val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"$valueIsNull = false;" - } else { + val valueNullCheck = if (valueIsNull != "false") { + ctx.addMutableState("boolean", valueIsNull, "") s"$valueIsNull = $value == null;" + } else { + "" } val arrayCls = classOf[GenericArrayData].getName From a38643256691947ff7f7c474b85c052a7d5d8553 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 14:25:26 +0800 Subject: [PATCH 0868/1765] [SPARK-18623][SQL] Add `returnNullable` to `StaticInvoke` and modify it to handle properly. ## What changes were proposed in this pull request? Add `returnNullable` to `StaticInvoke` the same as #15780 is trying to add to `Invoke` and modify to handle properly. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Author: Takuya UESHIN Closes #16056 from ueshin/issues/SPARK-18623. --- .../sql/catalyst/JavaTypeInference.scala | 21 +++++---- .../spark/sql/catalyst/ScalaReflection.scala | 44 ++++++++++++------- .../sql/catalyst/encoders/RowEncoder.scala | 27 ++++++++---- .../expressions/objects/objects.scala | 42 ++++++++++++++---- 4 files changed, 91 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 90ec699877dec..21363d3ba82c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -216,7 +216,7 @@ object JavaTypeInference { ObjectType(c), "valueOf", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -224,7 +224,7 @@ object JavaTypeInference { ObjectType(c), "toJavaDate", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( @@ -232,7 +232,7 @@ object JavaTypeInference { ObjectType(c), "toJavaTimestamp", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -300,7 +300,8 @@ object JavaTypeInference { ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case other => val properties = getJavaBeanReadableAndWritableProperties(other) @@ -367,28 +368,32 @@ object JavaTypeInference { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bea0de4d90c2f..814f2c10b9097 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -206,51 +206,53 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) @@ -446,7 +448,8 @@ object ScalaReflection extends ScalaReflection { classOf[UnsafeArrayData], ArrayType(dt, false), "fromPrimitiveArray", - input :: Nil) + input :: Nil, + returnNullable = false) } else { NewInstance( classOf[GenericArrayData], @@ -504,49 +507,56 @@ object ScalaReflection extends ScalaReflection { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0f8282d3b2f1f..cc32fac67e924 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -96,28 +96,32 @@ object RowEncoder { DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case d: DecimalType => StaticInvoke( Decimal.getClass, d, "fromDecimal", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case StringType => StaticInvoke( classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t @ ArrayType(et, cn) => et match { @@ -126,7 +130,8 @@ object RowEncoder { classOf[ArrayData], t, "toArrayData", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case _ => MapObjects( element => serializerFor(ValidateExternalType(element, et), et), inputObject, @@ -254,14 +259,16 @@ object RowEncoder { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - input :: Nil) + input :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - input :: Nil) + input :: Nil, + returnNullable = false) case _: DecimalType => Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), @@ -280,7 +287,8 @@ object RowEncoder { scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + arrayData :: Nil, + returnNullable = false) case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) @@ -293,7 +301,8 @@ object RowEncoder { ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ce07f4a25c189..24c06d8b14b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -118,17 +118,20 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. + * @param returnNullable When false, indicating the invoked method will always return + * non-null value. */ case class StaticInvoke( staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends InvokeLike { + propagateNull: Boolean = true, + returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") - override def nullable: Boolean = true + override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = @@ -141,19 +144,40 @@ case class StaticInvoke( val callFunc = s"$objectName.$functionName($argString)" - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + val prepareIsNull = if (nullable) { + s"boolean ${ev.isNull} = $resultIsNull;" } else { + ev.isNull = "false" "" } + val evaluate = if (returnNullable) { + if (ctx.defaultValue(dataType) == "null") { + s""" + ${ev.value} = $callFunc; + ${ev.isNull} = ${ev.value} == null; + """ + } else { + val boxedResult = ctx.freshName("boxedResult") + s""" + ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${ev.isNull} = $boxedResult == null; + if (!${ev.isNull}) { + ${ev.value} = $boxedResult; + } + """ + } + } else { + s"${ev.value} = $callFunc;" + } + val code = s""" $argCode - boolean ${ev.isNull} = $resultIsNull; - final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; - $postNullCheck + $prepareIsNull + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$resultIsNull) { + $evaluate + } """ ev.copy(code = code) } From 4852b7d447e872079c2c81428354adc825a87b27 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 5 Jul 2017 18:41:00 +0800 Subject: [PATCH 0869/1765] [SPARK-21310][ML][PYSPARK] Expose offset in PySpark ## What changes were proposed in this pull request? Add offset to PySpark in GLM as in #16699. ## How was this patch tested? Python test Author: actuaryzhang Closes #18534 from actuaryzhang/pythonOffset. --- python/pyspark/ml/regression.py | 25 +++++++++++++++++++++---- python/pyspark/ml/tests.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 84d843369e105..f0ff7a5f59abf 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1376,17 +1376,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha typeConverter=TypeConverters.toFloat) solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + "options: irls.", typeConverter=TypeConverters.toString) + offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " + + "or empty, we treat all instance offsets as 0.0", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, - variancePower=0.0, linkPower=None): + variancePower=0.0, linkPower=None, offsetCol=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ - variancePower=0.0, linkPower=None) + variancePower=0.0, linkPower=None, offsetCol=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1402,12 +1405,12 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, - variancePower=0.0, linkPower=None): + variancePower=0.0, linkPower=None, offsetCol=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ - variancePower=0.0, linkPower=None) + variancePower=0.0, linkPower=None, offsetCol=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1486,6 +1489,20 @@ def getLinkPower(self): """ return self.getOrDefault(self.linkPower) + @since("2.3.0") + def setOffsetCol(self, value): + """ + Sets the value of :py:attr:`offsetCol`. + """ + return self._set(offsetCol=value) + + @since("2.3.0") + def getOffsetCol(self): + """ + Gets the value of offsetCol or its default value. + """ + return self.getOrDefault(self.offsetCol) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ffb8b0a890ff8..7870047651601 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1291,6 +1291,20 @@ def test_tweedie_distribution(self): self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + def test_offset(self): + + df = self.spark.createDataFrame( + [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) + + glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], + atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) + class FPGrowthTests(SparkSessionTestCase): def setUp(self): From 873f3ad2b89c955f42fced49dc129e8efa77d044 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 20:32:47 +0800 Subject: [PATCH 0870/1765] [SPARK-16167][SQL] RowEncoder should preserve array/map type nullability. ## What changes were proposed in this pull request? Currently `RowEncoder` doesn't preserve nullability of `ArrayType` or `MapType`. It returns always `containsNull = true` for `ArrayType`, `valueContainsNull = true` for `MapType` and also the nullability of itself is always `true`. This pr fixes the nullability of them. ## How was this patch tested? Add tests to check if `RowEncoder` preserves array/map nullability. Author: Takuya UESHIN Author: Takuya UESHIN Closes #13873 from ueshin/issues/SPARK-16167. --- .../sql/catalyst/encoders/RowEncoder.scala | 25 +++++++++++--- .../catalyst/encoders/RowEncoderSuite.scala | 33 +++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index cc32fac67e924..43c35bbdf383a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -123,7 +123,7 @@ object RowEncoder { inputObject :: Nil, returnNullable = false) - case t @ ArrayType(et, cn) => + case t @ ArrayType(et, containsNull) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => StaticInvoke( @@ -132,8 +132,16 @@ object RowEncoder { "toArrayData", inputObject :: Nil, returnNullable = false) + case _ => MapObjects( - element => serializerFor(ValidateExternalType(element, et), et), + element => { + val value = serializerFor(ValidateExternalType(element, et), et) + if (!containsNull) { + AssertNotNull(value, Seq.empty) + } else { + value + } + }, inputObject, ObjectType(classOf[Object])) } @@ -155,10 +163,19 @@ object RowEncoder { ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) - NewInstance( + val nonNullOutput = NewInstance( classOf[ArrayBasedMapData], convertedKeys :: convertedValues :: Nil, - dataType = t) + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } case StructType(fields) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 1a5569a77dc7a..6ed175f86ca77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -273,6 +273,39 @@ class RowEncoderSuite extends SparkFunSuite { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + for { + elementType <- Seq(IntegerType, StringType) + containsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve array nullability: " + + s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") { + val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + + for { + keyType <- Seq(IntegerType, StringType) + valueType <- Seq(IntegerType, StringType) + valueContainsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve map nullability: " + + s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " + + s"nullable = $nullable") { + val schema = new StructType().add( + "map", MapType(keyType, valueType, valueContainsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema).resolveAndBind() From 5787ace463b2abde50d2ca24e8dd111e3a7c158e Mon Sep 17 00:00:00 2001 From: ouyangxiaochen Date: Wed, 5 Jul 2017 20:46:42 +0800 Subject: [PATCH 0871/1765] [SPARK-20383][SQL] Supporting Create [temporary] Function with the keyword 'OR REPLACE' and 'IF NOT EXISTS' ## What changes were proposed in this pull request? support to create [temporary] function with the keyword 'OR REPLACE' and 'IF NOT EXISTS' ## How was this patch tested? manual test and added test cases Please review http://spark.apache.org/contributing.html before opening a pull request. Author: ouyangxiaochen Closes #17681 from ouyangxiaochen/spark-419. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../catalyst/catalog/ExternalCatalog.scala | 9 ++++ .../catalyst/catalog/InMemoryCatalog.scala | 6 +++ .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++++++ .../spark/sql/catalyst/catalog/events.scala | 10 ++++ .../catalog/ExternalCatalogEventSuite.scala | 9 ++++ .../catalog/ExternalCatalogSuite.scala | 9 ++++ .../spark/sql/execution/SparkSqlParser.scala | 8 +-- .../sql/execution/command/functions.scala | 46 +++++++++++----- .../execution/command/DDLCommandSuite.scala | 52 ++++++++++++++++++- .../sql/execution/command/DDLSuite.scala | 51 ++++++++++++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 9 ++++ 12 files changed, 216 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 29f554451ed4a..ef9f88a9026c9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -126,7 +126,8 @@ statement tableIdentifier ('(' colTypeList ')')? tableProvider (OPTIONS tablePropertyList)? #createTempViewUsing | ALTER VIEW tableIdentifier AS? query #alterViewQuery - | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + qualifiedName AS className=STRING (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 0254b6bb6d136..6000d483db209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -332,6 +332,15 @@ abstract class ExternalCatalog protected def doDropFunction(db: String, funcName: String): Unit + final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + doAlterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit + final def renameFunction(db: String, oldName: String, newName: String): Unit = { postToAll(RenameFunctionPreEvent(db, oldName, newName)) doRenameFunction(db, oldName, newName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 747190faa3c8c..d253c72a62739 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -590,6 +590,12 @@ class InMemoryCatalog( catalog(db).functions.remove(funcName) } + override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + requireDbExists(db) + requireFunctionExists(db, func.identifier.funcName) + catalog(db).functions.put(func.identifier.funcName, func) + } + override protected def doRenameFunction( db: String, oldName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a86604e4353ab..c40d5f6031a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1055,6 +1055,29 @@ class SessionCatalog( } } + /** + * overwirte a metastore function in the database specified in `funcDefinition`.. + * If no database is specified, assume the function is in the current database. + */ + def alterFunction(funcDefinition: CatalogFunction): Unit = { + val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (functionExists(identifier)) { + if (functionRegistry.functionExists(identifier)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier) + } + externalCatalog.alterFunction(db, newFuncDefinition) + } else { + throw new NoSuchFunctionException(db = db, func = identifier.toString) + } + } + /** * Retrieve the metadata of a metastore function. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala index 459973a13bb10..742a51e640383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -139,6 +139,16 @@ case class DropFunctionPreEvent(database: String, name: String) extends Function */ case class DropFunctionEvent(database: String, name: String) extends FunctionEvent +/** + * Event fired before a function is altered. + */ +case class AlterFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been altered. + */ +case class AlterFunctionEvent(database: String, name: String) extends FunctionEvent + /** * Event fired before a function is renamed. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 2539ea615ff92..087c26f23f383 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -176,6 +176,15 @@ class ExternalCatalogEventSuite extends SparkFunSuite { } checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + // ALTER + val alteredFunctionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn4", Some("db5")), + className = "org.apache.spark.AlterFunction", + resources = Seq.empty) + catalog.alterFunction("db5", alteredFunctionDefinition) + checkEvents( + AlterFunctionPreEvent("db5", "fn4") :: AlterFunctionEvent("db5", "fn4") :: Nil) + // DROP intercept[AnalysisException] { catalog.dropFunction("db5", "fn7") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index c22d55fc96a65..66e895a4690c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -752,6 +752,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac } } + test("alter function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1").className == funcClass) + val myNewFunc = catalog.getFunction("db2", "func1").copy(className = newFuncClass) + catalog.alterFunction("db2", myNewFunc) + assert(catalog.getFunction("db2", "func1").className == newFuncClass) + } + test("list functions") { val catalog = newBasicCatalog() catalog.createFunction("db2", newFunc("func2")) @@ -916,6 +924,7 @@ abstract class CatalogTestUtils { lazy val partWithEmptyValue = CatalogTablePartition(Map("a" -> "3", "b" -> ""), storageFormat) lazy val funcClass = "org.apache.spark.myFunc" + lazy val newFuncClass = "org.apache.spark.myNewFunc" /** * Creates a basic catalog, with the following structure: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 2b79eb5eac0f1..2f8e416e7df1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -687,8 +687,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * * For example: * {{{ - * CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name - * [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; + * CREATE [OR REPLACE] [TEMPORARY] FUNCTION [IF NOT EXISTS] [db_name.]function_name + * AS class_name [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; * }}} */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { @@ -709,7 +709,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { functionIdentifier.funcName, string(ctx.className), resources, - ctx.TEMPORARY != null) + ctx.TEMPORARY != null, + ctx.EXISTS != null, + ctx.REPLACE != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index a91ad413f4d1b..4f92ffee687aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -31,13 +31,13 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * The DDL command that creates a function. * To create a temporary function, the syntax of using this command in SQL is: * {{{ - * CREATE TEMPORARY FUNCTION functionName + * CREATE [OR REPLACE] TEMPORARY FUNCTION functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} * * To create a permanent function, the syntax in SQL is: * {{{ - * CREATE FUNCTION [databaseName.]functionName + * CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [databaseName.]functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} */ @@ -46,26 +46,46 @@ case class CreateFunctionCommand( functionName: String, className: String, resources: Seq[FunctionResource], - isTemp: Boolean) + isTemp: Boolean, + ifNotExists: Boolean, + replace: Boolean) extends RunnableCommand { + if (ifNotExists && replace) { + throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + + " is not allowed.") + } + + // Disallow to define a temporary function with `IF NOT EXISTS` + if (ifNotExists && isTemp) { + throw new AnalysisException( + "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.") + } + + // Temporary function names should not contain database prefix like "database.function" + if (databaseName.isDefined && isTemp) { + throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + + s"is not allowed: '${databaseName.get}'") + } + override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val func = CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources) if (isTemp) { - if (databaseName.isDefined) { - throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + - s"is not allowed: '${databaseName.get}'") - } // We first load resources and then put the builder in the function registry. catalog.loadFunctionResources(resources) - catalog.registerFunction(func, overrideIfExists = false) + catalog.registerFunction(func, overrideIfExists = replace) } else { - // For a permanent, we will store the metadata into underlying external catalog. - // This function will be loaded into the FunctionRegistry when a query uses it. - // We do not load it into FunctionRegistry right now. - // TODO: should we also parse "IF NOT EXISTS"? - catalog.createFunction(func, ignoreIfExists = false) + // Handles `CREATE OR REPLACE FUNCTION AS ... USING ...` + if (replace && catalog.functionExists(func.identifier)) { + // alter the function in the metastore + catalog.alterFunction(CatalogFunction(func.identifier, className, resources)) + } else { + // For a permanent, we will store the metadata into underlying external catalog. + // This function will be loaded into the FunctionRegistry when a query uses it. + // We do not load it into FunctionRegistry right now. + catalog.createFunction(CatalogFunction(func.identifier, className, resources), ifNotExists) + } } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8a6bc62fec96c..5643c58d9f847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -181,8 +181,29 @@ class DDLCommandSuite extends PlanTest { |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', |FILE '/path/to/file' """.stripMargin + val sql3 = + """ + |CREATE OR REPLACE TEMPORARY FUNCTION helloworld3 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val sql4 = + """ + |CREATE OR REPLACE FUNCTION hello.world1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin + val sql5 = + """ + |CREATE FUNCTION IF NOT EXISTS hello.world2 as + |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + val parsed5 = parser.parsePlan(sql5) val expected1 = CreateFunctionCommand( None, "helloworld", @@ -190,7 +211,7 @@ class DDLCommandSuite extends PlanTest { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true) + isTemp = true, ifNotExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -198,9 +219,36 @@ class DDLCommandSuite extends PlanTest { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false) + isTemp = false, ifNotExists = false, replace = false) + val expected3 = CreateFunctionCommand( + None, + "helloworld3", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), + isTemp = true, ifNotExists = false, replace = true) + val expected4 = CreateFunctionCommand( + Some("hello"), + "world1", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), + isTemp = false, ifNotExists = false, replace = true) + val expected5 = CreateFunctionCommand( + Some("hello"), + "world2", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), + isTemp = false, ifNotExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + comparePlans(parsed5, expected5) } test("drop function") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e4dd077715d0f..5c40d8bb4b1ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2270,6 +2270,57 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create temporary function with if not exists") { + withUserDefinedFunction("func1" -> true) { + val sql1 = + """ + |CREATE TEMPORARY FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains("It is not allowed to define a TEMPORARY function with IF NOT EXISTS")) + } + } + + test("create function with both if not exists and replace") { + withUserDefinedFunction("func1" -> false) { + val sql1 = + """ + |CREATE OR REPLACE FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains("CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed")) + } + } + + test("create temporary function by specifying a database") { + val dbName = "mydb" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + withUserDefinedFunction("func1" -> true) { + val sql1 = + s""" + |CREATE TEMPORARY FUNCTION $dbName.func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains(s"Specifying a database in CREATE TEMPORARY FUNCTION " + + s"is not allowed: '$dbName'")) + } + } + } + Seq(true, false).foreach { caseSensitive => test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 2a17849fa8a34..306b38048e3a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1132,6 +1132,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropFunction(db, name) } + override protected def doAlterFunction( + db: String, funcDefinition: CatalogFunction): Unit = withClient { + requireDbExists(db) + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) + requireFunctionExists(db, functionName) + val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) + client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) + } + override protected def doRenameFunction( db: String, oldName: String, From e3e2b5da3671a6c6d152b4de481a8aa3e57a6e42 Mon Sep 17 00:00:00 2001 From: "he.qiao" Date: Wed, 5 Jul 2017 21:13:25 +0800 Subject: [PATCH 0872/1765] [SPARK-21286][TEST] Modified StorageTabSuite unit test ## What changes were proposed in this pull request? The old unit test not effect ## How was this patch tested? unit test Author: he.qiao Closes #18511 from Geek-He/dev_0703. --- .../scala/org/apache/spark/ui/storage/StorageTabSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 66dda382eb653..1cb52593e7060 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -74,7 +74,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { // Submitting RDDInfos with duplicate IDs does nothing val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY, Seq(10)) rddInfo0Cached.numCachedPartitions = 1 - val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0), Seq.empty, "details") + val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached), Seq.empty, "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) assert(storageListener._rddInfoMap.size === 4) assert(storageListener.rddInfoList.size === 2) From 960298ee66b9b8a80f84df679ce5b4b3846267f4 Mon Sep 17 00:00:00 2001 From: sadikovi Date: Wed, 5 Jul 2017 14:40:44 +0100 Subject: [PATCH 0873/1765] [SPARK-20858][DOC][MINOR] Document ListenerBus event queue size ## What changes were proposed in this pull request? This change adds a new configuration option `spark.scheduler.listenerbus.eventqueue.size` to the configuration docs to specify the capacity of the spark listener bus event queue. Default value is 10000. This is doc PR for [SPARK-15703](https://issues.apache.org/jira/browse/SPARK-15703). I added option to the `Scheduling` section, however it might be more related to `Spark UI` section. ## How was this patch tested? Manually verified correct rendering of configuration option. Author: sadikovi Author: Ivan Sadikov Closes #18476 from sadikovi/SPARK-20858. --- docs/configuration.md | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index bd6a1f9e240e2..c785a664c67b1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -725,7 +725,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedJobs 1000 - How many jobs the Spark UI and status APIs remember before garbage collecting. + How many jobs the Spark UI and status APIs remember before garbage collecting. This is a target maximum, and fewer elements may be retained in some circumstances. @@ -733,7 +733,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedStages 1000 - How many stages the Spark UI and status APIs remember before garbage collecting. + How many stages the Spark UI and status APIs remember before garbage collecting. This is a target maximum, and fewer elements may be retained in some circumstances. @@ -741,7 +741,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedTasks 100000 - How many tasks the Spark UI and status APIs remember before garbage collecting. + How many tasks the Spark UI and status APIs remember before garbage collecting. This is a target maximum, and fewer elements may be retained in some circumstances. @@ -1389,6 +1389,15 @@ Apart from these, the following properties are also available, and may be useful The interval length for the scheduler to revive the worker resource offers to run tasks. + + spark.scheduler.listenerbus.eventqueue.capacity + 10000 + + Capacity for event queue in Spark listener bus, must be greater than 0. Consider increasing + value (e.g. 20000) if listener events are dropped. Increasing this value may result in the + driver using more memory. + + spark.blacklist.enabled @@ -1475,8 +1484,8 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.application.fetchFailure.enabled false - (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch + failure happenes. If external shuffle service is enabled, then the whole node will be blacklisted. From 742da0868534dab3d4d7b7edbe5ba9dc8bf26cc8 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 5 Jul 2017 10:59:10 -0700 Subject: [PATCH 0874/1765] [SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs ## What changes were proposed in this pull request? Support register Java UDAFs in PySpark so that user can use Java UDAF in PySpark. Besides that I also add api in `UDFRegistration` ## How was this patch tested? Unit test is added Author: Jeff Zhang Closes #17222 from zjffdu/SPARK-19439. --- python/pyspark/sql/context.py | 23 ++++++++ python/pyspark/sql/tests.py | 10 ++++ .../apache/spark/sql/UDFRegistration.scala | 33 +++++++++-- .../org/apache/spark/sql/JavaUDAFSuite.java | 55 +++++++++++++++++++ .../org/apache/spark/sql}/MyDoubleAvg.java | 2 +- .../org/apache/spark/sql}/MyDoubleSum.java | 8 +-- sql/hive/pom.xml | 7 +++ .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../execution/AggregationQuerySuite.scala | 5 +- 9 files changed, 132 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java rename sql/{hive/src/test/java/org/apache/spark/sql/hive/aggregate => core/src/test/java/test/org/apache/spark/sql}/MyDoubleAvg.java (99%) rename sql/{hive/src/test/java/org/apache/spark/sql/hive/aggregate => core/src/test/java/test/org/apache/spark/sql}/MyDoubleSum.java (98%) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 426f07cd9410d..c44ab247fd3d3 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -232,6 +232,23 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a java UDAF so it can be used in SQL statements. + + :param name: name of the UDAF + :param javaClassName: fully qualified name of java class + + >>> sqlContext.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ @@ -551,6 +568,12 @@ def __init__(self, sqlContext): def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) + def registerJavaFunction(self, name, javaClassName, returnType=None): + self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + + def registerJavaUDAF(self, name, javaClassName): + self.sqlContext.registerJavaUDAF(name, javaClassName) + register.__doc__ = SQLContext.registerFunction.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16ba8bd73f400..c0e3b8d132396 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -481,6 +481,16 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + def test_non_existed_udf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + + def test_non_existed_udaf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ad01b889429c7..8bdc0221888d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.io.IOException import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag @@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends .map(_.asInstanceOf[ParameterizedType]) .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) if (udfInterfaces.length == 0) { - throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface") } else if (udfInterfaces.length > 1) { - throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") } else { try { val udf = clazz.newInstance() @@ -491,19 +490,41 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case n => logError(s"UDF class with ${n} type arguments is not supported ") + case n => + throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.") } } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => - logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } } catch { - case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") } } + /** + * Register a Java UDAF class using reflection, for use from pyspark + * + * @param name UDAF name + * @param className fully qualified class name of UDAF + */ + private[sql] def registerJavaUDAF(name: String, className: String): Unit = { + try { + val clazz = Utils.classForName(className) + if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { + throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") + } + val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + register(name, udaf) + } catch { + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") + case e @ (_: InstantiationException | _: IllegalArgumentException) => + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + } + } + /** * Register a user-defined function with 1 arguments. * @since 1.3.0 diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java new file mode 100644 index 0000000000000..ddbaa45a483cb --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + + +public class JavaUDAFSuite { + + private transient SparkSession spark; + + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); + Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); + Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); + } + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java similarity index 99% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index ae0c097c362ab..447a71d284fbb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java similarity index 98% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java index d17fb3e5194f3..93d20330c717f 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; +import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; /** * An example {@link UserDefinedAggregateFunction} to calculate the sum of a diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 09dcc4055e000..f9462e79a69f3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -57,6 +57,13 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index aefc9cc77da88..636ce10da3734 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.hive.aggregate.MyDoubleSum; +import test.org.apache.spark.sql.MyDoubleSum; public class JavaDataFrameSuite { private transient SQLContext hc; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f915977bd88..f245a79f805a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ import scala.util.Random +import test.org.apache.spark.sql.MyDoubleAvg +import test.org.apache.spark.sql.MyDoubleSum + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { def inputSchema: StructType = schema From c8e7f445b98fce0b419b26f43dd3a75bf7c7375b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 5 Jul 2017 11:06:15 -0700 Subject: [PATCH 0875/1765] [SPARK-21307][SQL] Remove SQLConf parameters from the parser-related classes. ### What changes were proposed in this pull request? This PR is to remove SQLConf parameters from the parser-related classes. ### How was this patch tested? The existing test cases. Author: gatorsmile Closes #18531 from gatorsmile/rmSQLConfParser. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../sql/catalyst/parser/ParseDriver.scala | 8 +- .../parser/ExpressionParserSuite.scala | 167 +++++++++--------- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../org/apache/spark/sql/functions.scala | 3 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../sql/internal/VariableSubstitution.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 10 +- .../execution/command/DDLCommandSuite.scala | 4 +- .../internal/VariableSubstitutionSuite.scala | 31 ++-- 11 files changed, 121 insertions(+), 127 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c40d5f6031a21..336d3d65d0dd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -74,7 +74,7 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - new CatalystSqlParser(conf), + CatalystSqlParser, DummyFunctionResourceLoader) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8eac3ef2d3568..b6a4686bb9ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -45,11 +45,9 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ - def this() = this(new SQLConf()) - protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -1457,7 +1455,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - if (conf.escapedStringLiterals) { + if (SQLConf.get.escapedStringLiterals) { ctx.STRING().asScala.map(stringWithoutUnescape).mkString } else { ctx.STRING().asScala.map(string).mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 09598ffe770c6..7e1fcfefc64a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -122,13 +121,8 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ -class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { - val astBuilder = new AstBuilder(conf) -} - -/** For test-only. */ object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder(new SQLConf()) + val astBuilder = new AstBuilder } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 45f9f72dccc45..ac7325257a15a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -167,12 +167,12 @@ class ExpressionParserSuite extends PlanTest { } test("like expressions with ESCAPED_STRING_LITERALS = true") { - val conf = new SQLConf() - conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true") - val parser = new CatalystSqlParser(conf) - assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) - assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) - assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + val parser = CatalystSqlParser + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + } } test("is null expressions") { @@ -435,86 +435,85 @@ class ExpressionParserSuite extends PlanTest { } test("strings") { + val parser = CatalystSqlParser Seq(true, false).foreach { escape => - val conf = new SQLConf() - conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) - val parser = new CatalystSqlParser(conf) - - // tests that have same result whatever the conf is - // Single Strings. - assertEqual("\"hello\"", "hello", parser) - assertEqual("'hello'", "hello", parser) - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld", parser) - assertEqual("'hello' \" \" 'world'", "hello world", parser) - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%", parser) - assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) - - // tests that have different result regarding the conf - if (escape) { - // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to - // Spark 1.6 behavior. - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) - - // Escaped characters. - // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work - // when ESCAPED_STRING_LITERALS is enabled. - // It is parsed literally. - assertEqual("'\\0'", "\\0", parser) - - // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. - val e = intercept[ParseException](parser.parseExpression("'\''")) - assert(e.message.contains("extraneous input '''")) - - // The unescape special characters (e.g., "\\t") for 2.0+ don't work - // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. - assertEqual("'\\\"'", "\\\"", parser) // Double quote - assertEqual("'\\b'", "\\b", parser) // Backspace - assertEqual("'\\n'", "\\n", parser) // Newline - assertEqual("'\\r'", "\\r", parser) // Carriage return - assertEqual("'\\t'", "\\t", parser) // Tab character - - // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) - // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", - "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) - } else { - // Default behavior - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') - assertEqual("'\\''", "\'", parser) // Single quote - assertEqual("'\\\"'", "\"", parser) // Double quote - assertEqual("'\\b'", "\b", parser) // Backspace - assertEqual("'\\n'", "\n", parser) // Newline - assertEqual("'\\r'", "\r", parser) // Carriage return - assertEqual("'\\t'", "\t", parser) // Tab character - assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", - parser) + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> escape.toString) { + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work + // when ESCAPED_STRING_LITERALS is enabled. + // It is parsed literally. + assertEqual("'\\0'", "\\0", parser) + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is + // enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + // The unescape special characters (e.g., "\\t") for 2.0+ don't work + // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. + assertEqual("'\\\"'", "\\\"", parser) // Double quote + assertEqual("'\\b'", "\\b", parser) // Backspace + assertEqual("'\\n'", "\\n", parser) // Newline + assertEqual("'\\r'", "\\r", parser) // Carriage return + assertEqual("'\\t'", "\\t", parser) // Tab character + + // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) + // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", + "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') + assertEqual("'\\''", "\'", parser) // Single quote + assertEqual("'\\\"'", "\"", parser) // Double quote + assertEqual("'\\b'", "\b", parser) // Backspace + assertEqual("'\\n'", "\n", parser) // Newline + assertEqual("'\\r'", "\r", parser) // Carriage return + assertEqual("'\\t'", "\t", parser) // Tab character + assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", + parser) + } } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 2f8e416e7df1b..618d027d8dc07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -39,10 +39,11 @@ import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. */ -class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { - val astBuilder = new SparkSqlAstBuilder(conf) +class SparkSqlParser extends AbstractSqlParser { - private val substitutor = new VariableSubstitution(conf) + val astBuilder = new SparkSqlAstBuilder + + private val substitutor = new VariableSubstitution protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = { super.parse(substitutor.substitute(command))(toResult) @@ -52,9 +53,11 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { +class SparkSqlAstBuilder extends AstBuilder { import org.apache.spark.sql.catalyst.parser.ParserUtils._ + private def conf: SQLConf = SQLConf.get + /** * Create a [[SetCommand]] logical plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 839cbf42024e3..3c67960d13e09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1276,7 +1275,7 @@ object functions { */ def expr(expr: String): Column = { val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser(new SQLConf) + new SparkSqlParser } Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2532b2ddb72df..9d0148117fadf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -114,7 +114,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf` field. */ protected lazy val sqlParser: ParserInterface = { - extensions.buildParser(session, new SparkSqlParser(conf)) + extensions.buildParser(session, new SparkSqlParser) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala index 4e7c813be9922..2b9c574aaaf0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala @@ -25,7 +25,9 @@ import org.apache.spark.internal.config._ * * Variable substitution is controlled by `SQLConf.variableSubstituteEnabled`. */ -class VariableSubstitution(conf: SQLConf) { +class VariableSubstitution { + + private def conf = SQLConf.get private val provider = new ConfigProvider { override def get(key: String): Option[String] = Option(conf.getConfString(key, "")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index d238c76fbeeff..2e29fa43f73d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType */ class SparkSqlParserSuite extends AnalysisTest { - val newConf = new SQLConf - private lazy val parser = new SparkSqlParser(newConf) + private lazy val parser = new SparkSqlParser /** * Normalizes plans: @@ -285,6 +284,7 @@ class SparkSqlParserSuite extends AnalysisTest { } test("query organization") { + val conf = SQLConf.get // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows val baseSql = "select * from t" val basePlan = @@ -293,20 +293,20 @@ class SparkSqlParserSuite extends AnalysisTest { assertEqual(s"$baseSql distribute by a, b", RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = newConf.numShufflePartitions)) + numPartitions = conf.numShufflePartitions)) assertEqual(s"$baseSql distribute by a sort by b", Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: Nil, basePlan, - numPartitions = newConf.numShufflePartitions))) + numPartitions = conf.numShufflePartitions))) assertEqual(s"$baseSql cluster by a, b", Sort(SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = newConf.numShufflePartitions))) + numPartitions = conf.numShufflePartitions))) } test("pipeline concatenation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 5643c58d9f847..750574830381f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} // TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { - private lazy val parser = new SparkSqlParser(new SQLConf) + private lazy val parser = new SparkSqlParser private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala index d5a946aeaac31..c5e5b70e21335 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.PlanTest -class VariableSubstitutionSuite extends SparkFunSuite { +class VariableSubstitutionSuite extends SparkFunSuite with PlanTest { - private lazy val conf = new SQLConf - private lazy val sub = new VariableSubstitution(conf) + private lazy val sub = new VariableSubstitution test("system property") { System.setProperty("varSubSuite.var", "abcd") @@ -35,26 +34,26 @@ class VariableSubstitutionSuite extends SparkFunSuite { } test("Spark configuration variable") { - conf.setConfString("some-random-string-abcd", "1234abcd") - assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") + withSQLConf("some-random-string-abcd" -> "1234abcd") { + assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") + } } test("multiple substitutes") { val q = "select ${bar} ${foo} ${doo} this is great" - conf.setConfString("bar", "1") - conf.setConfString("foo", "2") - conf.setConfString("doo", "3") - assert(sub.substitute(q) == "select 1 2 3 this is great") + withSQLConf("bar" -> "1", "foo" -> "2", "doo" -> "3") { + assert(sub.substitute(q) == "select 1 2 3 this is great") + } } test("test nested substitutes") { val q = "select ${bar} ${foo} this is great" - conf.setConfString("bar", "1") - conf.setConfString("foo", "${bar}") - assert(sub.substitute(q) == "select 1 1 this is great") + withSQLConf("bar" -> "1", "foo" -> "${bar}") { + assert(sub.substitute(q) == "select 1 1 this is great") + } } } From c8d0aba198c0f593c2b6b656c23b3d0fb7ea98a2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Jul 2017 16:33:23 -0700 Subject: [PATCH 0876/1765] [SPARK-21278][PYSPARK] Upgrade to Py4J 0.10.6 ## What changes were proposed in this pull request? This PR aims to bump Py4J in order to fix the following float/double bug. Py4J 0.10.5 fixes this (https://github.com/bartdag/py4j/issues/272) and the latest Py4J is 0.10.6. **BEFORE** ``` >>> df = spark.range(1) >>> df.select(df['id'] + 17.133574204226083).show() +--------------------+ |(id + 17.1335742042)| +--------------------+ | 17.1335742042| +--------------------+ ``` **AFTER** ``` >>> df = spark.range(1) >>> df.select(df['id'] + 17.133574204226083).show() +-------------------------+ |(id + 17.133574204226083)| +-------------------------+ | 17.133574204226083| +-------------------------+ ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #18546 from dongjoon-hyun/SPARK-21278. --- LICENSE | 2 +- bin/pyspark | 2 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../apache/spark/api/python/PythonUtils.scala | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- python/README.md | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.10.4-src.zip | Bin 74096 -> 0 bytes python/lib/py4j-0.10.6-src.zip | Bin 0 -> 80352 bytes python/setup.py | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- sbin/spark-config.sh | 2 +- 15 files changed, 13 insertions(+), 13 deletions(-) delete mode 100644 python/lib/py4j-0.10.4-src.zip create mode 100644 python/lib/py4j-0.10.6-src.zip diff --git a/LICENSE b/LICENSE index 66a2e8f132953..39fe0dc462385 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8a..d3b512eeb1209 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index f211c0873ad2f..46d4d5c883cfb 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 326dde4f274bb..91ee941471495 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -335,7 +335,7 @@ net.sf.py4j py4j - 0.10.4 + 0.10.6 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index c4e55b5e89027..92e228a9dd10c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9287bd47cf113..c1325318d52fa 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -156,7 +156,7 @@ parquet-jackson-1.8.2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.4.jar +py4j-0.10.6.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 9127413ab6c23..ac5abd21807b6 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -157,7 +157,7 @@ parquet-jackson-1.8.2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.4.jar +py4j-0.10.6.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/python/README.md b/python/README.md index 0a5c8010b8486..84ec88141cb00 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file +At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas). diff --git a/python/docs/Makefile b/python/docs/Makefile index 5e4cfb8ab6fe3..09898f29950ed 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip deleted file mode 100644 index 8c3829e328726df4dfad7b848d6daaed03495760..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 74096 zcmY(qV~{R9&@DQ)?b&1Qv2EM7ZQI&o;~CqwZQHi(`+n!1b5Gs;NK#2xex$0qv)1Y; zNP~i*0sT))&s3EBUz7jcf&Vu;c(Pd0EBtR0C?LvzO%f{;IB_}u?IGX*0U`Y#6C*=o zYX^HL7di*eVGZ5NB?ctl+ghj(2W^e$w>;vSJ@$f#yuDzFf7e&r7MV+=%bWhUq_X^r zH#f`P&42VV$1|9dGMLYINZV=CIs65z8arM2I~)(ioa!ao@ltTF#5lo^S4-PblK zy-)^8!w`v0o|QoxZ99dFDMc75zp?k3fo2|d1Z5ZRo8&x)3LBju_h)PcAkPz;!weMI z{)}eE=T9UMfkS~x(jv?~6%`-Wl>*`w0Zcx<@i_P*VX4?uGh!yOxFp444%{297-k?m z73R}T0KhRE(Gd<_T1K7%!Q$a?xK0?PJ!+t3Os~N@G}+=JO!)L>@SVh~N)E#5XjaVA zRRDFHHyv^zAeZ2SfNR`pgV!L@{#q3~20qP;IFg(Lug1yL$V^9<3?>ECLu##>*(a?(5%JZX?j+I~RO3kc(>$>z<^^d3qMZcn#>Yyk>3K3x@ z41k6#Wfu>{KQP9Ch!`waWhjP4OnHeDZz?j5zfmZ*-&VpVG=IBJk}bN_0U->h(1&=Y zp~`^s;&XijeW46KegJUco0?Fqo%t~6*a47k(LUcZ?bpCjm|U>zb{nnFC_C=*zfTp~ z{PT<6x^kEuR+R|~$-&ES&H*Z5d-0<0iUrH`F1V~G9oquImm??3f(8A$X4u&kMt+q_ zM3ZcJW_!RqWM%)WpjL8jM{N$}ufkEtuO{oyk}uXA%2x4V&5?AxV8#2I!e8miKmH@m z?=@RFwUHG-LGB1rUTKaoJ^I&ya=SIlLYGdM^Y4Ah{#Dxc8KFh3uJhJMq&_YnLD@VdMa{Ec{;W{F%n!HV&j7gCSgph=&y;u>)}hhC|R z7tx;^DWaz6-v*>G65LUz(D0J_hWhps8Wk@_6) z9jV8Ps*QfoVZ)cr>*8}*4hem=l%?0QDO~1rkQ+kLzWVKlhV1*h0 z-SccbgI_x%J&EHeXW-iq0Xa{m%l3p)( zg(;{&Kw2i{#Ud&wy`v|b8$9xax-=MFPvZNXggoBAovLZJkK-%pv>f=jQgf3?&T7=+S9>@qnQ zRmjEhp3~G66(z%1Q*ip6#&qJ5z$ur6fMC1>*D0cEnph38HZ}}sekyF0RZ0eSp+3lr zG=`V;jB&z;OBL;T8F0{4cO@D?D$c@8Q8$N9gkA(YLMb-g6MbZo4^$dsg#gubMz7Sj z?9mGRIwrZpb>e@WkswsdjMJbsttq8VbF0)T7+oz&Y-4JmJ2oyP0{Bod)5?A!W zN6+KN@>p+TiPB1m6j&4lm=f0{@_bgsMv3m_$n*1lpBP~P0J*y|_V%v*tE=qHRu$5I z>Arrx=zhM=-`|K|TCIHCS$@2*=K4BvwY`Cvn5E{6v$gayrR+{+Ly}nR-QlJw_H0Is z3k50*n`oXXqn;O{WWQ`h)jR8nKVGl>X^RD8(E)yX{X-MckeW(0{ceMy2Dm!;QV-8K zRcu>*utsrGu)q+hp>c{@(enlj0>;Rp?gdq5a4B?UQ-mn=jOS6}Y(p51I4NZ1ax{bP zOU#)HX$G6S=<{Zvx(eKG_bsMq^YktLX99?OzMG)vatO}` z=LmeO)`zKY3PMpisuG*WtjWn)^FkJRWF+>akWin*E`W=!ZEus_05F=5NBubT*qXjp z!i3)XwB5RTy@5E?(hQ_3OV`)K*oe0XV=Av|hn;VFLx}sndTN?YW<)!NqgNlefECuK z&e)1IyPH9zZXZr!HFV@&JSs@sV2ngllwiY!Mi)JJcwYT(_@`htK0DaRm>^P-UIe5# zDwN7f=lYn7N8DxK5h}jo8!)7j zqA?>t6;gG}u^s5J(K_Az8DZc3pPwO)K>_hFA`hmhwY)0i!&|OTd6yj7DSA=l#Ho1} zOJOvsywZXd@ViS1YO3eM48DK32G8KPwYHJreW*YyT1BPF6veW09BO0;p0&(5Ln(1KwEj|rg_3tHSjY=5Pw#6?6U0-UxuaSKE+$Dzny%=wb1>#>hla@`@$lPdb z;l6+++M)?hkm_5S7igBrBVjS$>|T^^A)2KJ+yBIRh9^-)pqjN|B4I#c#gx(V(iN23GQ&d{yID$&I*SXv+uZHy%@ zZwj#hjuzu!opRqD2?frtnQuc)D2JL|n+(wobld!^V@h11N4tzz2qNuO7#ImZx^M;r zJE;$IC8f=}SJVCUhN)m5?5U6P#05(jy+jORX*PZ)rap;)NaIx;DqS48a`5;su(33qP z(C-Ndau(t96nS4zW16r5$8(~+B^KwUGgM1`HL6E^(ikk)X%C9B4b1c@)2)Pv%j~3{ zQxe=XMuJwFhdTVXpd&-&Vf2uvrf}4bb0aJJRgEO?je(Zzs$rGiq>g^9>TZ|*%2coV zMS@6NLHXCMMtvSgM~)l%;?ov13z1@oJXE(bk3d~ zjVD5q%4SS4-yY#!$P1)m#lizf!L2AwMEsg)t+?rIZmof9}g+jCpRqSTQI27ks>hF21Y$50fb&D0p*5q$>)lMrlVy5nXva^{R= zA2R7>%cNS;#AtA=y^D+YD^SCH_>c_2n<*0Th9{j33SHzkda@@GaAS|iD`0O0BV-%3 z>nu|T+EgH;nn8cv9I4NAwUNBZOIl95l^5!PAB!u$+yt$?nj<_ZY#-cQt0}hwCcU!i za&YB9GUxJ>_^oUJu@Tu(e4AZ#b^qR)!2C1%?)u2R$TQUfw?LH|XB4|{OoH}{!wGDC z%j|FuNwFD`;uf;(gdlxz=2>8{#e>Dw)PmZD$L&KSxDoLaE` z)rzSEffj;y3Y&y5y_yg5D;(J<=h>S^&fvDdj>M_EiT!cCg%Ely+^VdJ$|eTvbX#Bh zx;On6DE%SC>4C=*pv7v_z}1JUvz+W+286fuVR6$%pbiG%zx%^#FG#(xVYbF#FbDSF znxIQhZXLx_?_F5VQBb3x22$8+!8I>Q-?Z!N64B=i34EpYAT+>lzbK%d14{6~Up#N~ zMkOKbCfzjlO1uPaXyJf;75(vezt#SnzwCXS8a5n&g8&pUxzB&j-3|V-vt?KWXRdZz z_vz|j!wZOg*n(v}i#|}QJ70??w-7$96SeZ4{~N{dPUZ4yvq9XrfjW)dK<*b-plz_kN75`RWl39lW6qTS~CUa7W0 z55eDBgJ3Ryb@m80E0mb^sgH`-TasMw);Itf*FN4$-a>Dv`+er}dAfx>9dl?$skzHL z!G#Sl%1C7?CMVaPV9um{_!!2xrbp<1%_~W|dHC#8t#d5PW*s}lZwA3;sC%@ve{Wa2 zq<-IF3NwpE$<5&N@>L@=NaLo3q8tMnleoKHk+#Z?hxBF;Jc^_ zaZ*~q!NP?%`jWFK;%mDap(_iQlH^*JtUKXvxL#TE+=FBD`AP{ZTujVq(YA9mKTN`P zs?U`$dFRV_#BPn3tBhJrA*tETuhg&hjw=rj&85vfHxwf$t~2A&5umj^;7W_v{0c>m zm{OCOWDa9E_CecHo1TuuP)3q;r!(6W=Dk`f+~|8VcU-UH!_iTA&pNU(Z8}~}EPKS< z*Q?be=XWaR!i!Hv)Fvpb=yZA44k_6WFV8d1`0VT#s7)KK_r9gFb534`VG)|cj3c6WB*=cB)OJLr66UAbhYgf zLEEs>`vc~@*MrNn>n~uX7j((JMFvc?Wv#NoPHQRl!W2U zFE(O3P8S_B9ca#wI?F2K* zbe^lV6!uWP4AS6ago29*mlp>%+~JX{i73lIEN!+?*|*X;P7a|(QD@ThVM}I606O^X1xFLIGx>3 zrz2D8F*14rJF+mqvU)tHA{uY0?%GYizW}7WV3Tje^B29#m6R7op{Kohny#Lc^3W5MAWr=z<0Lq`!puV%Z)ZYxW;Noa ze68c`cnd(E>q|G2bNL$Q?V8Enc}UV$j4h zqq3p4NxTto^`}@!o|U7gqZHgmkeRamwHRdW(;JgngzRA0b+r80oH2I;t!4KjI8*~T zF|u)g9+s<#r;+k;cfW%r!x^=;r30p8QKVYwyebFbeR?%Fe#ag@-j~P6s;>%??Pq5C z#)LQb5YbDGw`<~1csP5<379~S`H(|&b{}#)gm)DR-o{2VRq2uvO;b6O!M(?(7zeAE z8VD(cwRiu7A9w<*|82aNhRqZj?Kk@2zQ50fisE{E5s|<;Uw?|WFNJ_unB(dbhRAn! z;0kwvSjO!eH<)TqWx(U4fvU!rRqv!Pg>fq6 zR5fSr;FLq1wlLJ_Ir4`@w%j&$Y3oXp4-t;Ze02B?GXPpvTODb8#>RtWzWBVqyQ{rB zQ=q?dZ-Eua*h`u3uMKQ=NXIk&MA7H^@ zthG>VS0%h}YqI@}-BC?)Jha(&cBjkMGaZE6ie16Y+M~=sK*)$Rpi*mxHG)7axd9aw zHk)W5<5@$xxs2*b3mwD!WBiWt_$0CpkHypeA zyhjDze{7R$bL?jgLbIr2tJ2B0A1wEyuy1kILaq8%Y3qP!q%P-iDKV=>UfG*N&o(W# z#eXSPp`UyL#k_qf zW^`@;QI0sP(HgU%kC*&XjOCezb8jowfgs)6rl3~Lz*%>*_<^4>SXd(u-{Sr0?P>S* z4tD;Y<5bz*ya6ppbcc-q^YL5Sil4c&LNgHgr4b=Rx|&9xypEpE5b zm?CtdzM4yj*m_w3G-ZzQX&ku+>7t%0_x9Giqo^6aG=m+`LI*@67QA z`~Ro$?muHazCnP1jG+HhW&S@J&)(L-@PG7}Zq@17MFymvQ!24-=>}@#3*r!JFAbu- zfDnN{jslfqk`&eM%G}o6n_3WQih2@yD3W8+5ubxAy$P(p#V_nW?Jgt zNOLg3!6Tux^o$q9)s0gY?a-;4a@qH7V8OcpsYNjqIwxovTqHyysq8_dw|1SaJ}dBj zBZS&QQ=lhD^ofsU+w6v`nG*ZTF9Ek#c=$eWD9>3CS{4MOmm{@w{sF*>pC_pO{se`_ zCPy$gevK{ev^bHhTp!m<4JDS-vS&}`5EM`P9YZmp(WPZZR3!Yu6~c@~=WCD9!kD@p zHn&oVBJOZ?r#A<2^Qm0S2-k!nu8CB=mtf3IKH@g@KsO_Fk=KfJF9DMi+>7(#Ka!;+ z!T2OF^Yk%>C1+ercu;AI-%93tLM&(_tcB}?+hr~*{Y67DC&K|39J=MqJ!2+Ux2WtF z9;3H&o*8TOujT=mzk7c*-dSJTl73N_v>j{*|bt~-C5@f!)lqjYzLnz0w zD(k{aCKkmj3U^7$Ju7#k6U<%qbzb7lWz4f@3;Y)fN%)i#N}0as154zWN`g-V2K1FH z?Ie`P0k}6T{SuR}sDF16o38X`v`aGb%`&yaW37Wnh$Blv%^S;9y+?IJn7VU;{G zP3Os@N6SCz+WoCX6o1_>5(g_-%loSgw1RuvzFjdtim7BUZ{E~i^&A}QbBA3O$GUIm z9*{#x;bbh23 zES1x}kki`J=PQtlb>Wlv12xB?5R&Im=Jp_>N^cBaeN+f@)IN@vWQ zZ<504AzqZS{OCRzDIgaxXdtCAN%SLCy-g>TXJv{eqlBi<2!n)eb3eVwAsyE+w=QP( z7-*_iG;oHVUBe+e94S??TowSR=UlU#aNd=uZU|9urUbO_p`}AN0<|znxnMjRBzVYE zhk|q}d5^Nwt2iRlZ%%&ZE()>!ZTx2KXC;?RR5<6p{Z3VFJ2og$%dM|_qy_$<92~tn z4_K$zi|uGupdlx=&mjV-+IRrRHNiv(lcQduggt14WMDX?Dp5&e$IY;5t3Gcs5USE} zWRyn@7i?D=LW=m*P^BSXD;-jk_{}$S%cD6{ta`>|7#*sZ7->&q zDKi36Fw=QKTy2R!qPa+GGDRk5WP85t(fsZ^oEz59T(Ggxv(Ed-d+X0X*>gmxB7$Vo z9`JN@I|7gPyW=AjITdS_t1L=j<`seD6jFY5o3Dmc%au9Jlg}B;fKhndGHT;oH!KZ!92oZ*{B?XYe@a^xnuJ!> z;k6>Pixz&*@~VB*S^2TREj&pLN|2eUytb${LFk*D-m-EdLwSV83#|qz*?-WLkHbsr zfJ*jeowK{$Bb}sUrzN2co8cuL`;N&&Z{Fn`f&VXFp9fC*i0$)TX8HoIodaXJlH~p| z`^%#|Phq~bmUlZzGvgT+C0~cKbw#7L|nursAQUdyP zK+G;$Z0lLQ9HuONI;c+3s3oFSJK$LU|2qBuh+MbzX_VjEaky{ti&!pd|U z^_Wzz{R9^zp)$<~Pd@1ELZFM_=jlP)%CdR$ zdc~4QGen70Y5GtCoWut(0v%C^n^9zd)ce8C*At$PS1O1X>3H^^0xV(dST9pMZ^W%# z(nu4w8i>OD7ccjK)AchS>#q{>$d8EJkOdXF(*^xN)23B(EXsMQ#Xk~@n4wB_{Cg{V z|Jr4)JIuL6ya>WC((cs!?7zctKPfR^0{2KIl02vYcHN?A5ibCg*&~PKFfj~+$$6+C zH7dUf{(hVa9-Y#;WmM`((@JH`vL*CaHk}gT56APV^IJX%#6xXp%t%Ah^AWs=9j-F5 z>~<=0N$P|}CVj8xF8*j^du}RgLg*WN@V|sE(BJ0TFBVxn*S*eprWC)#H)*xtFi!7p zYb$lj1!NNadC0XhjHfQ~wJw^rrxkr1TfB+KdGDB(!Cd)c8&aZusgfNTp$m~#LPfOd zsu*9ukR_ZuwEp=lxT~hVrvc`qz27y3q>YQg6*GUPhqH)MD=CU#bQdnaPG4aj$F`a@1p%(V%)Z>=P<^2*q6SbRc)!ZPEDC!F4>FdQltnIC z@@D5>GXg#7`x?sMoOEL61MMj=7t2V_q-XRI8kC}n5Iu}1yZ8}j3?a@}If;k3w(eA9 z=mwOLba2I(aTMM0A(lzH5dkHO6DUMK7^PX~WVllNgb(O;76ByLt!_uGv50Vczbg3- zT<6m%k+o+^;qI|q$$3=UkA)fN9<)J?7?2naxDux>JmZH^WjGoIQT9g;D*o!YNJqdJ z?ir0lCyuq7CR`XkFDuR`tnd=m@Q9XCL!%UwI*OEI?%X-O9jbo~=tpekuIV+fLb*j9 zQ`k|4V>b3!rp{y;k_6CH2r5zfL0Y`KF{e50ztKZcV!ZY43pXcE^Lqovo*U4!y?=gU zmu~Sn#Y^>0d7`IlBB$X$fh}5N1suZ~GL3AC{Uk4?ThAo}ziLk~efvi-@YoCcVRE1u z(;SO)$?ld7O>7%TyT%nsune&c7AlD)yy|X$Ju-$CA9&!h_>vtU6b{gUnj3Uj1ZU#< z)71w|H2ZMl(|x$$gsW0UVDfCLo-*NPbcVo)Lvkg8v~ozvZ{*?N@=&@5w2a<1Ei5}l z59b6bb!+y6byiMlQ_k}`hiWda@cU^ZQVQ1Vr8rtzNlBCEwa(5e_UzA-HOQH&E~=?z zj%a%lZ0^+rmcd7Xgk3wa`;t`_mC`;0i`X1A#XhohR(R8PjKjjgJ>^)y*|ybA58+`! z^uz*RM!WqHhDN{rH*~K^%>6EqS|5S$h6(L-S|w()rV*i#6vpj>o2Xljh)w)Mp8TuM z*cvK7`4EW1*S8+TfDm+91?RU|IOk2x$$U4B)TrCdl)UtQpN!tZ+0PIym%*wZPpUV~ zLi8%=S{Pvx8E;GOv$|UkxbK_e^Aj@{Y^||}uHlK!jvUrzhxos#b;l2RxG zU)(}5Q0g9aC8v2=z5Zifna*m65d?|c*$l$Pzg4%=!K|(*tlygoa7yj$4B$$(RI~u< z)r#qhU4HAu|AITMfbAKAzf)(AIvR6y1t0v%Hj(z%9=sXk5N}b1wp9ASk5@W|9)_w6 zm3`BSi0)@h?>Z+zZaQwPg8#ZD?_I}R4S@8mgEYh0V`nq+4&p+Y_e+C@bFK!)qs4L)#T;VuK1^H^O$=VXOa z`AxQfK34HAhP7m_bIL^<*NzKYTV$@ha3P+toYn-lXivT+(9IN(Q0XDeh|cPdcSqCJ2wv9=?`H{Lyr4XFS&}HW+Tyh*I>hpgN~O&8S}nnGcs)) zU3e$SenoagtkZ3q3qw$yQ*?$hIY_UaN=j+rH+mRja1CBDfJ-6GV2~rPNjBh zHRUV9)Q2Jd!xV($lzNeilCQ5@R0e)5L2~=l6;cTv$;ru3^u7hLG96bdEbtz-j zvXPVCq&)%$9?ev>JO*8N^JkEktgW}T>wA*G-Cj9Sn~+WRXK!%uzgTTH5%5F-RZ+*q z24EW#8#`Sr3YM`+1x0c&b0Zcl0M7V7aw@U$Sh>l!g3%B!8P~Qq{A>t%0mWa-Pbxgl zRdYBFp|+?iHBP`qcd~5fMLCS4VHEM2lh91FOz|NGQ{CVRPFGd14yp(nQ2etd-?=;J z0Um`7wQcBd6ZY4Vf{pa9_-1PM(<}jgJb=&vje03bpqej`S+e&=tHuYnby^T_?2^Lq zpp>Kk>VCec*4h&yAA&WFe(Krmcki;brVT#NbZT`z$gby9T6v8L+2#IJ-=BXu^so;K z=4{6zpW~aM4`?UI4kuhIW)0nLRHZK7I_iuuE(9`5o2;Lq2SO;GfU5*1BFyH_@_lN; zrA1kNe35Eva>)@z2fWRk!qDc1xeY5XtLHlQwG279;(YJea$)p?cd zf%jhIj|L48^(U#Z*>!>I<%faqKY{M!&b0^EOTi>sxY@A=A!prwMh$y{>F*D z?8bE+jHU^_8-w6nS=|oBA6u2&12Y0{Ks6)?NpqdIPvx*RY9;1q9B9n)adQ8K;Oe2g zVy$4Qtj@hJebNrgH!i4Jm1qU4oTlXLOv-mFycN!A;eMC?JdIkD|6EMOCd{nbF=R?7 zZfe(SS~qS=PgYuQXG&k3DnuIoS3BNL72eLz*w2v%!&V*(Ig__Rue0UXd)LpS=eIu_ zRb9>RXl(5g%I5Wr`c369EQw1a{7WzhwCyZ=L^GuP*up1(nOKb`fL-p_uT{B@7p5hh?NtgK!@!Xqg=|I!i8s7$?IS*xsl0$#Gaw`p31q$<)&l zGjf_{Hx`S$sJ*92<4|Ivaht=Mma3#cTuPk-n;kcnj4~p9)cNv!M-wmQTc+COSpfgv zP4no|LGc=WyTxX_3aV8SGql;xUcFoickT+I);~-|tt0=UcORFZkE{7pX56-GA3vcn z)%uKyP0nVOo*ey~i;>a6&H>cbCx$MczBL@wlZIdKU#B|d;*au?j+65TdQpBQ57yKh zR9_ujsSf;QcF49?5t3kn#$)*%rAk06IAg3|?AyVbEsz(+JlOLIUMtV?E%-Gt#czwB zR!ra~W%dP9GBY}Sv7xEVT>Z^k#q_}2$C5;sGvbi$i-+IN=1)i*;Ih>k2Y_DNi!*o| zXme)-r)R}Ev6JO$^ypS4s>MH{SIIh_oS=e+QAltamy79Wnv(8alL4u+v4I)yiq#R# z0@d03c7d~fJyL0gkCr06*II7XO0pLK`;S%2GbNWDJ*?&8B%9Km;^x`?=``)&Oe^+1 z;VG;&eA%@N>bVMU9t{1;)7ePMK2TfpgwvDanSTsw164Q)wknp=e~~B7h`UqLz7rb@ zO<#!JWlLR%m$<4s8=$By`P@`rTg%Vm{k9D;RKV!Z6y(iJxg^B&5l)IMViPZ#H?bI@ z2ekb+8mBNaC!Hyw${k|{VH44vZ(0^EZ~VFBlW%lZBZ#QR13$Qq;C3M%udtW;0G`Vc z*I43Rdsv*)vU2(96e1f%qQj1vNM<{xE(SGZ^q>E9pXIvl==pvq``2N?ZolY*4aZTUE{MvR3;>RYU_a8#TrP0BkESL5<`lC`_Hy%i;wo(WYIMLJbcT)9<=oldSvD(e1p`}BqPbRt}L&bF-SD}&O{jWg|o?7(lu#u{e@*=gS zN@Wop`SP`GeLBd)wm@doCKyXioe(hn>2ar#A3qo#;w7;@`XX2!A3gc_tN8dUfbZL~ zwl%FbU;SC{fhwSz>SXxZ_kpLE4#RVOuk5+(QIk!(cEvTpRMNAv>H@uw&M-OJM(;pl zCQC6c0cmvPM*LF#I(!dEXgpYijdXPt2DcnvcJmGSgwC`%2l! zGC}s&^_P!y^Ly2vadwbH%zT@xDkbOF+-56G>#@kpN6p1cf-oB7Z@`0>;bX886UzN0 z5yi&er7afGdTAcqT^TR2tNiay(Zc*PYGtYg=>vPY^mQ-`kI3yGmMeaCDBCBpKei~sXG8h#UiMU-VUZdv1B;H`x~XdOgWbsK&q^N& zF4zT(PJC0xgK00Hz1pZVk4$)Te{tfoF`oq@=u&nEc?UeTrV@Z>Gx@PE=l~p4Kw9SOVVKjE0Ky~JI-?h*-cjKJZ>L8 zf9h60pPtQg@pH^ENBwz%wD>(f773pLguZTfAIGL{cVF-C^TXggH~gPFaf_vQz2?53 zFIP`rD0;rP;k&;-uh090bpF0xi8(pBKR<|ek%W|UgM)>QlUEOK%k$^fWp@vIcW0}m z^tpsN#@tF$+MGNd?+=T?gS%#!;-~*DLN4b|AbOkozxCTUMDG&hnueFs>l5JX5#$1p z7OLq#bbcPbSYKMH)PmW+KMT&|5`TDq8Vmd{T<86{L)Pu0aL5TGt@?c^tkssCpsVcl z&&(J^Mn-mlcRLyH;B!mG9D#Od(vmOZga$w$lUpd1W@C1DcIgn!=JqFTVaKj8=0#<7_&*&NY4~Z> z^83iS5^sL=O*x;(RSBXl1Rt<%b44(R9d>WzSZpxf%D98(LeYO` zda1Ei#%;2OfS|p)x>Z1aa!ONBel&~`59qn+s$ZJE+#ulCLo^Hv7Z-{))`-TGVvzTt zl|(KLm|?FZCx?oHi&}lA8kFz`xq5<{2I?uq+oOI=P_EDu1+&2ajQ4YYK!UVgikT7`vTVl49^`)kWolJwaf}DVxkzqM3 zg$cv7jtM!MHn>eOjtiQN`2|5)*j4TS@Nw{D>gMS6wplotz_(Bnn`kaQc6i8 zG2}3s0i(Np)JHG;R6APkz|%f=X<4 zx|mOODVnip3QFx83y|;$sf07da`m-Pd{KJR@dD-PA3cMVSOh$qR7jQ^**e~xzP7G^ zHnibJexwH))xaL71Mb5b!L-1&ML-6KXi}yd(vure#Evfbu&tjNpYW`akOFf5e1Uat z0>ub%Z0w*=@iAWClGP2b>jMFsE{wx75EibUY+bp7$0g@mG=d2#;G&(oauvjuT&i-) z>jy3HiYzr#1nA$Ls?t&(PBpB=-PX zI((lm_KrN=obE0t-(!2GV)gUsGIhItSSSlair`a~%Qk-(3lIqdeU0{H%BhMNiuy?c z?6f$!In32SaL4g4js@kF(F9}qjr($i%4)hnFG9)VfratFfyBopn@jML4GgpGosnD{>Nb+9pUbvcAoB^t7FABwcSEQ&hC>aKhIh|%E1eAe zWhQYj({4&nz@gZkajIhe2sa9^{abkI{PGtYV3;}^3D#1WM9Pg+J49HzFNC0;(QFq3znCATg8MJ&S`!l!A zB?!DhB+-YAUC{91WcI1%1N~Kvz2`n^X1yRqoW&N_-6`B$ATVk3390@8$v-2{a4wa-N6OSw_xVJXCyrZ^C1N8z(-O>?fU#;wIOw4K!mLg+ zo@)+GC&$yhIvrhXAPz|i;_zM<{wErE1uu3~IA}^bR!piq=CxvmGeMZyiung?^ zo<*sIpOEzeH9IIwTs@^^p{2JZRA=K-;=8 zm@s{bLt&f>l(5`hHciQu=MKt`J=13`Xa3*I5)@q*oy(jD#b&7}rD;jdfdI1L7D4t?2}J9m9pOfH@qqWkd*KM}pJ<&7tmb@OgRO63M#-O!m1Cz+ zx(1&2Xw#H-)6_ZGqnI(6^iAa9#0@)vacT?N#edw@S(A|NI0imBIpkrZ{@XdRuX5z9 z8~mP{{V*D!oyckS#!b4~%By9nVF?xg0^1RDMa z=|-x|0Av0ajod|w;{iw!^a>f)VC&IaV-&p(RuJfCwgeK&Fu=X3R6lT)P-j7+zoUel zGxjv5n}lQZF`I&8z#}A#QZLv-WQdqlqpdMdX&gXA^+NY(QGk7?-OKs*_2={Xyt43C zFD3fc2t1p@w@QgnvnBwoM1UrBU?OF}Xna&OPZvf=)SApA<0aM z_76&l(>X2~JPFfM<)`Vs1Z?#4wW{F|X9}mbY^-$5_SllmcO1Z2qH!R=a*Zb-F>vBa zt&yc1{e2$(#f%VJv))60Wf{cxA-IQ3gi)S}wAv`2l2mX&+1JK%&L?2*^DZzBtPJa< zBT&-vW{lmiIK^4K0hrmfpCj>8=FO*E8kr1Q4Qbk0W69aIRCz)W(Pf*V$w--PiI85A3}8Ec&^b} zF1qshS-^!}V^i#iXns53E=(J+zMq<-(E^^hPz>3jyf&SwajRDsdH4nz1-ji~C~*5R zBBL{tS8!hg50YN6oNqdhyDh8g;Y>JU;oy*}3wr+vl=$~<`Uv5L^glXc>dr2i{>Gk> ziEg19o&yICTXfxyJ(d_VQjC#yjx!si&YQjfgs4|~b&?8=@akwXS?rDsfg*?z%$PLF z7tiEFk?yDy3-_&XxZ-Y{f=(#)4p1244~tuW1uwBZQ__&#}h<41(bV+@$Z-845xB6dYrE)6A@W!6vC!K}qYL|3QBEXlj8B&*- zOA3!RhSQ8cl{TEp)sx{$Cm<#=Q63WopM@skE-nX|^Dv?A@1GdO2lBI{&73_>auxvj z=pq8lzRv_Mp7cRs2!W|ZtP1y{CvZaCAFd?jQ0aJklgyarxAY-%40yZ<D8xrSqx!$D^d09x%re6VYHK4l^y28_z<`rCQZDl>TndyYBvZQ!^ z;uEP)yD{RG(MOGkg=w?KkdZxJjv_w^a<36DMAK_USDaF4i{c-mT@Y?giDLzIe`M6< zw+Yx#J&*sJ@&);$y*v*31ld^-h|G*y_oj(T;lV0U)`>$Z#fY8#)5qZ07K6y*!P zM;rHs(YfY}K(4G@7Xk}jU!52kFJVXzUtij(nyinQSW0K{Ry=?e38&Q=TGdoXEZ`K?Y1SA{R1v&~dR_j4T-f1Xv16XrDu{Ri z?>3d_wm%KS`(upeK1xu<32t!iv;|GcuNFwwN?E)&Q_h^v zV zF_KXz2ZKuI$8)Bj;sE!}*kmd8p?~&dM@g#oe<=TvxY864usFle8uhUZRsPu4(!<>z z{0^_y^ueOQ{*$py;SYRC*~xB^6LVy;Mx^Gn9uP<`cG;2|zC-$1(PVT0fO2{=67$#< z)C@_KH|!U$5F4zZ2ud+3P8d!ith4 zqNXIvLa3{6AG*;G+P-HE6G)k^cX@PSyW>8Z@0=7y?-$%8<|sYUBprA?$8(uXR)C0d z_s)axI8vyCt_#EhM#dF-koF{%G6S`&jkZDTaxJ#y2wN$ON5IFe*mpt2s!m%r+K5j^ zE!Z8%=K1eKibgi!SiIn>!Q9v}L!bY>PEqX2?Ok;nsYN5BXoz|IJgkbMlhST~L_GUi zcER1qph{>gV50-Ai-G*?uzsB{2;73^!ws+A;x2L$FKE?1jf4090bW3%zlnQ=|6(bPzya*SMt%SA!A6Fq@FiQOqLI5^Ct0*pH z=uU=>=pS%2vE}yhv|!1GV)miKRQf?P62RX}E))bzi@~FFW&!|T(*y-bq`~l|7r%y6 zaOcmzIRCJzL$K%=^oKj1NmcxiLlhv+w#L$9IQ+}wlNU$lB({-1>nVOkK^nlHUt9vC z!+J7CAkyw)T zqUjZVJNF|R)i|pE5G$}FQde-I;JPcQsc&>Qs zLK|9DO-T+A=LXx3WZo1Af&ivhw7xlsQ_PWEv zHQXms4-S)Yc&tLRWD(7}IZpv+YM(xR#!H{K&!#YBqI|@EtAJqm-w(d$F%Een@O@UN z=TD5}@a2-?EWY6#N~RxlI(#necVX`ig4*L0)BWVC-w#ALWo?5#4y2M}>!kG!@fvLd z1IPxUKKnjLQ#Nw(dciS{$Uf>7={Q}njnlGO&hvYEtu+6fpCWByQ!nSEbH8*6g>=JR2B1ks41X3)ZH7p==GbVfF*LWho zVC~q@3;502B(>DI0RqxOAH>#AKC7xea!kkm@ai5Oj_{lF`6*7s)moxP@|vE?I-t=o zC}8kPzM9+sB5+;}cSk}4h6uE6tC-^sTya6N5E2|$6Si9YhyykF#|gX~eA4%aLJ2}+ zGzciK*%Er}DcBlx#;u|hj10*zCw^ciz)%xR0792I-4$CJ1aZ{^-FCX*c-WiAL3oK$ zXxW$|<~X6vOHztu01Hw_%XwZKYs?ryjE&h%O{qaIOd~x<*ukl5ta=qnUp70%fEH&z z15#jp2u%uoS#WLk$6p`4Jd2NM=$~8c^wZxQ9d=kF^$##e{q_gm@eTcpt*t_^!>t9C zzHHeY{2z}y8&b^U0Y(-4V-4%!%Y%!HokkTJ{jVEIvCBd1#t_$g_!<#&2O0oK*J6Kc zr-4{%Xbdb>V%-LsVc|%(01fA{ZSB(p7&AJsyPh*K^49hw&{O=~)WdYs*6Qu5`gPQ6 z+g9cC;WixZK1NPkLqp>3kcqcNCU$^qd^u3j^#>i9f*BnM`kti<^(EWD28W(4BeDya zO6x25nQ{6e!-pa6R2RVPc4M2c**%+Ota1O@m`1*ev zS%)ng^Ca*Xou2<8V6fZEKR3W^=-BW(6T!n|Et*t)efU zz~1oZTOK~ly73q^{N_7WqH4n%>{*A_q+xa7SwilpDV$YXk?@6$7z>AL$`rU|L^)%y zg<=xFPG-X;$!KVv;lG?J`o$V zQm(R$GJt8E{Rk!I0(}IZVE7w+d?JBxwr3dB1~~XAK|O#oOfA^3x~Wa&Qq!Ch=I9q- zEZ{YN>DjmK0h$t-sYu36W$3&Gfb=lK(AA92U~Gj90h|wpOm(P6sBkFUKA98y4}?6M z>Kr<~kZGiod3)_+o!aw#3c+H#8w7g>Fal(d+XV{sI|NX#7gSek5fE_Lt|XFXKf$j9 zhdgYIB&hX$-CW2dfsa9B41jDK>wzO<2#R(Qn$W@rtU@_0VZt-xLyd4CB0jycb+J-R zqxG(UCAC>?U$27wBNwa$+2V&o);5ZAM<8+2WBJlB>1ccluorkRdYe24h~W6MTX zddSj0Tjgq6JuBz1q3_u;$ELrH^U9atvg zJmVxm3}pZGs}EQ=9ules$eR@P`i0jJeG%gGiRKJ9FpruR4#VRP ziPzL~NeopsoGvtOOR0+w?h<{Bkz^(&r83FPxCqRVa~i2O+XbFn>#tC|qeIPA^=`f7 z=&**4BHtThtAGT*r39c}c|&9bR}DrV!c^O1#RAd*qdve?m00VZnkBEL{;ij|vamgP ze^a8nOsv-!+~G&8iHq``mWeO}$26Txp20-pY{a9gt1>zcty5*Q&_VB=xwWv`qM?{0k4 zz>Q&5!vnAkAIuO4#HZd&qo`yENcy;H?8voYUh5CPkn$@iq(8z! z%1Nl0{`3p#!h&M@!!M+iiNN$%0IP$puujLT0Dv>8SjqiqgvTHO>AzHl^qc-_#KVwST4=ti3tm@5@r2jU^)wsVsFZsWbr1`bVXm&#O93hCnOU8ymiX`Tg ziIr-IsscuC0=W7b)yzZ>xafK=N z2h_#G9I~Y?F^-qF!s0{rc!Z*R_O}~XJuw@;g=Q2IVpJA&cN(ImWl2AH&*Zfd(vfYG zN2_)3RBW)3qLGvdE~s$S#nY?FxkpGJGEgIe&yg_LL?;8_g)-H(`a{#DJ(i6~GxEaX zTu9IyZ3e6ST>tdbeYGW`7(bAl`#V2yM^Tx3~9niH{Z2+HK0=)i0(LV1P%=#5utG1knEe zv(5^VA+sH-wv9}*&a$n=zEra}#-f=+zwqNKsAfAIIr3htA=S{90T_uy0-v^CXzkSm z!2Bh6s}-Sy{@9_{5}AQd-GR)FH|3@TyW*2)QH`i!*Do@_TEiU&6WqK7a#JsIh~9=L z6XF5E({?2UdSHFmWgQ~7Dd~@tjeM`sq*8;C&T&raB3V6%Et8^yJX#?l(+EWd$Mvla z>|9TN<1jvH@)_G&6c&p+6MN0rlW9T8ZxovDa1=@OM2~>0Wsi(nHg~@(m~on6JkW9N zCke%u#+g(QMq)!D^PuX;>$1z{*ABAF3>Yp0TU#G>O3>{sbsGQ3 zdk%HzbKvgx;1~+G2Z zzNFT+0g*&uZeNFipwupoX}x%i;!>M`ua98~NHdt|HT11fDkAV1UWt>?am*PI>1W~+oa{(^V3YG$Q0V^y=s09K_K z!1Yi4F{ZR{LFir;;qu>UU8d1^4IB#e`YYF7}#4nL?vt;A=lIbKPo3_ocL3bsb3wq`{&$iZ_8MPq56 zh6?GZM!zW3wZ}@Oupba$JE$X&q+gUNQshC9~F45r56 zsBN$zPfhoL$3#?<; zuYD!g1XBm~k^#J8i9R%L-Bs*051MboCmK1}WX@Z&3DhQ-9zeIH zgwTm16I0BPhv|CF$)Q`6Mveh*Dtpe)qlz(4K9-Wn{|fk}bJ|%eY&gcH_!-)?kNvIQ zLLPNVHW`jgyj91Bm52|v&aDe*8>gmAwXGRQyQCo6x&=CBv=y`mTm$Tsf>~@z0sqK< zD{L#Z2Y-6k@WM8-}@Ir;4eH&}yY8 zyA|=rE|YK2q*+z5b*rcrbzlR=_8isu`RRGD*Fn^}?{?C4EWn1qV&iUmMQzNRUMbi< z$z)22n#4E6i;+OddbIELH-or=3pn(tEoxM+0@I^9Bv@LrogBhxcOjAp!KozijYtv3 zJ#7t>SY|;N99)4L9b@k3cK;O~xxZE!LG#C`R2w&WFrP=3PMjRLpp0+MQQ?+p@SXKN0*omFM=BcwBQ;{k%##o+0PDhzixO1GOf z+7ekEG0hz{*X@pq$gVAM4@Tbgws~`eDTkRDM#6z!jHFR&WKL7{v4ovmB>EJHaBO3GFW!VSWs zy~=HRq+l^=DE!}VkE<&j{ul^+C;1OLI4|K2eOyy~vV-@%ed;lCEohp(vTw^!2~ z5@LDNp{8`D3{`I=AIn@(3RixDt7Ghf zb5(+~q%)yCGl~ zbV8)A3K;XkRm< z_9){o7Kl%lv>JaO?aJ|eI2Hp_PpyACs<)&7#(ZyQdI*^#!iPbQHdEHDDDmX;(=Sdk zPOjmnu7H4dI-c8KuVznv(ic9uGv)2J|8{vk;*53Q&_WGgOUEn}@8qE%j9oHX7^b_) z+zrVjcFkw8gNgg=Fl2*Kw%w5JWD|0L{X5yOzjDK8w6tS#IrMwC*Kt)=yh`;*c3nzX zGb`^izEaJkJuOXoQa(bDNnKzwyT^!6JgV&jV=Fv`Ci(l`!0vh5*S}yLF6jpeFfvF; z{^B(9ZCH`O{{ zHcn8yg;`$VGT|($J}ha($TBL3x##+DPSVc))g>5n&|qg$eeN*vGNiPcHF6tpXB_*( zmK)cl#g9U@Je6M&-7@el$~`|iJUagTNW(6bEw{v6n$y|np`$X%piGO@+`BRK8;_C*`h`ux!W&p?*e*YfZ@W>< z<&h5tU4y3bE?4&WwdL3&?D`!rv_KF@ZOToniFa?|uS{z?dsrrpr~W4L7xPk5fU?-Z z!p1QR*xx<7e_p)*Zn2ucj%t(CDiqFOs)?P88X1wf*275rRLYUX<4*Vy)EoNRk@;`~ zA+Q!bGIxT!$Z& z>9}^o1!$=Rz-}QU1rRiBi-3c1k;yM?zopHv=g>v1?o=JMAv4?xt})S8=0RM=8RQ4n zxB|K19lhP|7i?<*zVG5vY^@g3+`yC?Y2RJ1+||0APhHCUgh?-2Njkr7o;BqA>J-v4 zLp(mx^Kc1vr1ij5t_5EED~>{ z0r%Q+da6UtvSm2hvC=_Dz1>@!*47U8ZT35VdxM6AthX(0!)C!}x1Sl5yWCzM?rT8i zM_0eh1iFO8o$7UF8Du}e=fx^!WW}$aG!LuC{ne0KDOy;F#*`|cp!eN^-tpMV0{|H{`ZIa9c zOaQ)?pgR&W2n$)ybH`G=qtb$Z?~40YYfsjZ_d4JNJJdv7*;ebAlU^uVZNG@MGh>n0 z4lTzn|8eX??(ic|Ci`tjjbKIW2R(^I(yk=tC?$5i8%!n$+qjQ6WHQHjanhK>Y?`Zw zl@|!l3N+Eqw%dCebkhOvUl7rLemJcDf2CP#d~B^5=L7TkdWIuUd7&p{c{CD9KQzXl z$&lm=o0|=xuv`H!4n6G>`CeDv@SBH29COws$&VXON==#lAO8 z4idb%k}a4_=hzinVFn+Vncp=hXi`Dc1=B|XZ#oS^qoHy}t zdAmn!&lrSTREs%bP&WK-cRfL_#e=(gGuf-?@av1~ijrkaP?s$ES`0{odo|}P&?yhJmt2xZI!N8&&%zt* z*V|;@zdShj!!o?cBc>`3;tPVk-ZK@HdNFgv9Acb$E{84SG!I45450k&g^X+UDA$WtsN1xjdoL4pnwo6#iLJQZK7U zWAOM*${iZil%MMU>4{F*Td~br;ZQr!^}tzBVFri*E2CEr?3iN zeR^bSEba@0%uez)BMetgog(v>L!t2&dLF8Xn@e$YywujP1)YKn^X4 z^ujeDHY?0kX)Y<198ZFpMl7TrSmOkxum(EMO7}>!362y%zCb!hzfR4WocQlGht7>$ za3%o4sp)uaN^+v|Wy}@dfGWUjl5|!VP4r=KXlwlE_qSs&ZE9|GZipdCuTyK*8O!4o zYNUF_K2NS*qts{BC!$dgQj76aHa@<*qdt z80iKcd`fv9hV5|K&^v24{BgpJS6|jE?a)jz8klQKBR+Q}W z4`)YRnM$^5=r|s7cdZ-J=*3~%|GN0N*WmX5H7Lsm;6FR8FwJ>ut{MC!Cw!9ALB9N2 zyr7gVk3i9HHbs#YuJ|%XaD31`$9>}SvbZy+I^V;0hW?9R{35B6eR`5gsNOL7rdXYz zrd3zut*uqVY*e>WRn1CkxaclDr*}S}@Gar$UQ1$;EK!?r1K}i3!MgCPq<95^&1ff=RZns)<9`(0kIO**}%BG|G}TB$C*L0 zyrmQ{^o+lo8t$`!gEYw3K-Q_u73+h~`a@Gq^SST4U6pc%#so(kBO!%}r6jUp(g`l@ zKhkv1yL8Z?UND`6;&A4FoMU z4RjLOzs`CPY5$Rx0NObGEt~*vqZf}0O82^zRPCQD)OH?0R}4sd2l{@P(k|uz*8uR@ z?_icLVa)d4wxjtF)pXkDizo-n){|#sqNGw8B2E9R?fV(x>p?t*r0{& z;%fV-?%I#wvfYJxsc_s~nEjD_#q=x5cQ#6p;q+~670l>w_P$P`b1-}qsk#kplj0} z)mZfr-LNp~yosF;aPK|hLV5{e!jBKX58e!!BTTl*s`{GlSl}!HkkN*(?8|9%03-MX za_8!xO_>ESgkCOyR!dJYR680{4TKuUE?1^J*j+$1N z(9~c!Cc4!q7Xw90I^djICQ^oOkFLQ-@uZr~43V>!Z&BiQiqQDJfNI5ePU73yx{+J$ zA8^;C&o4z~Qg*P<(bE2oMr#d&7r(sE{3|K{0ABqQ@V}q_{tqqpQ__Zz83riCNi9g5 z#=B*n>1ELJWXpYof_~6KP@uIprl>P9kIXfHu<))g9Y+@=j*YJDqTwf@Pu`8@B!El=NgkH<;GANR z=Uv%2bJ^hSN_~W;9)M^Zx+inb*ePn1NShUL0H1w<%oJUyPT2(O^1ZQ9Yvux8X-XO& zcoXH0S8x*SN8iJ#t6IMX3(x%#1FwSHf6Z(faN>;-w~X1EcQW+wrU2AH*ff{BlDs4LA9|ytlh}PJ zta}r+mVdy-ur5J7pc$(1pc`oIpzPQiTmB)vY2AUk+KVUDL!cEF?aI-V$61vv#vwH+ z;twH99SuERrrq{8`5mXxtD_Uy5!%x>^qZ6~A9{Grs{eKna+nS-PjEWkMRwe1jNeXloW~A(?*EM8g}f|jzj`C2QhS_3X~gSdJ7u+*qxJ2; zH9kCmUSp{3D{HU?q+@_>CnIBz?IOY8{{L;FX2*Di#;~B7ZY0gjg=of97up~P|XL(1u+dL8c-={DE2?v5X1PE0WRPJ&s{kQ}uaTB+-NRwbY zg$nWR%(XjvIdRGd5&{8@?cLdHc6N5ozMFl#^8I|6*HF?As@U$F*&nv?CsrW#WQsSB zi!YzdzIyLdkj{Y)JzSEgSzVNP3nukUVO;yJO@zB*o2%Q}6L^>?DN;oxkz4FcQ$wk! zdw8aB3J(>ApahOAe5~tqgHIF}mJqvi!9J=Uy+2ZG z=ok0SC{=a&c)kyZ54xMv=Y;I%U~?w>#pQ=nnB${c1T;J|Nj?`^d-LvIg;aK75ZBYs#RhT${GZGc0t!gd*cJe ze@BLSTwAGf*f9^7wws7HJEghgButWUM_^!RS?ZpZCx~$;s_XD-7YP;Wj_Q;vR0%%V zy17J%gpshG{f@~n$Sh#4O)(czIGpi5vK-^Up2z@?y@>C-G#}yzPY&{{E~gz~?7)8X zB+TexYFR=}4M2jd{KvJ>$T2J&Qi|v*UfVkgVwNa6+caIZDn(3lL&%?(pCm1!>)A z`d~9-T?(3tb=gwA25Jb)CkDwaFkD)oTDKriRZ)(KS}lcvze8ap?R=`BN=X#Pf~j-I zZDYJMWa_0 z6t6eUI3n>hai`f=Ted={M-dlWzO~3cFr!*msU%Ms$XS2IKq#E=12!8v*LU@L(GdW$ zJfj|~5d<{k!^lZzxTbNF6*3(ST&q-g^)i)!{yegb%5rLk)tIF){c4MB%Z-P5FNREpLf+`*PL+^&{)0{ zevx)O==@e`8xtQ{9eLr+D$ zN?4GxCH$qTwjtbv;a-}0%$%*Vhh{(}mFQ{>6tL){O`_gZAY=k_PhSkH`KYp-9E-C$ z%p%IRqjPoEkEI;R;htu-5|jiUh@*+y8O^AV26G)bUL!uPc5Qvnfe((Gy&~x7;TmqQ zxirqR-)cW$i}1FeJrn{G%9v>?pORS)HBqDhstwCN0KI!nZg|J;OZ12sumUk`sngFu zI_Y30uDr$HH(!*$RIOS%$|#|-dTNr|g6n!ZWl{r&w!8_>ubT?QZ@#CZ*HD`HyNzi`7R3OJPs;0$ z*0~c0}+7jA70PKr-Hl6=5lfV>jSY)gSSt~R^Q#%0-{ZJW$pG%MKr zsprDUQYxvaF_2IiApt40DW3rl+HC{DF5}i@oTnM;9bb|sYqES@ui%-cS9}NhB~~L4 zEd2DsN1M$EaO1&|Hhl2_FDLlR`TJhUaV;VeepNPHq9BwoJwjwIyTvW^tGp=@JE&JE zr)w%$_M3GH?djU{zT0pmGR(m8R(@WZtwFlY5kGthAM>AH7{#-s>rZSh;eY~dYn%*M zHm_RHNPI|M7cx8YLYB!O6)Y1r!lyk!qpMxzPudU@jG1q}b&vmbw+ z_~&w6%{FUb@oVF?VclnDZMOsp68l_uCE#);^lDiG?OsDrnVI>{bc?AVcLje22e@c4 zfr&<)A~JTfn`X1yrkKdjX(j>>FXq6R6f*=uca8WYp_op>i>h)Bp0wI%g!D7Wsu9cN zymwtcYL}K?Rv>S@h8~{Ib;S>o*O6kgd-B3UZ8*qVSrI@-CaV z4{fo9lfC|((hW1a0O*SkY#87-9jG(`-J?MhHrZk8BTdnsRR{suO?nD5u9~On_pT|8 z*A9Z)mM#8eSb3-|(&mNFx~iP@+-bNbTX3!ihqp6{v!9zz&5Lo+Uzc`-_a^R~Ah+yF zJ|wUfF?4QIgxRu)wlu;e_5(is0ySfN3rW=}0&b9fvtKkJDwXh$>Wga*(U=Q`@NNiF zBdY_4!fewP%j$j&$77+MPUHi*MR9-&19=&YE7vZXdP`BYN;pL;I>;NMOz>1>-{HKV z9)&P*Rj(_?_^$T#;j)`rlkhdVDu$H#sZR*N{BKQ2bSM$3D*r)_r#m2R66V2hGF~$Z zWr-&X6%nho>K@37n)Ui99K}f?oQ54}Z%XBw?snm{v{}xRJ&|SFd;Ug8_Bgw%Z*Gar zg^`F;!Rmppv9gN-Q)4Q-Ygu|h=O~zp-!Zpbqqv!Zwy-D_*Pl`$1xpc)0@Zuw)LNA& z&16O_N%>|>LInjUM2poD@HqfKgeE9NJy?z-hgtmEF1ji$e!hlx>BR$JP=0w7d z^fDvLwA%estLdQOLEP?G9H9EL;9@#NhYjB!RyMc2Vd3FP&?pH`9HEh2_^oWjDO397 zdp6Ss3J=5gyJq#$`!8Q!d<^!<=#%+;$6R)^+ja`)LT(fuIwR?BC&@doN8)i_0Et$> zwNLo6NSCv#N)d~=op$5zW*bQbCup;~xh+m&Q-ktR?3>$+I?9JG@^g`nrCEP*Guq3J zV-d@TJ50E8;1l=Q^d>W6yo2XS#CHeJ9<&rY9=~##bNg!`f8O_D9Qhucc`+jr&N#96 zD_v=5#>ajBAAlUsW|(3BJ&1^sU4H%H`xhTwyf0BI0z<|qPz7A`{UUO%wOcq3Y6(vY z&x?w@huJC^3p6h_IW^%84NP=)W6?peH)vk?2ReHse5JL^#QhjR`g_LbshAZ~! za(EpXJAf@W!>j5s5dkC{xh;PF!3R@_5ObvUc|e4zU&Kq(FK1C$=Fk9KSP~(U=Hw@~ zn6`AkknZngy@w$&C~5IH)A)bUjXyHs?-%eN=LYjHyN2)88TfNX+3krp6O`;#7CbGN z-J?9o%|4O8vdM_L1yhXg4%^#PF~1^%9$>3)N(M^Bc7^(kuKImQ^G@l30f&@yVHdxv zN|-fN3aHBa^9^8Z-CC5d+YK_A88DL+jtyf1LUj%(sNnc~5-WU-k9db=mK7Npiy@=h z1+=VDXh_0{JQrT|SH$f4nY$k3Hh^$|A7KnzR^^IKxC?)bD;&XDHpP*In%%mc8D$d` zAuc#BAQm1ow^ZQmqnhi;c_%v2JbjFc18D!*02&68kUJt$aFkGl#vRzk6fVsuuV-Y9 zdM4%$1xU>iok*|mqTxib88dmpzy4?ir5K8W<3G7HGcm6(zk2^O2owgEm0!5i+Rf-F zok^Q{40+sdArTiY;9{}NdxsH53Vi(~aNrg`q2AuVImur0-@NWsz$wbkyk>O{@fV7N zlkOsD&ReY!FLaM1G@jJ{(%5TL^wpnxjXqH39awDKwKq_4bS|bdY%bH-xZj&Zu{{MI95MFBC|P9clT*o0&2o$Zr^^@!cSIdm}LVHf%JV7W==$LIK>e7LHFUn189 zl%>(GK&PMGHrs9A^>G!;a$E-bkyEBnF<6H!;z9pFQV7(o?(q;(!EE>dGHdwtG#GJi z^85`8y_`(fvr;jsMSyGZ9P1Z zMAU{)dM?VXMA%dn_w?ETupSv1-W#7tJ5|kM=s*GmHvN|Jh0C`K3JFR{; zq@QU^Y2V0b;gb*8(k_cAfOvnSCJxY=$Ej>IQk?5VHdqVOpt`DEkc^Z|@lR|@s)f{B&@_3XNVtE*Rlkjll9xS*uo^(inZ zr{u9&&`tN4oakfS_CTqAuNFIO3zvq19ng;W%?Rr{YK4=N_JC+AjEq#mofSvQI{V+M zbYTW;V>T=(S}F9MF3Yrf*i>uG#R;M=vG&<1()E z&ErS|RsFm4!QY^f3r0-*3zTFvc|e(JilT8kc)fbXSDYAjoca{b*cPN9gZHf}Zlgnv z1#e!Kfw@nx)~sLThM98b0CLv zVf#8BsGc#`4f^YU;^D+skw+>LZQ|v-%{_(50WBRUD#~BbK1UZKbcADbh1J8;QAj>k zBbn^HgoL~m!iv4Us$>vt_0+;sABfW(57;mN2&uC}m-^+B{b#zOWyaCcx{cYQz(}+h z&C=1ia)eAJJUu5hpR06A==(TwhbQ4u3@1SWoYN@Q`tkH7qZQGAP`5v;{}zP9uC7y_ zgr)5-X0EZ83|S2?URCK`Qi3NnC(_^XIKj=(H_cs|Z5oQ9!fy+7a@gi$6Gvruws!z5 zt!r)~w89}H5kF4K=^n;mCMcjX3cwqV-7H;{yGly_5Ys0*RscfoLocc|E#PR^UBC3! zC^1jfDm9>w&!_J}k|h|;=a?gE@Ys@wk}SEbR=K6gKk|rw9zM*zn}3ijYWm@mZ>C?) zzJT8_oFTKJckuMJnIr3bv;JjGw=TutQQ_le09{Gwch&PvWsDQa2+;VtTYK886>1_D zH_XdYc(q;znga+Uz2&jPF^?#J)qqpsx_DX>`r&l?)pYQ5icCeCh{d8x?ue2dA?ts_ zu!syM@6D*2a-|rf7)EKx?`YyX&JT?Y@T`;QW{r!!a)Kci#Sz0H(aS$Q7*8nkBB;OT9=+AW8|3AaWk!a~amj!F!W7 z8C{f=eD!{h0Z1@ZE<#7AmNJJ=nx*tVY2)%G;L?JIPEYoV34Q6@$q2gLzZ@bXO0@t9Ti$_K%1OD#|z#d1$Cy@fjy_oNY(j zO2v&FJwWzLs-2WXk7FopQEs=mh3Sb4{EhB&Zv?cFIfMxnQfUo?BJ+&%C$t&Tgu2o+ zb!v#*6CvQ7syy)VDXWd$5D@6sO=%fS5K8vXxGom}B7||vm}tVJ3LI42FF8NnoaLaU zxYF|{17)sl&Dmy#J0$@m*#$gs`);$G{iDae>9O{^T56A5H`{m>&!Rutc6S8_q`-wM z|Dsb!r?%5XX@JViC)pgiX!AVA7d8V0%@4m)nHW42 z;QewUxg-xa5ykPsw?F>$U#~&k%vaJ^bi3mSVgB>$;Pc1M=T3=t=(Zw_xXU2XW zpiguXu=~ypkmC2^`|QQ=;;$=IPRy*_H zLWfO0n;p|vfeH0=0>cf#3mxHqc>UAa@6aWAK2mqKQL`mc-L8{0P z006i$000mG003}#G-@w!a&L5RV{dFOaCx;FYjYYm@caD=m41Pm-o$C9?a;b27+)Mx z3>0uuPm(#rAD`T}iE0%UWy&+a+-pdvU?x9-h6_Ol%8y`YD6|ILpMx z?E4i5JPijHJVU71Y<$lbQN}iL!fr(*5*~cN+b6p?6;>#+ZMi-|@uft3X)znRc1TJR9vWNCeM-GLF~;M4dkCvZqf@pIBCp?}1g)W9K{p zO#%@f*aZjoa6h#$Vm3(}GS}Vgvwb35m-*o?PBONSd@o)LH}ErbB#7dxG!FKea3zK% zDbW)@K2uaM9WVn=*6!-ZXCgud-O|`v;FuzX4(_Y9=1>7B;y|=wM-0>_&Xj=^a z1u?`AiIukN=BmQrArm}F_~B_6QSlp*y5Nk$pBf84b_cLd^n(H$vMh$c=&C=75PoDa zes?U(df9)~cdz*FIfea?&Tskv70_QLr7C5CX3ufcYobkX=CM~N!oAqz>Kn8TqLVs>So#EVk2QBNWKh>YSWsR># zMrzcurv1V1{YxWZ2hJJw=VsC#uy7VHCKq}n1TGWz%DFms$oFx}dNmmu_#K#O;mrDr ziLuiyNV)+o z;5WboxXib(SiOCg<^(QIPU7S&PFQim8j+@4LNi8C_IUos*o82g(XvBD-NEGQst>=X zau9Vtft3&y%#-7kSZgRMJ&3~)wyD$_^hcu{T}Pyc+@24e(M1ik0T8Yu2%GBxn7W4G zBf|W}_2OdkdhDtdQ^QIDi(;k*2>^m2IFbvs22T(U6Pe8FTtkGta7NP_WD5`(!!-~` z5Uj^nL+5o3_#T1Ip95#wsLT$*7sI+0)(8L_IsK~|Yz{E=K_53|mhM-ca)wgy5?0py z`n*PWYvwRNf|3?cM#F*Ah!hy4kuz>Y5e0L>92U)tJT;M-*aX7M{5UGkhMY)~xgP&M zp1f|BvCXl)jxX047~z%k5UWI~I)Dm##k~(K-9haLm3cY2 znKERj5}8{|T9ttlMtCxT=#2LWGKHBFQW}wJl(g41)$APT05rmF{RrYMP4dD|ivFP8 zl|pcSEU*cMiiKvJFvpSQT4NaJC@K#$;oQ+gC)bPVH7#k_8x~ZFe4;8^S2#V^4tTg) zbM~>z4m~!GBhitSwyM!8Xm>o=3q+NF*#Y!oNxQ?&u+U2OqNS*8y;GDRQMaX=wr$(0 zv~AnA?MmCWZQHhO+h%3w=^m%g`0wo)@emL3xc1z!=3X(s`U5UeNgKgTUM?t|WzSKH zhku3AMNrd)!tw^GkY&m|5pM~9$a3k@V5$4-Wap}zVd3p+c!tOw_XbhwLkh2v9heN&3 z`HMGC;x1=AW`^kY@G(RHp@eXOVC?>?m|Q>0fSxAW+mubcTI|XY3~(!s+zqrA*En>> z(6G^#G@#T|6d@&SjNT|ZcHoYJDV0luWIE3WG?N?@OC*7KMc`GABT$UFpx?3}O|yus z&k`q+^qG*G;aTR`pN({Zxp$rH8DsMxV7owE@3VfyK!de0Og7uKlFeY7RO0JRefhjK z){oq>dlUhood5ymxVLKdK}}Sf8s{lq#GqCv@hFAEXd^=@yYGDtz&}FxxkFs>AmCpx z9w`<^2mX*0q_Gt%&S?k!!^i&O;}-8hV`iwSC_JL}*Zbh!s^44**AMZMKY4E=NgXr+0 z&y?Xn=te1E!t%vqv2su-OIz20e`wU%yMV#AHUrZ8UGTC-^KO(l{V&T$8Xu|CuWO} z&*^caji59?+UP^=fwJi|DDL9Anz*nGEPu4RLuFB%GFnl0z&o1+=&WJ9`3}sB|I_qA~6-(&=Aq zplpM5jKE(gk+~c$ZL^!i88M^o|IH<~o)lv)H=)(4p{W;_fAMl?_%QIJ7Xc{eT1Llc z`tclvN)NPy#_S9F%qIMY4S3Xf`ZM}(IW;REZkZ2W+5{tPS>;FS0V59&LB2c8U&V=Q zG2^rA&AwPC`YqI~B_yLnzYW^CAJSsmiUXb%2aRFu@I*~5mDJ~B?qz}^Us`*`UCHb1 zo-1Odj)@@9*DPf)ylD_kU!73owJ2W;fAwjJRBW;48tv&U6p+p|lB0j;AD65}cVOYl z2k6W)koj>oB+4lXHh9IaXI)xtj@_r=tI&6n7_8o19ikRVjvdY>UM!!&(hUX`HoDAR zuSMP1HYv=dP0G_MuvFEW`!iH?Qt#0KuF}M$M46^jGg6} z=*X!|TrLUC?damvEo5jL!j3RI^l^emg4Topk}ioSm4#=DNE_Bjpx|Z10-Go~h-I2# zOUl6rXRZFBzzq1$P5kTml^&qHyrdBCHA?cUK|mIg@qVH#96Fw#zL#E`0n$r4MIe)) zM~=OT&ZWEiADR{LUetT)McW##`gF?j_kDC{0s&T3Wdrw@Dg+j>$^^A+N3J9ZB(h5b z=UpBTHS(`Eb*{<;G$?YDULYx)bql0l)90N?2&=?`m8EfZaThvTh&q}v9|^O6uzm8q{AfE~SwEfzPyZHp)pRUvpM9%v z^6Bcy-%%$~H+Seo3I8qBHj z&n;o@Ul6P8YrZel1{ za1^6SZy5QkCmn##L{QV0Cnee)zAK?*4@yR>m!|PWUB|tf;>U|VkZO(aK{zJSHS7WZ zs}(Lo8&FcpCM0uEz*_0NcisDoz|b*-<|v_1$w4bKUcY2}k+cvK5wj5>sP2=D8U^v$ zo9_H4@4F0V_JiKrfx^|E!%_4ypQAuvc@(5+pkkDn9}mh52$>kUTra|m<>;_*9mLvK zGoULYpljp@5iq8!Lg1e#wY-CS{RFVPaeK=17SayONIUk7eXgdfKsO$4@-qb|9W?4u zt!NpPOsC(&CQ=$0plWQH=;W?(bGc(~^#HGXu7X?h>5q-}AFwrVzMb(e`plh!!p_MZP7rMUk>OsHt&tDi>m=Wm7~1bf}6!1 zRRPqUBH{d>Z2yy2USjl5T&Nx2CKD6@0OFVE_S^ge*n6;8(mPq0{Zifj3+q&^Ha_!z z@P{;QBxN9QIP%4K{4|1w@#?Y=pls<)CE>W%p}9Rg%(@3flwq`?tl>Pn*1J>et?S!d zZW`5Cgz#D~``?Y7@aHnFmn+q36_{N5FJrXWO&?@f3{lL)ANjyvT;e;JE8*QIAKIqt zpGKZs_ErK?WH6s#AJT#ahZ!2d?tX$J0U$nd*RbYWh@c2yf6S^?d5{-iV&Xo(`gWJO1hQ!98)d~MkI-V|3V#HQLR+PV2tOkkU;PRYxJ6I%*X! zOfLU}?!Zm!#4=z2c%JX?XNy}2 zy2BPXG&xzp$1DjEz%;jq&!C(|heXq7Wj4_UPMlyeN*2YmU4y+-vhM|jZ6E_IOv{_9 zXS@&<1IZY8V{k4wQKgWeE{wvPFX2^x#&NRlY%~fW;Gyf5>(OUY3Cp$N5j@Ky;_xmp!0~LxGD9C`ZR?L$@4_;}W z=fh)12w*ymhg{cUXuaFij0hye)B9sKS1Go)4ewt%Q`fGzvc)#jlLBmD#H2c1oEYQG zMwmN~t167L&u~1}oFPP*R!k6ue9D8I_C&TZ{t+RVbcko?ffe?PB0fVI$1pk2Kd7wh zd0haapo5G~rd;;ZpAybe3NeJwRC5tEk=WvTQ&XTJlV1D~ug1623QMEGT8MKnk_l3z zUh_n%9XL*gvv)2LG9-qM*L#$fF%-x*!93xfGesyOM7`M_?NJqVgx*exzKD`9ru;4M z?j)9Suh(lXs25xBniX*Rz^78cSGK9@cKT0W0{HYh!by93lh6;##mM# zDi6HFy8~NFJ0KQk`1eON7TPL!UOAkqU&zAjVLDmC85SA`vXYK=0{O$GxPv0OoNl*l zWy`lCG}IJIO8!ot9|X)uP&s&*I5HWGAEOVEH8J6|;>di}t9O7j6ITEz=6Bfb^^ltG zBDzy$^?}V{bnmzcx)we`#sWOY_wd(YN>OhxCm>fJK7O>N;Fy)+m!A(_H2~7h&i8J= zKXeVq%R)u0vwUxd%MC2{#p=aDbZg(i^Vzsn-s9<2HDIIDHBuWkek^yb{Zo(@$t37H z@>D)%CZ&v8{}O&wn$KfGl&|8E@!07%PFkocAE|9e6IKK*YUcQ$cycB239 z7a;)t3+Pv4H}1O+0svt7tM>mF==XnKG&Hd_GPf~sv|{{K{$e#PyF~^R-`m={0Vyb* zjhM_f6pKE60?v|Vq&!9u+k=*!EcS+M8`OhdC-3oM63Xl`A^Oo%@6$)KTq83|_ka`+ zEXtX5VV3bWd5S?ng(=UQTPiAF%8a;!)~zw!*&(w`a6EjpQaK( z&rnRQy)#eIAy`B+^}QepB7mY-3tw+vmY=6*XEfh8Pdr}^-fwokcSPCPQ*yoCLP$>- zUc^Dg0ypk)0>}Qbbqq*CF{%2Pm2|8m0NcmgVPN=2&xYqXvqJ_O@*u;;pg z_5u9nvzJB|p$i z$|`df(mWUMiRprO6jj(N?qpyRt3z2TsIkJFV{)gGiQZPL2?&^8QPcp~axOc8k7HxP zn|Y9EFV!o`9y0Nh0KX=6%zA@{3V{3t3n4=ijs`8z;Vkha$6!M$Vj4r-lRNjC25pV| zDX4aLaCPJ(C8)23Fm>PEGR@hQy5-D?xrzD((Ucwr1-le8;XHb79YUkWDqkoffCVF%3N!+mMv%oO zG9>{jK+wSnn02eaA?DlxCkkxc^JD`g<(ROF^Zew`JLjn#A~m5Q3oP>t)SYKVmE9k< zObT=$OW!yk>53y1zfu37YHSkn%>nD8R0$x_09?x=ZdpLVYIaT&hD2Fs?33bHmle7o zFRelxTi7L6GXXRPv`wl#i`xPm#y@qtrq3%uj5p7hHh33;NA@S-g)BJz?GCkRKzUk=>lp3Bg2)0{q#bU|ypdlSK?xlbO^L0mOuNC(y}#jVC?HBUad$ zt(d1owb|rRMZBiprq8tg)pov8Oh8vzyM#S)5nbC#BaXP+dZCmY-w%1Auv}wTh|z^M;)M zws3#Kr};f#B;a|1e4XEl{<-`hnQuOu-fMyA+j zSVB)^c+%~FzGmk?^|0KTf@9eSBme33W#$ zPkG+{P=S)7a>E0xo6rnPba}S$ER3ziSoFYP#w1w5PO;VP)3_F1)8@}%p)GjrY=F;2 zK4RlbjLykaUxB2TDu-k^dqME@*?=mjx%@EVm=I3WTT$gk-hxZ@5RXuS7J*J)fzDvN zSJNUz1XE%7p3dZ3c6U%cTC2M`HDu#!hJD9bloO3#-0fxB910*euR+FA-Pg~`{gZHy z|8osV96lGnR07MLY~xox`#Z3w##J>#C@h*RMtUO}a|%*m)sb4D9=fL>Sbfat;$%X| zS?b|Zagq{S)W+X9iW+RrQ)kUCOV;(y1zkxE`wlrE5OnzIYCS?Esg8wWRBWnv2TW38d6GvAQ$Nw&vYaA=5P2s!Fzr*_sfv`1X zvIhE41{8#}{w?DyCx80fXCBs7sn4TiV;Et6vbVdrvWpUJIAq*m(^>0p>RUToXFHQ) z*8<(}>*Qq%o&SaTFnd1i5nG>6F{Sf6Wkl*dvS-v$XiU(w(n89;Ya8P(a>75q8Ka)g zGbQHBL+!0gsC5j-s8)_QY$VV#^&nSN?+K_sYYx&R)KeiSxwg;!qo)ExQodH}vB3K=6&5e74K9%UOed=R} z>TNzhCEPx_{7f`zxE3c4Df!QJj6XLF2BB{@uv%0bj2zMj4&A<7v@Tn47$bD~9#&lBaM;nyy(6X=p_r}n z-x|YmG$h|s1)`&i+=$0b9kzV#BP<7eizTDD1*_&FW;}bjr4Tuk&}pCj_(iW?_tq&>#OLqIra0Yr!Z4P>+jEhI^6|h z3m-2Lqj-Bu?pJrSe!)n81^fyyrA%&aK@ZQBU*FQWsVFoWl&rG77;Nbt7P~K&OV^2F zOk0DtH7TzoC7b*rnr^{cV9cFkbDGWs1gyrVq64CO=+G5DT(A1*G7+Z!I|N z&tGprX}n^ypR_Di2sgpDc+DH_=wA({k`3uWrSbZ8pRDt+_7*Y8i;F$Pfn2?3N?AHe zg}DP~1S2|i?CUDkvCm;b8DKJ`?*<(u{H^uL*IbwIhuI@dI7<;}5a$?S*USmaQm2dM zF%&14y&3frQAbt>2K`2sOuMSpQk7db!o()|$z6S|%F`&Y1~NHC z&Z08vmnIwMtGS*x59Nnaqk_G&v%l}P-I62opqPFGyPrgU!{w6z%R*W4u~P9?8{&6$ zNCZ?obFJdf7}gwjyf_UWq6RP?0ymJ!6-X=SyQPn0?tqVu{^Lej0nmuSdH4%8kLs7a zZV%~!X!8QoXH<|z(k5fYzYAqFXA_cAlK|Nl&rG}xnHwQ?^?JzH(JyTZIqM}X)*Dfp z=i}DCP-sf!#|U(=W=-n(Hraotvi-gER)bdQO|_~IR59_bJi)W!Q)goBpq$i5xdvopAP5MVn*KgX1*JEdIoOOd()UA= zY56B7EUp&inyXsZkD43opxQ565pqFGkjhcD78g!}Mem3m9I<}#f|Qn}8w|)?5a=JK z3Tw!M3t%3lH58&S!k^(pC>kaM01<{baWG<7VlRFqEJ9ing>;=zbvDgS8=iC|GY%h5 zAXkmuNI;(2MhbLLxH<7R|AHFqJd~i2bbwW!ZX@81QEg7{B(t*!V90G8>0x^6d5*b- zj5n<2p6GcVL$fw$F`?Q~wt-dbanaXMy45}oS%jR2RnBCdhyA(r`LG<1#zLs<_diQ;4Z7Og1Ryc7^?vRQ_MN6S3m8e8c;z}UminX{VT^9 zoWy|9X32zn!Yg8Wn*=ws^R8)Tfrcn115HNUEJ40xc;6W&+9Bf(}lF*%Km^ z9z0oqV#8ptIP}Nt9ol>@>_BK8AUeb(`MDSw*>}&d3#|eL!P6x3x<>R9HYinW0BNo$ zLb(J4Yf-CzFu45-*KeP&3d)^PuXs*|B3*!AT!oLqE)mlndMT9z=<`ym$l)Ej0sCj{ zOgAR9#j~vQMQ(FDaR+CX5f%fF+$Q=k^xO4Zp>Wk-p7&P8g2Fb*Mu@|&)ZCkz3g36m zF8y>1m&p!qE7$rtpCbG7@RK>PV5H6y!3o@;n**=cZ}6)fE^QMxf%rl5zDd&JxzZ~N zGzBOMbp@zyQPF_E_AlQCj7>2qJRFxy*LsXoaTmUwHU$DXPPPYeAbrV7o6WR|oHL!4 zV74K{n0f2$d3PM8JHgj=nX8h3rtp3aTp#B0^(@ZCJR1LJ67DN}9PDJ%C62fBL#=-Lt>c~%7yNapReYcFMbHH3B%4$`d@Z;_=mf7{NtQ7AUu^ySz zVMb$joKe4i@h0tY_lpW+4WI|yR+$U8TP#>lr}aMOmY>(WZM>#NcJVwrC!hQ zP^uLsuj+Xvqz9|8(tzH1qGm{Fh~UA1z=+%s*0d|CjMw#SItQ44^HHrGz<~j9(&hD` z{D8BA-g2mSqkP4Kho_mRnXSH(d^eu{{=H`sl`1kUddJD#xFU!l@tOLRCRf48vcMGC zbC5bu*jJEkCb$N}NRH9rvb6kVkpPuYW7A8eIy=5uTazJKSBKMsY zhXaj^C8y)+yXd{VYrH2A8BS7AyPbc#11k7G($%11TRIk&n!DY}fAy(X5BpHvY z3>N-6x4IJJlFvI!BE(&6k76aS5=mnRuo*bk#=_@{DIbWE*}gtEAUo((vt7L-y}N^5 z={u9>gZoGHvD~Q>gs#*CY%IW%Mo9;!t?=8En-loE$|$)ip()X}%8GOC@JN6rbV@k{ zXMzgB*Gw{Cdr_ypsJWMg0~CKi4;7Kbd%5J((a0HtaXB4FO;ZI&Y*`PSx{u(E4r6lZ zq^1>6^Np5xV8?^QrW8h8al^o9U)+JKVR1}6eB0NF4n-c2Ev&20nsasPZV9L9)Cda6 z%(vaHnDMyfg|-NV*U+uDEJ3LfUeTT74Cbo^Hu-r;G(f5wqzvi_d*KyRG7zs>tY%yL zSWQGOV1BP2tSA22#_h#VpRma^!R|FNas4De~-wWS+GYQNOdWL>!nnp0H-7P4m`fan& z7fXp|EheYj-SN_Ebd!Vr#k>!GKZbsT@q+S4=CD=tT`tdArEytvw06*OSa9WJcNERx zfyBB%W6(43^4`YA>}?~k7XcNDwhCNkI`u+RG5 zw3&_pPl>O)1i3fdpeiDSTOoTdf?H2@3@@`9%&%YAYApQjgtsZ(4iW^~)@iV)2;e6y zK{uxAsIF*Y2R?b3bJ*Ls+)15#1NhA}HfL|OsrTh!kS@(2NBmN%`-Hl=5K_7(AgY2I z)FRUPU|M!m1Qh_%T0WlHi4@N*LrS#7*u~qZgR;p)$j(EcbVa@g`(>8FgO4X8Qc^_8 zn?tobMwomk!^#Rfh!h{lfY(Z~PX-W`MbA`#Zpm#4I~D_w$(GrLX!C}FW(hC?^!e)$ zNRnbCL#aa5Mcwk++dqL-M~dZywY}5WLc?v*Tte>KCi^5zztZCH z9dzW4XB#+5Ah!mbbB4rAzVy(hU*=Jz{qkYIT>$gOc9@Py;k}nh!EI4@6_Y(i>edNY z_;kMeS-Sl{cihr#Z}JAN7P1}ieI&iBmTS?xBnp5SZhx>D)(PMs4`Qhjd*~s4{K6D`BEwuRUOx7ED*?)TdtAC-r03bQ*uQA)GT;0A6 z^aC`<<|ER!$GCWZ7QAv8rb$+r!^>=-birgOl3puSiwag7%FfAs%(M~&5^%nCbxXBb zNe~ces(u`H-8_?zsz04h<^I#6NfUelKmi5-Q2VuMaQ@SxF}3(T-nQ`kAA4q5ecLvZ z0j1}ZlExEQGYF`Hr5tIY!n!+cNoJ#?(xL%}y{TwKHdb?wAIp#kwB0nKD>q*MW|sC6nEEMO8Qpk1sRtNtz0l+dnUq=3Lh(!xu$W|e zMbSbrj3t3lKF3l%>XdmPrTr)vOw&SLwb}z)BRhH}?sDQ3Y^^(>$RVs%l~kQCr?!g( zHLCP4|DT6UE@|$6Ge~Wg!psYu(d-N&T8&!`RT$l)>X0dybre($0nD@tY|-1$k$cUm z<#TYnDRMlx@VDcgac^g22Q;(eCgjdXj-c^1K4}z5CFo26+@*fR;rs(?BS`&-3-WWJfktsV==o-n#za zJ?Dy+N4N$w4FD4C$KkciX_~|lrX8-=uY1B0wU`y7yX`fm|E*nckVL;h7zZT_7B9Dp z^8-Q@xFZs;JOG2tji*Nt$YFjKPeG_B0w+R7S_TZaBs$4CZB}hXjG;Xp1D`=A_@o^> zbhJkpj(twTnnK&|?RGgbV?zMnzcSp2A=ZA~eR2ryjooY3ctdV9AKf0ochcU7?we~H z{0awVZZCPyIlsAhB|Ww4K6W?k>NxxqeEW!bUBJXCBe#zN`^+lre;Z17PwQqF8^^IK zmC8^$J-P5x_Goh8AQ9*h(ahRd0#4YU_vFm$Tk2}s!w!4zB^EYIGJR_iDAys}O~}^} z8TtYG&(%a@UDJ2*yP5`20RXW6b2VAm{+i=1HovB)iSe&Deyk?t^j{+*pK9iNhRh@6 zvf1N+l74rR;1hs!Si?I-w%>w^SP!@HG~*4%Y_LP1rgzdx9?c&w-77T~It5L1P~slC zYB0wpN=h+IZ#}{My;pu)^6KCx-RpdvvRImrnpwg@`L4L)G|j1K9LiaM-Jq5Tg8<>D zUEA+lH(9GH#X}=-Jfe2vvlo6;Bb|2=os76sw zg86Q>=uNhM4?Z-2P2vKd^}S`619%NkZh|z`wK(7-8bqn=%TZ~99mO=gpfG&m>nM56 ztOpO4$RM%v=wK6{VXNk(zO@u_SCqt75WG`-G<-h8b}5r?x%1NX!3y^OEdG5t{%a8P zXnuS9`aZ}Mf)0RL2^%0Y!=aD`ty1QyQd^zZuxeE|7%68b_SN!~TDgO?z$$BycPj7h z5e*S2l_L75%+KPw)s0f5Z31$AhPjZj1Bfrtycoa`f|i_TD~0%HtZG!S6GgPitSA}d zRAq({4in`Pqgq0DN;f{hze$vr}r4 zJjgcf?e)l%Tn$*oN4cd(-11Vg8S`2wb^}>|`Yt^R*LK$))}eW2o5;f0P2MZ#^x)q)u;z4E zwGiIDdMMrf^W)&@E%snHnR}mOrH)`r*T^MK1kDS!DaNpv0KbST=pNEU#*KpnRRf?# z_{xKKH5^+tz#(xz(OV{G?OQsrISao;;~x!+TrYj2X9=kje0mYLPIWq7-Gk%E*2hPw zF86bAQ-YNje*+J8J0e!{KfvHqO@WY|0vkSCU&;+!zUXJHF_MMuT@3|@0+5G65B7UL zlLlYBImV+T%M7fb&j2Dm68Pv zc-BSmafeC%xGI(9Uiaev_^Tba^RfHCbe{D^>_@#zwcy5Y{wxepR&$zcC7u36YCMwv zr$XR%`2X&l003nW006lEsSuV1t_FGrj*bQ%|9u{ktSV!d$&b+eNUfU1FE}>h02fd5 zM`ocSc|&Sb$+(*AWo<+gt)p}3x%bxh&VV5;1)KS{JDm*%y64tQ4oF}>>XDbf9IoA+0AHXCm}zR}2=ir(FgeHIMmPsi(yf>ZR8exFYAl$DRF3T(jZz8~ zN}q}7j?77wst9i1168pY)00+;GKtf|B=T{ak#~S=uKVT&AG*k30xl;* zUIvQ0fWEv5G2$-)HHIPulxCQBpkheQsSSp>1YWS;queyig|DI0D+tmcICaA?ZhV<9 zs=sgxbN>ZXtAl{*p-I56D{)$H945jLnQEX`UMbxMDP`SYy#& zXHNrPykYq-I%Gq0qbFOkSKTD8s$*j;nSGH+u07(ltmp827?3`~e?HU7wAjFpFG>7; zNA-d@k?sN74fc8a{4`Xs`-{#WG%}3g$e$Wz2Vh4B02vlX>?3k`Ahu#hFYy_6-UL)i z_w-V=q1cj-7vRmML?CnOyRC-s+X(t~kk+1C0edz0;%m^5rzWxOoux~wF-xr4SEogX z$ZZCuCSeEFa<}&%6Yg%KI{9XTtAuF%3qTm=ozKp(DNFwp_bBWl+OWQ(j2z~8-ZKKPY7t*Z9`$$Oc&`QY&_|*f?+mbWY(zCrz!MPGO zi38KoO}t5a759sH{QeV+uOSBvz9v^9W~-IT5I9&R>jk*P{B(pe2agWoIx3|lZck#df z&jyDmeET;E5&)o);s1A{_rJ^6$iUj#(7?#*zbp7!(`@#4oBOw;X8?tdq)-4C0-{`y z)F37@iIg`cv(@S$aj-|iFYx6yL+l`$30L9=wZnS<{*d_aU%pfEI7 z40qw?wXJ#yCYy+AqBbuxm4JFLQ777Z2MJk^-ugB~S_s&1ZIV$GwZ{4o5Y_-voPR2x zy4I755c1s?ct=#?n}MLE*mNWv8Up#A^tK~7J#_=Z9%PbM1f|G#Pb7+qG^*>K5HEY2 zozfAt7+etFn1Bj#jCOW<5Q6o~M}@xXngk<02@aRx<)Eb%Oo(F_!m6gv(S(B~$}(2M zfRht9ctZKx_X2r#Py9RbA`0Dq${V#64~)yA6DTPJ@MFPKymXt#_~HX+BDw}XP?EOq z=eKbXlZKr*nb0WIuxU3aj+0esQnQnjC)iCK?4c~8J_6Q!lG5Zf@dC8vwW#{FvK5r0K#0Oh9k^d zMJUWNRS6W*k~ZvtS-W5J)zl7f4QHumPxMx?w`?qEJj~rf_Ap8j;XT`r5FYGPyffti{KPy3mE_?+kh>(6k)3v z$s3hBIeSSgbt7|P`5S<@<@JUzeJeZnbe?!`%qsM18#(!I%nE)nE0R+8GVaGFINELiz#LT~9t=~={UM!;L zHRNXC@XzZ!crV;ie!>xI0h=qk)b7i?c0KPX+Tr_~(+jil?c@Px`0hY(cl`JBMr)nu_VIicvCQviT5N~Q8em#Mk zPALf~tR7U`^7?aVk8RCQj_>;0|JoCys-nqin&SJMI{h)gXm7v?t%+`nuX4_uKz81be zieH;;{9S=z9Kac-h!ZnW+T)fh$O&&II&r;(OthXKFBo%vmPV~bvyz5)z!Q?GhlRTT zK{PdjkaL?DV$^xDwrhT9#==~x8Q{sxJuF#!Jz~eP=$bo z1@FC`kyegSb@?{Pe^iZ`l)s2lKFt&`)*j`l1R)c$HsxZJaon#8svl5kwXzBy9Qj1} z6z&S}LPi65Iuu07IDO&>je!0@hnKTHyS&;A8mlXaP$%?v#EaZk!>uc}8g#a%WIfNQ zzn=i6#F2|MYUVwR1+DkMZ;Q-bicRcDm@>{4e$L`+7!CmXx;N2VHm6W~evcDbrG3iu zi_x+(aFgLLxAW$fD?vyvzAHrA{W1PC?y_{S5X2PkSYpbYn_+p;UbpJg|$D@1J`_#3Z!VrzuWoP+m zm0ktcJz*K!|J79j`qWWFJIKzv!es4kfZ3PrE)B-Yi zTDP3Xjrq!s?ehJBm03-SgfyPhF{XnVP$1cJBgLd1srR%@hUCgNF|v;t+60q)Oja(s zo&O%{-7(mA39|eOO+_Q`WUZ|Nk`r>pRVPrapBW}JJxR^fFiO-+Tooq`wHs)M@Z|%+J%Q?RUj0WtU?ziw>sQuum`2%qf_H zEj89M&40ZzZ<1OO8HE*zKM|nAH`?WY%b4+1FvGq-L#^s|N!Io-`azxn|McTM^Xo!e zU&395Bk72JOZsgUoDYy6XG_&tUO1zrUi?HfWY@M1;i-Gi_L!*%e94JCh=Ig-+@yuf z-H79w8{SM*kO#-GLim*0YvdIafd~E}H}{UxNb^;FnwBT%Lrs4`D}wH}3PWm{$y^nY z{zDd%Qr%bnm%ZzOYSVd2!>iyV?o@jliA+Anqy_JDhnIYTuQTtEC(Ym$J=Fy~6|g}L zD8wi)Jz`nvHj-lgediDD4CUPclA5}`V`l-+Ac~t0JdzFs{lS`dj%gz{Q01qjZB2(f zuK%R<_imm6>2?*#RAJ3&luy7Q%ZEn&%8y9@3)f&()%XRK^05R<75D45`E%(V7;d=K z$K0lxS><5G5FdF}+Iz8_vVT&GUsX<662e%#qO$PK4{sf*P9%6lr((v++lX1a=`rDG zKrAZ>Fru&t+B($qbB$?lp_^B`v{(LX%Y`GSt=RixGI6e^aV=pqlpLMy3jPK^7>)`` zH9t^LA%41Xrk)<-|K`Bb1RsUP{YMPKitH2 z&f)|&?|3Y&0s=o>AbGAys^zUXA$!2j8htYN8Z&HY*y z%t!zL*#D1)#Mt70#ej2S~8*C+?j z3SXZxN860N9R9P8kc9w>y50it59WL{OMM2%Cy5{jkh(^482_r@B&W({Y zPgkKFtf*WvcP?gLs!J{4WIP$cm}`%BHhB;V(x~GU1D{cqYq^29T|Uka!yXUvAFAVT zkHo6)7G;3D>)Re8@4{4)Q!OiL{Gl$VCN ztjs}6s|PXtVr4u_{Ye9av4R{{Oj+M{;4v^%w5Nb7JEQ%jE7?M94jv(%;vD{Z%F-?| zSat=wEJOXP50TPYH)&V6X!~y83qd+udMcB;Zlm6n6?HSib&efM`6j3{Pge&G7VyB zUfFO{>ynm^2oC9zWAHBfV$PyQk8_eMmOvS(qr#_*p4my`J*PRuxkM6;5H?t`0gz;p z0c1{V9@@14CLd$V)uEFlry%$r+rX}EkUVv64@c>wHG@obE!+BlH#KXXBjeM1={`0=W?(pV zGbHIUeuL}2fsL(D{=+5T$}wXVM^jpEX>TApeXh6WtZQ#Z^p5u5mbTnilm6K5VueGkkVZ&Y^&1H&jzDu5>l?) z);VFf^cJ`mj(8y#Rd7prW+|*T)&Q%J&Xh-L~@6S2K- zfd3Q%t$xgq+25gF(J}x4{Qrj#m>D>mxc!eeagBTJcqn}JH+P_IN?*tlsg);B2vf2) zIIy#&>DcNldQ^h7oMRNOWl>k(oggO=U1;IP^V^2pV*iaoEORYEbwN_|`n9Fy@g{~9 zYE+kjXRfE6cl#Y7*ZRDp#|4Lx&A^Bjrm57r2=;?#`sS5yW8-5B?)-XQ*RlxG@uB)?ZYZIG}90=ezZhgGtR zA&m4Q?Ym3=!6vS5HT)KKSnA#Q1neeB1Oya<*zV*5>4@^DgS1Ch>|t?A#+GBmdY^oU z8%m5OMgS7&F@P`r&R{D4e=+t>?V*LymS$|*wryv}wr$&Xc5K@=cWm3XZ6}qgr|Wbd zT=doY0dszuYm9ds*gr;)BAuQtdbqkeZY|VI$M!hZ)X>Yt=sto(4Mw34c;u+|n?C$K zfo@OGl7HY(x@YAT59mqzYNLzeLVo<+$;@$p|$O>iV#BQ>> zI=@Tq7({ttq06I9HW^Li_E+y%s%%joJqPe_S!pr-RnBo(Uz1py9{p-ydO8r9?(X|G zJGYt{cbZ=oJ4a6^ndaa?n#^Vhyg*?1L5!$%4u^XB9?@B`cbH%= z?e{|nwnMtO=pkN&{Q*a%WnCc&2>sZ_m0|nKkpxPJLM&fFReG%%O*?#{+WuHtJb8)A zgp{%1Uit8TJhz!_?{%LQaGVoj65bH=t#6nTOKVF@jVxUcZ##*Md00kQPe(_Gi;tsA zwvVBm&S%|Av6F*`3~7+8#uILz&Ih$4$>lPL7q?$;$fK+tCU6 z%KB6L^HSrZc{cDT6b4Zqs(=Zmu?XjEOHbr-?&L(g93A@Qs_6+`1 zUS+g8>bB*`=fi1=AO}`)du)8`l_&SRSMWr`ZCXLW&=)U~;TSH+#&V7OYtM_QGqA zxr9XGH;S_rZJW-tsT8amGjo30eUOco^pA*u3o(+rQP|+T*BH%S`W_#_q=-ok{;jF1 zN)wnaKC7m1A~g9&kG;2UbZC{jmEXvY`zIQ!9YzZKy;y~X26HK|X(Zm|YaRMI?en6$ zW7eP->w>e3xfna~t&OQJBXR^>{;H{bnKc5~KpMab)?v7=ex7uEkNhEcjS=>PU-9{7`mu2Ed+DzCl6|r|^YW58Dn=Noc zQNGD`ezx#s2)dA8edyS&oW1dW*;0FXm8fH-RritX?EjSd(h64PF95(;f}8{F`$`ufj)l0iNCluThmt-tD+BMK_X0ga*q1O7HouO zl-3p^NGkO~W||mME7KpI1)OU@6*G#VATSBhwu5;~PEmH?&JcIVAv*afA?KdL46Go4 z9`?=z^%P}A5J#%$OOWI+&jaxcvisG$y*1G@x&R*c8N^eEK0y{34Tw_BVv?l>=#af< z2ZuEo82o?+ffkh?g{l0dZZ6=hL=l2jUxc><{HBOu7nn>C?;(X{l*RhECd1GTS%l^U zBosm}iH@Yz54ojGbXDA$J{vA(`+Z%(ff zal9QAihv;E<#=P*8Y{*}wWDYgA2CR&QtWcLA1s=7Dl^yJ7MuIoY40Qa6Zq!qp1ApE zCnO%;v$K7U8m#P|u;#WQA>1OvPpO=fbrkC#kT4BJ(6-s;R~-fx6p-?<7(i0_q!inW zj`#`5X?Uoj!ENgJxUo}m5-pBJ94wp|>phPt6tLGcv2X$b7rSWUu^79cBQnKqCct;V z{R8L`5BUOwR;L3dMo`W$Hmd?O;1b|$fLNS*zul?n%0qmu#%~+$fCRqE?DqW#=fPJE zWpz{Rl7%BM0HlJy2fol$+d4+2+tp!7wHp&3ZyoW7HA%`8;x3KJFse+LII_y44NG4_ zUf9tJBh2grkB4@PlX2spo1*i`&u64vq>NTd6iQl>*29tk8NeRvv&fT8=bD4ne&{e? zL3u!uF7i$L`IEOJ5RRyV%kr3`XmOT_f@U_ntPJX?IrL??U!;iHHj`cdGvV9GlkC)Q zfU3CUC{L}(m4@dIenhb?mIWaD^zM$D{{o7yN!Y6=>6GaYL5l4tbqNu7*J zG0;fa-HOS7W~Tqsv(k||D!d{jELb0EJ^6uQwtyZ37C+1lW`(*x4ig+sD7u1Icr3Ed z$^cO<=`0I%SnOPp|5N3*m#sV2{cTM#KmjZ4kHp|T|ByI%Vn}6yYYhcq09~9PSXe|-l(g?#+UWQ$U zZ>WVDBnr!yHDmfNPz+24=yE7=)w>MHK^CnSoGYK~kFd45sMP`UQ`b&yH8Fik zB`QozijJ;b3d(IN_dF#F>4inUprXnIr-fSo){Q^4DDFqP4U|88qI7}tWh;sjrf0J3 zLQp5(o)rE2gxc zn(i>zLiC-{l(3L%3xmSU_<{^ASQ#l$bWgb?xo+ru{7ySwsAnh?3c5rzrz5AX3~a;( zYDiS_hl?+!289Cphngq)Qu&fn3R8F(aP}seqH-Y|+P8&NFuDeMg zX?QdVE9s(PHbnfN<`*LjDnB>z0GxHkX5&msobc8R>CEXnY}w*RKP1z2D`G?mq2K++ zmsh$gl%M3&bC3t{!(d>=UuhUBU{w2RH#*c?*N27f4O0BOO#57I#-RH5NS7_@4uq!R zA{JdMD)UR9xHIfCym0ZIzPA~G1;2T;$>DAeb3$RGOA`TY3`w?NiX}yu?O--$xzELs zelv`%PHmde$XPcn>o?ZTypg1`!yWVbS{R~lKEs>iIns}Ky{T3MYww%;TnbJh5iZF` zilzuHknslWz-(x3;0PoMUj5&y4>d&Ml?$O_lGGaW?&=UTt zXuCY!l^gyeYgy!GHtbI?)(+pRw44~DQ|%a%32btbv;_p}C@fbj{$;$J=DEyX|HKyt z<$TBDQ2&YfHjOV7c5#M*ava%moQ|;|H~!SW(AeqB&pTlVfw+)a(MyuYL!P3A=DvzLkbCC4jA=|H@>#@%G?Q&CX4Q9G?fPW8+^P`)JgsOtTf2^+@kuH` zrFvzyqmmWy>g7s7amg3D7~2MK4>=F&p1%ffiztGw&+e;fZv80|p!o2O-A*;;8Ctp@ zSA!0WaOw?|aFV)X(H3a564b(F{$88S1x4Nj?+9kgKT!11NO=P&?s?H40=NE$2L1MB z%T0aTNQ|L4l$!1sl#Dq)v&r1$D+&<%81&UAuofqhR?PZm;I{c?Rxnl@ulH@;N1XMU z$w1)l3NIELN(ZscY#Nuw(&sx*`|>L3uf6#KLPFOO$zu0uLdcp)wGLTn{; zt|j(~f-}0=>=LsU-^YR5Qxe<-u(_?TI`YmxK6fD=h6e^$&?pvgf5F+;cG&H22l|kanel;|Q=V|Y%%E~I9)HZ6${~n!N>JMB3#3pR zO$#Ox!*w|RMKJ$*Bk{m+JOL;_G1HL8Hc!|GMl#`RFh@4BTOsI|C9uiZYANxWpcku#a5@6ga<#0!gCW_sGv5Gy zJ`j=B&o6r`yO6vsSHr_4^oi|}2w4jqg#?3{llhTX%2qe6aBj2V(3@wVY=$>% zhC#kq^)DIq;Bvnoy$qel=;!*F8bmu@UgBdYu5Ch6>P1wPPNiejvzF`6R8O(B<BGdF+N=FjAmB&p#cDW%;J|C^vn9Fj0NuF04$wkI;h0CS*jgj;I9jPQp*Kw`T8b zuCcq&Teu<&YK5+#BJ5X?_GiU^O}K$wS;Nnv<F5b$|31_TcV#f!0siJX;3LR$L-^`dX0?#*K6^HF-|aXS3z6J711L z+ju?eH%$(Pe@}6RKj{KccQeefWxFtXo2YZ+*3z4eQ+jG>dDFQ~EjMuQMU@6pt{J!H zQE7CQB8md7>0zHUL;+bi_4tavu)|TMoXS2Nz~tldWMdhsZm;nRmukFpFGj}MlV~RP z@zo|pKGA~G!H_;NVh1Q%w!qqMs~c~uqs;R)tJf6q=bFK@MXMYFy~PkZTZ|69X34fO zwfqrD=VbycQtsGmRXj-#ic?!n_8TKJoOfp~O;9y(V~Pw?0YL`mt&M5R5v@spJ(wtg zQ#?de0y8y@6`=*sA)d1@4U-aphGXxWg@jlQ{&9>8Eg{NG(Is;OjwIAHTnp&dRo}nj z#RK}bvpHjiRk6Uqnk*uEn6BlwS)X;Ca6?H^%ZX6qKs=141_XoU74;v#HaKWv0s&u( zw&p8&y2>Q^6oj>c%7eCI%cZ%?^z3{?0wCm<+uvB#q3ngxT}~j5eQ?MK<%5yLW~$d> z*mv?M#yaM1$|$wwl(xYy@J&RwJ+f2;t|Fe}zSIyWM(I_94#Sq6Xqby8m?}UNA9$Uv z`RYrA_sU9m+Fwh5xSD&glbr-t#IO01dIC?CS9eCkg+CDy9FpOkG0#wp(9|NQz8utQ zOjxf3gM&R~kX$%&uI=|A5BcTcG`!d4q~)MNP3x$i*vG|4sdF+oMluHPt+z!Y*e#6e zVXGNGY@qLWUUrY-WD!QkMsD^QfiGTe5=QW8^_w9^E_!?+xn*$S)<>(v7PB_)6kM3)Q}9aT@O7@c91Rq)X}$N|)5U7=oX^ds|z2 znp<-@#-UgL;`wPkuA;zx3S(El2Hgg`7ocwq_op}C;O`-Ts;xn3Er~L)UUW-sGVST> zVg!^fH~QczGr^C2Y_rezEP!h)xsXU;W^nEKVubT6O)z|)K)~b9iubbFEcWAGDR}lpmAfhA5?vt?-nZnB&`VzW=(Xve8huYj z@_cVKG&#vqM48tG0S!s@ep$w&Fjt}uctK+VI(ixh^#||P<2V1wQt6WSJ32q@9R~4& zy-GHW4&X_;-lFOV=gE_6B7efbGFG;#eMOZ*7>Ads_!K#5&V^-R*T^X!5q3PqeAMSv z^vDuhSlxW@HqDnx4gyK7v%J5^ZkE1yA(3zI+DEakBJ0i$Nl_!yC}T#~&3q&3o*X+x z%f3h`aDGM9wz-#6kEvFTPGh^nZsz2*1I+5v9ehft<)E=xUSF|%4DbxcK7cz$9>mz@ z&eKGzx(JY9t8MLW$(SC6Mpj?FvvZk`q5;!)G2uNkY}*(*-d*~U_W*hf$7<;ZUFP=5 z9`E~Rc`$Yy3FUYLJaOKeln3UD22?bc==j1OVe01eFDMb*gG^RWocMJa$!4_S`bQM^@)cfm2{;CUi27U!Y$yS=tBK- z7<%FZDkrLZ=$A?W#?g(7{w>c2?}pm{tETGz&3ISK5`A4x)h?q|5|N5un20dNDjsund-KJLsn5|mw9r%{o=jf)g8yELU(x4X%7t-fXe;qwynp}X>gNVIEpf>9=e2}yntgJn z6<79Ol4~UXMR+*<09m_(O1t~VjP*{E>_#xcf7KD@+WRG^-~}alWm#JTQ)i}t z@=^KdX%#iEpYC3F;_4zW;M*g6QFh z34%T@jR!NEjDlu18k@xYjHs)a(dEyf!$)yO>=| zCGKf{j?iDXlugV7pw0ci)`}CP7u3Y0(ak*R^|0s^D@tGwHbi!5PXO4%65g>1-c@A1 zG-$oLfc}n}U>dOPV6a5lPg4>#P~K@H$D8Oh4JAngS4uE7tzXm4B9dnQU!w$gh z&hn*cPS1N-tKDmLqv>cQg6X)+=y&oy=OT>ld^yqvmoFbK6KPV_j+V|aK#64)2y7`2 zN^5adoyU3|;A};hD%JY1=2KKz!i3h2$1!UGDKdY>>6{uv2H;OS+l8O|>xSsno+Z48 zuCP8}YS2Pnc$KL!7zP*9_)5RILmtMGo}9F+>eFtZ?~zaDP&xb^S5=u|{&JjDQaiUg z8TuzMvCB)8at{FimtPx8MNkBaYG(EV<^5F?IDax3C{LqLC?hrL#Yvqdc&VxU0He8u zca%?udQ5_dyf8`FTfYq4Laa)=X-e)fXp+|^ID}5&fRnKcvR2Lza*{v}p3wr>K0soKW2DhA!`VeFmFf-Tz2;_-mFlUqCa18E1CU4 zUzj@7%vZ1-`vj^WOF5n+Z>N2*Wwkf7fqZlI5CIo88I(yx?bkATN`>cn<{5;9WmjN@ z9asWHJwdeUZE*yXGCXN)oeKMtCX1#RN`m!J5M$WpeHxz`5Z%!eLegf0@BCx%o~J>C z{m6^YTf4lDY&DjHwZo{k4f~pEwtjVI60}MrNeQ0n81uVPxnD%vzF|X#no5lz7WxfZ z5W$y|+pLHY#Lql9Gf6>5NyPPO&Y)^gQK66X9Z1ywNx3Fw57-Qh5+-{`oGl(spgc;W z^(SLmp8x|vWJieGse}VD#)FqbNWv&rY>LpNWOWbM@fa6_#EOVE;buUgB}Y3()6`(C zBmw1;-AT_TS{n9waxqjMBUHPXSP88UgEY(>sjEoB3>%MPNvAiT8YeS~)8Q2-sg~53 zsMravWs&f4_jSV$*-3V~9}!p4dWpS&i049kdVv}Mg3qe375tgVN4<{;YQ*?=FiF7d zgp;i#Bc{JcN%RT%pZ|rky{HA={&KEtVg8#d?*GV~%}uQB|F6V(E2YnNg8?S=#si9| zV^vr`Z^$1A*byTXt;4$Gl8%AxfFhYzFd?x`j$gCo!c-yxXL+#NHo;f%&`j}u{v0S+ zUA4P`1Z1?|{34&qq=Feq(zBs#6UKh8seiZI7qi@g3quC$Eyr3-h zu)b%8FrKlopMoo&3z~=*l62Xi5=g>=Ef}V=sD97y-m!(86nOhV4H>^WG-t^2ebm!W zUwCAVz`jRBCrdyV%r?p!7Bk!@wK%=A%VU@+n_3Tv*iy9+v5r5%1cT=04zutH%7t8q zDwgz$%24PLbTX40o^_a)IVuBn6u5n@adC@BUCry?PMVU#1G($+tJXjT+i!#kEkXiZ zyPx?p{I?FS%T=R-kRPyF{Q#`8CDQ(ixKLpu2 zBvSafg70+iTiRe|7ILc0ap#HC$k1W$@E^{o>1W{nM7+BbB8A6xpcR? zdtKk;(NN|LJ|E|}17@%PB@d`{_1~f+7y!T;#sAzX))r3A|F=`7G;I2II(CBnam9&i+}|Uyng_?T@QTBLL#({(Myq#yCwHF-UHj zurZHg$cusi4~_`lWD^nNE8YR@-M&r#$|F;-t$;WSUbLuHgCIX-7{$=0_8iIKK}gdO zL@@K89io86_YBT6Mnql2RA-@K5kvAb4D=8zxyz6R|mm~C)xI6bG zrz4QY6NF|e9o%_sb%&vK1^&e>mBp!qvo!R6)$9Q6e3LN+=#+EbcJF*5r8*p3eXC$c zRKNuKvP#xJ=S7Zn)G%EOSQ*STzh+uO2B%;VU-G(u$$pXpSMpuzV;T*IaYeuHc$2Za zts!00o@9Iex*Ud3_yoVc!)UMjocmI6*+XMgHOQynIp^pn|<`QUZ7 zX>YzziAEh@-7NPGnACGa)xDdM$IlMY?jG(0;hS>}T|rF*5~La#Bhw@Ym%qb0Sy9?(oE^N5&l02g9TS?% z#Rp%FSN~2-Txz#8(3EJfFOs!n<0BU$qm+yfSQgT!S#tY#cdvbdx!ZyG~BgoT;c=u%w`^)#JPFB4A|!;j7v+|T{2#e z`UN3XA1*l|WmCjHwiJBzv%(HrU8U+rahsop+u0 zgOF-<(3;CmOzuR1AacD~v%w70PlE?HhdKk)$+!sX*9?7WOdk?Or*yNLPuAAX!oN0_ z4Z)jJEBUr7hua2igT8uM`0Ol@s^-x25J*HvHGeci-&9HVPqE>Uh$hl3b6WPJUY&xM zuZx<#Zy?ZEGCZD}Rgrd3E5o6Dv#q@ed%UrpWrxR) z-ndu~e%TTChJvw<{#sQ3!U@Y77y^aHcBdy5*r}JeE^Jmr|7g z8lRM}ruFa@ZBk48&nYGew7%T8EE9&G!nbA5D5q^$ry4b7l%zt-pk3SEw!Q)JimDw_ z8#1Xo^0e>Lv>o=;Sgkt(8FY3Q-fbq0ZA&Y&o>8iBNo+kL8Rk}c*Vm}vq#0njxZrF8 zS(+pnY6+V1Ler*;1}S9}lVLC|lEHi_2ugV4ZB-k3=xjwhgm~!lFZNX*oxvPSnLK_? z)bt_4_7kq2=gF4=;z)E_=ZhHnW15!Fh|$st)hpRQAxq9=9%t8n1Cx}9PN7G|R%oux zh6?zTGqDkte3~mZ%Gt20pyhh1H>od*I?J9^D*03U>&V@_T?EY*6Ak(7b`!OYYoB5x zGPhEvn-x3X-HOz0#7V zUA45Imz>|RU1o`kH8lD@)h>JJ4F;H|{#6JGEo7;|p?A?o?O-Ad*Jq5$~)7lohnKC0otulEq^*TD5ZNqub${^L?Jq@riF z!GPfRQaam-jve(k`-*h3ZmtG9ND%*DDN`^`E;e^Vv4mrz^<&44G52?35v@rbq~y`Z zN5@AtUCjj-I#gOfRCtEApyW0fiS8fh5sw+HMnp>y#ga)d4rQ;A^%F_u;Q3C*Xdp=m z1&MqZi*rL-ro2vuJd71UgL1S^MgeZ<5lq@>!014>q|F}6frKq)&oB;GYrpVUP?KCP z1%dYoR;4PS$vpCDwzQm(JXPoL15S`wHg&L&+^Pf|io^s-(&aqFGE+|ZE(>WgS5{v~ zfHUrxL_yi>-q{F=8NFr&6#XUEw<{($kpllV%J0|#(CuY+C%<<1)3#%2tNVP$eWAsd zyh+&pJ2;oH=Zip#Q=@ejtV#AWi`{nA=Mmq~1d>>zm(TXQ{GhQWkCyK{94?YDSu*mj zCk&y|3pg*wLabXKvtCR8a;eKj#?AWZp!>3VB^d^6B7u2EF4Gr#BP3s znDQU`xu4LY7UvHN=L`Hes{S7LgrZT%u=Pr!ToMVdtHpz}KmzZLnRFub2x|!iqs8rm zMQE2+EfNHJMqSZHMn76DVbREXU$rH3S{&cnWI-@5B z>_&&;&mryBVPv8Prn9j)>$-5HY8(md)+FCQd7x}+}q1VUd6pHoLX zD;`kpczV4Y1CJWvi3f|+g{B#{v%9VU=E%96)^M@Jl|w~;Tks5rD(S>*`Hkmna|OUY ziRA*g0sDQ2>ObR(-L1OIo7UmBMPIBm=iC<&5nGc-I0u-Sd?7@pQ=9`d7%dfy<(F@eKtfi&r@p zb&b_1jGuu2`NqG1B}ZZQyEdMFvEcvH+Hf-Yk89kPs!n940E*9xn)YR+P#mk)dmTa3 z7A9e-uEeC|WaVedKK?-Ml6fjI~tx)P+gYAZi3Cv*tlb^=o>4EVubDM-q=UWe``I zehrlwoHRISohtzR6Jx>?kDY2B<84jL zy{anpylcNoP>S!NYpXd$=Ng;11zhMQJ1Pd%F`PACYcvW;tP({3!0vK*(M>?&9B;LJ zGdl^UItqajdL!Q z{+y$V?Fsr48-0PCc>{RE7u*EI0?TdbH$bG<=(4fwCtj?-ksl7?i3Jn(iE!`OW8?7* z3j@`q>6d_6&RtLD%pc_X&3puGcDhg>`_ISGmZcb{{jJ$q{;I_S^3sPs8B&TeDD41y zkSRwTKD_G)L0ucIn7oasT=BUg=Z*gA4LtH!Il24XoK*gon__9XqRv1`O#yZBbfD7Z z;s#E=my1nP`<{Bt+Sec>2;st+n zvi2^4P1>&)Is~;VIl4G9Hohq9x(Uw3prke_^VEZdk^9w5UhlsNKRXAn+U>f0Cq1`N zQZMj1B)iZGhPse@H`Dvy=)?Im@Bg0!lK%~)2i1^NO#G!TGyJ|t|EK(X#)%}VBkqX*(M@-eIZ#SM@R1)0|XwE=1hx+v`+NT1n(S3@PJhMC_2bukt<8! zBqRmZbq+=9J{LIUf-Hl1vS692lxE;{Nsww&4JmS}LQZw+#U?fU7H+{8iwlXRtE#(f z;0skG)n}E=#;xs%I?IZDq~{fTXP6ok*NzQoikK_M{dama&IVg@z+(8)Ta zS;$P50yLbdeErePd)Wl0>3qkZRB(Yk`pCGYFS2-f*8YKmT&6g`Cq)&BJUMN1bf&^*TYX22 z^Z8vtZQI7T&%p7neS-E@&nH&?6 z!5Qy))%N`h_YeVGljAM&O@yQ93G1&<*tpG^xoRcxaSkk#al3qlg7-@MjGF8T9N5%g zfSAU3Q=|mvg*M&X=n<$sQ>tI_CnQ)^Bf6#a!zCJ*|B@rvukmsF!5`G#5B zRUd8t{igmxA7DszyK&4+cxfyuK-Lo4=&uwkj*Q;d{;9lPi$f=9r(jbq47=@lxsf=R z?W{gF0ULLH=37_SB;Js{tNiGi`f-7211Fw#D|2h|xm6SVo34KCNvz|q=T#h4Qc6@E zHKzGZpZJ4RZzrJ(=xRAe{OVU-V-i+kXd>dzDK0?(tRe%YZ-C#YdNfnAam{Q=lUwy7 zbnpRlew|M(%L?-`hmXBC4!BnIG$8YggD(EL%SAvp5=UHZ%g+$_xcARzDrp^?S}9dHgquyp4&Couh}IwTZ$1 zixAK`bNt`i*-q(fFh3j!yu<`&BBrExK8{Q&Bzx3|OVaS}Y9IJ-Y<%TFiucgC2s8(< zjkQ!qWwL9}Ov2H&K&#!hXbOAFR;1YJ)amInu%oHnz85pcRp_hP>FUX7&J-_y)Xb#x zN@wpk9>14C{e9yo8P1_y+IuJ&0{_FpORxi=mkGRG^1LBY0>!AW#*y{Fn^_=PqG zVQ&jML&0^tz!~Va`zRfA!tbNS$0di*B&8Rsmgg?F~WZ}>|oHS?>|ufhTLTG8RN zUF$_8F?`V(aPr}^wcRgp&vMFiLryl=0sjc8%=tpKv;s6%X|_XWL|I<7tay`zW3X~j ztV3nZ&$F-~^x?>X*0ABTy;Bcp4JB(I<~(?9>yVa{7WrL$o|y?!J=n;PzT0}#w}3BT z?AKDxOm2?ldmQc*4Lvo!I@8%g@v9Wzd~BgFELTkjW0)%k%TFdtl(~qp)xb?H%3) zUiDk}If9Wix$P`ppT@HVLN2>znZw7|<;j+-nP`ue(`&LrYnjj&k=4=>8UFN1jLjSv zTQ^O85^P~*qmrw6`xIu_u<$U$iBnK!1uE)g%qPTy>XyLGPLG%z_^C02kXTTo*%u-f z?CivWTJ7wq2fkq1)B6A+7hgeG;kP|V@rSUc?`z4*#Y?WDpliKCYy#Pg0=q0d7D{KY zL91gFi-pZ4h`S1>13VO^Q#RVdDPpAPGcS7OXH)FJsM{Ql_BThiv<13#-`i}`lxG4a z);siGPDpe%W5&Y4->ZJ{BlL9mZcnx*_zYuCQv`8~;zH6LqX%(nSo=y`tWY$4GawHl z$7pD33*etPE3Q_U4CnoJ4xeb1AgRm9HjF94UJ_!>CT}IVMYe;6H;iwc z?`Cl-*y|a2NpDGEANy}q(d=b{_D+Q&Ea%$=LunzGDd(aV@Z&(u(SMLlA;a(sf9+Er zMIw{JSm zU|6+QMLlF_FQR}$wiqPM*AB^~-W;d3`VEbb$Ii2l2O-*dMN zg($$)Xt5KCITKcsu8Y-yCSDksjL*3l(04+!w>kdukH?rWJL?CY2Js2-Ug+vw@5Ug> zQ*GrL2^jwkNezzaTZG7{7`l`v7vp6H*ok1oGF3{!qU$ zh?)=X%xueiKbjZ+G-N`@`-6>rjTgy}*WIbk)h#9Dm3S*V!W4Oz6u_qZOz4F@s7Pz0 zYAp}$;)FtP8_p=#X3Z|v*_`zi%;IX&?Q>zPq@D>ReMk9ZSgMguKK^GI=}{jm#E40g zWCTHls!i*X?l$GH%&?2gIRk$Bq4y_>zHV+y0)56gR@2F8fP&Rr*GSxZ855ARP9RLZ z=@zli?gL$f8*~RzA4EYIr{LOh{T>+0Tm^oRl$gW)FFFB7p+i5+W{&Lke&@3C9`-4K zM^$g3Gn(WzlNV$rvRitTN-I;r{oFV-gu8U{EpWi9?qFcoC7 zi7izpsu`RFT!6|3LiE9(g?qq6Axj(m;R2LF6 zOI0PZ84U_zsgD$8^~J*}Q_dBJ69bCYoa(qb&xscW)AyS{_A1^U?k=zMwy6M7Ni-z# zbj}KJK5^m)A#`dE^>B#@`8EcgGbl8s%rFB&s4eTbpquBMy8b`>w}!R{-xcwf74ee= zg>m034YX*;8GzY!!1&32-2n+zAexoyX^V04{+n5y7r+A!9*C2x;7wWe!}Tt?e_AC^ zs-vcIp}8$M4Np2u)kdphwl~CXljXAp7P6D`c<2d%ujQ*@rc?WEKBk$Z(T}g~qDnRd zlzlJCA}UUC!cs9hB!+1bO#dB_{3&iKqHeS0X9vw>Zx(ljz9L5}b~y%(D%n*x24_Gw z8-Ssjj73)=m12bs+$gybxo#7f9t^TLk)jalfR`!qFMeoLrlOfg>g$w4(YEhN*kQa9 z_GdJn?rlNvKyIHR@r&a-;4)bDib{SOM_!b=R*GOM#_Fw3LMOA@>_E9)mEOg*#&Jg# zbVFB}S3mNfSj6r%2TeT)BnWrn4=yti1r$x?zzeLWXX>AXe55S!2C_(1+Qhr@D4=9ic7)6ibUMZi zKqHZOjV5K{(2G#5(t3@X!f1jYw=TGJ)i$hZF(#*uI;ryGd?FPBP{7&zM~9;1>u3hy zs$ad6OubHqt!g;9bs{j8VRtgTo1CFqsctl&$(@o9Z*Q;ICjM@PPLp74_QW*?vlV(5 zdb*}5DGqz(MKv3`Lv1YCl?~lh&!(C*yNOyNSfE^$H6+}mwp-Du2gysWQwML`Y|kd(7WNUYupT$<1kTHHv@jr=u9PEqLq~QK@eJZH-Y;z%El?sR5e?-0 z)5-l3@s8Xd0gO4Hcz0M-F4N>pd+$6L;M_1(xf;E*rKFJzDjC(Xpw6YB*KS^}lklbY zSwnaV3-1x^UuMzu zW2L^LnBw|TywYnGI;@DU-_ci)o7&K?p|YR59X!GQBPd4-IRE}RsspNn8HLSsw0dXy z8b)Iu)pqVRERJl;gmrq6!|BT5Sbxcy(qyoSf;0A;)oUJ=@ksY?s(a?qN>lY+lS_SN zrGl`F+zvl9Zh5sfYfPSseXz~Y1gqIrmk0NH9Qj7)myU-Y7`$79lpQuhK2Q?pk8NPu z5Sj6%vhdVXoYE2i>!WrAtVa=znAGa03IvekM|A`lNo0$bOxGx(mX$oYm3hY>kANg{ z1#vesP(15xplCp2kYz=Vj6Z-rBWg@O@C!d4nB~4me?umpgvY_D=}T@SFYTH4dM@J6 zr~o263GRj&inBdJPznUjgQn_q0dHPZzOt*KI$XFExTZ zjlaN+>3gu~^!#aq45?%hitetK-BL!OrW4-TS!QbD;oh8P_~#17z;*VIG#%f<28F3> zOv+I(WiFqs~DDT z*^MBE_#~3X65x*LigatdGKofAkTMw@{V^YTO&yy_k=zaV+} zdAUgs3(NWk+o-uv=%fV%+X@vogc4ZB33ziN z`TsO|)~}p9yXMKE%mR+b9Nm^_d9|kl|7Ct@2#xJh`PCTJ#w-_XgJ$$;@jh? zeON7Nd@Ri$gWX<(d5ilceqX7Z_WHA_rFK%svA(Lm)6bX3d#1L!qje8J+|nG@o!Psb zzOe1%%1>*+cmQ+QH*0C+U6c~IzIE#_tYI^L4RIdm8b8qM?HnU({1BKcEdPLl183Ol zLX~sM6<9Bw*f*hl{+GO}W}~Bg`Cr*b*>8^$|7Qfk|4xe3`wzOEiMxS~y|u|N$2D1A zMrz;}(Rfpf;V8+N)It>^)Vm(ED+t>4muA7(kX^=g;If4EXM3s?55+m=VY}1g22_y1 zRr-_yw9r6{v1%Ey3T+|;fnB{LXb<)>ek%Zq3VROE*jlnIqtZL)0P?7#htCMOq6&t) z?WLo`7qt&A(lB=HF8%e-XK)lLry+H)I8+l_i$z5i6Bo-z783gg#BHJ?Sh+`4cs(#J zss;#E;BQder#?&bpn)=&m^|9LOUD-$nOY+lD0-Y60P3^Z`jKbzKiw+&RH?bX>^{2b zKX%OTSmNB!jGbUQkdFaZ%hO&xp+W3~os^W#a@3@Qj3A0BWlc$fwXCjaeznwD+a$s^ z^R8WEmQ-}#)pPc27t>Sv{ZevU&*3!_UF4qnuY50@ZSlzp>r*x#a%{QoEj({Dv&V!Lby7+`|k zctQ8ZGv@M?2$c;lAVy`CGjFl zAw-(Uz7%<>{GUNJ!!*DD^Z7iVnd|d$zH`sH_dfUDbDwh@xANEMYv)+$Y^2p$p6Tzz zzw=zyX7da17dtT{bB+A$8#YZd!aCUS`&HsaRf|BNu`` zw!X)=-_`)s^cZNn{Mb~5JdtAO21ND#-H3g3dz{+OQ#9}L)k#ZHcrTPrD{>bV-KC=c zR78^}xv2DrpH5Y4p|ygkE)zdL*%%77x2G?V`SfBpweDNjSdNVBI;GNx^M~oz{W(Jp z#PMd`0u?v@a)@(yGqFa-NN-A`EYMi{m}yc0E@8#UD^~Qb&bX;VQ8(mPZ_n;9F5?)fOaVC#GPFI#cx`q3j>BlpiL)dw;xX{o@k5g;SlejLTZTXApf9uTX zR>_;DXqW@6@`(@x62N8^FFP9_uWOFJb|_})HPA%HTWSH~ElF|c*;R%QsmXa5;}gA# z1`CGdvt+wG@LKBfZNC(`7q?uvC^lbp`eGA;>bmS|%YWWH&uF*k!*v=7?v(C4Z6AcD zkd^D?|?F{EX`Fzh%4?sYAb$WjWe1y zCs0EXTKI(0xS)n*y>Q-QN%GyTdio!ng*gpg`~faZ6QP@r>f%k{Txz>)se!H~*-&`~ zILqgor5yLrzw_1nk zpZM9jWX6U>aB()$Ix#SSq_5tc%wOr)_ptI-&TZEpHJcf1 z(E|~V*%tpgl!m7dfnI(VD9+qN*Fhfw*0}=2Bs)iUEELhc4z4-c1E~pqIlPvQr;rY> z;C_fzx4KAGq^fMLDy;0p$u=>^kHjoJ>XUzEF+_}-jNRy$>LZs?HAy-#+FmB-y$|HTbE8YQ4_v2Ao#vi$>iKkfe!v8F%FIG z$O)1@5!K2%4=;8}+nC!TGnL$Vd`~#(lHuHUwU#y2c-3?^wpRoj*vRqSBMSpJj4zrK zwq1Vpo))R-EH_v-oJcHfCF ze9KJOsNZD|YO8T0(fm0#r|PZbyAcL8atmz4cM7FtErRo7*3~@_1v7{pDMN_!3!YQe z21?{L_~W`8U+thnEP3r`uh!z`bVNxveCigvNV@;fxH&q=-$mY#5Rxt2>V^pa4q2j6`> zedU5)?~4I={SC?SgOAj5en|0%`X`$R=I$k4J>4>ExNy64T0QbpYAm)PY)YwVgfL%*Z<@JCT%6jJQ`IY7m<8T0Uv~Z{p%CF5^Jg`l?)Rz?sLQ= zFDKD*{}huqR&QDQzBxHRfre44aKm*~s+l^Uy*G<*V=%G_+A?QT5D&T9zO$_*QVx?0 z4O9ppE!)up5qDT<4@0k*R*+lTLLDaCI^j`^DMDOtL`4hdPA2Y#>bVQcdbeDThPmh6 zk5<2b$^o}{deC%QrUN!%@+Hkr{cU_qS#dcRB*0wU;f=&&7`4VnMqZW=9dpV7ri0FK z$I@Ion(|=lXR41B*{;+=T4ES`t_^!UT!z;Z&eF>aR;RlaGwu=cJdb>OzCy{91s%2#4kz7sJ}s+{8S>GRRxV>E-1wIaX* zGYj~(9_i1=A7uSA#4y^fTAOy%mlRAT_OuF^^7sRO@gu|Kkx`-r2&N94({9-{qL}Y^ zIgzX7MN5KsRm=0(Icr6a=C-wHsq_mkj_ao9ms+)>zqyXY0()Gvsb5s&*AtWar-Qu7 zI{!Bvbp2wp2Vt3r_Zz`s3F*C$P!X4tyl*{iJaX_ zJ@xVvGA8w45bB%rUd>r8F$9xJ-O48#r^p3*HWXQvq=$Wdp|Z_2{5skp&J^6t)~1#X zhSdr30g_v)wAP?&{>}bQrt44l*n`-_a_h9%SX2=O@He9i1TaVPamr*7)`)v;0v|$@ zbPyMUn8Uk9L=5dX|Fh0EPJT$74m% z;TSni&vMx@sKMLJ0fb*f_#@wk z>+*AJGx9|XSI(SKR=Z#RcV^g&;OLU7b2+%ubR*6)d!tX{U!YtCIJ96_xkq$MHgC2_XAuzsQ%iTUW~;wOBf zcHK*7zCTWqk^dAZa7EH3b)n;ZQQ2w9Touy*9fzKoPQRKk+eP|}p#fX@J<>fJyz@FQ zXS_gmB5vd0q&cut0$j)wAg4pN321=t$omX%BcFsq$i0$lLJ)|fo1-@b0<3=6z&h8( z)W8P7-@!USAa*NzI}+r>^*|1-sE+EnaUh}`@c+o0{HGis3G+2acQ-%)1akNt0s+f` zxy}m;pBEJ4+LwVO6M?v6YgE`R@kR}Rk~?ry{FH(6D>Uy2lZ2e7z~2q*xC!Gm{o68y z`JApGl~|4$RofKbKNp@%j%DL4i}pE8K52?=gDbIbaGbScJ-JTR`ze+E@=J}{ku1l> z-ppI~KuR~h{ab1^r0Wm(ka2MT-mg8Nk4#-v0Q>7A*k2tYi^yMty!N$20s`+Ba6q)6 zwa;6=Fz!K|pP!yEv6q2B`M21Es=eu^kIyCmfyjW-gyNT2q?}(n0>}XO9`0Vq6z8Y} zP_$odLe&5ct_|rtQu5C-DgkN8fUR={Dbyg4pApiS$Ym|^0t5F7z}^4h2idU?l4BpG zcpDFNGY;agR~}Lh%SS2eyK(;Cfq~H)IP6J=MyUm6RRn@e+T4|+L!+;fz;6|94hM49(F~oR;C5)0OLCDB@kdhR7DK|+S*4P zBpq;faj~=U2HNknqol*iIB1?cFv%a3#rn(zKT-n(3OB?#ViG(-RMKG#%s)`Bdt;{{ zcooXJfG*b{aCcQLJI>pEfl8?6m?ks-KvgvI?`VNWTy*EI0PIFNvIv;RnbA$`}t{g=fbhB zhwK=S++PO22hO-T0Xq6^?+`$EA-ZsDJ2x8#S1T`P;iJ_!wC`H@VKx3>?*ayHjnBXW z%K)G;{i+6%)Rcl&4eRUPb`T(OlcSsc&*OEdH3J>Wry9o2n=q-@wF7KY0}PsAG~PlQ z8t;G~%yLij5}5N`|%npy$0prHz-P(T-%n!lX+> z59}=T3_7T+?~l<~`(p!IJ{?Rr`!wu0^o%E{xaR8PxL-TOhfWMkME|GQiRclzQHgOa z$B8KMxiP^qPm*J`9D3AWR9ITuaoCX)rg zojr8kF()M^9P~YQ9J-eY6_<&NgW5`v4F;3X2@}+hh!U#-(LE`sP(J+s4MqA@Fi|6g zuv5{U5~x%|0xVPvE(%QaG8lF``sO<-edz=iI;O37O!%fk?0EE@TvR+0@p1g`eO*kf z$qno*^c_=F7Awhd){&hNOu!X7Dy;1ZecuHYs7Zzf{C}`m;xV!QwDLt|Nl>D(P*$Lr z>NI76Y0mVgmGMsw7??BN_{sUbPDTTKdxW^mHHgzEaR@jp|DHW#S~jBs+J6H6RE5JO d)4I&jB@+Sg@;TBu$__dWOq-RdfLSf*{{W)JaJB#d diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..2f8edcc0c0b886669460642aa650a1b642367439 GIT binary patch literal 80352 zcmafZ1B@`;wq@J4ZM*wx+qP}nwr$(CZQC}!wmtuSnLBspB`;ISE+lnMcBQhk0v&QGOsvq7S?-^UssrKY{S?;d*)&wieEMdUW<4|230^y}AYwy4HW6Nl6u= zoaJ5-LR@-QR$A^vl6rDZB|J`!nwFAQGA2%Ke42Kgo=QPnnre2Al2#%n?2d5qNx#vg zLS!W4-9uaZ{$8F;zKm8{9bO_C`oE>MI*Aq63km=LhxLD@WoTezWpC%`{QpU7MO`Xx ziw&XoQw?Flz@#5B5@*v;0sWj>WS)p39@G;8D1fR(qEen@k{AUt3$zi z;{dihzth9IJM2krlN>HW3%ZU@f)zh#x}qz1I%iKBe-(PUFCT?dhT0aD|MeuiGxy4R>IH0`W_ZMf4)j+w+$g z&DpJ|DXf0EK#%-2Xw7m&$!gLmp+z{0+|j)9=N-Q-kysDt#|9;^#5;m^kYh*zh&4#| z4df7CH$H;4FfyEB7=l^*M|dV=3w8Kc42583#*=jtV%zL*ueu&U2Oo z&)ds{ur9liY+7?~5wuIlOP?FV65|eTeC@3ZEeknNdgMEe>Yr-~rnGgV#w^%xwuzI{{u3=pNbvv`QwzoNZN_kuvWq3t?@y$ z8MEXqc4>Y2dgo}^sFQb_jN``W;U}$r=9)Y3S(74`h~W{;EmYZKnY^q{u~4q{4VQf~ z79KOih$5^^(v5a|aY3h~Q(t8JFvuJMhg%~`X$_cd@_nb7q3Pt3mWyq?5X=1Cj-+d>H*(Nq>;;*OGB2~ z6}$*}Tnw0Z$!e_-f+JI8IeHLiZWRdIV|J<#VaX-l9)ZP_)z(gWnlBX7X-&o)PRc&< zKOTvBXrbPmD>-L= z`cZbtSy-iQOV||H$!Y_4TiFHsY&F^43qL=mKdxUEPMtU`d(E{95DdDFx%`kp<6>h6 z#S4|1euLb|Jb~|)-B;LQ%_k-gCC-eYUZluQgEXN2=B+z*FU4ld<*LphdGKbhOaGWD zJ*gL?wLTtS-$re89X)Dv*)w3Y=s!~ySv#&?=eaJ+n$M`ks-Q7b-^}BK#+ELLM*q4? zzks;u+q55ct(g4>p8kby<~YP%=coVxzPbPaDF22hBWnv2TW2Q|M^_Wae`CSF!0uo1 z-@xu#`y%P6HRZRq4`{)gA;nzTHCwqvc|~jGVl{f}!l+g2NTZH@q<}!kx?S`GNk;m;nclK2jw~ zLSc`rMZB#`+SWu@jR>*IxhIv>J77;TLtm|dR+R}ET4PTt10Lz#=z$QCuf$}BB%@tC zwatRA;3~_2F&aaf+Te*<7ur#?_x}v1l{r2iA|ziF@?!8en1 z0+yUUz;*&d6itxsh^nKNH%F=JwVFeo={PWkG99)Ej#hIr`S{v^z7Coe;@DNGtt95y z_)cN9NK%NtQ!>4}8q1?h5+aHZ`(SP9qpoAG zp|pO?h_NlcHeZ^L?@WKo@Xbs#0MhQ~q!?NYyOrcBY0sUy6RXp(QPdlD0?Ed>-U7(u zFB51^`qfQV+$imO-#`mYU4RnMoxcI>XkFPQ%(Bm#W;-?yH)njvkwS zb9KgUYiobmy1eW{%z_!8F=R9wi)_mqCoE8czP@cTC(LyM2ypy#DnqSW@i#M#x|%?EQ7L~}E!QUlDq)X4OqO7?x5 z9sCQHB{GA$z}DKWKFje3nu5Bv1+45-_3_2QJ*$^sIYjTKZZ~4;-}~~q*MQ$GjcOuX zC^TyH{mNsx$Nb4$j6J{;SX>b6lrm~!&7X~;@V+H~ZXsNRS;Ha*t1VXq#P^*s%|wt$ z6jU%>%TLBW&Kt)Yd~#`+FzrrSDP9@Xsh-AgS8)uRU={?oBCedNqePLlNcqXYzi$i7 z%2}`ZKc@EncP;Tu#(PgSNT<9E#W>*dMSoZQ^s-Rwp5Kd-h z-XWRydEzVVYG!ZK+8jO5_s~>wf+FX$_m29Zu+IP3U}LytElYKIfk-2acnT}62p3uw zv+t9RQ|gpS$?JU7w*C>bke}#Rn1cG=88$+ygmnt{ei-`8s9^}}LbWE5gb4F~l-KwB zCTSRa<462Kdc<>S(4K39~X>J1BqRJETL>by{2Hkqof4uyPm) z24-Q)&~AJc56Z?}XSFQG=bP#Vi7e;ErG}p2mu;LG@Dww*L zCwP;<%9KE8qz+f3AyB#6w(TMnJ|Vz?d9)(SU(op?8ead#DMy{;eYsSc;1F#x*lJ!C z2CR~rVu}{{HT-vnv;YQ6^+2Mb)9I^oz?f7b z17Iqfu6OUnRtXUWj{WcV6cbBNHwjTYWK4sh3atk=sFHQg>8KWmWO+&(HL#RK3nFJ6 z3(tj#_Yo-2Fa3HCuSEm&)CgMyfi%WHX0c?aqR$A)zb(S!ERUswrmB=4hb5?~Hkl0b;BN8)z0 zHg^Qs;|E>fZn)76JSB~DVuNgTmUfB#IqMeh z6vxngU2+c`kU4|#I`@W-iepJl3|@sN7P8zX-Pnj;akhHD2ZWwKWKU96v#dWF^cItU z%aS4PI>Gk_?AwvoAcbdU;83eb{q;5{LZ+Yez-lM#BO77aSw>%gSVP*LiSa=?`OhR6 z>Yx32=Jyj2=2ch8f%usJ%=l9_%I0$VyUwJ_TtFHV5AYsaL}fn%T zb7qzq5c42X49xd&DK~U-tqf0iQ@q`hQLK|iRzYzP!`ts*_GvTSFZ zH{BpjK;lBcWc4L{l?r4WXN~BNrG=KdD5fR;-mGYN8(J3Gr zt5B8|oMeIkWsEorBX$M%VcnsBgpR%0G_M0cSuf;tj*0v3B?XpXt zG}z7a+yG4#ehaCMCx%&G9LInwL!GSXQk3jY&)B|SKZEyM4ed6!dF|xK!SDgXrIBo9 zT$|HJa>vLydtmJHo^b?$>XSxd-4@FB}`L(=h^|tk}qSaat#V(=3aGd`WDp zy0f^*mHCrN0Ws>ML5{1X-EdZE7(sRG!86)6_!Cbtp+4iIeshH22(O&+#jA5)w4=V? zyBT#9VR+shhEWFwY1d8Ak#Cqfwy9yiYP(7%>IY9t;gjw){ukOIK1A&rr=9yx zBsm7GZiWRfdF^V7YaB!)d?i{Oem8?C2WirE z-_9scgUE^|FZ6M_n;rS}e#^bD;d4~B3jXFGye&AQSKut`<)c85S2i8nXB11FW%WU! zt|b#aI^_B|26Bohn`1G&kfxc{->1sIgOeH^uiU+YoJ(k>C+#*C->J)Afj^JXH1e#~ zTz~$kh4P+q44JH6tVhl<%&#yU@7t2mi=3FL2%5R@+;yr$s%pz{lWkhmD1D^J&+ww8 z-I6|{nD->^BDl%8SZM`#NI~_}IGeHAy!@LwVmS^6>nGe z3@cH_wsy`JZEuF&Tr7PzHLt*P{jASW1vXE8wMy5yfaF2L3{|XTd8<`hOQ7sfNA<{5k8_!YBj_7H`zuvV;uW7hTrtL_|DYLr2KQc#5WKD=BoA-g zFUPrd_}TR|9w&o8?OW(yB~dn~3p+=Hh~_l$graf_`UqbS;y>9%st6lM|RTj6uJGEp(Lhp}dmhT`Jz+OIkLg+f=*N8J$NAG`k?&t0&Rz^nd zPb|DWycx3q)@fAre`I=m-_s1n?e3(zt=;1Y`w+@PW+T39t&}UUZVfxE>mB{zw;Px4 zy;WQr6QfL#VYgzpi;5??@05ncs5S{rsao)2gM8KG&c0K~g^=9Z@yHYTqS)=T9Xs zvQi3j^>TD0jxz0zmngwhHUj{aoV@J-bSKE2Z)4UL8$>)8-%+zil)oOgez9fGUzBh! zB%(jAiHaXO9qZMibWBFJ4rd7?Gf6KY&T%)JR9RI9)RU>;2BEiy8mbh?YGxVJU&nl= zU}A!j$0kk9*4EV4?pI8_nY*}oj%m_?WQW6OC3JU?C0Xq1vgktW6;e<2clxCDd|zJ^ z#0FFM8IhF*3Wx>pY>t=Q34QyCe4QVh?R3Xit_kWO#N?1PSJk6+jBf}dA} zRCKBqRnN8Kg?>HHebxtOauS0t>aYxYoeUMsbpa19Tc0Z)Dr*FelymGA%mF$n#|wwN z=Y)g$@MQJ%^hVR0va^rXnHA+m*ds+*d_C+u-Rf4|UAumr?|g~7z_U@I^^IO!I%E!3 zGC?dBJW8|fCMcdH9X&d^U=MD)4>nz^V%q_(8n&o-)kD6-5^j$93@_nE?d~jm1#gG) z3?J~+8#$P})Wy~PLInnuDs4+!t)<2`43iKPp*L3-%q0nl&dpQD4 zVr)getD`+z5c($5Cs+&HJ$7viXN9o?zdj0fC+%UVAvxq78gZ=aIG1qh@c5g3@9%@6 zMk$G;^VXa-e&3)Hrl=>6~IE<&={1?0W&q7cV3WYMmHlZ?7wYpO$4TK0Us<66 zD!x|x=1%QiBAA$W_}eit8GG$AF-~5(S{+V-abENe)~KXw1K=B<%wv&uz+(qMwDjPz zfWOB5r8Xu}Gx|QiF0hk2SmCY^(#&GUOkV36XK9|xrqcMr1SSV$v*Y{y*0Wx6AAP+| zYQRKvJ{`IS1)7P6bL8uygFE_r=C;2Wlw-rkr-}^s<|WVDf>FdlmDCXkfHi^`{mIt- z{FdgAav&!}m4{g34Hx|2u`q!gKY=Vik;FDBCq(=$e(Y2bUq*<4Z1xoI1@_?%CX^S# zxr{KN-0kY2YxQ$c~tyfThY__XGrb&DzM|)YA(yI%w+a+YdZ#( z@9T%a?XiH8r~6tSsfJ5y=+MY_7{ptNMXx`c!s@fSz99Rd1+-zb+gYoQ-13DCSr0(swB40%H{K6nKqsta;wwM zB0BVw5CYXJE6ltzcbHf~h9aXS@;qy2AJgf!(k8&Uo{f?zQm3=AILmZECaOgR_A&T% zLUzUH&$s64jjze4Ekn_{L{5SIZY+|niknyt-37dQ&~;kVCA+YOoBg)h*yUCMcf#%R z=cv7gw^s(W_ciZ ziELhk>M-5fV}wPqe>P^Ddf8nqo-VGmwtJh(##=-2wlrO#6Yg&Q zXQFv(Dt~n=+|(;lEmvhdpk1FN^ZZ{dvov^Z*ez-0uNQx?jM^68nb(v)MP^@Mz<=TY zU)Go^*?s~87yuv<@*f59|FFh(HueVpM?3r<_~I7Tnb;);gx)hs(H*HyD#S~oU@A}b zf)HRNpnze(8Y1!x>z=I5ifij7^dnGr#!*+_EL-&W^MZ!H_tUYXIR`p|6ijqC>T>Cy zLm7swnLsy3Ev6x4U3(}0ZlsC>Mf35JK*_soBQQwYbz*V?fhIAOyLWvy#eSyBfC!k{ zLQbea)G7!S>UO&NQ25I=xrT(1Rg;O7Z9$x~lnRtl#ycCzlm^IhqG(EXD%Eq)df}-; zCJZl81QuyJ=ZFf6vX6Wp=_ZBoznPoydpHOnapf$uB&S4a^{sFV2!;KBU;^-~PH3yw z5a`hFJbR)1ed&Ze++FZVX?*Tn-e~p`c)Ue`?EwXE)cr$71Lkt4p&rZDleX#Y4rsOC z0WzMM+gA)$lMVc!Dx82C$Foic;;UvKSBAy(Z3aCh>xj)P(jz0dig#Ch=T+=%LW)!C zzKx8<(?xAi+)f9Kg)LulFs{vUj`ZUxN8qHei*lZ7u~*y|^$*40|N3H|5E9Gz*V=V5 z277Ko)JoWgMxY2GktT$~u9YPHu3-MYl9+POo~Yh?x5jdgTQCxLFc>hoq)`#SrCU>T z3CMclF?zer-eH0IRn!ZDZ(B=w%2o3qbz}_fbPw_P+0)EccJ6Nm=QRIHZuQjGI8yfc zM4w#c60a!Y%Z0ZniPg=*Wl|{OS0wzzk#gj3v^WcLmjIYGX7L$uo_MqUlIL(9>QOy+ zK>RATPvBdw4nXrC)YUyVOY|hqvA#wC#EiP-jB3%|f@$Je`mxXVdS%$9n}$YO+fr_h zadYV20oZ|_p}*tlpoeT<-5la8h;Q~UqDTQ_4$*Glzeq2{NCP2m3)>|6 zb{(!`9dazIA49j3u-VL=lssbJ=p);~<`wtDv0lzAiL&4mP?=!Pa*JneH`cqd`VJ4p zC_WW1K`AT>a`LN@R@$#Kc7l{Idn~K?YCFyUNy78KbO5@-cM16rxWoQWef`Tcp!87` zy!20fEg}E_ApX0)Of75;tSvnMBgy;^u7PW{ov1BQ1iv*sMs0jcbgg32c_NN;ohbnV z3&I}WltV2kVk@UAorbQ+qu{UGobPzjjfb6%wJ^a#axa#=zV{iJ53tldT0%9@v@DpP z?jDVwW0HxH5s?Q`)EfSJ)FyfT2~sS5`yo=~FJjAaP-ZZeOi_uUwPH;wnP2v8_2pxP zbQ*E!jAA~C4NS5LdaR5QpyXh7>7h92w}X*a`a}|=nwsR<{Y(q3f__i%zD8Kd(HOak z<;u{NN^6@chZZ1Rc1bOB(lR7_QVvklK^p&>7^IpJm3i^Nk{XJ7>K112PtB9pdy^)w zce)vs9r=!X){ZKfM3RC6$$Q+cOQIC;3n{d-eiX@NI!MM9UfB;Tl#D ze-*E8{iwgYT?^`kq8Z4!Il#coxrUN19xVZrsaui7e~=Al9zH!X-SOK{MPLHj<5XIK zT2Unjv!m0SO;NI+8~9U}IUrIz%y>m-U3Q^LKiA1z9PWZo)-AQu12XNXppo{HRz}K0*);O zph^{MDEO(Q4C(>VF7bq9CbzLrOkms^mjX9JQmva{58$@p@&nb-`rk%mnSu^^UPFuA}Jk2SHU+sCt1nZO;JHqxEK^rXX)% zVSfIzB1Ur{n3GP>ppnTyD?JO|kl6C*$ zLXb&c`d?T%#wEQV{ES5PLAKLv0*_l%uXu5T-6AlRv4TZ=S7d+IoIW<%0?Y}{(azb7 zJWfJ~)gT0#8Y}AwbR`J#Ag47e%uF*Lv1>xXBTio1?-UsV(FjN;d3gwNq08wDPq^8T zSc473#czC0?W;&Em1uQqxq}B;=wrRE4BZAN?>YuWDWph7?qCROlG=_S7c|)Zy@|I3 zE+ZCmxQK0mWr>u!stw+sc|IdN*G(4lD-R0!V*laS3FKasDe7kB)7E*|@N6%9XU z^mJ90I{Cy?h6E#%-fuD=8?<5GQlMfWi$6*KL*=_rc6;UH^|yBf_s@gdi%-_@&l9TG zaUiZ&=b*HHLT)IJ{fR~9?|<{|+%}|ns=xvO0Eq$sVElV^urzQr&@-~Lwl**@PwBM)bb!17sU4vUPClTMJd_zr(4rZClsb zCQB+r>U76UQYPG?w@ozQhfkWvpbYhEUztOh?J4Bhcqo|p{(Z&Z`n0aa4#H3Q>5z!W1)wA)Csb}b5HAFY2Fl2B(#IE zXk=@Oan0JlPYrw%Xi~aw#08;uc7U9vhiGqS+7E@}17Hm}bx)n4cdt-e#5RB z$2@GONPYgX1DM!=A#1pTZ}R&I59H@JL{mYItp|XR{`!DslA|*FWCU^1aZecY+9F+R za03z3MIHM1+VA-K1tb!#z|?@9&BD~x%_^KGe}bv1F=PXIVWO-ng88R?^ntF2Y)^?u z1ylz?R0T2z3?=G>YqLb~f$>WECj-*smfFJ$`6G>2jX>uXiWxZ+K6x@p8pLOw?v*x+ z`&ipdLl`Hsfxun5;2~ukpd|pOX}k)?#$T#fI-B*0cP=~zi_A| z7(B&mB4T5bQGKu|sxwL2lK@O70}0`{e&bH?Iegcw>|9tGtasyb<8qH>_GR4XT!h)9xO)w-enyzW0DaR9|N5DI!O z&sz~X5y`SPkLV=S`QYNBw^d;!10@$ovQ&x-O}&?b>JmjHcSDrM=y7KQQb20+Ket7( zjla{D)`Xr}f!p~Dk1Z)A_Ul%(03^>8ZzcsSMWzQrj5-C#R|N`xeq#dz%Ln94_6R?^ z6_j9iP6`0uxVi?*8SZrZN|^fZ&~;Tey<%Jw#G)96=@8j_>JUbVvkQvpjs7ABj4`23 zfVD(WqDTGJTxRJ-F8}BU{3C?$0^9#4G^DC_$C09yJ{|ufR(+L9eIrjC2^Fb+x&bPUHE(%vVSoybxsasJ|7Ai*`YCY7))oTGFazTEg3%fB7!Lr)D*RVzhKWW zsDo@C`#tSzOav%TG+AF|0^r61Bv2++6o3khwjr6= zu+t_A$(zGg2-56U)f~@v*PW2*Y58RN&dO9yhaRTSzq1jJTW&5t5@SI_&b>EYMwmZY zk(Grx;IVvcJ166XQf;x>w_O99z11AhRxaRLRck!9K2S^G!Md3Md!Vwtg_Mg}54w*Pr3p3}y#nCS12~LxFY= z4I7qvXcO7^@OpI^d}5! zw4!L#`vEczEn6u>JyACk(aGev`CJl|A)4oax*|C+Kwm_@3xq&+HCm2hO6&BVi0;se zqCLNS5>)M!shC~kCjw?pseaUxk|1Sm_2XmyoYsFNah7bcqTPiVoA5^%Og2iBza+o(ceM%ENcj z)Kn2);Ssx;nb=vM$eEZK@XME+oK@P5oVft!Zv>}2q8w(r`cvr#r;>k zYar73#j=a9Z8Bc}yl6;N(NsdoaYZ*`2Wv2Xnh&Ku1TD?{?^=I_r#CW1hRsNJ-me%_ z1i2{}$kh=11&1a35xksSEO?xeF9;N>YkD3P&$s%{F~Q*%SErd!4saA#YP|poj=!dg z+VRvRScT09M9{~FntlS6wJ>N4#nfFsE8ds>A@LJqX{G4j0|%X2aQNL$3!!{OYj`aG zWiVsY76AL99E4C5wRp!@^R_^ctyHoW8T=Ul&ixl$kvrHqib9HNQ$olJ^n-!QHqmhu z9+w}Gu2yD+2f7m-ed{m!dSPD-BZwpnze=h{WpEq;N*Zf*8g)L9x`GmbAmel5HX1C4(7tciiCf} z>&UdHzv82O;fEjjt?j7VU$NOx^p+@Z7Y{%%9m-}filb^48B{`|LLJI&8OZsd9r_W0 zO(V*fYfK$uwB`D7%QgPixX0p;mKMt}{0XD@ZVS`}@BoHU$Tu6>U74&jY!mxw!H1)Z zZtVP=xY>N11Ux+Z_EgJCcS?_t+EO`}x>N6M4zC>H#5b1`I9{M!`Rym`#W%zZtqou* zY4pR-w!iyhuY_F@2FZ{*V$9v|y0tO}QiaSea<3n+%qkpkE%_L?F{+7!@MIW4E&A5k zVcIbhCK}ju|0)Rc4)Axy}V+TtX)i;s(m4nwy$rqAxm5FA|4DXKQqtHMVXm;H$wgtbtM`w}w4es~`EAXwq_1mGEWPABXYRpB33mwlG#kd#t+< zdG-FnBF7>lx_s@`yOb-X82q^8`jIS^D8$?Miyq`e2SgQEAf{pvfiTk-c6t5zJP zu-Y$E1u46oUs^b)YeaN;RAQ3?a8oUfzr8N$sn6zWhdRJL#Q`r`C+yeQSt0BymRf0D zZySZjE5_K9y4TG!CG@_n)PDDmcSz~&_ysVda!mmiWP4KwqmFMsVlQRnLknmnj0qdS zpSNevdvjO#XWc(mG?6p#!9I{f?8``R8+2V`a5)D_s;-Nnk9=|+>tj|#)Q zM&b2!8rk?9$|5&I%Ltwc2frt$KFA#0PPUASx;2yA8T}_!=B{9NlGb%uQaK7R4)S3| zc_kD32`NNqnoKM%h8}5Z)-h=@q(X7AXR2Z=HlXd8>oU&^%9^w0NVb!nW#U|rcCs|S z3s$UC&}hH9Y-;W1@O^lwL4SGUshhfAOqY4wY7erebY|Y%-Ob2D?UGIux%F*p0GVS#1Zt;^8KweJAqgtvpq*K6X5<)C(?Jql zE7Lr*aqqXCIeKqh>jVB|vRh-}qt;E?D|zK_b!RlCh%|cKwT|zw8$()rSngJ}VxvtP z#+<+RG{xtm=bo>P>Nv{u5HG3m0o3Ry=%YF;{o+r?l36*?-qjf~=BFLwNqrFgh|NT(*6uKlx_c%~+G3 z3%&*~wd;2M(xyBIyo|l$<`0W4D;K;s6P7k!%&zvXm-FWt9b3hn-+xG)e_3mJYA4Ce z! zD$#2jTmYU@p|5=n$e{|JMjUczXuAtDjoDLq*mf5arEniAPA66Gb^NCNmA$Wn5_ zsnS=iEafNB_>NP|rq`6%b#i6F#)#{w8iWY4ue=5JEVX}+fa=zO&FDcPkR7=m(M5b8OY zfK^A&nI1HqbkrG>0q0@o%ppwfP4e1ibRfJ;dChGTA+!io%&W zf!Sq=$4REv393fQ01br3gIQdmkDAeoFGD$~JiFXyB?oZ!|5?!4M9&peF-;>D{%F}|=u)yiOsV&tPrVmH_*vI@v1 zg3x!56aSFhigdY&HmJU96J5=)VddU|9GtOc4a&JN^bG}^l`$8s5QefN6#jt*_Pdpv z;GdjKQ(n*;!DQDbSiZo%@YDMI+vCn1ixyq&KJSgbNHFr>l`!48*WZ`5Rfpe`h35+f zPEOw5OL(xm_V1?+zu?ciy`@$kFAoRSV6{2lx1Z13ouRz2I{Mk7sj8?v-xn=zH({@y zpA!bG-0Wxi+Jd1Ag}UF5!>4w?kfU=|+vMGai;JkD-MEE(zh|_!YPT>tzGgZ<-ybrp zFLy`GXF6HA&z^&ckw1(*&mvh{@O8BcyNW6UpD&sz*R|$#_j!>&l&Y2AfXz!(Du79G zp{b4Zk&{$YPWEA8o3{G7{TugY`Dufbx)d-Fn-}0b86-x3IxfNi5s_mt*30Lg`WJem zkn`&&AB|;V{a74n#fS-`-l4QTRzz-b%l2oMB9Jzu(*ZqOSrM>-0QXHJt$IluOEXRf z`+lum)TL1&%bYxlfQ=4!YdXfS_SpxiH4-FZ4t>UUC581j{0G8#pQ6dT=;Qs zk^6ov3Lol{KCaG|Wscb}-P6OS36nmd1bSUjJ4;^AAQFMemGgqDkQ zz>h-~x2;V708BJCezX=6qOxr9DU_p)ItkK%gh7s)17OTVS*M7xyD_+9o+H5Ld{{vC zs|k7G0LuF@p`vM}8Wt9?k|6-Bp)*u_NU;d07PD#(hxeldddY{1(w|s1@|&}9c(^$7 z9xDTT0w1RwJNF|we;{DRVxZ3eP={53r2nFaCt9qy<_DYjPkgk%9#zO$BF!wm0J{3i zXUs)?F}J8lRJ6VjMf8izApM~)C<>Y|xUKYK`73tz?G_I2=}%LrriQ+jmZu>;>YbUo z>L#IXuTlvC0AJn+6Rm;h?p6ubtp#OC=`59Sg5z$Q3^7pF#R2+B*&<8-H;Asn!X8oQ zxenY1gs1dYabnrnw5mAr5yU36ItbGrGnq00n;LK`?vaK>t$Sly`sMhi`61Q z*9TB<5~w*$0aKrj6itE+Ebo!QSb_%+M(_&9(#O;BfZG+61TIB_v=gWX%LnGzAj}4V zD&Qms2W#DT+t1!jA24~fkIf-2zbVD!j~K<*)sNW|do1Z4yHY*> zU6FoAh^A@V-%#nx$E&5eYro}Cc)(#AlNOQGS+uk})T7WsM#oRA)aSZ6UT9jKH~m>l zO;ff%0#FBP{8+a{Rk%fBRr=(RJV;}jKm;GMXq3_l{NL`Clof~8jBoB;(Y8h|YbzJT z{3(O9$;h6gfaKs>DNfyoBl!vUn$g;?h?WqZ2#ERg@xA;GQuzhVVqD`3 zhas51{QBX{d+|W}is2aQBZu9p5A|{^_}x_S+x5}`8qmrd@LeS4?Ujcl+#$lO_l#}?if zH1xi@yfAaJ4d}37YxP(*IfNX6N0rXEe#ACI^$mIS2(2lt3V#Ze!8mJHP_Ju}l*DDP6Hax(>OAHlcp5mMRjKRJnrQ(0%@)9j7c0 z4pI#ei7h9}19J5sRQ#8qM)VrQL2w-`$T4fKI;F;5!ohd_^pipuP2=GU#Qu>5PzBvr z*`U{`>lx;xAmzy|UAI0B`LNiN0WJlyO1b(xL6qce*a8Wn_J-m#vN!3M+WSXN4l6tW z_sIxNM1(OzulZS^hRDD87$aJLKy1ch?Jx)`diw`Q<3rOv)$4Ow8-OOT*5R0x*DTGV zCVLauCPQdY8tP>bTM%)_l2jD@0dG;M6dIMEbM^FFw1?URXAb51`@f6$kUX-P+rtQi zQ@h)V4lo*RE7NtL&?X$g3LN1Html6@gNLs{xr5-(qs${!$-2 zkKM>6y2Dr-q&_YsQMS+VT(e1;uGUS5yl-L)E*MtDEEZow)=YVipgFw05a6NpDO~#T zK=#E{;yzBUJGmGZLO6(ZfkIRF@enlQRS`_0Q?5tUs>5jlS-) zLCYLd<@M43+p6H(HNS351NF`T>lUAB0R#>}2Moz%iRVjh6`(X29A(tq%wKyMY!06;3J0DWe!gm5~4?Wt*V;Huf0Ne_Um26YPJRTkMY92ywM zKjp^F44rbke`KnU)u|dO$I3C<;@FP?hR6`VKQf$x#DZxf(eFbMzrk`;>TvO32faj1 zrBfn!0^Kq^BfW(72p1CpLvyy`v45AGYqCu-)K;KsRyl8M(jaP9@STtQA)7jyVMG^# z!8^0aQ>`px!f3)0>VYE7?YzdydI;Vo35y9S<;$joVK*`kI)}K*NwEZQkBZJ4-m_Vs z3Uy7e7TjzbHqxvof%|ISRY3Wo3SO?Efx>~w3@EB9u9Hq^ANgtamURx&a+I$t9;VztpKu; zWTZ(nlXzN#!p=z5S(*|m#LHmM?L;?&i`I|UygZK`#jmE!DNEl?chq>`Lck`oLVH6I zdI&U?;FkXR5E+_c^QoksE9j^^>=A<9r1)j1Q_?(}IN9v}2GBtOE{GOof>ahM?Im>W z>;?t9dsG!7YWGDBWrKs4)lXNzZbY7eH6o|0S`VbUZ~G(=w79BdE!QS$(*atenHC>0 zj>O|N1X2+k?HZ0sQOl_Wxop;15vmAMy3YH$F zfqIE-Wu+IPB_ib1&&%oQg~HRvVbzZ*Rw=NNL+++Q$VV9)R4F~MMjbkeS}ZC-ZYS*q zKwvnI8LZi-joj=$PA}>11zEjpaxcAp3lc9F5bh2;YYdE1n3{v!-|3d%oTEQb3-*w- z8thqu?orsqg77V%ObvBH_cmg;v3+XsmWo$NrTSrVfg|IBQ)q^JAjYv_H}O#bj~e+2 zKs0fRHl{3SU~?=-H+VLMfj>@kBrfFXR%{P3X>eK5phZS9;(cn;%UUjSW;6g z$JAmKj}vN>2kO;0OTDwvg^J-(@0D+>GofIaQ96fHT`_ZquO7qhF*2ZXgjrRjWKu&a zv^>r0A~L8-wYSX@5cg)o)wlO0;6Y*(@2)}j9PN+?|8Svn3>)_vUetn@)4?|(tPM1+ zhsV4XJl{dvn4M~8>PWlADm$~(L`B|lqe5l2%(6QZoby$X!h`@`z>T=N&Q@u)`EoU3 zJvc=D6XW5N)7-yxoP2B0o2N8?N+|EIO=N~``yCMalE9K%ZUx&YKVPLAy?%TWk($}Vc;wp+6B(R>0rYxcCvMl*dZJc<{z0^|M`q5=bmAqY*z&q78miz z2j#f&x{8m`P*#p=tM(23v`W~49KdN%lsSpeDsQNnAaI+8D}g)G%A~h>Ra5r4Hz)Bs%u)l=T$qjMf4 z2Irx???)idJ|`>m=q=@qBKT$yEdCHn5e=a6d8c;FBI1MA_EZa`#>tdJJ?A7{L60J= zle&uAm8yyul892L0t&+|w3v#KYd#YC7_{O%T2mbPr&&ySM%PMatAx)1jpVAzn47~H z1^`{sGeIU}Z%Dw8$u@RB-FS@L>=__k%`(Y94dM56{BuV-d!ex&3rcAx%U{EK&F#pS zC;ijo?VFMO>Q$fiB6n^@IvK)3=lvaTS88fR+mU1ngNwFQP2IwC4N*BaOM!1v%{!Z4+}pSg9$@_8x*lJK69Jmp8i21%c%z5bNmt*@>|v*R=C0}M^8J>PJGDu!*9ROY{*Aoa z^)6)i&|`UPox!JL`Mj4*7|ni3A4H4Axt|zV9Ss?5#Y_!XThm|LQ4z~kS^1Mu)B1*j z+x1k$9RK)uy9}wi1#V1V{0xWdTwUXF!?--rVqVOGxG{2BF5wjcE;+C}89@W=X-8N0hc}pX0n^dx& zWJRIMJWHpv11g*BZ*xw*0nuk{t__n_=BJGKFWyUb+WI@bH92cZXGgh-@fleDqYEgD zBD+J4;J&zbW69Hp>3Se8TwgrKF(T@tkqIBo1!%|IC{8^R$9&^iR+U=jw3ETMjG@F#K4ij>ouB>-80i5@aIfWgNd-A zzHspRdKg_Dzq|UnsiWISsLz_Zcg-Xwe&`_p70B)PRt6W*>B;%4{uRMC+S%LZ1P4)< zxEc)UfHer9VOv&A;;~zcW=G7sZYlM_T|8l5DPs`&IK&LMU{D%7OT>qB7IxzVE10ZV z1UpiL9=am#E;1Pd1E{ArYrXS8m_h<|bNvJ9FXej7dLxyYL^G~j*^;3ip5`&)411Tp zNGU>ji12mOOQI~zW-|zCPxW`!H9Xkol;ysAGWXt*;W=@DjQ>PHh*J@azl-DVx@PpQ z9kyd;eX>xK8}E{-o%=3nb@K*I4?E~BS-XPitb&}bc@icnzEyZN!2EduDzn7~((6rA zo-`m|3M%7PI&?`f&j73mpgbO(tBNYa7R@x7mF6?n5=KfbGof-PfFp5uNW7 z=w8OWP=>~t52v03eP+)+;eKV6>7(eDTHP%CSvPl=_QU=C{ii0i!|+)uYbG^6;=Q$m zK>fGDvzp`3Zw#*+u8!||8RYOOrxPcC@xY6Y)b!UG5^wx%+}Ete*5e$fhhp4rH=&#L zTtR!t2S?0m+BRuz3!Zegfk)|rWu|AJpyh!zFPjDYaMLEu{Q=Zef%oDk0$7pT- zn4h9;;!yP0I`lPGeK&=mW4g7{Tjz@{1)gE5p|6V6t3$CUK;)c2t5#Xaj70{4&ehbZWuXAI{eB zkp5tX_xlA-1RjE+G8`q4@KlG%A)hiN7VsF4tiU5dG?`0`j>HE3r5HTAY7%pZE*UqZ z3L(?+Ghvg}clpQ)e>ul<)Q{T@IAIOKJFP-xM_RKnN?c2@?eRN3Dxlz!B8S!iKd|SV z*lzHQ5Kz}AmMdLp(2UpfKr_xRIyLU~E+HV zQ6z_JcJbm(zjtpn5`Tb$us(jU?Wz?1Y*bsd!FImb(An2U+k5su$hH3FdsviZHx&#i z__tDP@APOec#u|M(EoRu)O9&kv^%uk8?TWuU(rCo^J<%$%=oiWAT~YRX<#}bO1aG~ z-(r(q2=Gpen4X6;0j7)&?#{Nr{`}TT(#TWl(X>7r_k`A3RCOQgvDJiC{rk7qaJ2X- z>O&VAgdZVAEkp{dAfkc-MgQI2w>GzNBnf`^ujqj`!hjcor3vrF26yOPnxbWnMN&&r z_Bb>Opb0d|wm<+z1EP68|M#nWRb{;zASKzJ-9gM)BvF<1%F4=jBC&dp@I-T>U75bt zRDrx?2WTJ>BE=~MQn^y_V2_DU?M%%fP$5WHkDQoZLiW#1!E03 z4S=~|TX#XXtDwD`vHf#oc}gxl==e~vm!aFdY@|@2G2&=^`XnXd=!leTUpl% zM}{K-JH@1$8qNXG7zVMTuM?rbsmH$YqRE5U0(9|4ebNFSZ_ry%c*@(^{FVeAj4hG; zeMbedE?^d==Q-RH7QI=yW3xomH#gO!E|J-0x`pY0#^$2aV{RSK`)vFur?GRC439T57B?| zyrG6*)x*?g=8r@0?i~JP1H{@Ohwh5oK&FI@3bK8_9o!D;nL77`d^u5P=vKkUh{0<+ z>AfD<5V#Nm&UPCJVVGP9=yI|R6cwxI04&!t?5lcV9N-YMVXiO{8T{B)=wUqq`yk{y z6a!fmBTS4a;RD&U*In=N!6~{+7&_8zzJ(G{AkK(U+#kjagaIBKFDrTQiQ0WQt((>6 z`x?yto)?x~O5q=k*4`Za&3#mSFHX5*SS4JqmDrmb3K42-2&oxa5kizi186^jEzQW} ztl_#D#7F7P#?03H-8PTC@{^rkziCBIyDI}oX4lQobK5A4@X5A9k3@~g&bBj4Ow?aU z(o|*fmtQ!S0H{0HmQaIn-!I{YOWGBANa(;YEr1ISrZ36X3>rJs=FlrUpIz2h;L%VP zBW39xwcvGh9DPX=wHsvJ?4zgfuFkgEFO**f=?gq~_$_VlJgpHr^V>(?8QDUXN2;e*Lg%r`s;bKvbxI= z#q^SAJXCMcfpiU<&rZ9Bxm(qj6xO^V?9eN(r?P+|^lu zScP%F^ZDELg4~80jMlxgS~vn(qbV6wgctr63FB(LlzVv85_!uYEikhNb5!*h4Hhu0 z3WkKwdI2{J$Ga=8YqX<~CbJPGp>kzH+tpjwCYo7OOfh)((i}9(VCh}OTNyBkP0GHn)~IOvj)d3|jW<-%!$rAFH|3-XG6&$4U^g9QjW;v%Wau#F z3a*Zgns4e!#d$!zk(A?0*K=;PiU+hbYcM@P5WWk+Q^EH_y0B0Wo#M^%t&nEkUlva= zZn)uCZVZtEPGyX4IOXL^Y5w5j=(JehS95bU{FK1ferF>l>_P^~w-Y%dd6;(Pi3+$i z3M;bVN3NeG@X*!%l4wM*ofR3j!3yB=Zj8@h{Snz>&BA=qIi2!Zam09FRC0CMK5 zP%=}7I0Vm!%o3VfN=qVk!1Xa6@g9L zgJ<^a3TB9}V;bBNoZn}4bDd(aG#hak@Or_LK)%o$>R%oXDYb6zo+Gto7Vq=M0~GgW zf7*EB2|xIssEt5@gtIKk)YUO0m`m?UZ~bNl;f#VuGNz5Ojo(f;8yLi}Ka^DfyZUdkq#fdoP%`rTr94QWPl7Z{rD2|pP-J%4 z0`Ster|{-?<5XZCn5mXc-NLX${MyH2UhHKw-`bxuO*CQA&4(#M@*(*Ypo-o-0LQG& zuUn{56eE?3BiAXEZP*T+$zU)DJD)r|DNfN^=?bb8U~Gc)0c6Ichav_UfWiS<9`0iI zKIya_3WPv{{MlaPZJ4%a>zkW9tghsRBg=5&U<=j(Dg*{}iqLJZ*UDumizmxVddLh!fF zi*Pq$i&|IBCVTj+-pEp{DcALEwfp#Kx`!TE+eNz2tGDH5F3WL+#U(tK646Q|oJsVC z42D9xr^m|R<4oM<5uN)R3@4j9QCrz9?-)&s~g3OYceHN6v!fnymw)3+F~ zOBWIh+xWOCFG<&;VF9}%nyLBh@ub51Fbr=QbOm0< z4eVfDav&$=I<&mh;NzGj+KmulICMf9jI9o~lNF-35cA?rJ1qGitFHudamsb#cgV?j zCml0g;i7gd$hTFsP+1^FY1P1P~m6A(%1rq=49F_WJZCE|X;9(iL$yjRTGBY05!d#NeqDELU z9IR?*ypdybWS?87DSzTt1lND7dpU^qyB8ykj!^f$2v`A^V^d*LNJ^3jB~;Q?I}PHZ zaOplO^<&MRSi+Wa?Uv!0t0Qmx(!$iX@#Hvg*D3pmjc;R6m zWt=I2)+!w4O@7 zcjF@NUYchMecHL=fE4gfDU<>GLN^GHvv>M-mpnO46uAn#s6sshBoJd+ew2JwY!wgO zVwJbrA`M4taS-`L`5Sz9RCHn4+=B2#O(w#a9&OXrnsR?96dn}6(9CV{qK6eDI(%$J zhyM}ir3&HJq%#|vrsl`E8NHOYkQzbm=Qx`o?o9Hjh7V`$(Y5Kj0<#y78tDp&!UZ5h zJtKU0%Ti283mXvgj!LsU|A5UodT6h_`0ahD*nL09-wbmKZvJK?1hHy39;GD z!q(9V)6S4BaERoti~sF=cY85Qore_Yg!e3w8iexQ+f3T+j~}l`5H%LE>`TY58~wA* z;mb6wKUi8|+q{6*8ME+udU|r|=udD@+L(F1DlI%Q1IyStqK;}X?|3EnCw0i7W>Ly- z%}5;%2#U(oU2na^lVc-qA3t*=IwyCZ|KP|q;8mR#n_|*37In%Hj43x+M~@j&$Xmne z=_QMBX`U+c*j1O8pb|0!-|W?uV>?`)orqCZL+(+%Cr|MYhf^AENO&RUr)gA_O+9xH zC$=`?(+d`AE zoq#JJFVG-I(A$M!__W*YA~TP;0a^D@#W_esUujOF7xNpol$LVZwAy8qXuhK&dc{!* z>HDieTL4Zy(#L%7dxt-Mj_{7G*y-6mD9ZI{pu+>@6>c!1@7V(^G2r zC4h%-l^(k9z(R6W^S?k%OGswn&_7*|!MS+i_ftXYOHFngOQLr>2r+9F!GBwI{4GQq-d5#Xb0r|{V!{yU z%l_rf#iT55m;Ep{cvtO)zu_&>-i0TowzrL9hP93Wm2wHS6a2jig|7=oP-j;*Hw&ND z$i|P4;C76XS3j+&V;pAK-Ry|Tlw>yxyQQVOYb0s(3$iJkY@Ju-)vEPrsa0V;6ScuD%`<*@nVk_c2YoDhp_ zra*PXZSEwGiZk>UupWaK)g{D@a*jngKu)vB2!!RYXki*w7{{NQK@`nNAln)Q5b_l8 zx6&u4;_x1*Pj*UV4eTNMWTiybpj!3G_q7fMW+9n@I|)%Jx?R1Yq;LW&qI(x9zQ4MiOmQ_kNO(&h72ji~S0rO7ykBw=K28r?OBz?oC;`@W^O)U%19T%5cN{QM z?4zfi_wIG=o45J?i+U-0rMF;Uc#~KWr|nRzSZG(iW!QhyScbZ~aI_xvEC z^t4l_A6`G*c@rSTH)x=4IEd4?$xtnalBz)%?u3XX_=nCY+iU=fM`K<_#U-dPS2&YF zc74lw(d~i4PNziOaPYFFc$=MD5I9xB90NrtD+y`D4}!IHlphIiE$}4fo*wKU9DaA; zypUBTi`WHOk`p?Y;`XoKy9oD%>jVR&iY<_DWG$ck6eUJXX%V>fPtnY78%` zB=c2JFw?6`fNe+?5+W{tB5otlO7r6JB3c3p{I%vYL}n17&raHs|1q!$dQlTO)c`_Y=0qw%p0y2%#r{lK0{B2#jc~-(tkDMG&$4 ziRK|C6 zE-Kx-IZY|itY`t#tUwaYjh$X;u*@(%$Q9BEM)AbF;Jp7SH0O5ueWYeM4@htADUO1G zEmQ~@re8RcKm+**(ktVZAV)^23IRGiD!`keYa911nt8z$!P1rz1I|`ZGjp@C`T(Ss zdU0soY>L(GyjZN43ov|#b9&m7gD$#oc~E!T{AUkS!9dm=tRGmU@U zwxR;?FQGNvGiO&{tXDMRC(*f)jmv*MB7H_Te4-NSgf2LE!*_j;dGX~>DhE$oQwRXY z5f~YDMf8vX!ODQcEl4m^{stOKVD2m@r0R!Gr!Gms?R?q5A(20qZnz@aCk+!=z4RJ)u6rNYH^baR?z~x$XiZ;*Bvd&yH!_EB;@6-dcC!G zcpMA^L-B3wknsA+n=nd`j0T)g)>Q?APppgo*w-cIon(UK?CG&`UdFMXtVDpEvo8{m zBSccLVfcX$8bkAn*gyugu@rOVlxLr~nUTlBzyoU?+t7CbLu8@eHmTrH0Hi|OP@HVJw295k z=M}lm^*wOx0HJy+t4)nYR4Q6Mr%JH`%+?GLagm%v{GhGTkMC|qUKZg}WbS@Z4Xy@O ztuqpLThn9kqnX}dTF$NpbhtrPv4jHk@~7D*FAFSvxeG1dORK0o8#XGzV%LUYJF(?J^pWdTeIeFRA2bU#?yuiq;G4S=K{uz$` zx?F})XGk_JHHca3ZM;y2|5M#ihgzZY^N)MgRLgISu)Mck&7p$xv`Xfh!GFOBZ|(^M@L$iKjMl7o1kg~q=3P{ z(a*Mr`de2`+VIx~tp|KoBo97iApyGC@&>zlsoYL*jOR#YArKYD7L!PaNg1f@`Y$Tc zpDz`@f7$Vr#=H|bwK0F~DUN}flG>Ca|9KRpLWaM~UFMAmOK0RXz;?htv<~<|%3pK6 z-rWrANtjsb!)SsHsuB1$)g!mCIy0_)V2H@@@YyGsw)DeB^&!*j*J=7ZFi8340kXA` z;_?8K%7}FH*eB}4NbOsA#zmPb&Rq%`Ow1uyNGBx(j)YLc6WRtMw0&T;!K>qMk0C$B zCcS}7E3F#v5a&139;l-Tt+%WtFk|e0yZ8Jci7%Q8^77h+R9c^AYFk3kWd|^%!qUPX$sF+mU=W=&KPVotfWC%0w^AQjn_8 zt5@ljGjgv4>UH@I5#N4Szap?%cDZ;u*6k)=jGoTPB{xsKAJ&96{Rz`_`V(zQ zn#kYJtS1974fd`*B&(Qc*)2uxk$kI&1;kW0$E~_sLUC|{D%?b&it0Vp&W+5AI;&|YO3cC3Y2&#tRwO|dZn9H&|# zo7Q1*i(RnnyG@-4=n8W79P)`=PODTQ!AnedeYKt%whR-X2p^cCi@fo>YX5>VJNURt z>G>gukYr5TExLqbuix}GzQy-7j=5H_hldB>Z|BC7rSg&R-L|yXPu}!?2WOuTC=0h6 z(=Z>68Y_G241*0K-Io4qe@wULJoYBE$FN*;`hEHF>CPL22yGl>#@d)ctfz0R{qc8U zI<-pU9;ljO)RJ7?!rr_+k03tQo=zO)|&-ReTqlB>V)+FIjAjm@jq} zbNXPv*mU2CxIsD0jnb_1Rqtcb1|uYALg^@gJK%e~sK;DXBRet0?n$|!Oyk^Fo{EFF zEvMoTUXx0O8tMjRvay0?Ng?fvx%tiTovYglb3ie-^4*bUDN7Zp`dJZy7Nrz0u{tkd zX>$vxMid&$kX7kv296mC%p#ZI=vzY_4cc1NVr(}q!(OS}1_Olnk^lx%$xTA5@`i+D z_nnN)WbRGojA3m%576e=as5#VMbHrpEhj9FFrU_E(SfJ2Nq|U+4cPUFBXTLIja63| z{6;s?pA$JrnjfK1E2VyN$KoArpidlGb<|{jZiediMu@kbmkU9$EfqM`%X-Q{QnO__ zsUFWSFD;Uz`lWVhXoKh!3L7>Y)al~m~*;A9UN&*r%)vmPl;GBSy=WX4nvuvPXQ-?D7i3K3u zhwNeQh@N5z9BH#63g8#zto_ZwFCd%Xl8V%^dS6N&$W0T}@ZZN6Gum_YZ&{Hx4?9A$ z1Ow%6wdNhXi>H5yoo3X1bU%hQY4f-Ut_s;N!>wpR=|D<4p6jX(x7AW1Tr)Q;9m0c^ z5$*}(By48mu1tKTe$0wiX!_wQTR z5c6vYZ;ydu16U%erfyfzXcVllCP=ot~sGcO69X`8H zJct*mth6q=qsk?SFx8_L5aV<)CU8SKyE8)N&48ny&g+}%sGd&iCN5;3DfdAypH#$2 zq+xdxL`z=cbtdgKMDpB&quR`hSKDr7eXt6`J)fj}j2GKDblDH@mUmxm4 zpkvRIE@uU}fO9inVdO6-Q{2Njm$%70US@!Y3gWFo^9JTEw719ns>f-E z7`x>>+6rbQ$0bfQzQXNA>x~RQt<1?8_I$ox8o-bcjbB!{$)ZfMDRL7UhzucU<>tK; zbreG#aFBN-vzwdlzxHW}X5!1UZLF#0`7(CKEzfng?Z|^I$kIaar-H``m@4&@>~_ZM zW;MTohpQ8mb;IogWxeT(NmG#_Lr?YEP(vuIM9#&HVYpm_YSyd`JLhmwvR1R>^-VcJ zPe{1=*ufelK^Z%($17K5RrzPb2P`}@N&$jy;SK|yV^}@!5fhtD4!OzutMT zXZ25z2!-l>KtjPu-K^@FSv~+D&ogmKA}0vX8w);Nv>DE5-ROmkPXqQU9lUyys%gI$ zc}5v#8-^h>X6Z5g$O+kck=sQti(k)XFCgH1_D>HCt2Pu*pw?cIuMe^~zojP+vZB_r zDstp$^WBMKPLlj(V7c1BZlX7X*I)Sss>$|zih!9>m>gg$Jo|p}*~D;oGqATcIii4< z1p6cIL)6R2ff%82$UzxJm&moHg!j70O_1WnMFQ~K;w==1=Ia+^dXcf*JOJ$>J0Ns$ zfAza-`=>Y3==Ib7pWlqwbI+jHWIe_NGb8K>zCY&cWmKy>GZf7modsmS{?lX~Mgr4%>fVCH8GMhHm@B%ly~fkLf_33uoSffav<3Qlb@3 zcxV}~RI%1G)1#r9h%0sLJmsJ*+^hFbTKXK2hsONPz;E*9D87)vM0Y^U83Fq?;zdL? z7Nk&_ccfVIu}Kzu?3$79(7SeeNBjg8K5H_MSB7H_4>PQ%x~=#N4w=E4BOM?ywfZmV z@Lyf_Ecz1H<%*`Iwk8m*GDeR~*))%Rk9g|@YG(NFd!y?AB3#T#wdpV`un=PvRa7SM zt|#Q*dkcrLjP^eq$1Uz+12wm|$s3tuz=#bKs)CcAaTdJP4yeF+EmQ3pRzUEcJf(Tvu z#oMY{WZe6FMl(e{hU%)c& z5rlk4M=n@Dq2)*|cLzWTK}q2O1)BJqXATz;2OQ>BRJGdR<8<9XZFs*x&NjO@2q(I+ zhc{~~UT+#FwQmW+)>f@2w8}9733bfLt@Os<+ARl%7^DE^eOP?Ip3GfKdE{(yFEJuw z^9jVT({HT0hpHS*U+Y`^dvjF2rKgq@61+;&*+pui*@6QHwGFw_KFp{A@QSo^4i2uu ztn=csDDFRM5%+OcL{hyeRR+EfHk7OJfIDbPN6N}(#h1LEy_>(4iI8&w zTC3}u8pm1EH>>%)FtX(xuhhNAIx#`UryV|8EkhtkFOyg}JiB^>yO+zV>J>qxVla?8 zvy>JsOvwTmfcZ)3{NI&E7+#bsP^Nid7Cgm=XZzvG2)S}hfHR5_3?sZ8|vS`X!Ii_mUK1UpH%`O-Rf3a~pq z=HIN#@?|}R1YA^I5@QGQ8mrD#wfg&KU(e?qgBy2B+i-Me{@90~)W2Vql-D96;TM?v z*|fm;SD7DR;Z3<3Uz>TAS0!Kv8GdG~8tPBiSF_TLSNZmz^Eq`Q$BdN}3c=tFxf)>G z?9eaA=F9HiKkLN18)b`?MK^9rpnm zroHzdL85RdJ2&x0&4EiODe)_{80hZs6@0a&H)eHHAw%!{W5@Pp2Wp5BmN*xHG_q=~ zT0)@1b@5mUb{dWut`MavPT%Oj1u4!joYrIO>#c9Lz0^z<{tGXwswV?vVp^_e<7=yu z3Al--Pman@;1OAxF|Kzzoj%=q`x_9V`C`@VzJB9BpOd;8&S!?juWi(ZgwG5c^ER$m ziI@xB!U$JPuO_9T-Lv}tRzvxo=@wE!nhN|FJm74_1a19dIS0mWc(t6b7bzz41@1%; z;UxwZ zN0x%_w=e?$Wd*kZ2YW^#i@*>QEyo8O1|X|*L{<$NP~k?jLz)@WATjpM7VtZry6a>Fy91+IbiJYAjUk!?6~4~KsRMrz_&Jm+QsK!8v4!!bfLA zx|@DLw=19V!VNb2VvuNRxb{8^iy)j`R1UF_+kq0DK7rv(DzHKG_0_duaohp1t3ml7 z_RW7Hb+!NEC!?^?4k??KJ!NzL!$ilT?nD;O21PV4uHofgI#N>zra}NySCAw~Eq`FGzbRz}*~0)%QEPCt`yNE5-`;`^#JyQYl-kBS#_gRxRw z3`*QJb6_&f`%_~&HpmGGln)aMDoTjYkmrRD`4Pd*cRs1m$ygj2tsT1|i^38$j41sY z>2m6&Jf>7|cUG@CF4w?O1BOkia!M}T1AmQESix>_3M&aMyLB@ZW0Ps(*KpJgKsp$D ztt=_0V1|U{+6C_fmqOcPlsbS7Ka@bjLShmCL<-W~L59ZltlJxCno(X3(HnJ_-5mxH zgk{hP?fM!G`wp9-lBfA!fNq3R1jT`)f4V$OU;cb{vj6Sy?ELg#??`VexD*;}8S1LM^W(xe`@oK$bKA~Q``#=5S&;DP&@H)U2$<6YIHG2p@fjT%D zE&}en6=aVXx;qZg^rog-GbzEQPhHc{=>v4$fyc&817%4%d+a(x?lMc8_j_e?E!)s8 z_vUm`^{qSNMe)m%Uw3|a`fJZG7L(X*4+wjY*!$Zx6HG+dMR7gQu3z~1IgZM^i%N1j zLg+$6u?q(OjWoVq6zj!*`nbw%cg<~rAG%a2lpNO074e`y;}Qbq+9mEdb7{neZ$V}) zzn%mm!A+JxfRn9%{0V9y(QsBxOddC^pG>$rScqny9Km~LRjQ~(9Ra8jkyX(UKWVj~ zq_8!&6n(+-tC8FBumv5FggOb<#Xv`>-?YDA4B!kzT=PTMOKy*v>FM>nS!G3rYe1d{Rc$DyxOY*c62<-%h4&%^J^*Nk2Ot4IDu^~36nh#}C_d>_bZm}? z$%&E#!pOmzftgh!4e-OndOfIy#^Z142~GZc@`A_CYa1VMI(FybW?9UF_XbbRY;Dp9 z#d{3sV7vd8B!Fr6xy~8Aq7Em_=^H}Z8AMHm=-z+_*`xvvDe-FhFH)};PMVt1v7-Bb zI^9E`n}3Cy?c$UqFss`UxK~EZlDTgZl&h zlXgV3GlQ(q`2S?HjZ8LP$ss7LK_UA|1VlVSTUP@gs=1Xq$^z@t$codwX+MW@Px{3$ z&4kc=1O6kNYs-3Z+jzD4qG!QuAY~5sNgsyocaPEAm!Kqs%=I1IF3Sb)-G-@$ z2O#@iPpTmFz?<)Ev388H+p(MSZ3SxfZfAMXNq1FmKp1yzzsb)znlH{IJ<=w^X%}<3 z9GQbz76CpHP^lxba$J1=F8^E=cjdTj>5(%=tTwM zO};m|Y_@dWLLS(PR(TH~vfoU&Z%ed!Q8F8SG6DXk_aiM8FS-(V8Yy$oqV%Lb^8cg_ZPtr4=v* z-H1CHmi#w;4`!fmX4_FD0^#=@|@ z)K}q-E#}Q?>w6m-Zq*}4iZ`#yK-_A?*4nC3NvDAz2B8b>hFjmCKUuyZdM(A0ikx!d z$$eN>oXCi(D?qFl)v_D!0YAG6+ zPj|fFp#A-|&YCW@>Lt@}`o3mAAX^%GA>tJ^j6?(3EVa!QGh{}DCsx+sy9%#FOv_y6 zj#$Hmn2m}8a3}2cdarjLKYbH{m|y%zHQF8wVRNYKDQ&{U&KD}zSWgD1LHEnkl=coj ztij`+eV%VGe$NbM^IAB>7#u7ihYJ0506`85vxr0<75SvAy<@$5-e+6Kd+)^yb`sG^ zPtssH4+x`x(@_|_q1??p#l9j;uU(<@;Y`@H`KM zp%-maN{i8nVU$9CJN;Z1l=jR3k5r+X865glB|{8K*tIW5t6o0ZVq7?G7k531h3CW< z@Ld#jbT5rVSfS7W480D(t~tAsfc;1r7u`R|6pzWCSwiAOl#(8S^gLu}Ls(wtq}&t? z83{@O!Hor}001XLr6v=rEGM>+AD;|M8Gj{=sP_$XY5qXr#>N_4rJYNQJpmKsX>cPI zo_n4MUF4CX?Fv)(s^=|FK=cCTv7R|jeG8ohD`Wz8WlK&4c|Pkgh9zg|PuojN_VCw+IJfo=$l(<9DHI`@j3eCviZLX zQkGa%4n9UPt%^P!c45U#{y{+f83>`BNYCo;o5D|ZdW4@LUfR=s?Jxnf5BEeKb^(Hp ze^Qy_{LLeG*hJpQ&K*?0gx*O_bT5X|3`%p2%@xMq{00x0K5rv|wmOHfphBXoVbEk$ zQ2m4kw&3TwQf2C_BXTA~z#ewFqvw-Po2G#vFs>WQGT0=P7@u)jj+Otiz_=z%lr*Uf z2MqTUs*g7!J!r|H^z@ZL8M(1W>@$Otk_3_%0v?2YV`vh@NB08PV+#udVtB^N`vj{B zM!V>(wpm{QQ^bMglX|>zkSv0YkkS|`GwK&3;G*UC9{j;ppuqice`1~3sPEQ}#1_7U zMj#m>i5P91al`uU$Djh=Q06nN+S%qJitF7ie)-k^^Q`)LMIz}Fcimxy_8G!B4#O3U zt07lS_pJnb6DYgL7$bp3uqlxPyUvSYo0pS`-^+Hg**_`QuaR6KHG)Bey#tVJTa>O_ zwr$(CZQHhX*|zPfUADc;wr$(SF5NnPUq|25FFHCSXGTUw##)&fbB#H2{A2!KimBnO z17FP?LId~Rcuw0#l_SGMWwXGm7)2e_3v(SQ%qI=EqFT)-q*jk_6lw6NFQ$? zOkBpCzx}pobBEjUI8+4Wm?_f8(rPo1G9;-N89^G^RHoE~-L<#M8*Mg7uGB(=x?nRyOr9345_k0gBiQh|0ZW2R3k*Y9MG; zGNBwOLs11J*?_Xb5+?Y7*Z**Ff4sH38DaP^D)pzQQ|{T@ znC%E{9t{%0#>Mw<`Tuk*wAj|tBtqBz?=bLYZ0jFW`TGuEQ2F*68{5=>WAx=$l%C)-mVixa)jnWPhW3|D1+Xf;&SYd-D!U-R_YYk5V zf{b-{{6lS|dFKI%&bU;XDGBQ_Z>T^`g0q0xWJ{(d!IlYv5&==k1*7TVv+$E6-z~Bi z7Ghr{^S;Pp0#^iR)>)bpAq~OoMwSvn&4$$-F_zEhiqV!aUhT2Qqil&$rq6gmo4#QR z;;~GVK6SE1Op!^#kZv`vc(%Qd<;59LR;hlHRtIH}b^^^@pq}X8(LQ&rEZYD(sVeF+ z#LI-f5&_Mx@3AF-x}lT(BUQp0*W6`czT*~f+)3%t0ZVHhVLZIN+|5;VTfpT_(sY(k ziNQ2n@zhYq`GRI^yNnks5#cA*FK>LNx#EEqX+@lCBLTmw($T>jh`@kkxqIKB^GC7XnyA%asy zcgbbiz5%rJJLCwdXz(1x7LxIDT{)=ufR(_b72lGK_Xe&3)+2tS8Z^_&_cC(b%XcGw z0j33>w3uqvde(Yv%idJ8V@!>KdU(2ioxJJKd~tj$3xN?sdK3uwjp+u>JsGg1_Wk!)6X@A~XfItV`jb_P=AiHt-neV^0fD-+&_q^%M;P$za zldXentX{)8yaA?t@w`Pt>~{UbQ>(Y5p?>XhRyTG;xRqT=TH1iEy9=k?$~5Hi zJt%8ifmU}9v*d^KX{POf<#rzX_h@>t#dCoYr&9yK(EM75*G@3IugV+F_M(JcF>bL% ztaTvK+D1D`@B8<5RN??!H1p-uxNs1u-R9fGU5ZwO|M+UhFMB%OpHg~@UK}{-2iziy zTWQx@+u{NkM;9>F=C-}`+ltN_GNx#sd0Rlb?Vp9Sbmea&?UgRhGO23s5ayld_iLL_ zdGN)B4_7Xr2s365D7r6b-A8t&Ab{lecOGgs_xm<~<#-Olb|@4g^}qnN3Q?=mzx>4T zIQHZ8M3L7|Bkb(Fue^v+?wQDK*!3lKKqCazc@)n45kwF=8A~?-RD%H!&FJk<)DPH+YukiXt6=Ke&{YWug?cH&?$6csDEs#a4lHx5 zg6q<*9h8D>5Q#e5p$Z@ZvHyDRc9#S46VCc}q5Ra{&H zj8+4q8m>*JTQhe6GF*j8oD^)Ykz7>5Ji+aG1ee8D(`3DwHRJV}CjJC`sNdjVj&PEu zV|EehBZ&zC^}=!7>F@wMo?clz^APT^*yvu-lOaKG6&n=J?Bu2VtNr+b3zXDA)j3;* zO7u!twyWsNJns4O$QxNBN;x1Lot9Dy3N=wnnD!MqW8}MIPT6g zdFzLiuub~I`XP30rRvlIX?48V1S+XM*WoN_A&Hr-qSuA0c7&_=0PjKB2R(W4HH~fa z7-*{AS?-!+YXx9di_yXCew=MYjk=#Kw~4rs-MsWniL8QC7WVjlGo-GwNUoI>6;djy z^}?}B`YaTCwYTkheVj+fOlSI*uswRCg_<6x~dNaRB9mro#xsh=8JeMO6~Zs=zm1ONWm0saUBEy3%jK)`@G#eFTdLWsBTj zXmvaSFVU^sL80!_%aR!PbUC#nvmj2k=niC=Ru;O*V;Tm-kWGD-lH%7F*Z%S%Pzmjw zw&=KIYl{4{$Wg&<4M&#yx?z1UD$6xw z4ex--$tBQsHT;b`XjQdaD^aQiTSI0`te~O(wl7Y(BqBYBh;<-jTclj35mG;qKhc;b zEsqLdb0v}lFO-vR%y$4#?59I;y|b*4!rKymyf#He#G0&w{4 zNfaMZZ3HS3B49!Ryvf3PyAfj7M9bjE0KRuQkv}wm6a=9qwnOb5 z{fhJDJK+wcq+!%XM?Vfi=`69!P#V;DLMysU?I|8I7wOTQ4)u`089XHFKgVhGE0pLP3B|ZZIqI%%x7E8|-mJHb@Bs&rhVG%OV#gjTHUZyJkfsv*pnz@dTF04T;~qj@pWyaZh6JlrH~59riQ3`Xy~ zQ~^6DXB4T^;VxQ4wBYGWCbI&BA%Ka@&Azw8``R`%j#KSV{iU_ELf&R(qM$;Z`3qP| zH=9s9a}BBz-}9aUXQqR2Dezz#!U1v}NgMHK{c093)U_!m`{ zO2|&S0E;n=;P+LbN%Y)R1ZM9<F^r#I2CPPdY>`-T=2z;RW;X!^G{*;{ z;&1e_{g^&Q&6t1}`iV$dg)DHW=$Oww{%n!MQKS?VqNYTqrwlgioZHY5JnBjqEj zVr>FTmy0tCtobDlgDFRV@woxk5~(EoM=_!Oe*9R(_|h;D;cfw1t-3`u&UZ8+jtPN? zIVLvG6dMq=*(7RsjTg&a*vVBMJ{Nz#(g@=9RSNshUGlabAF~Jyn0V7ykASL> zB13n=kepo3Ujk_M*_nxh$4Hh4@3@j{1gI;_TdhL^DL4_ zLzX@KpJhuv)_UAh2q=-}X3@P$;CAi)J^b$$0B=4Z7A|OScy2z)B0Kz7zEyf^jpE># zC_XkLlXXVp1ID1Rd1pgF2wUSv=9PrMNFVdc#Tg{M{eJJKQRdymqvDUo+n4V2rzp=>Ok<2{MMVB@$7E@r)r3%VzJofZ+je+1%JPA|3@RXnWRppCjks^*ip z{Z06g?n0}D1<_i?bse1idhIP4@@fW>9cw7zgv}yJ+%PHY!&>Y7DD|#JK@eHy6OV^c zY6-}jP@-QuQAltX<&x5N1Z-Yb)TCN6FEk>b=>hs9Le00|A0C_Fh)gDJ(OYRT)YP;h zLI^rHS39Wgjzm{TCvW>2fW-v#C0cU0h;-MNDf1BiqhEZ+Uj`t??xG+|$p5bETt5*! zy>&Z%^;K^=)0*->`868N*r0JcQ|*9EL7QA#(Kr;S?vSXemXVxBb(NTV>iMU2pxnYx zK4ZZom!dm>GO1|$tT_gx*g0%)AsijJ zLTFOY8WRf7C@>b7KBJF8Ut8Wg59OpoK!C#>;9a{n+RD-Poe;I+H(iF*iZe^{2{Zv# zs79`p7R^$ecz%7J%3u4tKepM2sMG!Qag3<5C(wq$D4W_Y>BZDp?3-Re(v5Y~|Zk3j> zIfGX12`On=R_&yJD9)(e;V8gK`IeByh4(A%3b{!(=ASiE)7Mg6qOjajU({$Lh`=w< zZbc%2he#jC3vv^j^hpj^hl=2ULr?N&JoTuJ_po%Bw@eyN4)ql;WXZow-f=;YN#UJ6 z>Z-3OsM>m}1`tjixG*YS)z|(eh|%=O&(G|eL>?cIM-B0H=~d?5n8&W{Imi80L4P=b zxAA`kq~qWhMmD|sKClWug2;AT5(2;1c3B(Rr;zlD(K13$Q`eFCP~qJv8S~MB4*B>z zLvqZD^P$!LrA1VOsT8%Fk3mDS0*W6X=x_&;jc=WQPvQ46E{kEf(rI|^#{h8mF@rH< z3gy2g7>-}RLTY0w2*Qw??~0Ihk9mEBD_U7ix^9ezXqpAbIJA-j5EDuUR@j`WAS=TG zO8`jsVJ$Ri>u*S>GKw;hi(h@~w0LY!Dt|6a?c`u{@4&P4dsII9R@dR^KwVeffVZ4X zme}lL=*9HufOEUPT?*v_9hC~s7fGoz)N)Bx!RPY?b?Ptt%)%fmJ$HN}T77=~1(L;; zjd;tiOiFIBk(MX@|H38w!@Ly!gKygm3IG6s2mk>8pM?h}OYXRjbR!ZLlG9 zU8p0BiE43d<9SlRyv`EI)(d0;$T~s+(bFojH6oQKs<3^%_y{MUm~xMyRnw(PPOoIK z56n3|?;nF!x)3iJ4)AsQJ?*1wQReb`J(t=I%Q3he#=?sECB)Q4GgH3p%1+ubaFYI1S`sDJ&UFzLypB_AG16{LTHjFGqifzhxA2I0I=5SG>N^)P;rN^ zGp;Ux@wcE~mt|Ql;KuA8I&b<`UxXR0W6SaT>8-SfO}AOMafKI9V9hn;@`vZxjJvA1 zG*e>Wt&u&wG4YDEAqFH18gy-F&NX-FHXA)4eHxTgF-2J9Q*3woJiK3DlWveL*f$}q zQk=(XPi%Y4N_BllI|Nzg-D0;kz=b!Q70rlFq1xa;2nD5SNdQ=AVhT2+9#j~{Gy&xC zdAjUJDNKXCR;LdEa9n*`W#EC?iVB;1*2tMdruNq`8E-bc_7l;8%;dnB6lTIjNPwFK1xIEi>GIHSDJToJ!@9eO6tQ>38Psdx#Wc+D?>oFA|xN^taE|gNyo|P3QPP6g;PW5mg*H;2NHM!sbM1*13 z#e>sje~5|lVthi<%t`J=urHaNdV&`dY!ZM1AdV(ORsi-k`CN=Rr{K9P&v;D~c&8sDvpH8ur#x1P+x^QO#l?Y# zhF(5yQ$hUlWsPk6X+}D-S7@#BjtE)KA{|Y_4m)5l)>H0#QW^UnmO_a84$)-J3gtLJ zZ|JmcSW+);gbhcYvi$=uZdY6#fwNbwb==^>5?*!bdnPnFR`QV({^roBe&ULxvLWf4 zGC&$=S4zP+8z;h+EirKVI8WEi)rbRPE^4|Nx~ZDEctj~puq`L2vWh=GdCFfWaQB4G zf61TYOJ{R7c(w?{hd)IRcgljuA7aTu#`Ma*%(`a%#;T@R;05Y2yFMCf^2(QpBdIpm z0QFBzYZauVZ&*#H3Y&=3Z{&hj>$;Y=zK4$~vfhC~lZ53<+JhTv2Oao*QgRfI5}j3D zB&Es!5<0Qxe4aM@{C!<0wFXTr@>RfzIjym|?o_B&o9L#qoYf`^F z9!{F2yEw#kwO#IrSTZ2Y;^o5MKmB9m&>40qW<}1#@ zaF>vLA0uPPm+t(ygGR-^?!If^J7!yOq_I8Sj+?sTjH|W5BXZj|< z8ThqdMr>RmQW{I#M*&?vRoP)9TWJH@Om~Ti{gYZy{lr;LY0bp@`u?&%al&noY~KN+ zl}GkNqslvQeQu5y4!Iqe`Ba_);+oJQ@H)ZB}e9Iwk2U!u_W+ z@K$U1{!k`Xu8Yh!J%_!wk&}z=yX+z52u^_vSk*O)?`e3ZbEDxERI*dkrig!KV{7Os zce8EMj>98WcTH#WG3{mOSl=-6^!UORQn$Tg9GV!!n9G}U&ZJ<@F7l}&62UE=`U~)% z=nn<(&+Q*r&VM}l{~n<=w%AD}W`d@AsI4P7aobaM3kGt=GRNegRhOZ*nm+W~T; z5lW{bg9tboPM%%M)CsRfzVgCbyqC@hJczn^1O|``g1gk?=)8{-?qY_xsBsDVIuN~Y zZnjv*r_N+h+5m1eYyA)i%BM;sO{rNY?9O1?!de89E@a3>>Wc-$_5S<@Cd5OqjY5B7 z%e82UL8x`YOmW<}wW}3(wh~My4|~^yvmLP(M5ueapJe74$j6XY20h(+U8liFY$_tm z_JV;{=88v?j+sZXBt#25B{(lYRbk@8aIt@nx7B>{Rb!OaeV|t$!sSe%jMEkA+o$I9 z4_hqOo5+s>i&7<_l#z0gW3yFCBN%Kr3KJKBG>C(W22v$M#1NGT5fK~zxUtEUX0#mc z#ST=1Izw>E6Lgalv31wV)Rs_loM%neg&Ot6R5%+JZnH?2i=Y!}7U&#z?b@>&%7XjC zZf%t?ts69MuaKVt#=OcZujVQ`A?zfo1TE)IvK5@LK&mq)!21Hp79c{~NMH#pMAW@e zDByxiMKy^Mr_juGzZkzFj~|YlmWH7V4U4@0@!cVMiN9NIMw>v8igPN{@~=x63MPe< zHC13_z*sc8z5c+=o10~U7k_L;RIDi@Zla3<<8t*X)xSeS<7H9sK!%;Kh>C4(R_|Gr zNtc9l;RhBXi*_;#r0LvyAHi6uHfUGN9Cjb*wR7p-k^>%?h{HTUOaVP3k+nq1sK_;b z?+{q?VwG9SA=ORS0Z{MWm^twRZRlCCP4c83|8v`Jmlf6?kGbh_&k;e!g67Ce2&s9T zQxEwy&d|kt0OEt1D4vTgh8j{L6%>FS2FP6xemPVEVgHb2Ieb^2Y~s_G>^dHC8;V_P zYr065JvRAAoWuQmxW%iQr`%>I;Nw7=79>;!HU1X{1ZP5 zUb(FJ?%0F@i_KX}RVOBhEe2Tqz$a_?olQ?xM*Hw;KghKncw0`w#~Ve>0r2pxa`nie z`7o!K1w3uvs!vCFT)0P=e|^KPXIl@*wzT_xL1VFBD|JY>^<8RWgUQ>yOR5+{>612- z_i)<`d#dU*azp2){kV0oZlVY?n)C!65yL;L;=uOK((a1Pb&JGp@5aWz_PH6`EbRMS zqjTwf8a+hsq4VZ;$EK5;x4U!Qi@yDtzQMwC1N%<8(Gh}DO1_#2H=>jTUo_ogiZMwF zA&)WH0U^DQ1fp}mC=b`rIf?WwOLGTgQy7K*SFJ_}dGz@ZaWJ^FB$3?j_&@xM=tOZ= zs)d1nclJ*{_rmy%7vYsz*jf2^ZEC`Lt_RKof_8{g?K42gIa+`eJ|8;j|Q)6<@!D@Dc2=)jL;^`Mm zDM6H5acyx4YnhwnQR^P96?PvR5~0KIFf(y7yBRJ6XBIAZ<8E)X^x~v&{Q~}X{4B(# zkohNktNg@Jynn}NBM)a&eJ5u}OFMJ@ek2=7C&qJ#uAmxhr>bC@d#QGECzrk}*pL{2SEl{F1#qZY3|3He3~QM@Ey zNj3y8h*$U?iy<3YLS?KL{T)-9NSF>u_d$-yozxrp)4&M7268(Cwa~&AOck+Pj+RcUsGtp>P?#+>QLXewa*x|T>3GOcjNKSXYd0jAKem2}i%>#a$05l&I1wf6Rp^i#DhcCmsDy8?qISqs_TEtM^A@Nohp%>=D z^SVOS!==qVC5E1^EG&9Z7Geuo?)*aW!^k{dgx^k&fiGEHQHrJG+=iOoTR6Eg8^ZIt z5UWHZxlHes;{_haxec~l$gta!mJXSVA1nmR_3lU3zuVIo+WNymer?1zp4H7Oxk;PB zktCI|R-EcyKSl}7-CO&+5n5i0T6<zLTk=tEuC^rqlm` z+q>eicG}=b-1(%gyI0L|uB;;=Z*)uHwa*#(g>;g*ifV1f#G0g}h(be{4?zd0HK)7f zw`0d50YpNoD3_hlg{6vz1i^y!c@5TU{Cv8DS1&Ow+TcF!o6h6+h`;LSoHd5ZLZ0cj zT%XQZFJ)R1m8zN^44sR+wk#_0sqRrFnmHCu4SE|XxU#79sP;zlr2WWrKxm}Eau_k2 zN$n%zurQSrCx`-_YcbTwor@a(E)K(Yay%B}!DLd~-*L+|*^j!yqb%&&H7A{+LJJIj zEPndq`;}QcAPe{a`pJ_f*-uPsa!&-`_!bYziW)I*h;&!piQAw*wb|0jz1J9ZV$R>0 zD^pv$Et{>+FR!Q9m7~|w;rZR@ahL~qILnmyOR{b%NgZ92tTF#BlDdI^iE8rEc{;#w z6VSg&sVS2p`al!IC(}q9@b;cG&e-FR)WNf0D#$iQ!PV1@yq>In&Zj$D7hhH)gfB7LzZ^nC1y4uI1!(12P1+Qn%-;b?V6Vq;S$`5sI zd)7gkVY#9}YyrAsWa3tHOROjg5vxbOSB(nNDyxd2$cc}a$cLWC3P8Iw>_H!)vKtK| z$gLZ5Qxh0CA$P=2ItSlh-M|-WP+^u_%@EHM9$1w;K6t~C@EaoYd=)XEo_Wjj!qSn( z&tIQj&zF&1U72+!>Nk2iGOLG?pD#xjD>nD{fdRVQJ#Bvs(`Ke-#)hUqJoPuzzu5V+ zf9vb;F#C&C+VRjTN~~G#G5npi?|V86XlG%lbEo;KVHoE60Br%-kGas#!p>3WQEvlo z>(MOqI$ti(8<$_bwBxo;VXk*%!!Zm++Xpp8?*CD2 zhEIoU_QaP=9N+9z@B{P)s#u~|Y60-e2_ti$cngNNlhP~a^78812BEZfMo9#Q zD-SR_n6ntP&_ooG#A22*f?!kzCG{U8LtS-ZX5rC52nFph8D0how>{Kh6ui`rc*CKf zw3ntxx2~5mWTXZgY1E6fkOY9iB9MVRNRQ1gcPpXZ6CpJvurzVX=%ryP1-fckv)ird z{^@Bq+2-vxE@ zLHWKj_x%*8olz~c#eRYD>h3u3Qc<=`}s9*v#!eliGVr_&p1fw`V88S^LOrzDx z&Zvu#RCYp>2)C}1i=IM6(_H8QDlDJS10u5B&NN1s0+c}8D6AG62S)2w0ajRDEw-F9 zT{v$7K8D;T^oHbRsC4NRuhHqS|IR&*QGnRHQ8BYNkYa|(J+S2wHeRVI!=RX%T$|Za zaoKJ8YvMErnQ|$C^KE=|u^^l8kd;wh#~bWD)uF)@*6ZZH59Y!^yG)?`nc}qu@{^r> z;2_q7!_2(Q*Ch($SS!%H@UfuWpEkWaCFk_DO>89AZl1$y$MTJ0n(g#iNa8SPTivBV zH24F@n?x&qI)Q0T+A>B)v;(^y=~sHbA^1eNV8K0EL3$g=T{6I*0GBOR=|(_Pp-J88 z3P-%*+y#S+dfCnMRa05Gbws@X=|lCL`R9 zQt*+@Nqd|zM%=rHZOBSZoD)tCYw;|rz53lE1Rc%^JHldebwq+a631O4Uabc%th)q&w^RefcCaUAaAWG0zy}YW%j@NM?4}Y>QQIC(3!?~7nFZP2&W1(}BU-HNF*G7R?^@Tynje#h!%M;- zHtgK=`EMbJtUh?}DYgZbND*-8nw|Tn)S{1c45N96=Ww6FI?NY|35sMnxYrNNgIMhT zW|nnZ7%b>V`QTRSz*d~yJ@?i1DV$>{B*GvI%)s;^DtB;3s-OG^dupbD*yL@GaGo<~w^nKVv~(8~yi--b8e5 zx`b-CkrSryV3mU`0jcny+@3>@a3VW;xdNgUZ(N8Sr56iIsmm}B@wY&irn>(hL3~`N zIZFnNP4S!UIYPIL50EDQiDvI#UmMhfsgb2o;c!YV6h$wa`Yuo^@^=vZUm&1y*ZRUM zzQ{m$*VJ(@H=OGMcT$X%ylY|n7g7%RgQJ@_FaEX*H@8|Z5r5;>4$s$6y@Ey-C;OPW zw1dT`$GwmcR03Iw2(MeFv;`Dh53 z2Yu}dqT@L?a+D;lGBAy(WUvmAW%N=lGrfzPQvLzDc|FM5Wp9$u?GL0)Nzm;J;KqaJ z?B6{{TIA~xE=3ICt^QlOUB{CaM44!LNj%$!#}U)~?*IMQnhCFKt#j?;Y~`@F%gkUl z{jq4Fimmq1QZ<)zy6(qvQIN`X6Na1$mcg)=k~aJsFxOwZ1cEvh0eOaEj5apLLt71` z3`f#s1&m(~FPM=0-QwXD+9GH~@v*&ZThIKS+0K$;T$w<8&#hmC^J(227N&WD0ye#| ziN`MaEaR}F{|gJk!#uUKo*@1hClqh z_&d-}s||+x>u_8sL~TjNe<@Vh6Oh%2 zTrY&Q|NOj~D%Y-r_zyoi!S?lYdJH$$*Gi)CK^yFfVARoVg04tFh&KC8OXCU#mZmQP z)H2##%O-NIjg8YlMZ?8WF2=7vH3DrpF>N=?=2+ZCHCqZyS*I8Gv-o9HM&fBeRF78i zsYdlIY=ARdN^i!)=tPqK7=3abdSlBwmr^&zUD$5p4DVf<`&(#(xz`;n$9-X&T2@1r zj{uQZvJ8CAfzRAJQv$UCA^de;#{gc)lP@l(3t}$n?2bOVjo)qCewR}1IlA3}<@)HE=*AW?trdlepzK0sG8HOxMFSCHRFTaBBk>iK;;{k(0dlTJ~3 zwJbIVgH~$uSD0Kke_gi4wknAF3AD>VMbSo3Hv8WJ9V`@m-Z6HB%BTb?GQTHp9Jn%I zq_>$_fCS+1!95~E?L9MT?kqaWX6Po5(Zp=5yf{&O9-Ppg0rS^nbDg@ zmL=^GRE2v!Z8u|eIf9`eon~9*m3USF3E9wvwzpqNS65&%3=Bj9xx9y7tjk7`^**;d zyp>a|uxa|7?9*f*%l1Ycpt9sQ`4a@l($b6uAK`mnxy6Cp^s1C$4G2TE+ zNyrlVmO2^D?kY^<2(JS|%%Qn!pu=o%84h)M^usKx4>r3KxAl?hIvS)g)*GUOq)n+n z&3pzZ0Dz?^xM!D}-h*trkrtwL?7x(0r>4>!0FAa0#~TL`yuu-f&kH3dtDY%OiV7?c zKh3bH1R1)7E`^k@Rc_Y^D`#t%Dm&M%PzMY0 zw>?$;N^`|7VY`5zbrYVbFCV!D(ONtvuMp3aNxa<@D72}QqF>>uomQ|RBmWPI*+*fy z(8@wex&EM8q!t0X0dy34k^=8R{n$d1-_V9vI^ViFt*z1~#7@c7s3;T+6_QW+s#84a^N7j;cM!D_JGQC>;b zMciX1#(N?M*QZ>qAmp0$3s~ot(MXiSJJDo-IJ-O;R+Dme7WRlm0L34%iAFs;DJ?Kt zinfmmT&4U!M!|5QnSTz{jlh!M^1&u zYzFIy4GRMY^?q_fOup?Lfq+5O@{OcCqt$CSe<-ARj@pU`~xWfgW9G+nu={2OoN8NcV{3ROX~CKJ7x0#Sn(e&Tb#Fa~Tbe z^M?u)iO;hZqfS8NP4$WcZ}*2ggGNyVS;q{S#HM~Afkjs@jK*fFJP7P+!K*f#u23cy zb8YHw)5hxWyyd$VGViQ0(mETm=Uld*CavLuY zJ?sv(KG}Qkh&;)XsKS)-P-`!)#LQ3k17|THJfhUj4KT?D(`M z(IJXW0){DQZ>@urTFttxL(NbzzyF))LSs|Y`~GvG^`im+VEy|-v$XqBTwH8_YUQRT z|BK+Vr6KL~(`MEEpl+}Uonq*5KE}*P(lSpaOC!IJGY<_$l*VM6XaF!8`Q;-3Kmv%6 zm?9<1y%cTAgVwRb7ogf_5yPaV8Xu9!hLhqVogu*Ky(M5mrcEFZ6Iu4VwPOcwAyrZE zNK%AD>7eY?y}YZ6DjyMNmANk07>Drfj~e&y#c5VsELyf-%gB$hJjh*%c^x_Kc`@4= znoM}4uzJXV+6@mT@fBJLmt)7z9b_``n2(!NkFj%T>8qrX0ceyf-B;HX)x*MU2E-;h zjeU^2fu&mr=Aj78qioS>GyZ$nnR3J}c9gdEi+%oL~My~tw z<=0Ic$g=J`xS3{CNXSXW;I+ax1t(yU0QfG|%amTQsnr|CkXuJOk5cQJMZ!-o_sTwFK(blIxXzY2W>p zBcu#4lP4$wz?_x^lYChQQo^6CJ~L*H*EwcV+7vRnziZu9kF|l#_$t?)Lb0^8g zwv!M32-C0D4DKB%3;&Y0c2*0yw4-v>s$A&ZsPffA!%gEc#FsPHcS%mF9P31^dP>6& z-LrW>8f{-o&XE~*l_hb-BdagHtf%B~ox9`UVLiHB{Qib{RSbu$HtIB*~4icv0a^quEp+LR< z-SSPY)_nL84Yyc6txWrY_pn+2sejJFTkYGn^*Hlw?=^Mv?{youJwPi02mqiA0ssK_ z-{Hi{(A7}i(9zM*<6oNF|FCe!sL9%|3m|m8t5Mv}54N}D}oocILJ<2bs^ z@C7k}qon7oLOoLuF6GKk=8Q8Bztyl53r%ZgqW)p3Na)F<@Qj~_`5yQ8ofEe+M zM+M?U0i{>z7OWPYd0+!-GK-VyJ7Z@IT9&uzp%pZC*N?FIHNuj~FPf)#nK>Yp(>)C8 zm2QU|jJrw}*DOF2sOgRUEV@1jZ>K@z_6ZxPbtKMd7Z94?pn^n;fgbym&!Ro?o9*y% zz9e(JP7`?+%(Yn&%Vl3QR|@%N?ql#k8P+*7U?~f^ajvdcU#!?s-wlf(iSIs(MOfp= zKNp|s;jtA2k%ov&0nDZF7U*^ukjya23+0R+4VvGolb`$ox}X)7#v`{DcPv)Q!yDih zoDotsr2`s`=FAxFdYD#%=l(n!tNH2uheJSKezI^!H|2F{fIfyZ3Zs zUnf&*uPHO+G+`xKQvVDT4m#?iQ>$DnoYgUM4g;P8+5X*cI6|$nqCRNx7RrB?kl>n1 z7t^$6yfQ@kdgb-{V1o;nJ4i*ZKlqxos#bgkxbj5!G48B0x9yy>&0ZnPY|^CHvi_;QGi~{w(+V|H(^pDh{NG zWFt}5e|pnyIsSjL@_)>4V?!Goqo3}ue@*dxLxq%R6C;0(6j_yrqgib;VG)UK zl<^o3KOo?braYjC@lNv`g!GoffJ>z7j|v8(B##Z`P~g_Z`gI=x{zbHt!3+5d#3Vfs zYQg`Vv5-y@iNAeZkbW^K=+ZoqjtQeVs#btPKIr!Z_!`jwou^xvNr-oeRZueylDw88>KC$LZaht+*?@Kg*M?)`+ZmmxpA9%sGSo!FeHyAM%xQg_L0lN86eZLwQs0+MSY zH&K651|1Qd>xD8W2ZuV~71thi3_iqOfN&%@Jg~nbhByGQTGl~y+sxwC0`9LKZv~^GjkJeWkQ7{ zh_?G3XmZ8@30ho@sJ!3u)78FsQL}X*#I)Mjp141hI&(+jgE>qwgPd+3T;*b)_YQSm z&&CWeuBVnLNw`{qS(59UBS>*Z2hc8kHMp72Wa{5u3H*T_`9Gu!5cBC z)?jgLqCMa;EaxWlN&)uG46f^0w5Sy^E&=(E2m}daSGY`ea zsIki91#-OgWjZ4_1s8Pfb54pS#uS0bW5wn9!X&X262HQMo=gSo@mVczPfPc1CBI7x zb(^aDwq*(56l*B=Ig7WuJZ+p4r}+qP}n z&aCv4wry0}wr$(isonjh-_v_vSO15I6>CO}Id0N>8}KhsY?ln07Ro>Gk}5YAR6hY+ zvdN-9ChP@oF_;X=W#zB3*pbQQ?Od{zfS57TG?L6_x~Tkvd6i1Gml#^IAEuYlKfNij zBOyo}zyhC)7^m<&IgKjP{Q<=pWFbR7hSiB#y%^k_az7 zmwC7ChL^!4=#&%Ve9s(VRp*Wdy<+N>P7N$QFSBE(c`>)1Xb9g|7pB_8Fqe>&=%>Qz zVnee~sUXTC0O@Z6^rb<(mSey!K6i1A;2(!>DD`G$na2`kN)Pe_<@1u9K{Zk+3u~)p zJ*arZIs0Vr9Po7VJ-95^4=3GuPC!7ZK$RSM-vtI2nP-fv;=DuF+ti5W^pp?y#I1aR zcaw0j(e>+_@T2)pEkC5cChx2rlqj-9w#&6wuXL7{TBX}eKcJP{O_*Ipni((`_o@|Z zhUuuB@=747x1SN5_es*0?Re@o{u+<4OS(@4r6FWNl#Dne_D(uEsN+x|WlcdhY)~ls zz}T_Cs$jx3ieyoNX-Q+%=Xmj;YJH-1BgS14Y?A5yCLx2^CCu_`iGQDZvUqj07;bSg zxEVFYBfWZxSy*)c@dXv~xh@(6c@6u;2`0gXEAwTQiq1sEMbY-+2+8aPttmTmxOCG((3r7EyZaB$UJ~5zSeYXeojdCtskqKh*AzIcNB$o5;MO=Q?<_icVi&vv-zvOC%(^ZC) z-mlMJi+jQ}5mJQtYo}d(s&QPj=cM;cp@gpg@)6mr6aX2-6dXEV@KM3-LTFM-!%!Ou zqcpOgEXi=+JHStLp?88xhEk$dN8s6;%1ONAxusGLRbSiLRnI+wP5miC8%0Ctwbqf< zzg(!@wN8DkUH@EUanuA$Lb@w08>+8z@*7EgiEz(&M3LXJ|h2wwGFH&FU7i_Yt> z*ASUu)T-%tDsHc?evDJ(DGH2f`-6oxN4`loJt?t?&+3k zSX1hN!n7oGV}-c}7MLH14`@jlZcO+53{^|0E9L~^jN?(sR!fM)5l=S!kFaE5rK{Sg z;q<$mq7WyvY4&)a<%~<~R40~GOWY8av3;@pvvE$>%vo9Hr&=+yF`dfMk~;L3brWA{ z-}n9{4BU^d_zDh2GW44_Zfgn&s)I8@*aG z-m}e@)<(CtFo*G;uwY${@rJ&>ipP{CojOO&B~IJ<=Luj;%}RyI}cs?4|V$@(aT{y5le=sRlo ziL;9b1*eG+&6y-T#7-zLyol{yx%-( zJfZME;QyHkSFts;SN~lxF(U&3;rut5(8TgTDa`*N5ngF%%U`k~`OVbU^-CAjsBHmt z;YLD633O#UQ)MrN!w^l7CJpWmmS0^q{B(DdN@tk_7KC=QAVy%#_&5{bIQ3D7IBr*{ z0S!l}!DGNQv&woxq$z?2F|-LBsyQuF^~w=LITfzi$X}3XYRp zFo4epTKwg*?}((+egG?0)LGEW6geHgcpfjhf=)OI@BL)^7c-GMPUdsL#GMLFk!2_a z2ZEg>m_QKVV|53mn~oWth)~^d&+Q;i&gXfVM8W5Ip0soIX#zBbhNk3OjOJ6b*22vx z(zc7b*!DsWnN`&1Esn^{$UiH(b9b2z3s8Zm5CCKZxTR)${Q!(C_elg?Dv;`%`NQ{gY7j9?P+&ZL0ae-pfu2HaxV(Fr%vM5@#6|qYcdm>_X zJcY5~GrM5~7CJpW8*}b+OJ{t)#7XD{&lS^HTM(*VV&YGTH<84)vq|?3A*;ThlVPlR z|7Q@b+;yRWmJJjfi@dFyGG^8LWw596P(eBJDqPZS*D&2NUSkF;46-Ykp_*}0i!GN> z1%#`W?{s2PPfhW(iYU80#Z*E#Kd}mUj7=8!Q4Q`uG5I%3@P^&F>cS4hmwO0bqIWPQ zP&u-KKEmO6;O3##0*=W===!nj`Z3pLX0^l4^6G+MYrOLp*c+nO>5rJljhMur`@jhs zM%(I}<)!&xYm>m#VbqV)kxBTRHq6>8V<>O^kwuHC2HlEPtaP$YSZ@kXW ztNYqagHNj#FW4`eJyn9Kd9TKEE%P{it>=!?dhU%cdr6Onuu*yN zGASx1jO{y|u$3P-o6S4#ztr5zGV=`s-Jxdf6D?0B@m>^Gw2l&xaAnzK@HR097s8)e zNo;!`J}NkV4g~5moR%76)Fhf+5>%1Fb7E{0EqT@*@zlw4JdnW6Lcdr1l@`Z$1P;FS zT;X#b)XL^vU&l^!FfT~E5T;c3ROAK{v>z7gvE#&ly4+;YQHttd_m07@NpFs(&s=hF z61S5;5Y`J8D7P^A{eN`=4Sk~=1}IFZDL#XA`v0$@)_>Kqpn(2MGUk^Ctcphn1VmQ> z1Vr%Pd@RfjT}<8oe^*HVgLCYX*V1`ItYPPgdU&=H=2QU+SgNJa%d0s`X6^o5OVO>t zv`vAKv&bMAC|H>3A@14zl3?RH9hf0F8p(BV2FHGxJTNEcI;Z>ibdP+^t6OxqZ}7>y zVR+SWqH6O-^DrX~jimQZH4bYkLL<$KN^el?%#!Y!v#y`h4y!ku@!> zFKhLgB{C^~oz3winlMIkzhf4-nU}pJ|3FlCKc=8p+Q`Z|0)E>G*8?AK#BNvIasm2_ zH~&nVzeL}72CtZYD5>p$n;F}w3>vS)c|=-J@1YiF|H3BAFN6Dz)t*tlb5%E_FHP(h z6p^jxbZ*hD;9!2C!Ds03@7eEsUl*>wmx+68eLGisM?(fbMbrlKDT+*5Wq*yh^ksDA zZ_)IF(Ou#O1 zz`jFb`3rUT#R^SjvKd@bUQO#m1fM=$X`%+`#`Lsw_rbrbCf2`6@}aLJ@MLs%?l$=? z%I5i9N*)<}$pZb|>~Osyyj(A6R4caoGTA>xQ++w}AAW%gnX$B#(mj!X9%O-JqsaXU z;J?*5A0d3hO?(=-bG9)roXOJ}ExUSGynHpbf>Z{dYD3`)$PUXZfQ&uKuM21C%+LQk zei;n48`eV#gwsbvc-bM)Uq3O&NX5mNXUJEETpY)TC2e8GUl6r9Fo1wYN|?uczHIdSXT_27gU0wwmxjaEKBB@ed>zKZeyKL|)qj$|9= z`R)DApIg;b*eGo^Xy&_3e=QcD!>!sm#^lDq!5Qcq3hx6?0$U?lmh0Ks`HTB++$Y8I z@Ed>Y&e{n^@Ic|t<-^xR>n(8voWh*7@qQSegr}YvC&30g8jGM&j6n_ zPslO>7OZ9FOyI@E5FbbCIKdQy-}txio?ZT^KQ>tn`&%S*ju`(G-pUG>t`OT8jSIO! zecUqI`X-4x+B~I1|J}kZ!2P)_Ye)|2I8jpZk=fm@><4O#kXk!#YlRZ$A1lYshsMIk zo*F?{RO6H+plq`j{Fq~?#SH})KzdK9dwOXPQMbZem91x{K zBL77N4XsoQ?E~=7LKt9t=5S6Rkc@cpJr+4~Lqo`ek-<8v>m5#zq0G6A%{f416>WnRtVVhq&?Z#4KE|+<& zJ;f%-Nc<6;jilymOj7sp{1_n$$bdA`)o7ZRZEkoUs-mEJj!f~(v83>dup{xTFbl1#bu(7EREGm@yrHksV|lxcpW9QfRR|qm7 zX1>E5(86hDc`uxfyfc47C=C!G^V^bZCyKscvLPbb_Y#kFaSqdlwM6Xj*dZDm6Nd*X z7q2a(c8?Yvd7(y?cX1`CUlmOnD7y!T%;SOZ>v)ki^Xc4tg8$ zPBj!N=Hu49hS-{3KW!pzAd?#(vk5IntG>i`z93N88eq`eRYJwqLLC@%z&HKNfQJD~ zcIdB-tgdvF&J(aSFo$F~*T@#K$K^>=Vfo9GcfOAt=MXuWOP^{Mv?Vct;^d9mY-2ap zGQ@6w{mC4{lVHwrKfw}->7ds&x~jG>SlA5t6D7r->1c?c#ffn@E83N`zoK1b`+$y zcv$+^wJuc1tioWDHe{)9|AXO=7=E;h{A8V-M3?AYKs5u}$HDuMjnA$3zDN6WikF@P znto^xL_VZbXZt)AdYKA_x{qdP(ZboMWEYi@gt&!7D2Ur)1aTPTzXy0MZqfD0u_vb-6_%SxbL_h2xop=NiK>^Jw(#E=elbc=E7$}@1!zYLhX*h?{)C+R zTz+KRP=Q6VkuTz3`v;J5{_FCK!Aj{sMOiZebD3W8)bD;$T+62|B|Urk#1T1uo#|$j zpisQIa>>)3KrYeg1`hC7lLOQ2nSYqlY|+O{^r*dFD0^wUtzkJH1xPYce5i8l`<+K{ zQzP|4QBIxjy>xjZ6<7D-Yq1FvKf=w=b6a%TMCR*8UqNhUVP++e5a2}?s8egpcJ)TO z!Y@i_Z`B!nw+(|eg`i`=xxj-=jutv2$E=B16&%GEOOF3KOmB>@maUnjC9?L=ml}AZ z$v2%`Crw>p(7Gkq2sdCBV|z1X7FmYTXHsEM##BuEnWf3>b4>)=d}8Lz;S+4HY{-yg zC?$#_V@yfx>Dhl=7s>P|o?G_UXMwV8g8GDF{)QRxFYQezph`4GI}2X##(491B%LAyHz27LzY z57X@jvlNu@m)}&#$tu4fK`SsTCvZoCIjTI}KKYvHBF_JdFvhVr^5Z0=Ri^7!Kr$8U zv{cF6OkNs01fOYU*4TO`r@-}>LZ%C!e5q2G1;hA<#dK7G3;>uCIifArs zb5&?HmM@KxdY=^@y?(D|OgdC7ICfXO|Oh!1*(sF($uojeg7}3X-G948j7Zrp`)z58#jXz{>rsSlI;q z>t|wCR|Flhlcei!hIV7w7rYYfWd{5YNyd}0<|^NyhL|^*D(&VuFNvqOG95hPDw1#{ z#pRS9jzAK@*s`qhb|a=^ra7Tnt(0{d1dA0+IHjeyXs`=oh`af5G{w@ueFmKn*Kkk} zH=KgvW*pH~Q116n{y!rcn;!N}w%!$Ho1&MNU~V+Y)7#y9K@S&v&SpFUQS#=-KRx?+ zVU>e-cnde}dX(x&j!Fs8Rfh;lmo=Vg$Dd-ROVt)x#Lsp`olxzf7fmo#DWZEhH?NHO zky?)w$Rlbl1Au3w^aVZ;lsZp|{pt0g==UFeG&ojS3DrW9_$&9PYuM}aq%4HacgB>Dnn-Qa^# zXn5DjU&XGw<&@5FtIe2kw{-g2*{hp^{_g_;JYyBO_U93>tWI{+Rro> ztZs;fq`}d4U})39rO|vY%@i?Tqk^+&!6aKM!tn*q~_%mD@0h_DNy2Ck`rPiaJy@V)wzi^ z-P=OcEiTF5$hHQh>_yhd?ra4-j>JWV_svoW)|BPY^%jZsE_BF|`4EC04kIL;2aBNN z_UvKmQvl%#6#d3>u@}Z528%cftl_#1)wgvs7fv2?n$u$_YeyMjp4g~Kk>a|!RK?bv zMcX;d#Fw68qLyzJ^H?^Cd#u#RgX85*U{Etvduo<+mG49rM6Rms{+_WtJR_hUvD;rf zBiTzW%fE}z+$X?!4FK_T!P(|t&nZE%p#E+u&kirzC;$X7gKGVS_mLJ5+w=|kR{bY^ zn;xqjGivKc=MOQP6P4@ge6rvzwK&6<4Yy{HCKz___(f9pWl8|={LZdk^P#&!S*PAo^i0O+?F{mOW%$9oBU18t+d z%}pmIapJc9!;2_BV*NPbi=(L6+n42c6K8)n0y3sQs8d$N{Y<;tp?tlc$6AwcSqevy zuPc7u3jsluckc8uvEZDQ?FGW4%@|nZb+icQ_m-*Cd6wGCG#q@b5X&YM!-XZreLHjy z`fLt^jz;g-Jj2IVcIt+j{YHb{qu!G$3>Lku#Uk1y;BhRJ7sf^Xyi$5HdtS3%@pj4g zbayx#BA^i)UgCFaF{^|l&y}|{8k=+a@IpKMdV`E(=Th{aZGC*T%PRXs<{fS1&O1`0 zbW;7lB)-rEEC`>AL`8)fVV?Il{)`1t*RnC;t9^6{K3;B|0h1b+G|@OxjY87~8G7}j zH%wDNgJ>o9Y!BJ#Gev(sOf!Nh$nMS6-Y<84#&a8_v1I+|Ko=3k_Ja=skHpNxL+hyw zs!CKq?#%?|O&_qfrHzjRcT?c(SOX!6PNUxHU!b)d4qrXNJD(i6gUxf9VR0I~`McRc zCzKI(pi>7-CL?}xRWFOHvUySs#wXDFdHw==z=;a@#kerBYoDe?#7DOshUZRJvt_$+ zHZ0@5<-^WU=eJ&}XQI;^0{<|gpTfwyiwLXj7(MtSdcIT{W7A3P+JDImex8iXS7R6+KM*Y5(HN$xK&F~2 zkveULMj5e>k&;we@r0z&ut#g_GxSlA3u!=`J7u;aURBgB(q{2nnMMbT>htA%v<|~; zwpls4ndEphFl>!>=v|zrH`;tmZ(qN*eQl)r!h{q%r zH%6ISo%4J>W6CyRJ#2Y7YPl9$`vpgv`HiJ%+s^J!kgY}6!YWsidj)|zz_nc0nln|4 zZNftu*{R&JF`)Il+_XUzj;yPzcw2#O3~Im`H-dZ3bqebb07kq8giaIe+&A?x^!H%hK*e>-5b2k*>pTP2b6iFFni0im5%vCaX)mc63tO(8%*KP^MEjuOxc z&99C@#U5P{pr6OB(LuX;`&T4NFUM76G`|Q3^-JPuB4kwyqY&OZgiM(ZKdJM~MY9aO zAb<0Hzea|5>~j?OT~cn4XL7Kdre#DCI_W}$5_4395`^`h#1)9G)qv^AMkAY>qfY!g zHl}ZD1l_iX(7{QdGtF*J=sKc6lNg-5WqsYM3A97Sd}LUB)UZ?ue$A|uIm8BiJ^T?q zU~@tDqC;;v56EvKA#+(f%fI2QG={z&kA{$9LaQ_#hF_mqQ#p~w2W)_7(sQ^MQYV-G zdrg^k%^}yQca9)W*-6L)(C?LgzDV4|>SV~@w$Mh^BH@d8GUS-xLhgMXzUq_rNGklM z!+=d?xjI{lchxc4IDWC*usbV}^H8#tut%WsG;%4G!9f3cTmMt%>4&fGP7%7$8M+@A zhSzu51U0VHJxG^`kPt)pcRn2@_BKkf9z!0$l?mc01CFFDZ0H_3seJ=QHQ$-1?8vce zn3c37MSH%~QlR=2x*bQJVCWfwxoaP!sy7t zZ!)^8|4&{&=MHMGO&TMWX^C#Xqa>7Uo538jDY135A&)b-C8q$2CHfA{9kzL=6ymJF z7R8+d(@Q;&v7Z$Zq()Zw2FqozCROKup~()tE^xR_+S%W8?-h8xom-KJeE) zY~recJ3^F7o{B}UU&2|0t#XP6E$1FFCnBFjWjF+VxCU>My>qnonO>$+7}=)DQ?bnc zblyIpB)<%$omxez5Rm{OgHkkK$UfHS1CT$f=$B_NaXNt<-`~d;i7a{hJI4gBP9fYu zn;iob+wPE=M0(_&3BvUnWw{9kN)r8E&BYzodTQd`@ee}ec)#oVqD!g3y$5bKC(r{V zrG&5tW-{gYen~Gj*)BgzGKfn>5pa@kZ0pLqMRn0)&{4+*lhFKRp#mC6;F?jmO1|)> z<1nIO!S%=h8m5|<_g;A^Zx>T`XAifMJE#_;e8f3Th=5cW6W`o0fa&zq@nnY@Z6_X9 zzcLnu^e4H6)!Nxc-`!oh>Se2ME@$iT82?ppj^f~KXx8xpZ2(}{r*zpm3C zV>giKP@z7)#@&4ivkiJZ^BJDMshWhI3~9N ztI};GS-j8kw8zNd$hM?pay}Y?ML$RqGDqT6WH8n_FV;DkL)z#{EP%L=p^>59cw+eO zs-SwU-bXB%bL5a_Jnl&T?h?aj>XeO4l_e@+R|wrRU_w0-#qwr$s4~oJ>XN&Y*NYqL ze-EszugYuQ?wNjQ7$TZuq#+!48|C)>gyn8c)n!VZy!0a`{7B@V^}NE5KbAM=ej6=; zuOHtqg1%{gJj0UMu*?WSm*Kab=dTM_VQfP4m!+zP8*n9Kf5Cn1xc6}TMJ$FN#Wi`S z=@&>G-g$I#&6{2ri(Lf=IXqF6GIjf;giTGm_|Cqu=1_1>CDk^jW}04! zEfH0j$iU8F&yef;vt|jMd`bLZa6RBFy5OBmgcGVm4LXOv+GZN9dLG;kPAn0Iu)Hb0 zY$0?B(Z$P;;QU}|eKth0Kf!mSA+Rb zY?gEi6_T9Lvbu+|X`gwrSvnb=NtW|Hl{vO zgd8g)fv`PK>xcv?f&%4T0pBixup^gV+~*g34tl-- zc7pWw@)EE}Yk%A3#Wj#HT^B-~URNp6!%RirIrDp4rK>BUD5$fXCrFQtRL zA2E&4g*z#1i4$>Ix=N1H8Vc*A)EQ2~|BPY&(~e;uw8T|>%?4c^LiP|AZ@v1IJ$B4` zKrS@bgwQJ_T6a`s!l?#AeaA6%=Xz#DG`6iPH!j@GKqkRzGa-QNK>)n-*8ZypyD5I# zaY#LA6fKOGs)-V#W=8ygwyU1U-dp&F%NsO|!_qzn%znEQx?8uF1UHVBKul1{PY=w$ zI$qhG?xLKqyYy~ME#xEWweohW9Bdj9YS*`M>;gxG_B^?9UyJH&z8UP`HNg#qt&z%h zRj2p7S}SlWN~>c@XYpB+n!C~Gi$rQ%-6U4bG-bo)SWM8{yO73~!|YMk>H(OfJa>!O zf$$yScU}2t(YXrc)CD&Ni&)`qFyfBd^Mlt%-PuEvpzvP}H8-SWYlBg#~j3gOGrpfCQ-LIT^S$f+HJREKU_U-SwcYxZg+0X3{RkDC9ex& z`Ej8Aw5_?&54ORoN3MCf&UcegDZZPAB}i;Cv}1)ApLgX!-9ENzYwQ~9pZ3g__d(K< zyYrEenL^IgXHQPJNWTrX%cWQzD7##toP(I^ksBo}UU>Al z)$9RBa4exyoLIT~GVT5N^}`(DfX{7_v5NjL9Sjn+(XPCrrgHUBj@YCRjV02H zkFTqHL;tsA5kE;j4B*9@N3PQ#GwA&BvvPJ6fFV?xhL34*>Juho@za^cRAl3tHR)>G%x~=LfJamcH3E6 zYu)jq~dv~iaMBUAN*gWI_hA=~`!ymSPi zoed?l>eSz^7LqU8iR2TM;a({skBWONq-^EDjA#O?1IxJv%w5RWIZl`VO4piHm7b|r z!%=RurWvfepH|Jk+bxMLEbHoN%@dE}Tw&FoC!6flQGRrYVX?%NlG@}8iPJ~oc;Ion zKChU9HdgX5N;QMoK-avl-cH?I>o)3^PuIS(&8Ryr*mlKRo7^u_NnD~Wi%WE@*xI*- z*60#nV4pV3=LeD1S~3>;aJMhwZ@}) zsc;aWR>_T@vA{+NPk! zlin&g>VN9nMSVeNn#Qz9z8F1JV-}SUJ{}csw>HS_%YFZ0t`FZH7BQ?YvTb|q8z7~u zs{GOp>mZ14YjCBQdp0SFrfiv^AWN6M`?RXB82^bVvAw7V-TzY$+QI!_b<_Wa^;?+Q zIQ$3F|35HFODTPJ8w_w^H=Z!WU6i5*d7^!NjD$^7@i) zO5Cz7xGlp=vh(s$pE6QX?a^lEv2i? zJp%|oTR&}F`qlGG#xoona8rDbeVCwjnqbc1q2K&*z2h)BuVl%588i{Yx*QRTnsk-9 z{6E;pdQR(J^G|LUx(ampBlYb27d}M+*q<}l<3%7V=DLHrgBh;#TEbS9wJEHOjjj^m z8F_k`xEAO*!P~ysdqIBaL_F1S^>SQ!MX9Xt5WdtjX4jpPdTxbKLj{c>dq$%CadG2h zP7YeKt6lf%-*#aEr-cuAi6ue;8nIcqql0S%U95Us0)joa;$TxOAV=`+g@rk56y+KE}Kw z2p^?BT_YFo0VvW}*-bnQ3TP<>x&4VEz4T){MBwI|cUpGYpWP<89vyrJO&>uZhM8p@PZ zjD9{6l1-FR%l}%EuSggqyx`y8-s(ja9C66qz=Ka31WE*L_v6CrBPi^C^gMd?9As5P zW#?JTtic$qUV0%w()?b_W6(LKDMbpSTrUb?N~W00q9KeSfj}Ku1KPWY8RQJEyFGnE z$cYggrZ!PNyY{)78Y3T)v6il{y`x32=a}Q}tnB2wZFd)40s>&9rtQN!K%|4n{Wdg_ z+lTRJDdEIyMV@2CNsLr-G@9;G|waiE1(on{DK)%pH1_#y2iN<~iC zr7f)mEM=1g?|^C~VXU%|(sPteUDR<(+EVTu?8q!pgf;oo(_JFr(xw45mPf^hc)Ytv z)Q1r=hZiWs?jxQ+o#Lc!U!+fJW0qWIp{F|^f{FFsl~x?AFTLr4Rd3PX+mhc@r6c%b z6#aTPn4{A%_3MPGB7-RLD{oHFiyeg=uC?yO0U{`(>tF|t`Yu{#e=vtE7m7O+FuJ1d z^inQ%5C=a-QZV0F(T|YN91{pN|4z@&IfzLIj=c(9oL_4dBYR>HFK%KYtM`9^lGcz_ z?11dHH9|&M<0&}g_kywnLu8wM{mhgBQ&Ns^1TQ)&Byum}!XjYr?t0v6HeN8Vz2N_# zenp<>&#_KA<^UU2QVM?}lC*&&6l2;YZ<;(KuzPf@i0SN9PLXq!(brH5Sm%JfP{5GJ z+-Ks-3q9u*CA0{z<`85mUcgkO&7c8H7f9^r-~IKTRYzby+M8-G+`Pf!7LoAr&<03Y zBgB2Vr_3!(jLS0G$=Ub2CWMl*7^Y zHB@J^HOe6?oT;E*FO&yKWK%`_$>G1Neza;)p5unqdMzJDfI4@4M&Dl;S_3gsj)*)j zEWDUgqts2GiOUV7*8(V3RW<2x8t=G6jC%ICAKT8^oLfbnETFGOV993x z!jT#Law0!{Dtc*HF6thxqc-x|Uv!tVg)Clw;iH$pqnC%) z4Nhuhkk3Oq;D+#9_-CI3(ldtBlC2PTdh~#eyZTAe@@vz%oFeW<{ws+bX ze%4d9ekGoEybCI{cqU1nY3koL3f1S0A1rOySJ~pszPxt^6N7(U)|M+2SI$un+`^bz zCN$}!!K=VYTRN$#s8R)dor%$9!s)YPyIC!>2`3UF&db{W>EjeB{RUEp?gE4LLa)1%vR(fYnD-TFu^h@*#&4WfP0 zlIPd*7p8mWhOTT?cl$%Rp_&RL%U(O)K6>HIusbeB`Ks3|h{2V5+Sh1*P6)HJgK(KD z+C95j4indwH2^ZRVVc3ptMS?l?$#~YM*cc-6CP@nwst{3yP*_QS2rIfz#BK-n3c~x z{utOUe?M49hGtvALajWaw&@i)tvm6z_F{I%?8jF2c_~^z-lBQN3&7!`c8G3@?2Lka z#7|c!tdu2(f#@MDPNQjQjC@Mr%6z;-*+(Zh_NJUId1sizRg;?gX6FPCE-3q4Uepi1 z!sv2LXnJ(l_v>tI)*C#PKXYVoxSvYVSM+gmmdO{qR$9Eeq{z5mC|*Z=Zx$n2vTUHt3wVnhG;?6$4p|I>fbklK#@!aohijJ{#NG?l7iX1}y3 z+3c`@goJlMsa3<~aXY!J(SfrX#QM+lnxuBe6`QQYZYtNw$K96uiSq$REFB9JC}$FU zM-aF-x5=6XPgKE60jqWSKfPp=rZD)EFKW`a)I^43x0P&VptIVdl-0m`iN-jthbrc5 zK|c`+Y_h6#fIGIL%Pb2-JlHg#-Cwz+_%F(5ii6G2*Yp8KqVKtg@UnePC_1YG*kepJ zGn|eC_4=D6u^23Tao=8PnJTyrWh^u%xFxyQn-Dl)@SV>x7Px~)FH^KmxzUBR-I zTcq?OkuCK-AFKg~?~-+x-?1@o+qNwN*T)usIuUQ>$qsoT~m|7{Qf_1 z=KtD13)u3Mw*U6e6YPJ#f1FMKpZ*88)OMmX1^#JlYB4XPh2j^eeL0|-rT{5IXMzi* zg)~VVlBlk&9oF-oFL}60mIydu&v;GF1fvJzl(1(`E z*y7Is`~CBL1o;`%Fd>&ED`?S;ITS*ulTMa2)o-HtOhzQ%a7;4^8}G_m9?76<&$~@) zcC~!vaZSy}`zvEaSb~RLsvlvEt3ev;rpE;aoz-gJes()c4>nB%x<8iZpqFtWGvO&5 z)+Z~`WNw!_>n6_Bt?SX|cboPGyLH&G5X|I(mCHog&7c|T(KpxQtwW&AN5vq_5^fzg z`w3bm(YywjJ%rk{3Sk@IdSCbBt@g(jlc$5;3_#c2v^EZFPDzOz_Z=9#>)qR*V;$u- z9aqFv-t{b)*F;h!LJHV;7}RGji=~oTH$EBli?O(Tqz_k#yKbvicwniH+MLSo-eSLD zY;t0eis~kY^5Yl45@jUYNjKdna^rt}V-$YT_BVAgmeZ7X$fdZV}XXJv^7B2 ze&`ypJG7pV>G!%i;Psxs|Dcuy%N>HA%_mRL{-Nai-s>T~3Y<;FQubhDA{jOZYSSgW zO8Mr7H8DB)nFHZ`4I|0f)0n2+`ofS*2H?0|IbJRt|_II_0P!9@GlzwckLY)8|VL*L{qG${a=a3 zZ>5%SF?=+wgiNpl1e}G&wUJeD*`vPE&;ll?B-Vxt_1`Ei+?$mf1q|_)tm1)BqbQ!w zY4>Zm0R~JQ0|cf5v^VhMLl*MV5ej~YzL_c~Y1{3Df7O*F{?vX!qyis1mwzfIYIUr< zOW<<#T8T0$IGBH`$Gx>PQW9g5)AZ{X>Lh3yG;U*V;@g!>6o9E0UF4$Z%ue_pxK0PGIaKAD(;>EsX6nlQ%}S?oDrt{% ztruLUyLJ=n&Ac$*?ua<&Pk94PvQoz>V^XkaS5R;I?itX9b^Qg`INqs`(GklW$+2R0 zyH2O9+Vu+YWMrs&l&in?lniDmxw7t;H_mcCNPalm+H!Q`YHKMoz8sjmfCs|W4{@{% zlx^;YPuQ8goWDP=hNj-aQ)?c`3{2^pUAGNH($Zn{+DTa_EGjk5OP*~8(_duQIlX5- zOH#gb--KhrU?6F9x=k?W<-BF2WR{!RG`$w9W1mUpSzFLsH(OMTo^@LB$z|i@xJ>?9 z_fhvZOd2fo0f$z%8N|{@md2q0WvQZ%V#mV8bJ6)ZxaQY$bLa-|i}@p|r&DxieoIf_>7ACCt`j5(w+k~b>@Cj*oPcKX`8N@i!Dc-Gg_0>> zNxd`L>*=z=7z6NDij+Sw)GS1XkHoj%e0`w~&@ zoVEY1SYfr>9Zz*8?ci|g&T;0XdydoK7**~kGTw9beVvsT)@DBZ^tNvBYXd``Z)nZKi2XoVA%Qrce9Gr)N14-?^ zqaq(?Cnw~;t-L;VBMqq*hP?CB)?7Fgr%nlB{F&?#&mVgpc=AdxVEB27|;9-LmG^xqUv{O@^F|GR{d7$4zvKGI^Zyz{ zVryz^@8qd(V`})Hf{*{w8hjI*}}rwU!pfft#av_FOD=?_wkTlKK?2W&)zA5FIFxG z_m-P`WiYqTHg4}`(zPa@IrAe#)ZHw_T$W)C`bN%1(+~acWID>uaB@&**?e0k3ZSHK ziN}}X7I&Dzlv`=ir`~;`yWWo0*@)4{oR(N{7$N9_mdBHpph@9%ZL=R;Pb)b#mvDxLl2RfJ5*Rw5Ou}1YOE3=T{({; zZz4>apH?yY79{IS-&Qz&xK&&(m+`5l_4#(v)q1>fz#ljDa4?sN-cY*wZPt(Jv1%53 z(~RWGm~(P;A6J66-bcEM0Oe{|AK9bb*cM%=wr#NT6b`pOBk901F2NAIz2~S^uKZ2P z3Q;PxRyV`&PAhUdm_O^dK)#CnE>)=e7$^RvPV3_2<@WjgNdJ8&{4=n--}q+6;K2?D zUeRKEeR&yk$67E+JgMRuT}=@_Vu2sUhD>1X%#5i&sO!i2cM6Ag(Zh-JqkHFUdeGb( z;fLv1t^%r~th!~1j$r1b#yA>*tvkh{qJ7Tzv>dxh!9%Iik2iJ3s1CIDvPU z-wy!k@D#um)4?MyuIR;>w+q%-Sj;z%+FUPvSK+kyD_QB-R&g?ZM}#FTx3M03n>-DC zP}rH_&7XNJ8LnTqWJB^dIVHn6@Pkzy?fxJwb8f(y;w@? z414TEfp=RPW%8V1frZkomPj5fhqHc(56ZZ|(G3hp{B-ig;XS+Gt-s%pM9ozmUi-5{ZWbhmVOBi%@ebSmB5DBWGsEhXLEA>6}PKhW>~ z^xng>pXdB>c;DG;*38~Bv(|d)6D8yX$Tr1`w&%dx=eR7)7Ul+5$6kr_uXQP-S_%?H zPRs(onLB49nq21noW^QeYl#WG8yXrgjO1|~#BD_iep`r>tcWysA1)DeNC|V+IsxaL zJo2r!A>ADNt86HtJ8}U|T;XZ58f>~Qr({_77xjL)LnMM=bOr=dt=d?;hLF7|2$$Cs zB+ehV$25ST=Gc#6(BUkKOuC+@H<>E6!;7n|amIZn1Zuy zqstW0@=NrT^HnoPYM^-DPosi*`FerrcO+p2sPnbq!$G;Hsy%Dy5~>p1ZK;HM*2zXb z{fVjKU5Lmm#Y@JI_1?xT_vk!KQ4C}L2uaXgAV~d*IxF$Zq$u^4HzobNY-0FCg%ONw zZ%eR|xGhesd4ug{yZHLbgtV@lwAh)I1c1Td#i{O4v1y!Mu&dv`G7V(zwG`%aO0aNi z_>wHsB<6Hscd%CBkM?XZPsg-W9(~zHjT+3+1_|D z4m1Qf-x$=LQsR)%6&u-!`@1=sdRP<_W_KDDX03Nnp9)8% zN7STZ*DEC0V~2$fz}c$Crz=z~TvXrDxfN#ie_E~O^H{gL?^ilnTU%L@$W;Jh6t9rw zS2fR|y(VuWk=!m#z)13-Lw?D7aynB zV#$BGG=9&vB2BFGt>%Jo-x(ZsuXseiyF{#5?JMw}iQZXcOzBvR))45lHQL>z>5JaV zfzpgth1o_5yCZWfJq$Rw=ExS?{a^#2G5taIe&lH0THeo(-hct4onE9zLaKg)x(N)rX+>f_{W-lxd zc7%22&ysUVtT$l=Zzm2U?!%oDY6^5vx>6o^yHX`|+k#`hs&P}Yvv@@8eo1_f0ZcPx zT&u_}tv^wX=_&fMtd}V-Qotilg#nv9c%%sAkSw zRBdoTcMjbYl^TUAPUyV0v_1`_M$Q6I1Da$Q8$aTjCzRYQ45hFWr(J6gI+pGvTL=}h zI|!DOLU#%X%PqGJR8EgP84Cd}M8;R@D1)0%SbiQXd*2L;xa-xjDlB^&-RP?trg-so zq`5Ow>}9S>BE+sdQClMG&hjbfnz6zMs`20zUbQ1H*k^sLP&Lpu6ydBe;kr=xq6(1n z0p^XtAK3s8F?&{NzZ9qqbb5;i6rUxuFJ0Bw<{eoaCsSC@tY5(Xt#YektA@zQw6n_s zcCEn6<=P=0$b_p=s2FAJ>RSI@6(4#vD6Hq218wCTy=#h&n#!kMpAZuozGE9lNKw$R zv0f--2bR;?X#FlcIQKQfSZlJ?+v}?WbKM~VYr0*b%EmbRQ7v&|v z$Q$$q`*Z!%qIQhHq-gC7rF6Cl7l8yY$NhY~gQODXWOf~=zis;gZR^P>RFcE6$>jyL z4ZT8b6T6Oo5ke~~Gi9Db82>h412BR|@&$OcRU{@Di+T!lFOB_m`Lq{FA*NU^u)|jh zyH;GcwPfba2ffV(%jT1acxrP~818!!_Osz#ccc&U!K*f)QVW#6#Ek3qeu_A?R>qvT zlsONi3t?**RC^AA;9ST?w$2~ON?J-)#fiydxTD9*?qVW%5MQX<>3d|0Cq4&DIyG)H zp@4TIwk>^+8geb=U%AQVHD9mAgPM6=m_jh8jdpO(;qR1>L1d%P5Nqs{-5DENMnKX-1t*of?aYM0_A~?aI?_FiWK)WfCdf@^mOF!^-g@Ujmq^w=exql_E3prwfgQs$fb?tyaI@%)D>|xQPe5Hv>_Gmx>P>959M6FQLoRpl@Hterv4>7 z`Kr~Rj|TaOTc_vb$_~xjUs(1wST|OB-IU5tc)PW|Obdnz5G~!RwGiaHyVWMpy-b(> zVt-!t0unFhCQ{>`uj61%IA~ET_J~SV&AGOT#DomDd^^~3w!enGW-596^u1N}J+td>oq71_kA$Ppy4_xH?Z zS>PYg2)c7O9T1LE6{qsr%*Dy#UX27RO=-uaZY}O!O8`>{2HY3NcP$@I!gwbEslDcq z8j@w^M8(qF8&?V(fHZ%ep$z&4?h6~^BEiS%i@nc!QDxgSW^68by@FEso;1fc;rixm z(bH<=az~x#3F5a_iWyqUP^GEfShze8EY9a7@%$R#70`}WOrc&_J8`Wb8Ql8spN>=%i zatt&tMxs$j*O9{ZzupK&l-S@EIUw00l|zlh0iENQRuVekAG_PZ$J00~g>vaP=TKQs z7_6~<6Cqn7>i4!SFN#@tyo9|46D?c0Yorc2zRyQ&sl~?BEoP`FsuC2PSX8!-3;m1G z*)UA*0MJa1ExZLzu}Y>;LcaW)sV*EX-j~LXtpr?dFA*e~%=GtdcG;gf5PN9VqYT~3 z(!7W6?0){a|69-pHINlG7E474bU0P$G!y|*^4d!K_h0KvqdjqPNvA^j0ztN*KDWNM z;78}JqpyNLl#$Kiw(T{VY|bvuw56d)Jk2S2tCv_N8vsLH@R9=GG3PcjG*4r0Adw!w zA3}&)E^?u?$d`G*Oh3&|L9XqUD3wJLDR&Sypd1T?Z}c+(pA$qdIAkboda!Q@#qcs} zLKPcVuYeOehG4lGHYdaGt|n7zod$~>d+lT8k8U$RVp%EMuEr>gC_~9$nOx8mj15Sa1l$WNZ4zqlTBcY9yOR;}w8ncp=g%DBl6b2V|CtT%4B;LoDK z1Y?59xk(yEa!;?nMJj;H;5uIznb)-J z3upu|DlqPfNU@E^uJF1Im}#1C^N1Ccr63Hc20s(Z>x9&)$wbOJRjbKDw4JS19gK}- z6{UJ?SyfP`1{V+LORDGNe!kjL3*7pHs2hFRYq-mpJcBeuXd4S@Y!2w!lWRlPqvh;Y z*9*?P7#R1HdCjj+H(XhlvNo(Bh1ryr(eao(5j5e=T#BvoUv!t0Y`s-Gl_GZ~h!JUfSi?DHyDWhK|ite+L z>)Rg6HZYwjpMt`3SGq{N3}?okbt^=}I;4 zQjVx zvEhcM>Nyvp{dWF0LCWmx7-a0tY}9nlomVCJ!co!N2Q;cx(GVHYuZ@h1jj_#9?U$Oe z>0a1XZQ#rBc*`c2<7=(#a)rko_GHf;UmN?LzjGX*Q`sD;%b+ytP4;y-gv?y@JINb~ zGn&=g!$1%MKH4h-?(#!VDIZO@h-0J=hQicN80!^e10DqV>OA5XiGRkBmyGuH!zwQk z*uD-9#D#&azi5#QD}CCY+=$e>CN#FjTu}M#sD`u{h=e4$=lk^A0~Wze)!02Se8&gT z{rF-KtD;f8A?>vWrrPr;Uk`ubky(o)>Y zfo!cKi0{s4KK+9G&8Y1y-Y(=I4Xx;!kp{@I(rS}A(;FUfG|wEdz{0~OP0lyz=iw+uu4Q0Ep{2R_8a(I#g>Za4z2E+uFTM5L8N^>kwfA70sMC@6A9pX|Gz~ zH$uDlzAn=Nn=Mgxq(}>N5|JV!+y0F-Hns&h8_q{W>BO=72>-Tti9m2OQldMfGpDsd zN;GvAWqmugt*O<$6mCLZ@0pa_%1iKhB zbkXmX($Yhsf^=0PT`bbmL0u=p)DQ%8Qzq0Z<=4w5^9v8k8K`eFJ&1eE(@NyHiBi|@ z3T&0%!9WTY?@aHUCG|o7Ooib4PM1||gJZ`~xCb!lu);|y zuqM~&C-~>{sG3EdVbo_jF_Mvm9|%^zqDY;GUjVbI5WuZ|P&xSH+4p;6-ZEUu959Cl z9NeJ^P+$azb<@*<33FiXvYH$SYp!(|-nBR=w{3criawuC6ToOJ#rGi0aHW!rMTj77 zIHsUgK{%XdrO8e)&GVo(pH|4{^|Hl8gZU;76VaWU$u-I`BfZHI`b|(eyRF{D`cqbn zla6Ip%cQ_8h#F3hh`DX?%ViA9dwEwTmz!2DdcFAs$!-Vr;P81PySqxPzW5dXgT9uF z(`?iI-kP27(4^3O!dGl`-K90}|O;iK34OE*jl|Hr4;HpR037|wD zF4j*~xBjrt8&hL|=*n*$^zFTtvIc~}Hg2exY<2Ey=z}|A!mOg&s|OZSmhG}Wd%H~u z(49M`tyR7Mo`*1Nw@d1{UBb<}NmMD;BogGBW%qVadPufo!?vQ~yDyxU@LUL;#YpEPy4xmjh%TWLc!!^f@Q6v8(319&CPNmNIw!QzOud);nW@RA%T9lhrw>i$tA0T?D!W~)25(`vQzX-O!AYvjk6BdGEckLG$a2BoTgHE$x!|`3%bFjuP5v?rq>}43r~hte z&?ya%x(>SZut7=H;pwTTU`6(H5Vim{YxF1D+!UtoTE0@n+a?IiRCEwcUH}pB+_Qt~atY4#`>Q(insX zgcxd7=WS5kA>L_?yc0NE9H)F0AGi)i3=U9~#qzO2WI);U__!LC_u=y z*qL!bd7v07@L@CYRgs0TMQ-Eat#s?5VzyucA0wDKp>VJ3VQi#0K6G^zy0!VWwG=0K zH|!;Y%Q4G4UU5@=!5LCc(>l!MT?p|7;Zejr6O;izQ`S4wV{n`AZx z$!Nz?l-hbd0fsHOVwaJ^72z3u1;|~cI%5k640DnxKcGDowB4!z3M9S5q)!I-Y_7Rz z-=#f(%xc6aNU4C=bme6;u?kglKoxP(&b|=JRuQ6{`24Yx4o3L}(>KMP;mCro#adx_ z{Z1;=TE$AK^se>hCHYgw_l)z;xbH6H!61)n8GwBcT>}bVT5Ex7Kn1KYcd+R)vX@$B zbl!tW_3ORb7MaFeZ{`|(0YrnxA`aR8A=8h9WZ4YTpeC`cp6Iz$AmN(Y-J^1!@qiNf zPN6hkSpeZ;{BS@(SWh3{ZwxJr9;OmQfS!z@rII2v5NHOVF8;&yLv8%VxEc6D(Wr+& z`o%FaN5*U}IV}N`YYw9`^DDy&1x8BpPyZwIr!suS&Ds|Y}OCyPrYQ6KUrPDDM@p4%R zWhE8HAl&0T(d*e_Yvetm)IzAGHYy?ixTiv0#{GBVcMH}TgviZDxcA{Wttbtmu(Tq< zZLw)YqLBws@)vutd1i9U^o)7UI z<^q90bkh4Je0D{WTCOO}AYIbwa1tS?oLX;(JZjKFQ}KwfFEaDp@Cyk;pM2NlA>i&R zx>0>ItC-Xkp{2qSIOE8mElhM->L8I6PQ?p_i92=SiYp+BR)g$PtqgJgwij)i9oIhg zD|C}B_t*7{BF%ob*Y?N0HL+Gv-rdI2?Ly>tNqJ-t(ffr(I-8gwBikT~bCxSd!tl_pneEWtc}vYq_=roYA}jOt zmB?QlODh?^W`;m8~5ctB-lB7uOW+u|46zJ+TX#9H|=}zbfxMs#U*g zD>IWQcb&Bra`PGbr7}&#Way#$>2u}9A$_1Mx=_neU^ZM)p~0h z5EN*-AaNI!d9qkGi1!d;8H2k-@fKj98lK&yqQP2HJdU!U>g)xY5?cW`uhf$7E^A)+ zWkpnb?5!O*pMW4@?z6nbtHq2qAVQ9tw|v+2R+zKb8{K>y_4Cr1`{I;FZ?Vg~(`fy3 z)5NOTL3EFPY{v4V@~BnnY#5cBxZ0h@-dnwd!<%Yg+by&g(m^d2tGRb@FQ{o7F9^#P zP{@+IeJO}t_)}dR4+lSeiC0KJSaiYSV1nO6=*V~D`_|Y^*mDDqA6umx=Qv*mx^Pg# z#(kND$BOKskLjJ0h`)(m%j{8xf4sm4t|yS3yg7;&<@FR;t9xg_S1Bg)?u~}B%n_p*WSTaxTPTueJMu3U z@$qws7$bugnM{WAL0i5Oi<^2IgCbYeDJdLZ*i+Ft0=I-1BQEgBjgqx1PDSUK)$%@ZO~N`m#unzLoNT zN!low6rK}u@a>W#K>~VmngR{U>gTEa#qU!zC~sE85y`!V?0f~ohs09s+b_9s+1!gu zz+vM=m7ydM#`dhr2Z5NZnP?Q8^Wa5pp%}Bbt(<}WyxVtyg8^EUsde#_{ znxI|7>1#1+8t#?|pjMR88EzKnI25|1j%6Jo}CLM*8dK#Zb6#=D2a z)_V93S$)&db zE3Ho=C}pyCUb`XVsGLi5Yvnjw`6PMS(C%>E*|Wo~;v`h>mc7N*5-h*ZO9_(|<#V-? zb!oA0T6G3C({C`kWi7ukY=qJ)4T@xMn5xUek@V3P1zmgUdIo-K)e$fX47nTc(-wHxb3An8NNhl-T`yxezG2(nym7jRj|}lEoJ3C zKS&#yi=nQ`PT{iD`0n(5I;Ft+36IGl*)%GA%zG1S@j2+0)xLEFHin>iS$NJdI4BM4 z8Xc6r*HCyO$kjZ4qZbWXubOU&DZQZX{rkOE;OmamXI$e=M{h zzU&3FHP1!mC{!$%46$ii8IbHbTp7q~)xB)horFMF%C;O|wwl!8@H67@0OuHFWXiI` z0N1h2#PEt$9rSL-PYxZkC*|pX<^tq!r@OT%^)faQE{cZI=F^J9P&LCh& z@sIze96hyol8~FrtE_OEJTqLIbwIL9J;Zw@~z%csf z1|5nw*r0lReW-*HDm=bzZ(p!sv}&H|R7Ne;TE0ZZK_mEHiL`)wuQ`Zd4Tp&3@>MUT zcSv-@lrdw^t|49XsG}e;HF#)9!ixh*9nY?Wg*7F|s9XI8{QE1>Xac~|my|DslRod~ z2&xrJI{qu4kDe#zyxJ_ItPK~Jyp`iGtw)>}iU+<92h232-+HTxlTcQSN=|xo>(OoLj(#_sRF>ULnZhVy-Ja^%*QgLL$rf%zQ=*XDE=rHo zfe2rEVnnp#=;*{-D37sC*R9PQvvk#$eoT-_KEi#C#IzD-bXOE4CQ-|(8sz!5>A1kx z+muXCD$md+q1`pO&(&EL1U4Glb|qY^mZ5rPp0rGH5%sxrdRjDtNCq)PC%8;P;6c2a zw%k(aWnzD?LrDOoMiZAb9O6Zfgf@~Qs{R)n{}uk5%bk^X{adFm8WBEJxw>7-I`j)d zr+UD|F?p=ZTawRoIpvU8mE!lZJ-5ZlRKmF+LPfJ#Hw6^?@HEhi=LqZoi)cii`k?bl zfd$vMENyDELw3E9Xh(TB`%?qk*@Bj@T4VGx6a~+kCbVRg847P7ao$o{|1sWVaJU@q zg`N?xD?+P-#t~*x)O)Ak`|!|R?WZd`Ph{4nD!`tXGazj9hZXPtVKLXJ+zZHOL~Pul zffy8WL?fy2pnP^2K@LA`zI=)d20|B#l2efvp{Sm@+5BN#oKCi=CLcaPXz-Zx08Yf% zhK~y}6`a*ot?m2z`t2vJa>t~-))U7dNlsV%VUI60exw5YnIRLW-2A>|k-m&4}1+*9P*R*mO5+bqH%);~Ob8rvS@OnC<93fLX z2#yd?Wd2CdBM;P!IT|@EQS3ytGWM_3L3}(S1X~EJCCL-Ra`w_vdjKaK~d5`p<<`*IP;jYkI#oXdEQ z252Ch4xK`gu*mBQQ=*U9?R4!J+iFT>r*v_$nuJ6|90gh~$SHQGlImj=m7(0-0sBZo z0nFUHG0ykQmPLxm>9hOqg(A_7RlKsibDMPK(s8}~wk*DmcibO2;Ca9`l5d;AC}-m7 zXYx12iz}ZsIjPyP;*-qA9V)XEDb}5tXDC+K5O(yi%c2H#I&0(}C%`lZN=5I*v5tk{ zq@=}DEd>Q5a^!VU6 z0eT2jyCZ&ojU=@mGw1so=xNLbUn6loe{qang#Po;!@(3vKuz`p0^*P8nb&DseXr*U zuX6)=oxFMG@Lg{ISdRb>=>Q*U)irTD1X+Cy!1sp>;qUc`qak49(Aw(nO>c~V9yFD%jd=4P}5q0u!)CbWr8*4lJ-$Z>dtO+hj{U$&*)Kd!x2=z%@7XT@+f84gd zg(*Nj^o^k%z>a?-|H29n$kZ`5y_kUQs3^cP68;G?D&F77)|NK9|FrfOm=9r{3fGx_ zV*t!iL?9rfCz$qxe_{TkEc|ycBskn2Y%_`d?M?|3?0WIUeflZys)rGXPA%#PXv{ z9K8Jt*~ZS=-df+<;&0T4GW)-9G{IT|%SwQwK>=i&5uRAXB>FGZzn(sSgZ{53{`cVp z`xod#qxblY;{O7DICX#Xi$V4$;Dalmd0&D*47rC3{Yfhep920rhFXbrujf7>UBV3E zM}I&_c+yxsLo0m~OI@UcfHr9a{Lb8u__Il7;){KEep`rxU+_hT0I$|o%U z49X`A|Et^l$5v03jUO{~)I4GMN$2?2c6zvjr&6bn@lhI|;QulHb)O2K{+i-vwaLd6 zC(S=o{BwZ+PYeoxhnjVh{6o-B$_(|adK=KHp`)6c7KBxRWqNi~rkBKx;e@*mfRLSo#J&m+@ z%#?-pUzq+i?&9|tp9Ts%W>mrWFO2^lHt>6tPY;qGQx;(TO!7a#r)>^W^E)!(;d)%BS!@o+|&h=Ktq?@acl?F@6rs hU-*xgcL30TeIEe_EK-1g$N+yGfcJL@+K0FH{{dOjytV)U literal 0 HcmV?d00001 diff --git a/python/setup.py b/python/setup.py index 2644d3e79dea1..cfc83c68e3df5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -194,7 +194,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.4'], + install_requires=['py4j==0.10.6'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e5131e636dc04..1dd0715918042 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1124,7 +1124,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.4-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59adb7e22d185..fc78bc488b116 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -249,7 +249,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.4-src.zip", + s"$sparkHome/python/lib/py4j-0.10.6-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index f2d9e6b568a9b..bac154e10ae62 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi From ab866f117378e64dba483ead51b769ae7be31d4d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Jul 2017 18:26:28 -0700 Subject: [PATCH 0877/1765] [SPARK-21248][SS] The clean up codes in StreamExecution should not be interrupted ## What changes were proposed in this pull request? This PR uses `runUninterruptibly` to avoid that the clean up codes in StreamExecution is interrupted. It also removes an optimization in `runUninterruptibly` to make sure this method never throw `InterruptedException`. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18461 from zsxwing/SPARK-21248. --- .../org/apache/spark/util/UninterruptibleThread.scala | 10 +--------- .../apache/spark/util/UninterruptibleThreadSuite.scala | 5 ++--- .../sql/execution/streaming/StreamExecution.scala | 6 +++++- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 27922b31949b6..6a58ec142dd7f 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -55,9 +55,6 @@ private[spark] class UninterruptibleThread( * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning * from `f`. * - * If this method finds that `interrupt` is called before calling `f` and it's not inside another - * `runUninterruptibly`, it will throw `InterruptedException`. - * * Note: this method should be called only in `this` thread. */ def runUninterruptibly[T](f: => T): T = { @@ -73,12 +70,7 @@ private[spark] class UninterruptibleThread( uninterruptibleLock.synchronized { // Clear the interrupted status if it's set. - if (Thread.interrupted() || shouldInterruptThread) { - shouldInterruptThread = false - // Since it's interrupted, we don't need to run `f` which may be a long computation. - // Throw InterruptedException as we don't have a T to return. - throw new InterruptedException() - } + shouldInterruptThread = Thread.interrupted() || shouldInterruptThread uninterruptible = true } try { diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala index 39b31f8ddeaba..6a190f63ac9d0 100644 --- a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -68,7 +68,6 @@ class UninterruptibleThreadSuite extends SparkFunSuite { Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) try { runUninterruptibly { - assert(false, "Should not reach here") } } catch { case _: InterruptedException => hasInterruptedException = true @@ -80,8 +79,8 @@ class UninterruptibleThreadSuite extends SparkFunSuite { t.interrupt() interruptLatch.countDown() t.join() - assert(hasInterruptedException === true) - assert(interruptStatusBeforeExit === false) + assert(hasInterruptedException === false) + assert(interruptStatusBeforeExit === true) } test("nested runUninterruptibly") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index d5f8d2acba92b..10c42a7338e85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -357,7 +357,11 @@ class StreamExecution( if (!NonFatal(e)) { throw e } - } finally { + } finally microBatchThread.runUninterruptibly { + // The whole `finally` block must run inside `runUninterruptibly` to avoid being interrupted + // when a query is stopped by the user. We need to make sure the following codes finish + // otherwise it may throw `InterruptedException` to `UncaughtExceptionHandler` (SPARK-21248). + // Release latches to unblock the user codes since exception can happen in any place and we // may not get a chance to release them startLatch.countDown() From 75b168fd30bb9a52ae223b6f1df73da4b1316f2e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 6 Jul 2017 14:18:50 +0800 Subject: [PATCH 0878/1765] [SPARK-21308][SQL] Remove SQLConf parameters from the optimizer ### What changes were proposed in this pull request? This PR removes SQLConf parameters from the optimizer rules ### How was this patch tested? The existing test cases Author: gatorsmile Closes #18533 from gatorsmile/rmSQLConfOptimizer. --- .../optimizer/CostBasedJoinReorder.scala | 7 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 36 +++++++++---------- .../optimizer/StarSchemaDetection.scala | 4 ++- .../sql/catalyst/optimizer/expressions.scala | 14 ++++---- .../spark/sql/catalyst/optimizer/joins.scala | 6 ++-- .../BinaryComparisonSimplificationSuite.scala | 2 +- .../BooleanSimplificationSuite.scala | 2 +- .../optimizer/CombiningLimitsSuite.scala | 2 +- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../optimizer/DecimalAggregatesSuite.scala | 2 +- .../optimizer/EliminateMapObjectsSuite.scala | 2 +- .../optimizer/JoinOptimizationSuite.scala | 2 +- .../catalyst/optimizer/JoinReorderSuite.scala | 27 +++++++++++--- .../optimizer/LimitPushdownSuite.scala | 2 +- .../optimizer/OptimizeCodegenSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 24 +++++++------ .../StarJoinCostBasedReorderSuite.scala | 36 ++++++++++++++----- .../optimizer/StarJoinReorderSuite.scala | 25 ++++++++++--- .../optimizer/complexTypesSuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 4 +-- .../execution/OptimizeMetadataOnlyQuery.scala | 8 ++--- .../spark/sql/execution/SparkOptimizer.scala | 6 ++-- .../internal/BaseSessionStateBuilder.scala | 2 +- 23 files changed, 137 insertions(+), 82 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 3a7543e2141e9..db7baf6e9bc7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -32,7 +32,10 @@ import org.apache.spark.sql.internal.SQLConf * We may have several join reorder algorithms in the future. This class is the entry of these * algorithms, and chooses which one to use. */ -case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { + + private def conf = SQLConf.get + def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.cboEnabled || !conf.joinReorderEnabled) { plan @@ -379,7 +382,7 @@ object JoinReorderDPFilters extends PredicateHelper { if (conf.joinReorderDPStarFilter) { // Compute the tables in a star-schema relationship. - val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq) + val starJoin = StarSchemaDetection.findStarJoins(items, conditions.toSeq) val nonStarJoin = items.filterNot(starJoin.contains(_)) if (starJoin.nonEmpty && nonStarJoin.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 946fa7bae0199..d82af94dbffb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ * Abstract class all optimizers should inherit of, contains the standard batches (extending * Optimizers can override this. */ -abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) +abstract class Optimizer(sessionCatalog: SessionCatalog) extends RuleExecutor[LogicalPlan] { - protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) + protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) def batches: Seq[Batch] = { Batch("Eliminate Distinct", Once, EliminateDistinct) :: @@ -77,11 +77,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) Batch("Operator Optimizations", fixedPoint, Seq( // Operator push down PushProjectionThroughUnion, - ReorderJoin(conf), + ReorderJoin, EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, - LimitPushDown(conf), + LimitPushDown, ColumnPruning, InferFiltersFromConstraints, // Operator combine @@ -92,10 +92,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineLimits, CombineUnions, // Constant folding and strength reduction - NullPropagation(conf), + NullPropagation, ConstantPropagation, FoldablePropagation, - OptimizeIn(conf), + OptimizeIn, ConstantFolding, ReorderAssociativeOperator, LikeSimplification, @@ -117,11 +117,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineConcats) ++ extendedOperatorOptimizationRules: _*) :: Batch("Check Cartesian Products", Once, - CheckCartesianProducts(conf)) :: + CheckCartesianProducts) :: Batch("Join Reorder", Once, - CostBasedJoinReorder(conf)) :: + CostBasedJoinReorder) :: Batch("Decimal Optimizations", fixedPoint, - DecimalAggregates(conf)) :: + DecimalAggregates) :: Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, CombineTypedFilters) :: @@ -129,7 +129,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) ConvertToLocalRelation, PropagateEmptyRelation) :: Batch("OptimizeCodegen", Once, - OptimizeCodegen(conf)) :: + OptimizeCodegen) :: Batch("RewriteSubquery", Once, RewritePredicateSubquery, CollapseProject) :: Nil @@ -178,8 +178,7 @@ class SimpleTestOptimizer extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, - new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)), - new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))) /** * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change @@ -288,7 +287,7 @@ object RemoveRedundantProject extends Rule[LogicalPlan] { /** * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. */ -case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { +object LimitPushDown extends Rule[LogicalPlan] { private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { plan match { @@ -1077,8 +1076,7 @@ object CombineLimits extends Rule[LogicalPlan] { * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. */ -case class CheckCartesianProducts(conf: SQLConf) - extends Rule[LogicalPlan] with PredicateHelper { +object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { /** * Check if a join is a cartesian product. Returns true if * there are no join conditions involving references from both left and right. @@ -1090,7 +1088,7 @@ case class CheckCartesianProducts(conf: SQLConf) } def apply(plan: LogicalPlan): LogicalPlan = - if (conf.crossJoinEnabled) { + if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition) @@ -1112,7 +1110,7 @@ case class CheckCartesianProducts(conf: SQLConf) * This uses the same rules for increasing the precision and scale of the output as * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ -case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { +object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ @@ -1130,7 +1128,7 @@ case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone)) + DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) case _ => we } @@ -1142,7 +1140,7 @@ case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone)) + DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) case _ => ae } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index ca729127e7d1d..1f20b7661489e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.internal.SQLConf /** * Encapsulates star-schema detection logic. */ -case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { +object StarSchemaDetection extends PredicateHelper { + + private def conf = SQLConf.get /** * Star schema consists of one or more fact tables referencing a number of dimension diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 66b8ca62e5e4c..6c83f4790004f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -173,12 +173,12 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 2. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ -case class OptimizeIn(conf: SQLConf) extends Rule[LogicalPlan] { +object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > conf.optimizerInSetConversionThreshold) { + if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } else if (newList.size < list.size) { @@ -414,7 +414,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. */ -case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { +object NullPropagation extends Rule[LogicalPlan] { private def isNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => true case _ => false @@ -423,9 +423,9 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => - Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) + Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone)) case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => - Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) + Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone)) case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) @@ -552,14 +552,14 @@ object FoldablePropagation extends Rule[LogicalPlan] { /** * Optimizes expressions by replacing according to CodeGen configuration. */ -case class OptimizeCodegen(conf: SQLConf) extends Rule[LogicalPlan] { +object OptimizeCodegen extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: CaseWhen if canCodegen(e) => e.toCodegen() } private def canCodegen(e: CaseWhen): Boolean = { val numBranches = e.branches.size + e.elseValue.size - numBranches <= conf.maxCaseBranchesForCodegen + numBranches <= SQLConf.get.maxCaseBranchesForCodegen } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index bb97e2c808b9f..edbeaf273fd6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.internal.SQLConf * * If star schema detection is enabled, reorder the star join plans based on heuristics. */ -case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Join a list of plans together and push down the conditions into them. * @@ -87,8 +87,8 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe def apply(plan: LogicalPlan): LogicalPlan = plan transform { case ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - if (conf.starSchemaDetection && !conf.cboEnabled) { - val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions) + if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { + val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) if (starJoinPlan.nonEmpty) { val rest = input.filterNot(starJoinPlan.contains(_)) createOrderedJoin(starJoinPlan ++ rest, conditions) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index 2a04bd588dc1d..a313681eeb8f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -33,7 +33,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index c6345b60b744b..56399f4831a6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -35,7 +35,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, PruneFilters) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index ac71887c16f96..87ad81db11b64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -32,7 +32,7 @@ class CombiningLimitsSuite extends PlanTest { Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyConditionals) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 25c592b9c1dde..641c89873dcc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, - OptimizeIn(conf), + OptimizeIn, ConstantFolding, BooleanSimplification) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index cc4fb3a244a98..711294ed61928 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -29,7 +29,7 @@ class DecimalAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), - DecimalAggregates(conf)) :: Nil + DecimalAggregates) :: Nil } val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala index d4f37e2a5e877..157472c2293f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -31,7 +31,7 @@ class EliminateMapObjectsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = { Batch("EliminateMapObjects", FixedPoint(50), - NullPropagation(conf), + NullPropagation, SimplifyCasts, EliminateMapObjects) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a6584aa5fbba7..2f30a78f03211 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -37,7 +37,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 71db4e2e0ec4d..2fb587d50a4cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -24,25 +24,42 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) - object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Batch("Join Reorder", Once, - CostBasedJoinReorder(conf)) :: Nil + CostBasedJoinReorder) :: Nil + } + + var originalConfCBOEnabled = false + var originalConfJoinReorderEnabled = false + + override def beforeAll(): Unit = { + super.beforeAll() + originalConfCBOEnabled = conf.cboEnabled + originalConfJoinReorderEnabled = conf.joinReorderEnabled + conf.setConf(CBO_ENABLED, true) + conf.setConf(JOIN_REORDER_ENABLED, true) + } + + override def afterAll(): Unit = { + try { + conf.setConf(CBO_ENABLED, originalConfCBOEnabled) + conf.setConf(JOIN_REORDER_ENABLED, originalConfJoinReorderEnabled) + } finally { + super.afterAll() + } } /** Set up tables and columns for testing */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index d8302dfc9462d..f50e2e86516f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -32,7 +32,7 @@ class LimitPushdownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Limit pushdown", FixedPoint(100), - LimitPushDown(conf), + LimitPushDown, CombineLimits, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index 9dc6738ba04b3..b71067c0af3a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ class OptimizeCodegenSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index d8937321ecb98..6a77580b29a21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -34,10 +34,10 @@ class OptimizeInSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, - OptimizeIn(conf)) :: Nil + OptimizeIn) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -159,16 +159,20 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) .analyze - val notOptimizedPlan = OptimizeIn(conf)(plan) - comparePlans(notOptimizedPlan, plan) + withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { + val notOptimizedPlan = OptimizeIn(plan) + comparePlans(notOptimizedPlan, plan) + } // Reduce the threshold to turning into InSet. - val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan) - optimizedPlan match { - case Filter(cond, _) - if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => - // pass - case _ => fail("Unexpected result for OptimizedIn") + withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "2") { + val optimizedPlan = OptimizeIn(plan) + optimizedPlan match { + case Filter(cond, _) + if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => + // pass + case _ => fail("Unexpected result for OptimizedIn") + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index a23d6266b2840..ada6e2a43ea0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -24,28 +24,46 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy( - CBO_ENABLED -> true, - JOIN_REORDER_ENABLED -> true, - JOIN_REORDER_DP_STAR_FILTER -> true) - object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: - Batch("Join Reorder", Once, - CostBasedJoinReorder(conf)) :: Nil + Batch("Join Reorder", Once, + CostBasedJoinReorder) :: Nil + } + + var originalConfCBOEnabled = false + var originalConfJoinReorderEnabled = false + var originalConfJoinReorderDPStarFilter = false + + override def beforeAll(): Unit = { + super.beforeAll() + originalConfCBOEnabled = conf.cboEnabled + originalConfJoinReorderEnabled = conf.joinReorderEnabled + originalConfJoinReorderDPStarFilter = conf.joinReorderDPStarFilter + conf.setConf(CBO_ENABLED, true) + conf.setConf(JOIN_REORDER_ENABLED, true) + conf.setConf(JOIN_REORDER_DP_STAR_FILTER, true) + } + + override def afterAll(): Unit = { + try { + conf.setConf(CBO_ENABLED, originalConfCBOEnabled) + conf.setConf(JOIN_REORDER_ENABLED, originalConfJoinReorderEnabled) + conf.setConf(JOIN_REORDER_DP_STAR_FILTER, originalConfJoinReorderDPStarFilter) + } finally { + super.afterAll() + } } private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 605c01b7220d1..777c5637201ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -24,19 +24,36 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, STARSCHEMA_DETECTION} +import org.apache.spark.sql.internal.SQLConf._ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, STARSCHEMA_DETECTION -> true) + var originalConfStarSchemaDetection = false + var originalConfCBOEnabled = true + + override def beforeAll(): Unit = { + super.beforeAll() + originalConfStarSchemaDetection = conf.starSchemaDetection + originalConfCBOEnabled = conf.cboEnabled + conf.setConf(STARSCHEMA_DETECTION, true) + conf.setConf(CBO_ENABLED, false) + } + + override def afterAll(): Unit = { + try { + conf.setConf(STARSCHEMA_DETECTION, originalConfStarSchemaDetection) + conf.setConf(CBO_ENABLED, originalConfCBOEnabled) + } finally { + super.afterAll() + } + } object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 0a18858350e1f..3634accf1ec21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -37,7 +37,7 @@ class ComplexTypesSuite extends PlanTest{ Batch("collapse projections", FixedPoint(10), CollapseProject) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyConditionals, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e9679d3361509..5389bf3389da4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.internal.SQLConf */ trait PlanTest extends SparkFunSuite with PredicateHelper { - // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules - protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) + // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules + protected def conf = SQLConf.get /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 3c046ce494285..5cfad9126986b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -38,12 +38,10 @@ import org.apache.spark.sql.internal.SQLConf * 3. aggregate function on partition columns which have same result w or w/o DISTINCT keyword. * e.g. SELECT col1, Max(col2) FROM tbl GROUP BY col1. */ -case class OptimizeMetadataOnlyQuery( - catalog: SessionCatalog, - conf: SQLConf) extends Rule[LogicalPlan] { +case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.optimizerMetadataOnly) { + if (!SQLConf.get.optimizerMetadataOnly) { return plan } @@ -106,7 +104,7 @@ case class OptimizeMetadataOnlyQuery( val caseInsensitiveProperties = CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) - .getOrElse(conf.sessionLocalTimeZone) + .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1de4f508b89a0..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -22,16 +22,14 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate -import org.apache.spark.sql.internal.SQLConf class SparkOptimizer( catalog: SessionCatalog, - conf: SQLConf, experimentalMethods: ExperimentalMethods) - extends Optimizer(catalog, conf) { + extends Optimizer(catalog) { override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ - Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ + Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 9d0148117fadf..72d0ddc62303a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -208,7 +208,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. */ protected def optimizer: Optimizer = { - new SparkOptimizer(catalog, conf, experimentalMethods) { + new SparkOptimizer(catalog, experimentalMethods) { override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules } From 14a3bb3a008c302aac908d7deaf0942a98c63be7 Mon Sep 17 00:00:00 2001 From: Sumedh Wale Date: Thu, 6 Jul 2017 14:47:22 +0800 Subject: [PATCH 0879/1765] [SPARK-21312][SQL] correct offsetInBytes in UnsafeRow.writeToStream ## What changes were proposed in this pull request? Corrects offsetInBytes calculation in UnsafeRow.writeToStream. Known failures include writes to some DataSources that have own SparkPlan implementations and cause EXCHANGE in writes. ## How was this patch tested? Extended UnsafeRowSuite.writeToStream to include an UnsafeRow over byte array having non-zero offset. Author: Sumedh Wale Closes #18535 from sumwale/SPARK-21312. --- .../spark/sql/catalyst/expressions/UnsafeRow.java | 2 +- .../scala/org/apache/spark/sql/UnsafeRowSuite.scala | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 86de90984ca00..56994fafe064b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -550,7 +550,7 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a32763db054f3..a5f904c621e6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -101,9 +101,22 @@ class UnsafeRowSuite extends SparkFunSuite { MemoryAllocator.UNSAFE.free(offheapRowPage) } } + val (bytesFromArrayBackedRowWithOffset, field0StringFromArrayBackedRowWithOffset) = { + val baos = new ByteArrayOutputStream() + val numBytes = arrayBackedUnsafeRow.getSizeInBytes + val bytesWithOffset = new Array[Byte](numBytes + 100) + System.arraycopy(arrayBackedUnsafeRow.getBaseObject.asInstanceOf[Array[Byte]], 0, + bytesWithOffset, 100, numBytes) + val arrayBackedRow = new UnsafeRow(arrayBackedUnsafeRow.numFields()) + arrayBackedRow.pointTo(bytesWithOffset, Platform.BYTE_ARRAY_OFFSET + 100, numBytes) + arrayBackedRow.writeToStream(baos, null) + (baos.toByteArray, arrayBackedRow.getString(0)) + } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) + assert(bytesFromArrayBackedRow === bytesFromArrayBackedRowWithOffset) + assert(field0StringFromArrayBackedRow === field0StringFromArrayBackedRowWithOffset) } test("calling getDouble() and getFloat() on null columns") { From 60043f22458668ac7ecba94fa78953f23a6bdcec Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Jul 2017 00:20:26 -0700 Subject: [PATCH 0880/1765] [SS][MINOR] Fix flaky test in DatastreamReaderWriterSuite. temp checkpoint dir should be deleted ## What changes were proposed in this pull request? Stopping query while it is being initialized can throw interrupt exception, in which case temporary checkpoint directories will not be deleted, and the test will fail. Author: Tathagata Das Closes #18442 from tdas/DatastreamReaderWriterSuite-fix. --- .../spark/sql/streaming/test/DataStreamReaderWriterSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 3de0ae67a3892..e8a6202b8adce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -641,6 +641,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { test("temp checkpoint dir should be deleted if a query is stopped without errors") { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() + query.processAllAvailable() val checkpointDir = new Path( query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) From 5800144a54f5c0180ccf67392f32c3e8a51119b1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 6 Jul 2017 15:32:49 +0800 Subject: [PATCH 0881/1765] [SPARK-21012][SUBMIT] Add glob support for resources adding to Spark Current "--jars (spark.jars)", "--files (spark.files)", "--py-files (spark.submit.pyFiles)" and "--archives (spark.yarn.dist.archives)" only support non-glob path. This is OK for most of the cases, but when user requires to add more jars, files into Spark, it is too verbose to list one by one. So here propose to add glob path support for resources. Also improving the code of downloading resources. ## How was this patch tested? UT added, also verified manually in local cluster. Author: jerryshao Closes #18235 from jerryshao/SPARK-21012. --- .../org/apache/spark/deploy/SparkSubmit.scala | 166 ++++++++++++++---- .../spark/deploy/SparkSubmitArguments.scala | 2 +- .../spark/deploy/SparkSubmitSuite.scala | 68 ++++++- docs/configuration.md | 6 +- 4 files changed, 196 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d13fb4193970b..abde04062c4b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -17,17 +17,21 @@ package org.apache.spark.deploy -import java.io.{File, IOException} +import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL import java.nio.file.Files -import java.security.PrivilegedExceptionAction +import java.security.{KeyStore, PrivilegedExceptionAction} +import java.security.cert.X509Certificate import java.text.ParseException +import javax.net.ssl._ import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Properties +import com.google.common.io.ByteStreams +import org.apache.commons.io.FileUtils import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} import org.apache.hadoop.fs.{FileSystem, Path} @@ -310,33 +314,33 @@ object SparkSubmit extends CommandLineUtils { RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) } - // In client mode, download remote files. - if (deployMode == CLIENT) { - val hadoopConf = new HadoopConfiguration() - args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull - args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull - args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull - args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull - } - - // Require all python files to be local, so we can add them to the PYTHONPATH - // In YARN cluster mode, python files are distributed as regular files, which can be non-local. - // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. - if (args.isPython && !isYarnCluster && !isMesosCluster) { - if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}") + val hadoopConf = new HadoopConfiguration() + val targetDir = Files.createTempDirectory("tmp").toFile + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + FileUtils.deleteQuietly(targetDir) } - val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") - if (nonLocalPyFiles.nonEmpty) { - printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles") - } - } + }) + // scalastyle:on runtimeaddshutdownhook - // Require all R files to be local - if (args.isR && !isYarnCluster && !isMesosCluster) { - if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") - } + // Resolve glob path for different resources. + args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull + args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull + args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull + args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + + // In client mode, download remote files. + if (deployMode == CLIENT) { + args.primaryResource = Option(args.primaryResource).map { + downloadFile(_, targetDir, args.sparkProperties, hadoopConf) + }.orNull + args.jars = Option(args.jars).map { + downloadFileList(_, targetDir, args.sparkProperties, hadoopConf) + }.orNull + args.pyFiles = Option(args.pyFiles).map { + downloadFileList(_, targetDir, args.sparkProperties, hadoopConf) + }.orNull } // The following modes are not supported or applicable @@ -841,36 +845,132 @@ object SparkSubmit extends CommandLineUtils { * Download a list of remote files to temp local files. If the file is local, the original file * will be returned. * @param fileList A comma separated file list. + * @param targetDir A temporary directory for which downloaded files + * @param sparkProperties Spark properties * @return A comma separated local files list. */ private[deploy] def downloadFileList( fileList: String, + targetDir: File, + sparkProperties: Map[String, String], hadoopConf: HadoopConfiguration): String = { require(fileList != null, "fileList cannot be null.") - fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",") + fileList.split(",") + .map(downloadFile(_, targetDir, sparkProperties, hadoopConf)) + .mkString(",") } /** * Download a file from the remote to a local temporary directory. If the input path points to * a local path, returns it with no operation. + * @param path A file path from where the files will be downloaded. + * @param targetDir A temporary directory for which downloaded files + * @param sparkProperties Spark properties + * @return A comma separated local files list. */ - private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = { + private[deploy] def downloadFile( + path: String, + targetDir: File, + sparkProperties: Map[String, String], + hadoopConf: HadoopConfiguration): String = { require(path != null, "path cannot be null.") val uri = Utils.resolveURI(path) uri.getScheme match { - case "file" | "local" => - path + case "file" | "local" => path + case "http" | "https" | "ftp" => + val uc = uri.toURL.openConnection() + uc match { + case https: HttpsURLConnection => + val trustStore = sparkProperties.get("spark.ssl.fs.trustStore") + .orElse(sparkProperties.get("spark.ssl.trustStore")) + val trustStorePwd = sparkProperties.get("spark.ssl.fs.trustStorePassword") + .orElse(sparkProperties.get("spark.ssl.trustStorePassword")) + .map(_.toCharArray) + .orNull + val protocol = sparkProperties.get("spark.ssl.fs.protocol") + .orElse(sparkProperties.get("spark.ssl.protocol")) + if (protocol.isEmpty) { + printErrorAndExit("spark ssl protocol is required when enabling SSL connection.") + } + + val trustStoreManagers = trustStore.map { t => + var input: InputStream = null + try { + input = new FileInputStream(new File(t)) + val ks = KeyStore.getInstance(KeyStore.getDefaultType) + ks.load(input, trustStorePwd) + val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + tmf.init(ks) + tmf.getTrustManagers + } finally { + if (input != null) { + input.close() + input = null + } + } + }.getOrElse { + Array({ + new X509TrustManager { + override def getAcceptedIssuers: Array[X509Certificate] = null + override def checkClientTrusted( + x509Certificates: Array[X509Certificate], s: String) {} + override def checkServerTrusted( + x509Certificates: Array[X509Certificate], s: String) {} + }: TrustManager + }) + } + val sslContext = SSLContext.getInstance(protocol.get) + sslContext.init(null, trustStoreManagers, null) + https.setSSLSocketFactory(sslContext.getSocketFactory) + https.setHostnameVerifier(new HostnameVerifier { + override def verify(s: String, sslSession: SSLSession): Boolean = false + }) + + case _ => + } + uc.setConnectTimeout(60 * 1000) + uc.setReadTimeout(60 * 1000) + uc.connect() + val in = uc.getInputStream + val fileName = new Path(uri).getName + val tempFile = new File(targetDir, fileName) + val out = new FileOutputStream(tempFile) + // scalastyle:off println + printStream.println(s"Downloading ${uri.toString} to ${tempFile.getAbsolutePath}.") + // scalastyle:on println + try { + ByteStreams.copy(in, out) + } finally { + in.close() + out.close() + } + tempFile.toURI.toString case _ => val fs = FileSystem.get(uri, hadoopConf) - val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath) + val tmpFile = new File(targetDir, new Path(uri).getName) // scalastyle:off println printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.") // scalastyle:on println fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath)) - Utils.resolveURI(tmpFile.getAbsolutePath).toString + tmpFile.toURI.toString } } + + private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = { + require(paths != null, "paths cannot be null.") + paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path => + val uri = Utils.resolveURI(path) + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(path) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(path)) + } + }.mkString(",") + } } /** Provides utility functions to be used inside SparkSubmit. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 7800d3d624e3e..fd1521193fdee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -520,7 +520,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | (Default: client). | --class CLASS_NAME Your application's main class (for Java / Scala apps). | --name NAME A name of your application. - | --jars JARS Comma-separated list of local jars to include on the driver + | --jars JARS Comma-separated list of jars to include on the driver | and executor classpaths. | --packages Comma-separated list of maven coordinates of jars to include | on the driver and executor classpaths. Will search the local diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b089357e7b868..97357cdbb6083 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.nio.file.Files +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.io.Source import com.google.common.io.ByteStreams -import org.apache.commons.io.{FilenameUtils, FileUtils} +import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} @@ -42,7 +44,6 @@ import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} - trait TestPrematureExit { suite: SparkFunSuite => @@ -726,6 +727,47 @@ class SparkSubmitSuite Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) } + + test("support glob path") { + val tmpJarDir = Utils.createTempDir() + val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) + val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir) + + val tmpFileDir = Utils.createTempDir() + val file1 = File.createTempFile("tmpFile1", "", tmpFileDir) + val file2 = File.createTempFile("tmpFile2", "", tmpFileDir) + + val tmpPyFileDir = Utils.createTempDir() + val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir) + val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir) + + val tmpArchiveDir = Utils.createTempDir() + val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) + val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar", + "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", + "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", + "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", + jar2.toString) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + sysProps("spark.yarn.dist.jars").split(",").toSet should be + (Set(jar1.toURI.toString, jar2.toURI.toString)) + sysProps("spark.yarn.dist.files").split(",").toSet should be + (Set(file1.toURI.toString, file2.toURI.toString)) + sysProps("spark.submit.pyFiles").split(",").toSet should be + (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) + sysProps("spark.yarn.dist.archives").split(",").toSet should be + (Set(archive1.toURI.toString, archive2.toURI.toString)) + } + // scalastyle:on println private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = { @@ -738,7 +780,7 @@ class SparkSubmitSuite assert(outputUri.getScheme === "file") // The path and filename are preserved. - assert(outputUri.getPath.endsWith(sourceUri.getPath)) + assert(outputUri.getPath.endsWith(new Path(sourceUri).getName)) assert(FileUtils.readFileToString(new File(outputUri.getPath)) === FileUtils.readFileToString(new File(sourceUri.getPath))) } @@ -752,25 +794,29 @@ class SparkSubmitSuite test("downloadFile - invalid url") { intercept[IOException] { - SparkSubmit.downloadFile("abc:/my/file", new Configuration()) + SparkSubmit.downloadFile( + "abc:/my/file", Utils.createTempDir(), mutable.Map.empty, new Configuration()) } } test("downloadFile - file doesn't exist") { val hadoopConf = new Configuration() + val tmpDir = Utils.createTempDir() // Set s3a implementation to local file system for testing. hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") // Disable file system impl cache to make sure the test file system is picked up. hadoopConf.set("fs.s3a.impl.disable.cache", "true") intercept[FileNotFoundException] { - SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf) + SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf) } } test("downloadFile does not download local file") { // empty path is considered as local file. - assert(SparkSubmit.downloadFile("", new Configuration()) === "") - assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file") + val tmpDir = Files.createTempDirectory("tmp").toFile + assert(SparkSubmit.downloadFile("", tmpDir, mutable.Map.empty, new Configuration()) === "") + assert(SparkSubmit.downloadFile("/local/file", tmpDir, mutable.Map.empty, + new Configuration()) === "/local/file") } test("download one file to local") { @@ -779,12 +825,14 @@ class SparkSubmitSuite val content = "hello, world" FileUtils.write(jarFile, content) val hadoopConf = new Configuration() + val tmpDir = Files.createTempDirectory("tmp").toFile // Set s3a implementation to local file system for testing. hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") // Disable file system impl cache to make sure the test file system is picked up. hadoopConf.set("fs.s3a.impl.disable.cache", "true") val sourcePath = s"s3a://${jarFile.getAbsolutePath}" - val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf) + val outputPath = + SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf) checkDownloadedFile(sourcePath, outputPath) deleteTempOutputFile(outputPath) } @@ -795,12 +843,14 @@ class SparkSubmitSuite val content = "hello, world" FileUtils.write(jarFile, content) val hadoopConf = new Configuration() + val tmpDir = Files.createTempDirectory("tmp").toFile // Set s3a implementation to local file system for testing. hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") // Disable file system impl cache to make sure the test file system is picked up. hadoopConf.set("fs.s3a.impl.disable.cache", "true") val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") - val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + val outputPaths = SparkSubmit.downloadFileList( + sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",") assert(outputPaths.length === sourcePaths.length) sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => diff --git a/docs/configuration.md b/docs/configuration.md index c785a664c67b1..7dc23e441a7ba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -422,21 +422,21 @@ Apart from these, the following properties are also available, and may be useful spark.files - Comma-separated list of files to be placed in the working directory of each executor. + Comma-separated list of files to be placed in the working directory of each executor. Globs are allowed. spark.submit.pyFiles - Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. + Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. Globs are allowed. spark.jars - Comma-separated list of local jars to include on the driver and executor classpaths. + Comma-separated list of jars to include on the driver and executor classpaths. Globs are allowed. From 6ff05a66fe83e721063efe5c28d2ffeb850fecc7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 6 Jul 2017 15:47:09 +0800 Subject: [PATCH 0882/1765] [SPARK-20703][SQL] Associate metrics with data writes onto DataFrameWriter operations ## What changes were proposed in this pull request? Right now in the UI, after SPARK-20213, we can show the operations to write data out. However, there is no way to associate metrics with data writes. We should show relative metrics on the operations. #### Supported commands This change supports updating metrics for file-based data writing operations, including `InsertIntoHadoopFsRelationCommand`, `InsertIntoHiveTable`. Supported metrics: * number of written files * number of dynamic partitions * total bytes of written data * total number of output rows * average writing data out time (ms) * (TODO) min/med/max number of output rows per file/partition * (TODO) min/med/max bytes of written data per file/partition #### Commands not supported `InsertIntoDataSourceCommand`, `SaveIntoDataSourceCommand`: The two commands uses DataSource APIs to write data out, i.e., the logic of writing data out is delegated to the DataSource implementations, such as `InsertableRelation.insert` and `CreatableRelationProvider.createRelation`. So we can't obtain metrics from delegated methods for now. `CreateHiveTableAsSelectCommand`, `CreateDataSourceTableAsSelectCommand` : The two commands invokes other commands to write data out. The invoked commands can even write to non file-based data source. We leave them as future TODO. #### How to update metrics of writing files out A `RunnableCommand` which wants to update metrics, needs to override its `metrics` and provide the metrics data structure to `ExecutedCommandExec`. The metrics are prepared during the execution of `FileFormatWriter`. The callback function passed to `FileFormatWriter` will accept the metrics and update accordingly. There is a metrics updating function in `RunnableCommand`. In runtime, the function will be bound to the spark context and `metrics` of `ExecutedCommandExec` and pass to `FileFormatWriter`. ## How was this patch tested? Updated unit tests. Author: Liang-Chi Hsieh Closes #18159 from viirya/SPARK-20703-2. --- .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../command/DataWritingCommand.scala | 75 ++++++++++ .../sql/execution/command/commands.scala | 12 ++ .../datasources/FileFormatWriter.scala | 121 ++++++++++++--- .../InsertIntoHadoopFsRelationCommand.scala | 18 ++- .../sql/sources/PartitionedWriteSuite.scala | 21 +-- .../hive/execution/InsertIntoHiveTable.scala | 8 +- .../sql/hive/execution/SQLMetricsSuite.scala | 139 ++++++++++++++++++ 8 files changed, 362 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 26f61e25da4d3..b4caf68f0afaa 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1002,6 +1002,15 @@ private[spark] object Utils extends Logging { } } + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala new file mode 100644 index 0000000000000..0c381a2c02986 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.ExecutedWriteSummary +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * A special `RunnableCommand` which writes data out and updates metrics. + */ +trait DataWritingCommand extends RunnableCommand { + + override lazy val metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + "avgTime" -> SQLMetrics.createMetric(sparkContext, "average writing time (ms)"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), + "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } + + /** + * Callback function that update metrics collected from the writing operation. + */ + protected def updateWritingMetrics(writeSummaries: Seq[ExecutedWriteSummary]): Unit = { + val sparkContext = SparkContext.getActive.get + var numPartitions = 0 + var numFiles = 0 + var totalNumBytes: Long = 0L + var totalNumOutput: Long = 0L + var totalWritingTime: Long = 0L + + writeSummaries.foreach { summary => + numPartitions += summary.updatedPartitions.size + numFiles += summary.numOutputFile + totalNumBytes += summary.numOutputBytes + totalNumOutput += summary.numOutputRows + totalWritingTime += summary.totalWritingTime + } + + val avgWritingTime = if (numFiles > 0) { + (totalWritingTime / numFiles).toLong + } else { + 0L + } + + metrics("avgTime").add(avgWritingTime) + metrics("numFiles").add(numFiles) + metrics("numOutputBytes").add(totalNumBytes) + metrics("numOutputRows").add(totalNumOutput) + metrics("numParts").add(numPartitions) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 81bc93e7ebcf4..7cd4baef89e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ @@ -37,6 +38,11 @@ import org.apache.spark.sql.types._ * wrapped in `ExecutedCommand` during execution. */ trait RunnableCommand extends logical.Command { + + // The map used to record the metrics of running the command. This will be passed to + // `ExecutedCommand` during query planning. + lazy val metrics: Map[String, SQLMetric] = Map.empty + def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { throw new NotImplementedError } @@ -49,8 +55,14 @@ trait RunnableCommand extends logical.Command { /** * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. + * + * @param cmd the `RunnableCommand` this operator will run. + * @param children the children physical plans ran by the `RunnableCommand`. */ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { + + override lazy val metrics: Map[String, SQLMetric] = cmd.metrics + /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 0daffa93b4747..64866630623ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -22,7 +22,7 @@ import java.util.{Date, UUID} import scala.collection.mutable import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl @@ -82,7 +82,7 @@ object FileFormatWriter extends Logging { } /** The result of a successful write task. */ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String]) + private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) /** * Basic work flow of this command is: @@ -104,7 +104,7 @@ object FileFormatWriter extends Logging { hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], - refreshFunction: (Seq[TablePartitionSpec]) => Unit, + refreshFunction: (Seq[ExecutedWriteSummary]) => Unit, options: Map[String, String]): Unit = { val job = Job.getInstance(hadoopConf) @@ -196,12 +196,10 @@ object FileFormatWriter extends Logging { }) val commitMsgs = ret.map(_.commitMsg) - val updatedPartitions = ret.flatMap(_.updatedPartitions) - .distinct.map(PartitioningUtils.parsePathFragment) committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") - refreshFunction(updatedPartitions) + refreshFunction(ret.map(_.summary)) } catch { case cause: Throwable => logError(s"Aborting job ${job.getJobID}.", cause) committer.abortJob(job) @@ -247,9 +245,9 @@ object FileFormatWriter extends Logging { try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val outputPartitions = writeTask.execute(iterator) + val summary = writeTask.execute(iterator) writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) })(catchBlock = { // If there is an error, release resource and then abort the task try { @@ -273,12 +271,36 @@ object FileFormatWriter extends Logging { * automatically trigger task aborts. */ private trait ExecuteWriteTask { + /** - * Writes data out to files, and then returns the list of partition strings written out. - * The list of partitions is sent back to the driver and used to update the catalog. + * The data structures used to measure metrics during writing. */ - def execute(iterator: Iterator[InternalRow]): Set[String] + protected var totalWritingTime: Long = 0L + protected var timeOnCurrentFile: Long = 0L + protected var numOutputRows: Long = 0L + protected var numOutputBytes: Long = 0L + + /** + * Writes data out to files, and then returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to update the metrics in UI. + */ + def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary def releaseResources(): Unit + + /** + * A helper function used to determine the size in bytes of a written file. + */ + protected def getFileSize(conf: Configuration, filePath: String): Long = { + if (filePath != null) { + val path = new Path(filePath) + val fs = path.getFileSystem(conf) + fs.getFileStatus(path).getLen() + } else { + 0L + } + } } /** Writes data to a single directory (used for non-dynamic-partition writes). */ @@ -288,24 +310,26 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { private[this] var currentWriter: OutputWriter = _ + private[this] var currentPath: String = _ private def newOutputWriter(fileCounter: Int): Unit = { val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val tmpFilePath = committer.newTaskTempFile( + currentPath = committer.newTaskTempFile( taskAttemptContext, None, f"-c$fileCounter%03d" + ext) currentWriter = description.outputWriterFactory.newInstance( - path = tmpFilePath, + path = currentPath, dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) } - override def execute(iter: Iterator[InternalRow]): Set[String] = { + override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { var fileCounter = 0 var recordsInFile: Long = 0L newOutputWriter(fileCounter) + while (iter.hasNext) { if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { fileCounter += 1 @@ -314,21 +338,35 @@ object FileFormatWriter extends Logging { recordsInFile = 0 releaseResources() + numOutputRows += recordsInFile newOutputWriter(fileCounter) } val internalRow = iter.next() + val startTime = System.nanoTime() currentWriter.write(internalRow) + timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } releaseResources() - Set.empty + numOutputRows += recordsInFile + + ExecutedWriteSummary( + updatedPartitions = Set.empty, + numOutputFile = fileCounter + 1, + numOutputBytes = numOutputBytes, + numOutputRows = numOutputRows, + totalWritingTime = totalWritingTime) } override def releaseResources(): Unit = { if (currentWriter != null) { try { + val startTime = System.nanoTime() currentWriter.close() + totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 + timeOnCurrentFile = 0 + numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null } @@ -348,6 +386,8 @@ object FileFormatWriter extends Logging { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ + private var currentPath: String = _ + /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => @@ -403,19 +443,19 @@ object FileFormatWriter extends Logging { case _ => None } - val path = if (customPath.isDefined) { + currentPath = if (customPath.isDefined) { committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) } else { committer.newTaskTempFile(taskAttemptContext, partDir, ext) } currentWriter = desc.outputWriterFactory.newInstance( - path = path, + path = currentPath, dataSchema = desc.dataColumns.toStructType, context = taskAttemptContext) } - override def execute(iter: Iterator[InternalRow]): Set[String] = { + override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { val getPartitionColsAndBucketId = UnsafeProjection.create( desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns) @@ -429,15 +469,22 @@ object FileFormatWriter extends Logging { // If anything below fails, we should abort the task. var recordsInFile: Long = 0L var fileCounter = 0 + var totalFileCounter = 0 var currentPartColsAndBucketId: UnsafeRow = null val updatedPartitions = mutable.Set[String]() + for (row <- iter) { val nextPartColsAndBucketId = getPartitionColsAndBucketId(row) if (currentPartColsAndBucketId != nextPartColsAndBucketId) { + if (currentPartColsAndBucketId != null) { + totalFileCounter += (fileCounter + 1) + } + // See a new partition or bucket - write to a new partition dir (or a new bucket file). currentPartColsAndBucketId = nextPartColsAndBucketId.copy() logDebug(s"Writing partition: $currentPartColsAndBucketId") + numOutputRows += recordsInFile recordsInFile = 0 fileCounter = 0 @@ -447,6 +494,8 @@ object FileFormatWriter extends Logging { recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. // Create a new file by increasing the file counter. + + numOutputRows += recordsInFile recordsInFile = 0 fileCounter += 1 assert(fileCounter < MAX_FILE_COUNTER, @@ -455,18 +504,33 @@ object FileFormatWriter extends Logging { releaseResources() newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - + val startTime = System.nanoTime() currentWriter.write(getOutputRow(row)) + timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } + if (currentPartColsAndBucketId != null) { + totalFileCounter += (fileCounter + 1) + } releaseResources() - updatedPartitions.toSet + numOutputRows += recordsInFile + + ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + numOutputFile = totalFileCounter, + numOutputBytes = numOutputBytes, + numOutputRows = numOutputRows, + totalWritingTime = totalWritingTime) } override def releaseResources(): Unit = { if (currentWriter != null) { try { + val startTime = System.nanoTime() currentWriter.close() + totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 + timeOnCurrentFile = 0 + numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null } @@ -474,3 +538,20 @@ object FileFormatWriter extends Logging { } } } + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param numOutputFile the total number of files. + * @param numOutputRows the number of output rows. + * @param numOutputBytes the bytes of output data. + * @param totalWritingTime the total writing time in ms. + */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + numOutputFile: Int, + numOutputRows: Long, + numOutputBytes: Long, + totalWritingTime: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ab26f2affbce5..0031567d3d288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -21,6 +21,7 @@ import java.io.IOException import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -53,7 +55,7 @@ case class InsertIntoHadoopFsRelationCommand( mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex]) - extends RunnableCommand { + extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName override def children: Seq[LogicalPlan] = query :: Nil @@ -123,8 +125,16 @@ case class InsertIntoHadoopFsRelationCommand( if (doInsertion) { - // Callback for updating metastore partition metadata after the insertion job completes. - def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { + // Callback for updating metric and metastore partition metadata + // after the insertion job completes. + def refreshCallback(summary: Seq[ExecutedWriteSummary]): Unit = { + val updatedPartitions = summary.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) + + // Updating metrics. + updateWritingMetrics(summary) + + // Updating metastore partition metadata. if (partitionsTrackedByCatalog) { val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions if (newPartitions.nonEmpty) { @@ -154,7 +164,7 @@ case class InsertIntoHadoopFsRelationCommand( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, - refreshFunction = refreshPartitionsCallback, + refreshFunction = refreshCallback, options = options) // refresh cached files in FileIndex diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a2f3afe3ce236..6f998aa60faf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -91,15 +91,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -111,7 +111,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -138,14 +138,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -153,18 +153,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") } } } - - /** Lists files recursively. */ - private def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 223d375232393..cd263e8b6df8e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -31,14 +31,16 @@ import org.apache.hadoop.hive.ql.exec.TaskRunner import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.{CommandUtils, RunnableCommand} +import org.apache.spark.sql.execution.command.{CommandUtils, DataWritingCommand} import org.apache.spark.sql.execution.datasources.FileFormatWriter +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion} @@ -80,7 +82,7 @@ case class InsertIntoHiveTable( partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifPartitionNotExists: Boolean) extends RunnableCommand { + ifPartitionNotExists: Boolean) extends DataWritingCommand { override def children: Seq[LogicalPlan] = query :: Nil @@ -354,7 +356,7 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, partitionColumns = partitionAttributes, bucketSpec = None, - refreshFunction = _ => (), + refreshFunction = updateWritingMetrics, options = Map.empty) if (partition.nonEmpty) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala new file mode 100644 index 0000000000000..1ef1988d4c605 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.File + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +class SQLMetricsSuite extends SQLTestUtils with TestHiveSingleton { + import spark.implicits._ + + /** + * Get execution metrics for the SQL execution and verify metrics values. + * + * @param metricsValues the expected metric values (numFiles, numPartitions, numOutputRows). + * @param func the function can produce execution id after running. + */ + private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = { + val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + // Run the given function to trigger query execution. + func + spark.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = + spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size == 1) + val executionId = executionIds.head + + val executionData = spark.sharedState.listener.getExecution(executionId).get + val executedNode = executionData.physicalPlanGraph.nodes.head + + val metricsNames = Seq( + "number of written files", + "number of dynamic part", + "number of output rows") + + val metrics = spark.sharedState.listener.getExecutionMetrics(executionId) + + metricsNames.zip(metricsValues).foreach { case (metricsName, expected) => + val sqlMetric = executedNode.metrics.find(_.name == metricsName) + assert(sqlMetric.isDefined) + val accumulatorId = sqlMetric.get.accumulatorId + val metricValue = metrics(accumulatorId).replaceAll(",", "").toInt + assert(metricValue == expected) + } + + val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get + val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt + assert(totalNumBytes > 0) + val writingTimeMetric = executedNode.metrics.find(_.name == "average writing time (ms)").get + val writingTime = metrics(writingTimeMetric.accumulatorId).replaceAll(",", "").toInt + assert(writingTime >= 0) + } + + private def testMetricsNonDynamicPartition( + dataFormat: String, + tableName: String): Unit = { + withTable(tableName) { + Seq((1, 2)).toDF("i", "j") + .write.format(dataFormat).mode("overwrite").saveAsTable(tableName) + + val tableLocation = + new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location) + + // 2 files, 100 rows, 0 dynamic partition. + verifyWriteDataMetrics(Seq(2, 0, 100)) { + (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) + .write.format(dataFormat).mode("overwrite").insertInto(tableName) + } + assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + } + } + + private def testMetricsDynamicPartition( + provider: String, + dataFormat: String, + tableName: String): Unit = { + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE $tableName(a int, b int) + |USING $provider + |PARTITIONED BY(a) + |LOCATION '${dir.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1) + .selectExpr("id a", "id b") + + // 40 files, 80 rows, 40 dynamic partitions. + verifyWriteDataMetrics(Seq(40, 40, 80)) { + df.union(df).repartition(2, $"a") + .write + .format(dataFormat) + .mode("overwrite") + .insertInto(tableName) + } + assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + } + } + + test("writing data out metrics: parquet") { + testMetricsNonDynamicPartition("parquet", "t1") + } + + test("writing data out metrics with dynamic partition: parquet") { + testMetricsDynamicPartition("parquet", "parquet", "t1") + } + + test("writing data out metrics: hive") { + testMetricsNonDynamicPartition("hive", "t1") + } + + test("writing data out metrics dynamic partition: hive") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + testMetricsDynamicPartition("hive", "hive", "t1") + } + } +} From b8e4d567a7d6c2ff277700d4e7707e57e87c7808 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Thu, 6 Jul 2017 16:00:31 +0800 Subject: [PATCH 0883/1765] [SPARK-21324][TEST] Improve statistics test suites ## What changes were proposed in this pull request? 1. move `StatisticsCollectionTestBase` to a separate file. 2. move some test cases to `StatisticsCollectionSuite` so that `hive/StatisticsSuite` only keeps tests that need hive support. 3. clear up some test cases. ## How was this patch tested? Existing tests. Author: wangzhenhua Author: Zhenhua Wang Closes #18545 from wzhfy/cleanStatSuites. --- .../spark/sql/StatisticsCollectionSuite.scala | 193 +++--------------- .../sql/StatisticsCollectionTestBase.scala | 192 +++++++++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 124 +++-------- 3 files changed, 258 insertions(+), 251 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index d9392de37a815..843ced7f0e697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,19 +17,12 @@ package org.apache.spark.sql -import java.{lang => jl} -import java.sql.{Date, Timestamp} - import scala.collection.mutable -import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ @@ -58,6 +51,37 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("analyzing views is not supported") { + def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { + val err = intercept[AnalysisException] { + sql(analyzeCommand) + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + + val tableName = "tbl" + withTable(tableName) { + spark.range(10).write.saveAsTable(tableName) + val viewName = "view" + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + } + } + + test("statistics collection of a table with zero column") { + val table_no_cols = "table_no_cols" + withTable(table_no_cols) { + val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) + val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) + dfNoCols.write.format("json").saveAsTable(table_no_cols) + sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") + checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) + } + } + test("analyze column command - unsupported types and invalid columns") { val tableName = "column_stats_test1" withTable(tableName) { @@ -239,154 +263,3 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } - - -/** - * The base for test cases that we want to include in both the hive module (for verifying behavior - * when using the Hive external catalog) as well as in the sql/core module. - */ -abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { - import testImplicits._ - - private val dec1 = new java.math.BigDecimal("1.000000000000000000") - private val dec2 = new java.math.BigDecimal("8.000000000000000000") - private val d1 = Date.valueOf("2016-05-08") - private val d2 = Date.valueOf("2016-05-09") - private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") - - /** - * Define a very simple 3 row table used for testing column serialization. - * Note: last column is seq[int] which doesn't support stats collection. - */ - protected val data = Seq[ - (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, - jl.Double, jl.Float, java.math.BigDecimal, - String, Array[Byte], Date, Timestamp, - Seq[Int])]( - (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), - (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), - (null, null, null, null, null, null, null, null, null, null, null, null, null) - ) - - /** A mapping from column to the stats collected. */ - protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), - Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), - Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) - ) - - private val randomName = new Random(31) - - def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes >= 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - - stats - } - - /** - * Compute column stats for the given DataFrame and compare it with colStats. - */ - def checkColStats( - df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { - val tableName = "column_stats_test_" + randomName.nextInt(1000) - withTable(tableName) { - df.write.saveAsTable(tableName) - - // Collect statistics - sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + - colStats.keys.mkString(", ")) - - // Validate statistics - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } - } - } - } - - // This test will be run twice: with and without Hive support - test("SPARK-18856: non-empty partitioned table should not report zero size") { - withTable("ds_tbl", "hive_tbl") { - spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats - assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") - - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { - sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") - sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats - assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") - } - } - } - - // This test will be run twice: with and without Hive support - test("conversion from CatalogStatistics to Statistics") { - withTable("ds_tbl", "hive_tbl") { - // Test data source table - checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) - // Test hive serde table - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { - checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) - } - } - } - - private def checkStatsConversion(tableName: String, isDatasourceTable: Boolean): Unit = { - // Create an empty table and run analyze command on it. - val createTableSql = if (isDatasourceTable) { - s"CREATE TABLE $tableName (c1 INT, c2 STRING) USING PARQUET" - } else { - s"CREATE TABLE $tableName (c1 INT, c2 STRING)" - } - sql(createTableSql) - // Analyze only one column. - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") - val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { - case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) - case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) - }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) - // Check catalog statistics - assert(catalogTable.stats.isDefined) - assert(catalogTable.stats.get.sizeInBytes == 0) - assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) - - // Check relation statistics - assert(relation.stats.sizeInBytes == 0) - assert(relation.stats.rowCount == Some(0)) - assert(relation.stats.attributeStats.size == 1) - val (attribute, colStat) = relation.stats.attributeStats.head - assert(attribute.name == "c1") - assert(colStat == emptyColStat) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala new file mode 100644 index 0000000000000..41569762d3c59 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.{lang => jl} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.Decimal + + +/** + * The base for statistics test cases that we want to include in both the hive module (for + * verifying behavior when using the Hive external catalog) as well as in the sql/core module. + */ +abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { + import testImplicits._ + + private val dec1 = new java.math.BigDecimal("1.000000000000000000") + private val dec2 = new java.math.BigDecimal("8.000000000000000000") + private val d1 = Date.valueOf("2016-05-08") + private val d2 = Date.valueOf("2016-05-09") + private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** + * Define a very simple 3 row table used for testing column serialization. + * Note: last column is seq[int] which doesn't support stats collection. + */ + protected val data = Seq[ + (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, + jl.Double, jl.Float, java.math.BigDecimal, + String, Array[Byte], Date, Timestamp, + Seq[Int])]( + (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), + (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), + (null, null, null, null, null, null, null, null, null, null, null, null, null) + ) + + /** A mapping from column to the stats collected. */ + protected val stats = mutable.LinkedHashMap( + "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), + "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), + "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), + "cstring" -> ColumnStat(2, None, None, 1, 3, 3), + "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) + ) + + private val randomName = new Random(31) + + def getCatalogTable(tableName: String): CatalogTable = { + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + } + + def getCatalogStatistics(tableName: String): CatalogStatistics = { + getCatalogTable(tableName).stats.get + } + + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = getCatalogTable(tableName).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + + /** + * Compute column stats for the given DataFrame and compare it with colStats. + */ + def checkColStats( + df: DataFrame, + colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + val tableName = "column_stats_test_" + randomName.nextInt(1000) + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + + colStats.keys.mkString(", ")) + + // Validate statistics + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) + } + } + } + } + + // This test will be run twice: with and without Hive support + test("SPARK-18856: non-empty partitioned table should not report zero size") { + withTable("ds_tbl", "hive_tbl") { + spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats + assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") + sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats + assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + } + } + } + + // This test will be run twice: with and without Hive support + test("conversion from CatalogStatistics to Statistics") { + withTable("ds_tbl", "hive_tbl") { + // Test data source table + checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) + // Test hive serde table + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) + } + } + } + + private def checkStatsConversion(tableName: String, isDatasourceTable: Boolean): Unit = { + // Create an empty table and run analyze command on it. + val createTableSql = if (isDatasourceTable) { + s"CREATE TABLE $tableName (c1 INT, c2 STRING) USING PARQUET" + } else { + s"CREATE TABLE $tableName (c1 INT, c2 STRING)" + } + sql(createTableSql) + // Analyze only one column. + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") + val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { + case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) + case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) + }.head + val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + // Check catalog statistics + assert(catalogTable.stats.isDefined) + assert(catalogTable.stats.get.sizeInBytes == 0) + assert(catalogTable.stats.get.rowCount == Some(0)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + + // Check relation statistics + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head + assert(attribute.name == "c1") + assert(colStat == emptyColStat) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index c601038a2b0af..e00fa64e9f2ce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,7 +25,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -82,58 +81,42 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes // Non-partitioned table - sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + val nonPartTable = "non_part_table" + withTable(nonPartTable) { + sql(s"CREATE TABLE $nonPartTable (key STRING, value STRING)") + sql(s"INSERT INTO TABLE $nonPartTable SELECT * FROM src") + sql(s"INSERT INTO TABLE $nonPartTable SELECT * FROM src") - sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") + sql(s"ANALYZE TABLE $nonPartTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable") === BigInt(11624)) - - sql("DROP TABLE analyzeTable").collect() + assert(queryTotalSize(nonPartTable) === BigInt(11624)) + } // Partitioned table - sql( - """ - |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') - |SELECT * FROM src - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') - |SELECT * FROM src - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') - |SELECT * FROM src - """.stripMargin).collect() + val partTable = "part_table" + withTable(partTable) { + sql(s"CREATE TABLE $partTable (key STRING, value STRING) PARTITIONED BY (ds STRING)") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-01') SELECT * FROM src") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-02') SELECT * FROM src") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-03') SELECT * FROM src") - assert(queryTotalSize("analyzeTable_part") === spark.sessionState.conf.defaultSizeInBytes) + assert(queryTotalSize(partTable) === spark.sessionState.conf.defaultSizeInBytes) - sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") + sql(s"ANALYZE TABLE $partTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) - - sql("DROP TABLE analyzeTable_part").collect() + assert(queryTotalSize(partTable) === BigInt(17436)) + } // Try to analyze a temp table - sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") - intercept[AnalysisException] { - sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + withView("tempTable") { + sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") + intercept[AnalysisException] { + sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + } } - spark.sessionState.catalog.dropTable( - TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } test("SPARK-21079 - analyze table with location different than that of individual partitions") { - def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes - val tableName = "analyzeTable_part" withTable(tableName) { withTempPath { path => @@ -148,15 +131,12 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - assert(queryTotalSize(tableName) === BigInt(17436)) + assert(getCatalogStatistics(tableName).sizeInBytes === BigInt(17436)) } } } test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { - def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes - val sourceTableName = "analyzeTable_part" val tableName = "analyzeTable_part_vis" withTable(sourceTableName, tableName) { @@ -188,39 +168,19 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Register only one of the partitions found on disk val ds = partitionDates.head - sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')") // Analyze original table - expect 3 partitions sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") - assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + assert(getCatalogStatistics(sourceTableName).sizeInBytes === BigInt(3 * 5812)) // Analyze partial-copy table - expect only 1 partition sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - assert(queryTotalSize(tableName) === BigInt(5812)) + assert(getCatalogStatistics(tableName).sizeInBytes === BigInt(5812)) } } } - test("analyzing views is not supported") { - def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { - val err = intercept[AnalysisException] { - sql(analyzeCommand) - } - assert(err.message.contains("ANALYZE TABLE is not supported")) - } - - val tableName = "tbl" - withTable(tableName) { - spark.range(10).write.saveAsTable(tableName) - val viewName = "view" - withView(viewName) { - sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") - assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") - assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") - } - } - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -290,8 +250,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto if (analyzedByHive) hiveClient.runSqlHive(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") val describeResult1 = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") - val tableMetadata = - spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName)).properties + val tableMetadata = getCatalogTable(tabName).properties // statistics info is not contained in the metadata of the original table assert(Seq(StatsSetupConst.COLUMN_STATS_ACCURATE, StatsSetupConst.NUM_FILES, @@ -327,8 +286,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) - checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = None) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) // ALTER TABLE SET TBLPROPERTIES invalidates some contents of Hive specific statistics // This is triggered by the Hive alterTable API @@ -370,10 +328,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } test("alter table should not have the side effect to store statistics in Spark side") { - def getCatalogTable(tableName: String): CatalogTable = { - spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) - } - val table = "alter_table_side_effect" withTable(table) { sql(s"CREATE TABLE $table (i string, j string)") @@ -637,12 +591,12 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness - withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } - withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // We still can get tableSize from Hive before Analyze checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") @@ -759,8 +713,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val parquetTable = "parquetTable" withTable(parquetTable) { sql(createTableCmd) - val catalogTable = spark.sessionState.catalog.getTableMetadata( - TableIdentifier(parquetTable)) + val catalogTable = getCatalogTable(parquetTable) assert(DDLUtils.isDatasourceTable(catalogTable)) // Add a filter to avoid creating too many partitions @@ -795,17 +748,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto "partitioned data source table", "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET PARTITIONED BY (key)") - test("statistics collection of a table with zero column") { - val table_no_cols = "table_no_cols" - withTable(table_no_cols) { - val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) - val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) - dfNoCols.write.format("json").saveAsTable(table_no_cols) - sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) - } - } - /** Used to test refreshing cached metadata once table stats are updated. */ private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean) : (CatalogStatistics, CatalogStatistics) = { From d540dfbff33aa2f8571e0de149dfa3f4e7321113 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 6 Jul 2017 19:12:15 +0800 Subject: [PATCH 0884/1765] [SPARK-21273][SQL][FOLLOW-UP] Add missing test cases back and revise code style ## What changes were proposed in this pull request? Add missing test cases back and revise code style Follow up the previous PR: https://github.com/apache/spark/pull/18479 ## How was this patch tested? Unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18548 from gengliangwang/stat_propagation_revise. --- .../plans/logical/LogicalPlanVisitor.scala | 2 +- .../BasicStatsEstimationSuite.scala | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index b23045810a4f6..2652f6d72730c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -38,10 +38,10 @@ trait LogicalPlanVisitor[T] { case p: Range => visitRange(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: ResolvedHint => visitHint(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) - case p: ResolvedHint => visitHint(p) case p: LogicalPlan => default(p) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 5fd21a06a109d..913be6d1ff07f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -78,6 +78,37 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { checkStats(globalLimit, stats) } + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + /** Check estimated stats when cbo is turned on/off. */ private def checkStats( plan: LogicalPlan, @@ -99,3 +130,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = checkStats(plan, expectedStats, expectedStats) } + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. + */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) + extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override def computeStats(): Statistics = if (conf.cboEnabled) cboStats else defaultStats +} From 565e7a8d4ae7879ee704fb94ae9b3da31e202d7e Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 6 Jul 2017 19:49:34 +0800 Subject: [PATCH 0885/1765] [SPARK-20950][CORE] add a new config to diskWriteBufferSize which is hard coded before ## What changes were proposed in this pull request? This PR Improvement in two: 1.With spark.shuffle.spill.diskWriteBufferSize configure diskWriteBufferSize of ShuffleExternalSorter. when change the size of the diskWriteBufferSize to test `forceSorterToSpill` The average performance of running 10 times is as follows:(their unit is MS). ``` diskWriteBufferSize: 1M 512K 256K 128K 64K 32K 16K 8K 4K --------------------------------------------------------------------------------------- RecordSize = 2.5M 742 722 694 686 667 668 671 669 683 RecordSize = 1M 294 293 292 287 283 285 281 279 285 ``` 2.Remove outputBufferSizeInBytes and inputBufferSizeInBytes to initialize in mergeSpillsWithFileStream function. ## How was this patch tested? The unit test. Author: caoxuewen Closes #18174 from heary-cao/buffersize. --- .../shuffle/sort/ShuffleExternalSorter.java | 11 +++++--- .../shuffle/sort/UnsafeShuffleWriter.java | 14 +++++++--- .../unsafe/sort/UnsafeSorterSpillWriter.java | 24 ++++++++++------- .../spark/internal/config/package.scala | 27 +++++++++++++++++++ 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c33d1e33f030f..338faaadb33d4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -43,6 +43,7 @@ import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; +import org.apache.spark.internal.config.package$; /** * An external sorter that is specialized for sort-based shuffle. @@ -82,6 +83,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize; + /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -116,13 +120,14 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.taskContext = taskContext; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); this.peakMemoryUsedBytes = getMemoryUsage(); + this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); } /** @@ -155,7 +160,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. This array does not need to be large enough to hold a single // record; - final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + final byte[] writeBuffer = new byte[diskWriteBufferSize]; // Because this output will be read during shuffle, its compression codec must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -195,7 +200,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { - final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); Platform.copyMemory( recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 34c179990214f..1b578491b81d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -55,6 +55,7 @@ import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; +import org.apache.spark.internal.config.package$; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -65,6 +66,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @VisibleForTesting static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; + static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; @@ -78,6 +80,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; private final int initialSortBufferSize; + private final int inputBufferSizeInBytes; + private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -140,6 +144,10 @@ public UnsafeShuffleWriter( this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE); + this.inputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.outputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -209,7 +217,7 @@ private void open() throws IOException { partitioner.numPartitions(), sparkConf, writeMetrics); - serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); serOutputStream = serializer.serializeStream(serBuffer); } @@ -360,12 +368,10 @@ private long[] mergeSpillsWithFileStream( final OutputStream bos = new BufferedOutputStream( new FileOutputStream(outputFile), - (int) sparkConf.getSizeAsKb("spark.shuffle.unsafe.file.output.buffer", "32k") * 1024); + outputBufferSizeInBytes); // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - final int inputBufferSizeInBytes = - (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; boolean threwException = true; try { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 164b9d70b79d7..f9b5493755443 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,9 +20,10 @@ import java.io.File; import java.io.IOException; -import org.apache.spark.serializer.SerializerManager; import scala.Tuple2; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.storage.BlockId; @@ -30,6 +31,7 @@ import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.internal.config.package$; /** * Spills a list of sorted records to disk. Spill files have the following format: @@ -38,12 +40,16 @@ */ public final class UnsafeSorterSpillWriter { - static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + private final SparkConf conf = new SparkConf(); + + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. - private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + private byte[] writeBuffer = new byte[diskWriteBufferSize]; private final File file; private final BlockId blockId; @@ -114,7 +120,7 @@ public void write( writeIntToBuffer(recordLength, 0); writeLongToBuffer(keyPrefix, 4); int dataRemaining = recordLength; - int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + int freeSpaceInWriteBuffer = diskWriteBufferSize - 4 - 8; // space used by prefix + len long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); @@ -122,15 +128,15 @@ public void write( baseObject, recordReadPosition, writeBuffer, - Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (diskWriteBufferSize - freeSpaceInWriteBuffer), toTransfer); - writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; - freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + freeSpaceInWriteBuffer = diskWriteBufferSize; } - if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { - writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + if (freeSpaceInWriteBuffer < diskWriteBufferSize) { + writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer)); } writer.recordWritten(); } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8dee0d970c4c6..a629810bf093a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -336,4 +336,31 @@ package object config { "spark.") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.file.buffer") + .doc("Size of the in-memory buffer for each shuffle file output stream. " + + "These buffers reduce the number of disk seeks and system calls made " + + "in creating intermediate shuffle files.") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, + s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .createWithDefaultString("32k") + + private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") + .doc("The file system for this buffer size after each partition " + + "is written in unsafe shuffle writer.") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, + s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .createWithDefaultString("32k") + + private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") + .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v > 0 && v <= Int.MaxValue, + s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") + .createWithDefault(1024 * 1024) } From 26ac085debb54d0104762d1cd4187cdf73f301ba Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Fri, 7 Jul 2017 01:04:57 +0800 Subject: [PATCH 0886/1765] [SPARK-21228][SQL] InSet incorrect handling of structs ## What changes were proposed in this pull request? When data type is struct, InSet now uses TypeUtils.getInterpretedOrdering (similar to EqualTo) to build a TreeSet. In other cases it will use a HashSet as before (which should be faster). Similarly, In.eval uses Ordering.equiv instead of equals. ## How was this patch tested? New test in SQLQuerySuite. Author: Bogdan Raducanu Closes #18455 from bogdanrdc/SPARK-21228. --- .../sql/catalyst/expressions/predicates.scala | 57 ++++++++++++------- .../catalyst/expressions/PredicateSuite.scala | 31 +++++----- .../catalyst/optimizer/OptimizeInSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 +++++++ 4 files changed, 78 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f3fe58caa6fe2..7bf10f199f1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.TreeSet + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. """.stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } } case _ => - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure("Arguments must be same type") + val mismatchOpt = list.find(l => l.dataType != value.dataType) + if (mismatchOpt.isDefined) { + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType} != ${mismatchOpt.get.dataType}") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } } } override def children: Seq[Expression] = value +: list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) + private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { var hasNull = false list.foreach { e => val v = e.eval(input) - if (v == evaluatedValue) { - return true - } else if (v == null) { + if (v == null) { hasNull = true + } else if (ordering.equiv(v, evaluatedValue)) { + return true } } if (hasNull) { @@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def nullable: Boolean = child.nullable || hasNull protected override def nullSafeEval(value: Any): Any = { - if (hset.contains(value)) { + if (set.contains(value)) { true } else if (hasNull) { null @@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } - def getHSet(): Set[Any] = hset + @transient private[this] lazy val set = child.dataType match { + case _: AtomicType => hset + case _: NullType => hset + case _ => + // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows + TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + } + + def getSet(): Set[Any] = set override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.genCode(ctx) ctx.references += this - val hsetTerm = ctx.freshName("hset") - val hasNullTerm = ctx.freshName("hasNull") - ctx.addMutableState(setName, hsetTerm, - s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") - ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") + val setTerm = ctx.freshName("set") + val setNull = if (hasNull) { + s""" + |if (!${ev.value}) { + | ${ev.isNull} = true; + |} + """.stripMargin + } else { + "" + } + ctx.addMutableState(setName, setTerm, + s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; boolean ${ev.value} = false; if (!${ev.isNull}) { - ${ev.value} = $hsetTerm.contains(${childGen.value}); - if (!${ev.value} && $hasNullTerm) { - ${ev.isNull} = true; - } + ${ev.value} = $setTerm.contains(${childGen.value}); + $setNull } """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 6fe295c3dd936..ef510a95ef446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test(s"3VL $name") { truthTable.foreach { case (l, r, answer) => - val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType)) + val expr = op(NonFoldableLiteral.create(l, BooleanType), + NonFoldableLiteral.create(r, BooleanType)) checkEvaluation(expr, answer) } } @@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (false, true) :: (null, null) :: Nil notTrueTable.foreach { case (v, answer) => - checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer) + checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer) } checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) } @@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), - Seq(NonFoldableLiteral(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + Literal(2))), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + true) + checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) checkEvaluation( - And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + Literal(2)))), true) - val ns = NonFoldableLiteral(null, StringType) + val ns = NonFoldableLiteral.create(null, StringType) checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) checkEvaluation(In(ns, Seq(ns)), null) checkEvaluation(In(Literal("a"), Seq(ns)), null) @@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case _ => value } } - val input = inputData.map(NonFoldableLiteral(_, t)) + val input = inputData.map(NonFoldableLiteral.create(_, t)) val expected = if (inputData(0) == null) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { @@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("BinaryComparison: null test") { // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757 val normalInt = Literal(-1) - val nullInt = NonFoldableLiteral(null, IntegerType) + val nullInt = NonFoldableLiteral.create(null, IntegerType) def nullTest(op: (Expression, Expression) => Expression): Unit = { checkEvaluation(op(normalInt, nullInt), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 6a77580b29a21..28bf7b6f84341 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest { val optimizedPlan = OptimizeIn(plan) optimizedPlan match { case Filter(cond, _) - if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => + if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 => // pass case _ => fail("Unexpected result for OptimizedIn") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 68f61cfab6d2f..5171aaebc9907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)")) assert(e.message.contains("Invalid number of arguments")) } + + test("SPARK-21228: InSet incorrect handling of structs") { + withTempView("A") { + // reduce this from the default of 10 so the repro query text is not too long + withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) { + // a relation that has 1 column of struct type with values (1,1), ..., (9, 9) + spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a") + .createOrReplaceTempView("A") + val df = sql( + """ + |SELECT * from + | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows + | -- the IN will become InSet with a Set of GenericInternalRows + | -- a GenericInternalRow is never equal to an UnsafeRow so the query would + | -- returns 0 results, which is incorrect + | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L), + | NAMED_STRUCT('a', 3L, 'b', 3L)) + """.stripMargin) + checkAnswer(df, Row(Row(1, 1))) + } + } + } } From 48e44b24a7663142176102ac4c6bf4242f103804 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Jul 2017 01:07:45 +0800 Subject: [PATCH 0887/1765] [SPARK-21204][SQL] Add support for Scala Set collection types in serialization ## What changes were proposed in this pull request? Currently we can't produce a `Dataset` containing `Set` in SparkSQL. This PR tries to support serialization/deserialization of `Set`. Because there's no corresponding internal data type in SparkSQL for a `Set`, the most proper choice for serializing a set should be an array. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #18416 from viirya/SPARK-21204. --- .../spark/sql/catalyst/ScalaReflection.scala | 28 +++++++++++++++-- .../expressions/objects/objects.scala | 5 +-- .../org/apache/spark/sql/SQLImplicits.scala | 10 ++++++ .../spark/sql/DataFrameAggregateSuite.scala | 10 ++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 31 +++++++++++++++++++ 5 files changed, 79 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 814f2c10b9097..4d5401f30d392 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -309,7 +309,10 @@ object ScalaReflection extends ScalaReflection { Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } - case t if t <:< localTypeOf[Seq[_]] => + // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array + // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. + case t if t <:< localTypeOf[Seq[_]] || + t <:< localTypeOf[scala.collection.Set[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -327,8 +330,10 @@ object ScalaReflection extends ScalaReflection { } val companion = t.normalize.typeSymbol.companionSymbol.typeSignature - val cls = companion.declaration(newTermName("newBuilder")) match { - case NoSymbol => classOf[Seq[_]] + val cls = companion.member(newTermName("newBuilder")) match { + case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]] + case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] => + classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) @@ -502,6 +507,19 @@ object ScalaReflection extends ScalaReflection { serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) + case t if t <:< localTypeOf[scala.collection.Set[_]] => + val TypeRef(_, _, Seq(elementType)) = t + + // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. + // Note that the property of `Set` is only kept when manipulating the data as domain object. + val newInput = + Invoke( + inputObject, + "toSeq", + ObjectType(classOf[Seq[_]])) + + toCatalystArray(newInput, elementType) + case t if t <:< localTypeOf[String] => StaticInvoke( classOf[UTF8String], @@ -713,6 +731,10 @@ object ScalaReflection extends ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< localTypeOf[Set[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 24c06d8b14b54..9b28a18035b1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -627,8 +627,9 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - // Scala sequence + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) || + classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala sequence or set val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 86574e2f71d92..05db292bd41b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -171,6 +171,16 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.3.0 */ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + /** + * Notice that we serialize `Set` to Catalyst array. The set property is only kept when + * manipulating the domain objects. The serialization format doesn't keep the set property. + * When we have a Catalyst array which contains duplicated elements and convert it to + * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. + * + * @since 2.3.0 + */ + implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5db354d79bb6e..b52d50b195bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -460,6 +460,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.select(collect_set($"a"), collect_set($"b")), Seq(Row(Seq(1, 2, 3), Seq(2, 4))) ) + + checkDataset( + df.select(collect_set($"a").as("aSet")).as[Set[Int]], + Set(1, 2, 3)) + checkDataset( + df.select(collect_set($"b").as("bSet")).as[Set[Int]], + Set(2, 4)) + checkDataset( + df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])], + Seq(Set(1, 2, 3) -> Set(2, 4)): _*) } test("collect functions structs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index a6847dcfbffc4..f62f9e23db66d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.immutable.{HashSet => HSet} import scala.collection.immutable.Queue import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer @@ -342,6 +343,31 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) } + test("arbitrary sets") { + checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4)) + checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong)) + checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble)) + checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat)) + checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte)) + checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort)) + checkDataset(Seq(Set(true, false)).toDS(), Set(true, false)) + checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2")) + checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2)) + checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong)) + checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble)) + checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat)) + checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte)) + checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort)) + checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false)) + checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2")) + checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]], + Seq(Set[Integer](1, null), Set[Integer](2)): _*) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) @@ -352,6 +378,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) } + test("nested set") { + checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4))) + checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) From bf66335acab3c0c188f6c378eb8aa6948a259cb2 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 6 Jul 2017 13:58:27 -0700 Subject: [PATCH 0888/1765] [SPARK-21323][SQL] Rename plans.logical.statsEstimation.Range to ValueInterval ## What changes were proposed in this pull request? Rename org.apache.spark.sql.catalyst.plans.logical.statsEstimation.Range to ValueInterval. The current naming is identical to logical operator "range". Refactoring it to ValueInterval is more accurate. ## How was this patch tested? unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18549 from gengliangwang/ValueInterval. --- .../statsEstimation/FilterEstimation.scala | 36 ++++++++-------- .../statsEstimation/JoinEstimation.scala | 14 +++---- .../{Range.scala => ValueInterval.scala} | 41 ++++++++++--------- 3 files changed, 48 insertions(+), 43 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/{Range.scala => ValueInterval.scala} (65%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 5a3bee7b9e449..e13db85c7a76e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -316,8 +316,8 @@ case class FilterEstimation(plan: Filter) extends Logging { // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. - val statsRange = Range(colStat.min, colStat.max, attr.dataType) - if (statsRange.contains(literal)) { + val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType) + if (statsInterval.contains(literal)) { if (update) { // We update ColumnStat structure after apply this equality predicate: // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal @@ -388,9 +388,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => - val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange] + val statsInterval = + ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => - v != null && statsRange.contains(Literal(v, dataType)) + v != null && statsInterval.contains(Literal(v, dataType)) } if (validQuerySet.isEmpty) { @@ -440,12 +441,13 @@ case class FilterEstimation(plan: Filter) extends Logging { update: Boolean): Option[BigDecimal] = { val colStat = colStatsMap(attr) - val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] - val max = statsRange.max.toBigDecimal - val min = statsRange.min.toBigDecimal + val statsInterval = + ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] + val max = statsInterval.max.toBigDecimal + val min = statsInterval.min.toBigDecimal val ndv = BigDecimal(colStat.distinctCount) - // determine the overlapping degree between predicate range and column's range + // determine the overlapping degree between predicate interval and column's interval val numericLiteral = if (literal.dataType == BooleanType) { if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) } else { @@ -566,18 +568,18 @@ case class FilterEstimation(plan: Filter) extends Logging { } val colStatLeft = colStatsMap(attrLeft) - val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) - .asInstanceOf[NumericRange] - val maxLeft = statsRangeLeft.max - val minLeft = statsRangeLeft.min + val statsIntervalLeft = ValueInterval(colStatLeft.min, colStatLeft.max, attrLeft.dataType) + .asInstanceOf[NumericValueInterval] + val maxLeft = statsIntervalLeft.max + val minLeft = statsIntervalLeft.min val colStatRight = colStatsMap(attrRight) - val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) - .asInstanceOf[NumericRange] - val maxRight = statsRangeRight.max - val minRight = statsRangeRight.min + val statsIntervalRight = ValueInterval(colStatRight.min, colStatRight.max, attrRight.dataType) + .asInstanceOf[NumericValueInterval] + val maxRight = statsIntervalRight.max + val minRight = statsIntervalRight.min - // determine the overlapping degree between predicate range and column's range + // determine the overlapping degree between predicate interval and column's interval val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f48196997a24d..dcbe36da91dfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -175,9 +175,9 @@ case class InnerOuterEstimation(join: Join) extends Logging { // Check if the two sides are disjoint val leftKeyStats = leftStats.attributeStats(leftKey) val rightKeyStats = rightStats.attributeStats(rightKey) - val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - if (Range.isIntersected(lRange, rRange)) { + val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + if (ValueInterval.isIntersected(lInterval, rInterval)) { // Get the largest ndv among pairs of join keys val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) if (maxNdv > ndvDenom) ndvDenom = maxNdv @@ -239,16 +239,16 @@ case class InnerOuterEstimation(join: Join) extends Logging { joinKeyPairs.foreach { case (leftKey, rightKey) => val leftKeyStats = leftStats.attributeStats(leftKey) val rightKeyStats = rightStats.attributeStats(rightKey) - val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) // When we reach here, join selectivity is not zero, so each pair of join keys should be // intersected. - assert(Range.isIntersected(lRange, rRange)) + assert(ValueInterval.isIntersected(lInterval, rInterval)) // Update intersected column stats assert(leftKey.dataType.sameType(rightKey.dataType)) val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala similarity index 65% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index 4ac5ba5689f82..0caaf796a3b68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.types._ /** Value range of a column. */ -trait Range { +trait ValueInterval { def contains(l: Literal): Boolean } -/** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: Decimal, max: Decimal) extends Range { +/** For simplicity we use decimal to unify operations of numeric intervals. */ +case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval { override def contains(l: Literal): Boolean = { val lit = EstimationUtils.toDecimal(l.value, l.dataType) min <= lit && max >= lit @@ -38,46 +38,49 @@ case class NumericRange(min: Decimal, max: Decimal) extends Range { * This version of Spark does not have min/max for binary/string types, we define their default * behaviors by this class. */ -class DefaultRange extends Range { +class DefaultValueInterval extends ValueInterval { override def contains(l: Literal): Boolean = true } /** This is for columns with only null values. */ -class NullRange extends Range { +class NullValueInterval extends ValueInterval { override def contains(l: Literal): Boolean = false } -object Range { - def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { - case StringType | BinaryType => new DefaultRange() - case _ if min.isEmpty || max.isEmpty => new NullRange() +object ValueInterval { + def apply( + min: Option[Any], + max: Option[Any], + dataType: DataType): ValueInterval = dataType match { + case StringType | BinaryType => new DefaultValueInterval() + case _ if min.isEmpty || max.isEmpty => new NullValueInterval() case _ => - NumericRange( + NumericValueInterval( min = EstimationUtils.toDecimal(min.get, dataType), max = EstimationUtils.toDecimal(max.get, dataType)) } - def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { - case (_, _: DefaultRange) | (_: DefaultRange, _) => - // The DefaultRange represents string/binary types which do not have max/min stats, + def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match { + case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) => + // The DefaultValueInterval represents string/binary types which do not have max/min stats, // we assume they are intersected to be conservative on estimation true - case (_, _: NullRange) | (_: NullRange, _) => + case (_, _: NullValueInterval) | (_: NullValueInterval, _) => false - case (n1: NumericRange, n2: NumericRange) => + case (n1: NumericValueInterval, n2: NumericValueInterval) => n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } /** - * Intersected results of two ranges. This is only for two overlapped ranges. + * Intersected results of two intervals. This is only for two overlapped intervals. * The outputs are the intersected min/max values. */ - def intersect(r1: Range, r2: Range, dt: DataType): (Option[Any], Option[Any]) = { + def intersect(r1: ValueInterval, r2: ValueInterval, dt: DataType): (Option[Any], Option[Any]) = { (r1, r2) match { - case (_, _: DefaultRange) | (_: DefaultRange, _) => + case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) => // binary/string types don't support intersecting. (None, None) - case (n1: NumericRange, n2: NumericRange) => + case (n1: NumericValueInterval, n2: NumericValueInterval) => // Choose the maximum of two min values, and the minimum of two max values. val newMin = if (n1.min <= n2.min) n2.min else n1.min val newMax = if (n1.max <= n2.max) n1.max else n2.max From 0217dfd26f89133f146197359b556c9bf5aca172 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Jul 2017 17:28:20 -0700 Subject: [PATCH 0889/1765] [SPARK-21267][SS][DOCS] Update Structured Streaming Documentation ## What changes were proposed in this pull request? Few changes to the Structured Streaming documentation - Clarify that the entire stream input table is not materialized - Add information for Ganglia - Add Kafka Sink to the main docs - Removed a couple of leftover experimental tags - Added more associated reading material and talk videos. In addition, https://github.com/apache/spark/pull/16856 broke the link to the RDD programming guide in several places while renaming the page. This PR fixes those sameeragarwal cloud-fan. - Added a redirection to avoid breaking internal and possible external links. - Removed unnecessary redirection pages that were there since the separate scala, java, and python programming guides were merged together in 2013 or 2014. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tathagata Das Closes #18485 from tdas/SPARK-21267. --- docs/_layouts/global.html | 7 +- docs/index.md | 13 +- docs/java-programming-guide.md | 7 - docs/programming-guide.md | 7 + docs/python-programming-guide.md | 7 - docs/rdd-programming-guide.md | 2 +- docs/scala-programming-guide.md | 7 - docs/sql-programming-guide.md | 16 +- .../structured-streaming-programming-guide.md | 172 +++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 3 - 10 files changed, 169 insertions(+), 72 deletions(-) delete mode 100644 docs/java-programming-guide.md create mode 100644 docs/programming-guide.md delete mode 100644 docs/python-programming-guide.md delete mode 100644 docs/scala-programming-guide.md diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index c00d0db63cd10..570483c0b04ea 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -69,11 +69,10 @@ Programming Guides
    +### Reporting Metrics using Dropwizard +Spark supports reporting metrics using the [Dropwizard Library](monitoring.html#metrics). To enable metrics of Structured Streaming queries to be reported as well, you have to explicitly enable the configuration `spark.sql.streaming.metricsEnabled` in the SparkSession. + +
    +
    +{% highlight scala %} +spark.conf.set("spark.sql.streaming.metricsEnabled", "true") +// or +spark.sql("SET spark.sql.streaming.metricsEnabled=true") +{% endhighlight %} +
    +
    +{% highlight java %} +spark.conf().set("spark.sql.streaming.metricsEnabled", "true"); +// or +spark.sql("SET spark.sql.streaming.metricsEnabled=true"); +{% endhighlight %} +
    +
    +{% highlight python %} +spark.conf.set("spark.sql.streaming.metricsEnabled", "true") +# or +spark.sql("SET spark.sql.streaming.metricsEnabled=true") +{% endhighlight %} +
    +
    +{% highlight r %} +sql("SET spark.sql.streaming.metricsEnabled=true") +{% endhighlight %} +
    +
    + + +All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). + ## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). @@ -1971,8 +2082,23 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat -# Where to go from here -- Examples: See and run the -[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r/streaming) -examples. +# Additional Information + +**Further Reading** + +- See and run the + [Scala]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming) + examples. + - [Instructions](index.html#running-the-examples-and-shell) on how to run Spark examples +- Read about integrating with Kafka in the [Structured Streaming Kafka Integration Guide](structured-streaming-kafka-integration.html) +- Read more details about using DataFrames/Datasets in the [Spark SQL Programming Guide](sql-programming-guide.html) +- Third-party Blog Posts + - [Real-time Streaming ETL with Structured Streaming in Apache Spark 2.1 (Databricks Blog)](https://databricks.com/blog/2017/01/19/real-time-streaming-etl-structured-streaming-apache-spark-2-1.html) + - [Real-Time End-to-End Integration with Apache Kafka in Apache Spark’s Structured Streaming (Databricks Blog)](https://databricks.com/blog/2017/04/04/real-time-end-to-end-integration-with-apache-kafka-in-apache-sparks-structured-streaming.html) + - [Event-time Aggregation and Watermarking in Apache Spark’s Structured Streaming (Databricks Blog)](https://databricks.com/blog/2017/05/08/event-time-aggregation-watermarking-apache-sparks-structured-streaming.html) + +**Talks** + +- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) - Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7be4aa1ca9562..b1638a2180b07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -520,7 +520,6 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming @@ -581,7 +580,6 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time * before which we assume no more late data is going to arrive. * @@ -605,7 +603,6 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. From 40c7add3a4c811202d1fa2be9606aca08df81266 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Jul 2017 08:44:31 +0800 Subject: [PATCH 0890/1765] [SPARK-20946][SQL] Do not update conf for existing SparkContext in SparkSession.getOrCreate ## What changes were proposed in this pull request? SparkContext is shared by all sessions, we should not update its conf for only one session. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18536 from cloud-fan/config. --- .../spark/ml/recommendation/ALSSuite.scala | 4 +--- .../apache/spark/ml/tree/impl/TreeTests.scala | 2 -- .../org/apache/spark/sql/SparkSession.scala | 19 +++++++------------ .../spark/sql/SparkSessionBuilderSuite.scala | 8 +++----- 4 files changed, 11 insertions(+), 22 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 3094f52ba1bc5..b57fc8d21ab34 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -818,15 +818,13 @@ class ALSCleanerSuite extends SparkFunSuite { FileUtils.listFiles(localDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet try { conf.set("spark.local.dir", localDir.getAbsolutePath) - val sc = new SparkContext("local[2]", "test", conf) + val sc = new SparkContext("local[2]", "ALSCleanerSuite", conf) try { sc.setCheckpointDir(checkpointDir.getAbsolutePath) // Generate test data val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) // Implicitly test the cleaning of parents during ALS training val spark = SparkSession.builder - .master("local[2]") - .appName("ALSCleanerSuite") .sparkContext(sc) .getOrCreate() import spark.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 92a236928e90b..b6894b30b0c2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -43,8 +43,6 @@ private[ml] object TreeTests extends SparkFunSuite { categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { val spark = SparkSession.builder() - .master("local[2]") - .appName("TreeTests") .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 0ddcd2111aa58..6dfe8a66baa9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -867,7 +867,7 @@ object SparkSession { * * @since 2.2.0 */ - def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized { f(extensions) this } @@ -912,21 +912,16 @@ object SparkSession { // No active nor global default session. Create a new one. val sparkContext = userSuppliedContext.getOrElse { - // set app name if not given - val randomAppName = java.util.UUID.randomUUID().toString val sparkConf = new SparkConf() options.foreach { case (k, v) => sparkConf.set(k, v) } + + // set a random app name if not given. if (!sparkConf.contains("spark.app.name")) { - sparkConf.setAppName(randomAppName) - } - val sc = SparkContext.getOrCreate(sparkConf) - // maybe this is an existing SparkContext, update its SparkConf which maybe used - // by SparkSession - options.foreach { case (k, v) => sc.conf.set(k, v) } - if (!sc.conf.contains("spark.app.name")) { - sc.conf.setAppName(randomAppName) + sparkConf.setAppName(java.util.UUID.randomUUID().toString) } - sc + + SparkContext.getOrCreate(sparkConf) + // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions. } // Initialize extensions if the user has defined a configurator class. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index cdac6827082c4..770e15629c839 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -102,11 +102,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.conf.get("key1") == "value1") assert(session.conf.get("key2") == "value2") assert(session.sparkContext == sparkContext2) - assert(session.sparkContext.conf.get("key1") == "value1") - // If the created sparkContext is not passed through the Builder's API sparkContext, - // the conf of this sparkContext will also contain the conf set through the API config. - assert(session.sparkContext.conf.get("key2") == "value2") - assert(session.sparkContext.conf.get("spark.app.name") == "test") + // We won't update conf for existing `SparkContext` + assert(!sparkContext2.conf.contains("key2")) + assert(sparkContext2.conf.get("key1") == "value1") session.stop() } From e5bb26174d3336e07dd670eec4fd2137df346163 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 6 Jul 2017 18:11:41 -0700 Subject: [PATCH 0891/1765] [SPARK-21329][SS] Make EventTimeWatermarkExec explicitly UnaryExecNode ## What changes were proposed in this pull request? Making EventTimeWatermarkExec explicitly UnaryExecNode /cc tdas zsxwing ## How was this patch tested? Local build. Author: Jacek Laskowski Closes #18509 from jaceklaskowski/EventTimeWatermarkExec-UnaryExecNode. --- .../sql/execution/streaming/EventTimeWatermarkExec.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 25cf609fc336e..87e5b78550423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.MetadataBuilder import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.AccumulatorV2 @@ -81,7 +81,7 @@ class EventTimeStatsAccum(protected var currentStats: EventTimeStats = EventTime case class EventTimeWatermarkExec( eventTime: Attribute, delay: CalendarInterval, - child: SparkPlan) extends SparkPlan { + child: SparkPlan) extends UnaryExecNode { val eventTimeStats = new EventTimeStatsAccum() val delayMs = EventTimeWatermark.getDelayMs(delay) @@ -117,6 +117,4 @@ case class EventTimeWatermarkExec( a } } - - override def children: Seq[SparkPlan] = child :: Nil } From d451b7f43d559aa1efd7ac3d1cbec5249f3a7a24 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 7 Jul 2017 12:24:03 +0800 Subject: [PATCH 0892/1765] [SPARK-21326][SPARK-21066][ML] Use TextFileFormat in LibSVMFileFormat and allow multiple input paths for determining numFeatures ## What changes were proposed in this pull request? This is related with [SPARK-19918](https://issues.apache.org/jira/browse/SPARK-19918) and [SPARK-18362](https://issues.apache.org/jira/browse/SPARK-18362). This PR proposes to use `TextFileFormat` and allow multiple input paths (but with a warning) when determining the number of features in LibSVM data source via an extra scan. There are three points here: - The main advantage of this change should be to remove file-listing bottlenecks in driver side. - Another advantage is ones from using `FileScanRDD`. For example, I guess we can use `spark.sql.files.ignoreCorruptFiles` option when determining the number of features. - We can unify the schema inference code path in text based data sources. This is also a preparation for [SPARK-21289](https://issues.apache.org/jira/browse/SPARK-21289). ## How was this patch tested? Unit tests in `LibSVMRelationSuite`. Closes #18288 Author: hyukjinkwon Closes #18556 from HyukjinKwon/libsvm-schema. --- .../ml/source/libsvm/LibSVMRelation.scala | 26 +++++++++---------- .../org/apache/spark/mllib/util/MLUtils.scala | 25 ++++++++++++++++-- .../source/libsvm/LibSVMRelationSuite.scala | 17 +++++++++--- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index f68847a664b69..dec118330aec6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.spark.internal.Logging import org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vectors, VectorUDT} @@ -66,7 +67,10 @@ private[libsvm] class LibSVMOutputWriter( /** @see [[LibSVMDataSource]] for public documentation. */ // If this is moved or renamed, please update DataSource's backwardCompatibilityMap. -private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { +private[libsvm] class LibSVMFileFormat + extends TextBasedFileFormat + with DataSourceRegister + with Logging { override def shortName(): String = "libsvm" @@ -89,18 +93,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour files: Seq[FileStatus]): Option[StructType] = { val libSVMOptions = new LibSVMOptions(options) val numFeatures: Int = libSVMOptions.numFeatures.getOrElse { - // Infers number of features if the user doesn't specify (a valid) one. - val dataFiles = files.filterNot(_.getPath.getName startsWith "_") - val path = if (dataFiles.length == 1) { - dataFiles.head.getPath.toUri.toString - } else if (dataFiles.isEmpty) { - throw new IOException("No input path specified for libsvm data") - } else { - throw new IOException("Multiple input paths are not supported for libsvm data.") - } - - val sc = sparkSession.sparkContext - val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism) + require(files.nonEmpty, "No input path specified for libsvm data") + logWarning( + "'numFeatures' option not specified, determining the number of features by going " + + "though the input. If you know the number in advance, please specify it via " + + "'numFeatures' option to avoid the extra scan.") + + val paths = files.map(_.getPath.toUri.toString) + val parsed = MLUtils.parseLibSVMFile(sparkSession, paths) MLUtils.computeNumFeatures(parsed) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4fdad05973969..14af8b5c73870 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -28,8 +28,10 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler @@ -102,6 +104,25 @@ object MLUtils extends Logging { .map(parseLibSVMRecord) } + private[spark] def parseLibSVMFile( + sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = { + val lines = sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value") + + import lines.sqlContext.implicits._ + + lines.select(trim($"value").as("line")) + .filter(not((length($"line") === 0).or($"line".startsWith("#")))) + .as[String] + .rdd + .map(MLUtils.parseLibSVMRecord) + } + private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = { val items = line.split(' ') val label = items.head.toDouble diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e164d279f3f02..a67e49d54e148 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -35,15 +35,22 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { override def beforeAll(): Unit = { super.beforeAll() - val lines = + val lines0 = """ |1 1:1.0 3:2.0 5:3.0 |0 + """.stripMargin + val lines1 = + """ |0 2:4.0 4:5.0 6:6.0 """.stripMargin val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") - val file = new File(dir, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + val succ = new File(dir, "_SUCCESS") + val file0 = new File(dir, "part-00000") + val file1 = new File(dir, "part-00001") + Files.write("", succ, StandardCharsets.UTF_8) + Files.write(lines0, file0, StandardCharsets.UTF_8) + Files.write(lines1, file1, StandardCharsets.UTF_8) path = dir.toURI.toString } @@ -145,7 +152,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("create libsvmTable table without schema and path") { try { - val e = intercept[IOException](spark.sql("CREATE TABLE libsvmTable USING libsvm")) + val e = intercept[IllegalArgumentException] { + spark.sql("CREATE TABLE libsvmTable USING libsvm") + } assert(e.getMessage.contains("No input path specified for libsvm data")) } finally { spark.sql("DROP TABLE IF EXISTS libsvmTable") From 53c2eb59b2cc557081f6a252748dc38511601b0d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 7 Jul 2017 14:05:22 +0900 Subject: [PATCH 0893/1765] [SPARK-21327][SQL][PYSPARK] ArrayConstructor should handle an array of typecode 'l' as long rather than int in Python 2. ## What changes were proposed in this pull request? Currently `ArrayConstructor` handles an array of typecode `'l'` as `int` when converting Python object in Python 2 into Java object, so if the value is larger than `Integer.MAX_VALUE` or smaller than `Integer.MIN_VALUE` then the overflow occurs. ```python import array data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] df = spark.createDataFrame(data) df.show(truncate=False) ``` ``` +----------+ |longarray | +----------+ |[0, 0, -1]| +----------+ ``` This should be: ``` +----------------------------------------------+ |longarray | +----------------------------------------------+ |[-9223372036854775808, 0, 9223372036854775807]| +----------------------------------------------+ ``` ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #18553 from ueshin/issues/SPARK-21327. --- .../scala/org/apache/spark/api/python/SerDeUtil.scala | 10 ++++++++++ python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 6e4eab4b805c1..42f67e8dbe865 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -73,6 +73,16 @@ private[spark] object SerDeUtil extends Logging { // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly val data = args(1).asInstanceOf[String].getBytes(StandardCharsets.ISO_8859_1) construct(typecode, machineCodes(typecode), data) + } else if (args.length == 2 && args(0) == "l") { + // On Python 2, an array of typecode 'l' should be handled as long rather than int. + val values = args(1).asInstanceOf[JArrayList[_]] + val result = new Array[Long](values.size) + var i = 0 + while (i < values.size) { + result(i) = values.get(i).asInstanceOf[Number].longValue() + i += 1 + } + result } else { super.construct(args) } diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c0e3b8d132396..9db2f40474f70 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2342,6 +2342,12 @@ def test_to_pandas(self): self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) + def test_create_dataframe_from_array_of_long(self): + import array + data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] + df = self.spark.createDataFrame(data) + self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + class HiveSparkSubmitTests(SparkSubmitTests): From c09b31eb8fa83d5463a045c9278f5874ae505a8e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 7 Jul 2017 13:09:32 +0800 Subject: [PATCH 0894/1765] [SPARK-21217][SQL] Support ColumnVector.Array.toArray() ## What changes were proposed in this pull request? This PR implements bulk-copy for `ColumnVector.Array.toArray()` methods (e.g. `toIntArray()`) in `ColumnVector.Array` by using `System.arrayCopy()` or `Platform.copyMemory()`. Before this PR, when one of these method is called, the generic method in `ArrayData` is called. It is not fast since element-wise copy is performed. This PR can improve performance of a benchmark program by 1.9x and 3.2x. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_131-8u131-b11-0ubuntu1.16.04.2-b11 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Array Best/Avg Time(ms) Rate(M/s) Per Row(ns) ------------------------------------------------------------------------------------------------ ON_HEAP 586 / 628 14.3 69.9 OFF_HEAP 893 / 902 9.4 106.5 ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_131-8u131-b11-0ubuntu1.16.04.2-b11 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Array Best/Avg Time(ms) Rate(M/s) Per Row(ns) ------------------------------------------------------------------------------------------------ ON_HEAP 306 / 331 27.4 36.4 OFF_HEAP 282 / 287 29.8 33.6 ``` Source program ``` (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val len = 8 * 1024 * 1024 val column = ColumnVector.allocate(len * 2, new ArrayType(IntegerType, false), memMode) val data = column.arrayData var i = 0 while (i < len) { data.putInt(i, i) i += 1 } column.putArray(0, 0, len) val benchmark = new Benchmark("Int Array", len, minNumIters = 20) benchmark.addCase(s"$memMode") { iter => var i = 0 while (i < 50) { column.getArray(0).toIntArray i += 1 } } benchmark.run }} ``` ## How was this patch tested? Added test suite Author: Kazuaki Ishizaki Closes #18425 from kiszk/SPARK-21217. --- .../execution/vectorized/ColumnVector.java | 56 ++++++++++++++++++ .../vectorized/OffHeapColumnVector.java | 58 +++++++++++++++++++ .../vectorized/OnHeapColumnVector.java | 58 +++++++++++++++++++ .../vectorized/ColumnarBatchSuite.scala | 49 ++++++++++++++++ 4 files changed, 221 insertions(+) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 24260a60197f2..0c027f80d48cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -100,6 +100,27 @@ public ArrayData copy() { throw new UnsupportedOperationException(); } + @Override + public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } + + @Override + public byte[] toByteArray() { return data.getBytes(offset, length); } + + @Override + public short[] toShortArray() { return data.getShorts(offset, length); } + + @Override + public int[] toIntArray() { return data.getInts(offset, length); } + + @Override + public long[] toLongArray() { return data.getLongs(offset, length); } + + @Override + public float[] toFloatArray() { return data.getFloats(offset, length); } + + @Override + public double[] toDoubleArray() { return data.getDoubles(offset, length); } + // TODO: this is extremely expensive. @Override public Object[] array() { @@ -366,6 +387,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract boolean getBoolean(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract boolean[] getBooleans(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -386,6 +412,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract byte getByte(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract byte[] getBytes(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -406,6 +437,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract short getShort(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract short[] getShorts(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -432,6 +468,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract int getInt(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract int[] getInts(int rowId, int count); + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -465,6 +506,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract long getLong(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract long[] getLongs(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -491,6 +537,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract float getFloat(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract float[] getFloats(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -517,6 +568,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract double getDouble(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract double[] getDoubles(int rowId, int count); + /** * Puts a byte array that already exists in this column. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index a7d3744d00e91..2d1f3da8e7463 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -134,6 +134,16 @@ public void putBooleans(int rowId, int count, boolean value) { @Override public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + @Override + public boolean[] getBooleans(int rowId, int count) { + assert(dictionary == null); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = (Platform.getByte(null, data + rowId + i) == 1); + } + return array; + } + // // APIs dealing with Bytes // @@ -165,6 +175,14 @@ public byte getByte(int rowId) { } } + @Override + public byte[] getBytes(int rowId, int count) { + assert(dictionary == null); + byte[] array = new byte[count]; + Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count); + return array; + } + // // APIs dealing with shorts // @@ -197,6 +215,14 @@ public short getShort(int rowId) { } } + @Override + public short[] getShorts(int rowId, int count) { + assert(dictionary == null); + short[] array = new short[count]; + Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + return array; + } + // // APIs dealing with ints // @@ -244,6 +270,14 @@ public int getInt(int rowId) { } } + @Override + public int[] getInts(int rowId, int count) { + assert(dictionary == null); + int[] array = new int[count]; + Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + return array; + } + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -302,6 +336,14 @@ public long getLong(int rowId) { } } + @Override + public long[] getLongs(int rowId, int count) { + assert(dictionary == null); + long[] array = new long[count]; + Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + return array; + } + // // APIs dealing with floats // @@ -348,6 +390,14 @@ public float getFloat(int rowId) { } } + @Override + public float[] getFloats(int rowId, int count) { + assert(dictionary == null); + float[] array = new float[count]; + Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + return array; + } + // // APIs dealing with doubles @@ -395,6 +445,14 @@ public double getDouble(int rowId) { } } + @Override + public double[] getDoubles(int rowId, int count) { + assert(dictionary == null); + double[] array = new double[count]; + Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + return array; + } + // // APIs dealing with Arrays. // diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 94ed32294cfae..506434364be48 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -130,6 +130,16 @@ public boolean getBoolean(int rowId) { return byteData[rowId] == 1; } + @Override + public boolean[] getBooleans(int rowId, int count) { + assert(dictionary == null); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = (byteData[rowId + i] == 1); + } + return array; + } + // // @@ -162,6 +172,14 @@ public byte getByte(int rowId) { } } + @Override + public byte[] getBytes(int rowId, int count) { + assert(dictionary == null); + byte[] array = new byte[count]; + System.arraycopy(byteData, rowId, array, 0, count); + return array; + } + // // APIs dealing with Shorts // @@ -192,6 +210,14 @@ public short getShort(int rowId) { } } + @Override + public short[] getShorts(int rowId, int count) { + assert(dictionary == null); + short[] array = new short[count]; + System.arraycopy(shortData, rowId, array, 0, count); + return array; + } + // // APIs dealing with Ints @@ -234,6 +260,14 @@ public int getInt(int rowId) { } } + @Override + public int[] getInts(int rowId, int count) { + assert(dictionary == null); + int[] array = new int[count]; + System.arraycopy(intData, rowId, array, 0, count); + return array; + } + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -286,6 +320,14 @@ public long getLong(int rowId) { } } + @Override + public long[] getLongs(int rowId, int count) { + assert(dictionary == null); + long[] array = new long[count]; + System.arraycopy(longData, rowId, array, 0, count); + return array; + } + // // APIs dealing with floats // @@ -325,6 +367,14 @@ public float getFloat(int rowId) { } } + @Override + public float[] getFloats(int rowId, int count) { + assert(dictionary == null); + float[] array = new float[count]; + System.arraycopy(floatData, rowId, array, 0, count); + return array; + } + // // APIs dealing with doubles // @@ -366,6 +416,14 @@ public double getDouble(int rowId) { } } + @Override + public double[] getDoubles(int rowId, int count) { + assert(dictionary == null); + double[] array = new double[count]; + System.arraycopy(doubleData, rowId, array, 0, count); + return array; + } + // // APIs dealing with Arrays // diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 80d41577dcf2d..ccf7aa7022a2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -709,6 +709,55 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("toArray for primitive types") { + // (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { + val len = 4 + + val columnBool = ColumnVector.allocate(len, new ArrayType(BooleanType, false), memMode) + val boolArray = Array(false, true, false, true) + boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } + columnBool.putArray(0, 0, len) + assert(columnBool.getArray(0).toBooleanArray === boolArray) + + val columnByte = ColumnVector.allocate(len, new ArrayType(ByteType, false), memMode) + val byteArray = Array[Byte](0, 1, 2, 3) + byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } + columnByte.putArray(0, 0, len) + assert(columnByte.getArray(0).toByteArray === byteArray) + + val columnShort = ColumnVector.allocate(len, new ArrayType(ShortType, false), memMode) + val shortArray = Array[Short](0, 1, 2, 3) + shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } + columnShort.putArray(0, 0, len) + assert(columnShort.getArray(0).toShortArray === shortArray) + + val columnInt = ColumnVector.allocate(len, new ArrayType(IntegerType, false), memMode) + val intArray = Array(0, 1, 2, 3) + intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } + columnInt.putArray(0, 0, len) + assert(columnInt.getArray(0).toIntArray === intArray) + + val columnLong = ColumnVector.allocate(len, new ArrayType(LongType, false), memMode) + val longArray = Array[Long](0, 1, 2, 3) + longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } + columnLong.putArray(0, 0, len) + assert(columnLong.getArray(0).toLongArray === longArray) + + val columnFloat = ColumnVector.allocate(len, new ArrayType(FloatType, false), memMode) + val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) + floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } + columnFloat.putArray(0, 0, len) + assert(columnFloat.getArray(0).toFloatArray === floatArray) + + val columnDouble = ColumnVector.allocate(len, new ArrayType(DoubleType, false), memMode) + val doubleArray = Array(0.0, 1.1, 2.2, 3.3) + doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } + columnDouble.putArray(0, 0, len) + assert(columnDouble.getArray(0).toDoubleArray === doubleArray) + }} + } + test("Struct Column") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType().add("int", IntegerType).add("double", DoubleType) From 5df99bd364561c6f4c02308149ba5eb71f89247e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Jul 2017 13:12:20 +0800 Subject: [PATCH 0895/1765] [SPARK-20703][SQL][FOLLOW-UP] Associate metrics with data writes onto DataFrameWriter operations ## What changes were proposed in this pull request? Remove time metrics since it seems no way to measure it in non per-row tracking. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #18558 from viirya/SPARK-20703-followup. --- .../command/DataWritingCommand.scala | 10 --------- .../datasources/FileFormatWriter.scala | 22 +++---------------- .../sql/hive/execution/SQLMetricsSuite.scala | 3 --- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 0c381a2c02986..700f7f81dc8a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -30,7 +30,6 @@ trait DataWritingCommand extends RunnableCommand { override lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( - "avgTime" -> SQLMetrics.createMetric(sparkContext, "average writing time (ms)"), "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -47,23 +46,14 @@ trait DataWritingCommand extends RunnableCommand { var numFiles = 0 var totalNumBytes: Long = 0L var totalNumOutput: Long = 0L - var totalWritingTime: Long = 0L writeSummaries.foreach { summary => numPartitions += summary.updatedPartitions.size numFiles += summary.numOutputFile totalNumBytes += summary.numOutputBytes totalNumOutput += summary.numOutputRows - totalWritingTime += summary.totalWritingTime } - val avgWritingTime = if (numFiles > 0) { - (totalWritingTime / numFiles).toLong - } else { - 0L - } - - metrics("avgTime").add(avgWritingTime) metrics("numFiles").add(numFiles) metrics("numOutputBytes").add(totalNumBytes) metrics("numOutputRows").add(totalNumOutput) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 64866630623ab..9eb9eae699e94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -275,8 +275,6 @@ object FileFormatWriter extends Logging { /** * The data structures used to measure metrics during writing. */ - protected var totalWritingTime: Long = 0L - protected var timeOnCurrentFile: Long = 0L protected var numOutputRows: Long = 0L protected var numOutputBytes: Long = 0L @@ -343,9 +341,7 @@ object FileFormatWriter extends Logging { } val internalRow = iter.next() - val startTime = System.nanoTime() currentWriter.write(internalRow) - timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } releaseResources() @@ -355,17 +351,13 @@ object FileFormatWriter extends Logging { updatedPartitions = Set.empty, numOutputFile = fileCounter + 1, numOutputBytes = numOutputBytes, - numOutputRows = numOutputRows, - totalWritingTime = totalWritingTime) + numOutputRows = numOutputRows) } override def releaseResources(): Unit = { if (currentWriter != null) { try { - val startTime = System.nanoTime() currentWriter.close() - totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 - timeOnCurrentFile = 0 numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null @@ -504,9 +496,7 @@ object FileFormatWriter extends Logging { releaseResources() newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - val startTime = System.nanoTime() currentWriter.write(getOutputRow(row)) - timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } if (currentPartColsAndBucketId != null) { @@ -519,17 +509,13 @@ object FileFormatWriter extends Logging { updatedPartitions = updatedPartitions.toSet, numOutputFile = totalFileCounter, numOutputBytes = numOutputBytes, - numOutputRows = numOutputRows, - totalWritingTime = totalWritingTime) + numOutputRows = numOutputRows) } override def releaseResources(): Unit = { if (currentWriter != null) { try { - val startTime = System.nanoTime() currentWriter.close() - totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 - timeOnCurrentFile = 0 numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null @@ -547,11 +533,9 @@ object FileFormatWriter extends Logging { * @param numOutputFile the total number of files. * @param numOutputRows the number of output rows. * @param numOutputBytes the bytes of output data. - * @param totalWritingTime the total writing time in ms. */ case class ExecutedWriteSummary( updatedPartitions: Set[String], numOutputFile: Int, numOutputRows: Long, - numOutputBytes: Long, - totalWritingTime: Long) + numOutputBytes: Long) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala index 1ef1988d4c605..24c038587d1d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -65,9 +65,6 @@ class SQLMetricsSuite extends SQLTestUtils with TestHiveSingleton { val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt assert(totalNumBytes > 0) - val writingTimeMetric = executedNode.metrics.find(_.name == "average writing time (ms)").get - val writingTime = metrics(writingTimeMetric.accumulatorId).replaceAll(",", "").toInt - assert(writingTime >= 0) } private def testMetricsNonDynamicPartition( From 7fcbb9b57f5eba8b14bf7d86ebaa08a8ee937cd2 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 7 Jul 2017 08:31:30 +0100 Subject: [PATCH 0896/1765] [SPARK-21313][SS] ConsoleSink's string representation ## What changes were proposed in this pull request? Add `toString` with options for `ConsoleSink` so it shows nicely in query progress. **BEFORE** ``` "sink" : { "description" : "org.apache.spark.sql.execution.streaming.ConsoleSink4b340441" } ``` **AFTER** ``` "sink" : { "description" : "ConsoleSink[numRows=10, truncate=false]" } ``` /cc zsxwing tdas ## How was this patch tested? Local build Author: Jacek Laskowski Closes #18539 from jaceklaskowski/SPARK-21313-ConsoleSink-toString. --- .../org/apache/spark/sql/execution/streaming/ForeachSink.scala | 2 ++ .../org/apache/spark/sql/execution/streaming/console.scala | 2 ++ 2 files changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index de09fb568d2a6..2cc54107f8b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -63,4 +63,6 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria } } } + + override def toString(): String = "ForeachSink" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 3baea6376069f..1c9284e252bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -52,6 +52,8 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) .show(numRowsToShow, isTruncated) } + + override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" } case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) From 56536e9992ac4ea771758463962e49bba410e896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 7 Jul 2017 18:32:01 +0800 Subject: [PATCH 0897/1765] [SPARK-21285][ML] VectorAssembler reports the column name of unsupported data type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? add the column name in the exception which is raised by unsupported data type. ## How was this patch tested? + [x] pass all tests. Author: Yan Facai (颜发才) Closes #18523 from facaiy/ENH/vectorassembler_add_col. --- .../apache/spark/ml/feature/VectorAssembler.scala | 15 +++++++++------ .../spark/ml/feature/VectorAssemblerSuite.scala | 5 ++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index ca900536bc7b8..73f27d1a423d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -113,12 +113,15 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColName = $(outputCol) - val inputDataTypes = inputColNames.map(name => schema(name).dataType) - inputDataTypes.foreach { - case _: NumericType | BooleanType => - case t if t.isInstanceOf[VectorUDT] => - case other => - throw new IllegalArgumentException(s"Data type $other is not supported.") + val incorrectColumns = inputColNames.flatMap { name => + schema(name).dataType match { + case _: NumericType | BooleanType => None + case t if t.isInstanceOf[VectorUDT] => None + case other => Some(s"Data type $other of column $name is not supported.") + } + } + if (incorrectColumns.nonEmpty) { + throw new IllegalArgumentException(incorrectColumns.mkString("\n")) } if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 46cced3a9a6e5..6aef1c6837025 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -79,7 +79,10 @@ class VectorAssemblerSuite val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "Data type StringType is not supported") + assert(thrown.getMessage contains + "Data type StringType of column a is not supported.\n" + + "Data type StringType of column b is not supported.\n" + + "Data type StringType of column c is not supported.") } test("ML attributes") { From fef081309fc28efe8e136f363d85d7ccd9466e61 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Jul 2017 20:04:30 +0800 Subject: [PATCH 0898/1765] [SPARK-21335][SQL] support un-aliased subquery ## What changes were proposed in this pull request? un-aliased subquery is supported by Spark SQL for a long time. Its semantic was not well defined and had confusing behaviors, and it's not a standard SQL syntax, so we disallowed it in https://issues.apache.org/jira/browse/SPARK-20690 . However, this is a breaking change, and we do have existing queries using un-aliased subquery. We should add the support back and fix its semantic. This PR fixes the un-aliased subquery by assigning a default alias name. After this PR, there is no syntax change from branch 2.2 to master, but we invalid a weird use case: `SELECT v.i from (SELECT i FROM v)`. Now this query will throw analysis exception because users should not be able to use the qualifier inside a subquery. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #18559 from cloud-fan/sub-query. --- .../sql/catalyst/parser/AstBuilder.scala | 16 ++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../sql/catalyst/parser/PlanParserSuite.scala | 13 --- .../resources/sql-tests/inputs/group-by.sql | 2 +- .../test/resources/sql-tests/inputs/limit.sql | 2 +- .../sql-tests/inputs/string-functions.sql | 2 +- .../in-subquery/in-set-operations.sql | 2 +- .../negative-cases/invalid-correlation.sql | 2 +- .../scalar-subquery-predicate.sql | 2 +- .../test/resources/sql-tests/inputs/union.sql | 4 +- .../results/columnresolution-negative.sql.out | 16 ++-- .../sql-tests/results/group-by.sql.out | 2 +- .../resources/sql-tests/results/limit.sql.out | 2 +- .../results/string-functions.sql.out | 6 +- .../in-subquery/in-set-operations.sql.out | 2 +- .../invalid-correlation.sql.out | 2 +- .../scalar-subquery-predicate.sql.out | 2 +- .../results/subquery/subquery-in-from.sql.out | 20 +---- .../resources/sql-tests/results/union.sql.out | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 82 ++++++------------- .../org/apache/spark/sql/SQLQuerySuite.scala | 13 +++ .../org/apache/spark/sql/SubquerySuite.scala | 8 +- 22 files changed, 83 insertions(+), 123 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b6a4686bb9ec9..4d725904bc9b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -751,15 +751,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * hooks. */ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - // The unaliased subqueries in the FROM clause are disallowed. Instead of rejecting it in - // parser rules, we handle it here in order to provide better error message. - if (ctx.strictIdentifier == null) { - throw new ParseException("The unaliased subqueries in the FROM clause are not supported.", - ctx) + val alias = if (ctx.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + "__auto_generated_subquery_name" + } else { + ctx.strictIdentifier.getText } - aliasPlan(ctx.strictIdentifier, - plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) + SubqueryAlias(alias, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8649603b1a9f5..9b440cd99f994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -253,7 +253,7 @@ abstract class LogicalPlan // More than one match. case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1).mkString(", ") + val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") throw new AnalysisException( s"Reference '$name' is ambiguous, could be: $referenceNames.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5b2573fa4d601..6dad097041a15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -450,19 +450,6 @@ class PlanParserSuite extends AnalysisTest { | (select id from t0)) as u_1 """.stripMargin, plan.union(plan).union(plan).as("u_1").select('id)) - - } - - test("aliased subquery") { - val errMsg = "The unaliased subqueries in the FROM clause are not supported" - - assertEqual("select a from (select id as a from t0) tt", - table("t0").select('id.as("a")).as("tt").select('a)) - intercept("select a from (select id as a from t0)", errMsg) - - assertEqual("from (select id as a from t0) tt select a", - table("t0").select('id.as("a")).as("tt").select('a)) - intercept("from (select id as a from t0) select a", errMsg) } test("scalar sub-query") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index bc2120727dac2..1e1384549a410 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -34,7 +34,7 @@ SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), FROM testData; -- Aggregate with foldable input and multiple distinct groups. -SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) t GROUP BY a; +SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Aliases in SELECT could be used in GROUP BY SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index df555bdc1976d..f21912a042716 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -21,7 +21,7 @@ SELECT * FROM testdata LIMIT true; SELECT * FROM testdata LIMIT 'a'; -- limit within a subquery -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) t WHERE id > 3; +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3; -- limit ALL SELECT * FROM testdata WHERE key < 3 LIMIT ALL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 20c0390664037..c95f4817b7ce0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -7,7 +7,7 @@ select 'a' || 'b' || 'c'; -- Check if catalyst combine nested `Concat`s EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) t; +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)); -- replace function select replace('abc', 'b', '123'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql index 42f84e9748713..5c371d2305ac8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql @@ -394,7 +394,7 @@ FROM (SELECT * FROM t1)) t4 WHERE t4.t2b IN (SELECT Min(t3b) FROM t3 - WHERE t4.t2a = t3a)) T; + WHERE t4.t2a = t3a)); -- UNION, UNION ALL, UNION DISTINCT, INTERSECT and EXCEPT for NOT IN -- TC 01.12 diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql index f3f0c7622ccdb..e22cade936792 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -23,7 +23,7 @@ AND t2b = (SELECT max(avg) FROM (SELECT t2b, avg(t2b) avg FROM t2 WHERE t2a = t1.t1b - ) T + ) ) ; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index dbe8d76d2f117..fb0d07fbdace7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -19,7 +19,7 @@ AND c.cv = (SELECT max(avg) FROM (SELECT c1.cv, avg(c1.cv) avg FROM c c1 WHERE c1.ck = p.pk - GROUP BY c1.cv) T); + GROUP BY c1.cv)); create temporary view t1 as select * from values ('val1a', 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 00:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index 63bc044535e4d..e57d69eaad033 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -5,7 +5,7 @@ CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2); SELECT * FROM (SELECT * FROM t1 UNION ALL - SELECT * FROM t1) T; + SELECT * FROM t1); -- Type Coerced Union SELECT * @@ -13,7 +13,7 @@ FROM (SELECT * FROM t1 UNION ALL SELECT * FROM t2 UNION ALL - SELECT * FROM t2) T; + SELECT * FROM t2); -- Regression test for SPARK-18622 SELECT a diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 9e60e592c2bd1..b5a4f5c2bf654 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -72,7 +72,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 9 @@ -81,7 +81,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 10 @@ -99,7 +99,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 12 @@ -108,7 +108,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 13 @@ -125,7 +125,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 15 @@ -134,7 +134,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 16 @@ -143,7 +143,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 17 @@ -152,7 +152,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 18 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index e23ebd4e822fa..986bb01c13fe4 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -134,7 +134,7 @@ struct -- !query 14 output diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index afdd6df2a5714..146abe6cbd058 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -93,7 +93,7 @@ The limit expression must be integer type, but got string; -- !query 10 -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) t WHERE id > 3 +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 -- !query 10 schema struct -- !query 10 output diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 52eb554edf89e..b0ae9d775d968 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -30,20 +30,20 @@ abc -- !query 3 EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) t +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) -- !query 3 schema struct -- !query 3 output == Parsed Logical Plan == 'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias t ++- 'SubqueryAlias __auto_generated_subquery_name +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] +- 'UnresolvedTableValuedFunction range, [10] == Analyzed Logical Plan == col: string Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias t ++- SubqueryAlias __auto_generated_subquery_name +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out index 5780f49648ec7..e06f9206d3401 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out @@ -496,7 +496,7 @@ FROM (SELECT * FROM t1)) t4 WHERE t4.t2b IN (SELECT Min(t3b) FROM t3 - WHERE t4.t2a = t3a)) T + WHERE t4.t2a = t3a)) -- !query 13 schema struct -- !query 13 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index ca3930b33e06d..e4b1a2dbc675c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -40,7 +40,7 @@ AND t2b = (SELECT max(avg) FROM (SELECT t2b, avg(t2b) avg FROM t2 WHERE t2a = t1.t1b - ) T + ) ) -- !query 3 schema struct<> diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 1d5dddca76a17..8b29300e71f90 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -39,7 +39,7 @@ AND c.cv = (SELECT max(avg) FROM (SELECT c1.cv, avg(c1.cv) avg FROM c c1 WHERE c1.ck = p.pk - GROUP BY c1.cv) T) + GROUP BY c1.cv)) -- !query 3 schema struct -- !query 3 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out index 14553557d1ffc..50370df349168 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out @@ -37,26 +37,14 @@ struct -- !query 4 SELECT * FROM (SELECT * FROM testData) WHERE key = 1 -- !query 4 schema -struct<> +struct -- !query 4 output -org.apache.spark.sql.catalyst.parser.ParseException - -The unaliased subqueries in the FROM clause are not supported.(line 1, pos 14) - -== SQL == -SELECT * FROM (SELECT * FROM testData) WHERE key = 1 ---------------^^^ +1 1 -- !query 5 FROM (SELECT * FROM testData WHERE key = 1) SELECT * -- !query 5 schema -struct<> +struct -- !query 5 output -org.apache.spark.sql.catalyst.parser.ParseException - -The unaliased subqueries in the FROM clause are not supported.(line 1, pos 5) - -== SQL == -FROM (SELECT * FROM testData WHERE key = 1) SELECT * ------^^^ +1 1 diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index 865b3aed65d70..d123b7fdbe0cf 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -22,7 +22,7 @@ struct<> SELECT * FROM (SELECT * FROM t1 UNION ALL - SELECT * FROM t1) T + SELECT * FROM t1) -- !query 2 schema struct -- !query 2 output @@ -38,7 +38,7 @@ FROM (SELECT * FROM t1 UNION ALL SELECT * FROM t2 UNION ALL - SELECT * FROM t2) T + SELECT * FROM t2) -- !query 3 schema struct -- !query 3 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 506cc2548e260..3e4f619431599 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -631,13 +631,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val ds2 = sql( """ - |SELECT * FROM (SELECT max(c1) as c1 FROM t1 GROUP BY c1) tt + |SELECT * FROM (SELECT c1, max(c1) FROM t1 GROUP BY c1) |WHERE - |tt.c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) |OR |EXISTS (SELECT c1 FROM t3) |OR - |tt.c1 IN (SELECT c1 FROM t4) + |c1 IN (SELECT c1 FROM t4) """.stripMargin) assert(getNumInMemoryRelations(ds2) == 4) } @@ -683,20 +683,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t1") Seq(2).toDF("c1").createOrReplaceTempView("t2") - sql( + val sql1 = """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t2) - """.stripMargin).cache() + """.stripMargin + sql(sql1).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |NOT EXISTS (SELECT * FROM t2) - """.stripMargin) + val cachedDs = sql(sql1) assert(getNumInMemoryRelations(cachedDs) == 1) // Additional predicate in the subquery plan should cause a cache miss @@ -717,20 +712,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t2") // Simple correlated predicate in subquery - sql( + val sqlText = """ |SELECT * FROM t1 |WHERE |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) - """.stripMargin).cache() + """.stripMargin + sql(sqlText).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) - """.stripMargin) + val cachedDs = sql(sqlText) assert(getNumInMemoryRelations(cachedDs) == 1) } } @@ -741,22 +731,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t1") // underlying table t1 is cached as well as the query that refers to it. - val ds = - sql( + val sqlText = """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin) + """.stripMargin + val ds = sql(sqlText) assert(getNumInMemoryRelations(ds) == 2) - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).cache() + val cachedDs = sql(sqlText).cache() assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) } } @@ -769,45 +753,31 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t4") // Nested predicate subquery - sql( + val sql1 = """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin).cache() + """.stripMargin + sql(sql1).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin) + val cachedDs = sql(sql1) assert(getNumInMemoryRelations(cachedDs) == 1) // Scalar subquery and predicate subquery - sql( + val sql2 = """ - |SELECT * FROM (SELECT max(c1) as c1 FROM t1 GROUP BY c1) tt + |SELECT * FROM (SELECT c1, max(c1) FROM t1 GROUP BY c1) |WHERE - |tt.c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) |OR |EXISTS (SELECT c1 FROM t3) |OR - |tt.c1 IN (SELECT c1 FROM t4) - """.stripMargin).cache() + |c1 IN (SELECT c1 FROM t4) + """.stripMargin + sql(sql2).cache() - val cachedDs2 = - sql( - """ - |SELECT * FROM (SELECT max(c1) as c1 FROM t1 GROUP BY c1) tt - |WHERE - |tt.c1 = (SELECT max(c1) FROM t2 GROUP BY c1) - |OR - |EXISTS (SELECT c1 FROM t3) - |OR - |tt.c1 IN (SELECT c1 FROM t4) - """.stripMargin) + val cachedDs2 = sql(sql2) assert(getNumInMemoryRelations(cachedDs2) == 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5171aaebc9907..472ff7385b194 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2638,4 +2638,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21335: support un-aliased subquery") { + withTempView("v") { + Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") + checkAnswer(sql("SELECT i from (SELECT i FROM v)"), Row(1)) + + val e = intercept[AnalysisException](sql("SELECT v.i from (SELECT i FROM v)")) + assert(e.message == + "cannot resolve '`v.i`' given input columns: [__auto_generated_subquery_name.i]") + + checkAnswer(sql("SELECT __auto_generated_subquery_name.i from (SELECT i FROM v)"), Row(1)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c0a3b5add313a..7bcb419e8df6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -112,7 +112,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | with t4 as (select 1 as d, 3 as e) | select * from t4 cross join t2 where t2.b = t4.d | ) - | select a from (select 1 as a union all select 2 as a) t + | select a from (select 1 as a union all select 2 as a) | where a = (select max(d) from t3) """.stripMargin), Array(Row(1)) @@ -606,8 +606,8 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | select cntPlusOne + 1 as cntPlusTwo from ( | select cnt + 1 as cntPlusOne from ( | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0 - | ) t1 - | ) t2 + | ) + | ) |) = 2""".stripMargin), Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) } @@ -655,7 +655,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { """ | select c1 from onerow t1 | where exists (select 1 - | from (select 1 as c1 from onerow t2 LIMIT 1) t2 + | from (select c1 from onerow t2 LIMIT 1) t2 | where t1.c1=t2.c1)""".stripMargin), Row(1) :: Nil) } From fbbe37ed416f2ca9d8fc713a135b335b8a0247bf Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 7 Jul 2017 20:10:24 +0800 Subject: [PATCH 0899/1765] [SPARK-19358][CORE] LiveListenerBus shall log the event name when dropping them due to a fully filled queue ## What changes were proposed in this pull request? Some dropped event will make the whole application behaves unexpectedly, e.g. some UI problem...we shall log the dropped event name to facilitate the debugging ## How was this patch tested? Existing tests Author: CodingCat Closes #16697 from CodingCat/SPARK-19358. --- .../main/scala/org/apache/spark/scheduler/LiveListenerBus.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index f0887e090b956..0dd63d4392800 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -232,6 +232,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { "This likely means one of the SparkListeners is too slow and cannot keep up with " + "the rate at which tasks are being started by the scheduler.") } + logTrace(s"Dropping event $event") } } From a0fe32a219253f0abe9d67cf178c73daf5f6fcc1 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 7 Jul 2017 15:39:29 -0700 Subject: [PATCH 0900/1765] [SPARK-21336] Revise rand comparison in BatchEvalPythonExecSuite ## What changes were proposed in this pull request? Revise rand comparison in BatchEvalPythonExecSuite In BatchEvalPythonExecSuite, there are two cases using the case "rand() > 3" Rand() generates a random value in [0, 1), it is wired to be compared with 3, use 0.3 instead ## How was this patch tested? unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18560 from gengliangwang/revise_BatchEvalPythonExecSuite. --- .../spark/sql/execution/python/BatchEvalPythonExecSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 80ef4eb75ca53..bbd9484271a3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -65,7 +65,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF: no push down on non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") - .where("b > 4 and dummyPythonUDF(a) and rand() > 3") + .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: GreaterThan), @@ -77,7 +77,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF: no push down on predicates starting from the first non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") - .where("dummyPythonUDF(a) and rand() > 3 and b > 4") + .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f } From e1a172c201d68406faa53b113518b10c879f1ff6 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Sat, 8 Jul 2017 13:47:41 +0800 Subject: [PATCH 0901/1765] [SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas ## What changes were proposed in this pull request? Adds method `summary` that allows user to specify which statistics and percentiles to calculate. By default it include the existing statistics from `describe` and quartiles (25th, 50th, and 75th percentiles) similar to Pandas. Also changes the implementation of `describe` to delegate to `summary`. ## How was this patch tested? additional unit test Author: Andrew Ray Closes #18307 from aray/SPARK-21100. --- .../scala/org/apache/spark/sql/Dataset.scala | 113 +++++++++++------- .../sql/execution/stat/StatFunctions.scala | 98 ++++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 112 +++++++++++++---- 3 files changed, 258 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b1638a2180b07..5326b45b50a8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -224,7 +224,7 @@ class Dataset[T] private[sql]( } } - private def aggregatableColumns: Seq[Expression] = { + private[sql] def aggregatableColumns: Seq[Expression] = { schema.fields .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) .map { n => @@ -2161,9 +2161,9 @@ class Dataset[T] private[sql]( } /** - * Computes statistics for numeric and string columns, including count, mean, stddev, min, and - * max. If no columns are given, this function computes statistics for all numerical or string - * columns. + * Computes basic statistics for numeric and string columns, including count, mean, stddev, min, + * and max. If no columns are given, this function computes statistics for all numerical or + * string columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to @@ -2181,46 +2181,79 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * + * Use [[summary]] for expanded statistics and control over which statistics to compute. + * + * @param cols Columns to compute statistics on. + * * @group action * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = withPlan { - - // The list of summary statistics to compute, in the form of expressions. - val statistics = List[(String, Expression => Expression)]( - "count" -> ((child: Expression) => Count(child).toAggregateExpression()), - "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), - "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), - "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - - val outputCols = - (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - val aggExprs = statistics.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - - val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq - - // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } - } else { - // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => Row(name) } - } - - // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - // `toArray` forces materialization to make the seq serializable - LocalRelation.fromExternalRows(schema, ret.toArray.toSeq) + def describe(cols: String*): DataFrame = { + val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*) + selected.summary("count", "mean", "stddev", "min", "max") } + /** + * Computes specified statistics for numeric and string columns. Available statistics are: + * + * - count + * - mean + * - stddev + * - min + * - max + * - arbitrary approximate percentiles specified as a percentage (eg, 75%) + * + * If no statistics are given, this function computes count, mean, stddev, min, + * approximate quartiles (percentiles at 25%, 50%, and 75%), and max. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting Dataset. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * ds.summary().show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 50% 24.0 176.0 + * // 75% 32.0 180.0 + * // max 92.0 192.0 + * }}} + * + * {{{ + * ds.summary("count", "min", "25%", "75%", "max").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 75% 32.0 180.0 + * // max 92.0 192.0 + * }}} + * + * To do a summary for specific columns first select them: + * + * {{{ + * ds.select("age", "height").summary().show() + * }}} + * + * See also [[describe]] for basic statistics. + * + * @param statistics Statistics from above list to be computed. + * + * @group action + * @since 2.3.0 + */ + @scala.annotation.varargs + def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq) + /** * Returns the first `n` rows. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 1debad03c93fa..436e18fdb5ff5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -220,4 +221,97 @@ object StatFunctions extends Logging { Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } + + /** Calculate selected summary statistics for a dataset */ + def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { + + val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + + val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) + val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { + val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) + val percentiles = pStrings.map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) + } + } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + (percentiles, pStrings, rest) + } else { + (Seq(), Seq(), selectedStatistics) + } + + + // The list of summary statistics to compute, in the form of expressions. + val availableStatistics = Map[String, Expression => Expression]( + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) + + val statisticFns = remainingAggregates.map { agg => + require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") + agg -> availableStatistics(agg) + } + + def percentileAgg(child: Expression): Expression = + new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) + .toAggregateExpression() + + val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) + } + if (hasPercentiles) { + aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + } + + val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + + val basicStats = if (hasPercentiles) grouped.tail else grouped + + val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) + } + + if (hasPercentiles) { + def nullSafeString(x: Any) = if (x == null) null else x.toString + val percentileRows = grouped.head + .map { + case a: Seq[Any] => a + case _ => Seq.fill(percentiles.length)(null: Any) + } + .transpose + .zip(percentileNames) + .map { case (values: Seq[Any], name) => + Row(name :: values.map(nullSafeString).toList: _*) + } + (rows ++ percentileRows) + .sortWith((left, right) => + selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) + } else { + rows + } + } else { + // If there are no output columns, just output a single column that contains the stats. + selectedStatistics.map(Row(_)) + } + + // All columns are string type + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes + // `toArray` forces materialization to make the seq serializable + Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9ea9951c24ef1..2c7051bf431c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,8 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} @@ -663,13 +662,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("describe") { - val describeTestData = Seq( - ("Bob", 16, 176), - ("Alice", 32, 164), - ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + private lazy val person2: DataFrame = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + test("describe") { val describeResult = Seq( Row("count", "4", "4", "4"), Row("mean", null, "33.0", "178.0"), @@ -686,32 +685,99 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.describe("name", "age", "height") - assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeTwoCols, describeResult) - // All aggregate value should have been cast to string - describeTwoCols.collect().foreach { row => - assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) - assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) - } - - val describeAllCols = describeTestData.describe() + val describeAllCols = person2.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) + // All aggregate value should have been cast to string + describeAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } - val describeOneCol = describeTestData.describe("age") + val describeOneCol = person2.describe("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").describe() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) + val describeNoCol = person2.select().describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} ) - val emptyDescription = describeTestData.limit(0).describe() + val emptyDescription = person2.limit(0).describe() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) } + test("summary") { + val summaryResult = Seq( + Row("count", "4", "4", "4"), + Row("mean", null, "33.0", "178.0"), + Row("stddev", null, "19.148542155126762", "11.547005383792516"), + Row("min", "Alice", "16", "164"), + Row("25%", null, "24.0", "176.0"), + Row("50%", null, "24.0", "176.0"), + Row("75%", null, "32.0", "180.0"), + Row("max", "David", "60", "192")) + + val emptySummaryResult = Seq( + Row("count", "0", "0", "0"), + Row("mean", null, null, null), + Row("stddev", null, null, null), + Row("min", null, null, null), + Row("25%", null, null, null), + Row("50%", null, null, null), + Row("75%", null, null, null), + Row("max", null, null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val summaryAllCols = person2.summary() + + assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(summaryAllCols, summaryResult) + // All aggregate value should have been cast to string + summaryAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } + + val summaryOneCol = person2.select("age").summary() + assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age")) + checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} ) + + val summaryNoCol = person2.select().summary() + assert(getSchemaAsSeq(summaryNoCol) === Seq("summary")) + checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} ) + + val emptyDescription = person2.limit(0).summary() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) + checkAnswer(emptyDescription, emptySummaryResult) + } + + test("summary advanced") { + val stats = Array("count", "50.01%", "max", "mean", "min", "25%") + val orderMatters = person2.summary(stats: _*) + assert(orderMatters.collect().map(_.getString(0)) === stats) + + val onlyPercentiles = person2.summary("0.1%", "99.9%") + assert(onlyPercentiles.count() === 2) + + val fooE = intercept[IllegalArgumentException] { + person2.summary("foo") + } + assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") + + val parseE = intercept[IllegalArgumentException] { + person2.summary("foo%") + } + assert(parseE.getMessage === "Unable to parse foo% as a percentile") + } + test("apply on query results (SPARK-5462)") { val df = testData.sparkSession.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) From 7896e7b99d95d28800f5644bd36b3990cf0ef8c4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 7 Jul 2017 23:05:38 -0700 Subject: [PATCH 0902/1765] [SPARK-21281][SQL] Use string types by default if array and map have no argument ## What changes were proposed in this pull request? This pr modified code to use string types by default if `array` and `map` in functions have no argument. This behaviour is the same with Hive one; ``` hive> CREATE TEMPORARY TABLE t1 AS SELECT map(); hive> DESCRIBE t1; _c0 map hive> CREATE TEMPORARY TABLE t2 AS SELECT array(); hive> DESCRIBE t2; _c0 array ``` ## How was this patch tested? Added tests in `DataFrameFunctionsSuite`. Author: Takeshi Yamamuro Closes #18516 from maropu/SPARK-21281. --- .../sql/catalyst/expressions/arithmetic.scala | 10 +++--- .../expressions/complexTypeCreator.scala | 35 ++++++++++-------- .../spark/sql/catalyst/expressions/hash.scala | 5 +-- .../expressions/nullExpressions.scala | 7 ++-- .../ExpressionTypeCheckingSuite.scala | 4 +-- .../org/apache/spark/sql/functions.scala | 10 ++---- .../spark/sql/DataFrameFunctionsSuite.scala | 36 +++++++++++++++++++ 7 files changed, 74 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ec6e6ba0f091b..423bf66a24d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } @@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 98c4cbee38dee..d9eeb5358ef79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") + override def checkInputDataTypes(): TypeCheckResult = { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -93,7 +94,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -119,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); + ctx.addMutableState("UnsafeArrayData", arrayDataName, "") val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.") + TypeCheckResult.TypeCheckFailure( + s"$prettyName expects a positive even number of arguments.") } else if (keys.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + - "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + TypeCheckResult.TypeCheckFailure( + "The given keys of function map should all be the same type, but they are " + + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else if (values.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " + - "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + TypeCheckResult.TypeCheckFailure( + "The given values of function map should all be the same type, but they are " + + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else { TypeCheckResult.TypeCheckSuccess } @@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(NullType), - valueType = values.headOption.map(_.dataType).getOrElse(NullType), + keyType = keys.headOption.map(_.dataType).getOrElse(StringType), + valueType = values.headOption.map(_.dataType).getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } @@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.size % 2 != 0) { + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") + } else if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ffd0e64d86cff..2476fc962a6fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression { override def nullable: Boolean = false override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0866b8d791e01..1b625141d56ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children == Nil) { - TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 30459f173ab52..30725773a37b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { "input to function array should all be the same type") assertError(Coalesce(Seq('intField, 'booleanField)), "input to function coalesce should all be the same type") - assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Coalesce(Nil), "function coalesce requires at least one argument") assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") @@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for Greatest/Least") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('booleanField)), "requires at least two arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3c67960d13e09..1263071a3ffd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1565,10 +1565,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "greatest requires at least 2 arguments.") - Greatest(exprs.map(_.expr)) - } + def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) } /** * Returns the greatest value of the list of column names, skipping null values. @@ -1672,10 +1669,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "least requires at least 2 arguments.") - Least(exprs.map(_.expr)) - } + def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) } /** * Returns the least value of the list of column names, skipping null values. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0e9a2c6cf7dec..0681b9cbeb1d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + + test("SPARK-21281 use string types by default if array and map have no argument") { + val ds = spark.range(1) + var expectedSchema = new StructType() + .add("x", ArrayType(StringType, containsNull = false), nullable = false) + assert(ds.select(array().as("x")).schema == expectedSchema) + expectedSchema = new StructType() + .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) + assert(ds.select(map().as("x")).schema == expectedSchema) + } + + test("SPARK-21281 fails if functions have no argument") { + val df = Seq(1).toDF("a") + + val funcsMustHaveAtLeastOneArg = + ("coalesce", (df: DataFrame) => df.select(coalesce())) :: + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: + ("named_struct", (df: DataFrame) => df.select(struct())) :: + ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: + ("hash", (df: DataFrame) => df.select(hash())) :: + ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil + funcsMustHaveAtLeastOneArg.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least one argument")) + } + + val funcsMustHaveAtLeastTwoArgs = + ("greatest", (df: DataFrame) => df.select(greatest())) :: + ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: + ("least", (df: DataFrame) => df.select(least())) :: + ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil + funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least two arguments")) + } + } } object DataFrameFunctionsSuite { From 9760c15acbcf755dd5b13597ceb333576f806ecf Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 8 Jul 2017 14:20:09 +0800 Subject: [PATCH 0903/1765] [SPARK-20379][CORE] Allow SSL config to reference env variables. This change exposes the internal code path in SparkConf that allows configs to be read with variable substitution applied, and uses that new method in SSLOptions so that SSL configs can reference other variables, and more importantly, environment variables, providing a secure way to provide passwords to Spark when using SSL. The approach is a little bit hacky, but is the smallest change possible. Otherwise, the concept of "namespaced configs" would have to be added to the config system, which would create a lot of noise for not much gain at this point. Tested with added unit tests, and on a real cluster with SSL enabled. Author: Marcelo Vanzin Closes #18394 from vanzin/SPARK-20379.try2. --- .../scala/org/apache/spark/SSLOptions.scala | 20 +++++++++---------- .../scala/org/apache/spark/SparkConf.scala | 5 +++++ .../org/apache/spark/SSLOptionsSuite.scala | 16 +++++++++++++++ 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 29163e7f30546..f86fd20e59190 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -167,39 +167,39 @@ private[spark] object SSLOptions extends Logging { def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) - val port = conf.getOption(s"$ns.port").map(_.toInt) + val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) port.foreach { p => require(p >= 0, "Port number must be a non-negative value.") } - val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) + val keyStore = conf.getWithSubstitution(s"$ns.keyStore").map(new File(_)) .orElse(defaults.flatMap(_.keyStore)) - val keyStorePassword = conf.getOption(s"$ns.keyStorePassword") + val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") .orElse(defaults.flatMap(_.keyStorePassword)) - val keyPassword = conf.getOption(s"$ns.keyPassword") + val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") .orElse(defaults.flatMap(_.keyPassword)) - val keyStoreType = conf.getOption(s"$ns.keyStoreType") + val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") .orElse(defaults.flatMap(_.keyStoreType)) val needClientAuth = conf.getBoolean(s"$ns.needClientAuth", defaultValue = defaults.exists(_.needClientAuth)) - val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) + val trustStore = conf.getWithSubstitution(s"$ns.trustStore").map(new File(_)) .orElse(defaults.flatMap(_.trustStore)) - val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") + val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") .orElse(defaults.flatMap(_.trustStorePassword)) - val trustStoreType = conf.getOption(s"$ns.trustStoreType") + val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") .orElse(defaults.flatMap(_.trustStoreType)) - val protocol = conf.getOption(s"$ns.protocol") + val protocol = conf.getWithSubstitution(s"$ns.protocol") .orElse(defaults.flatMap(_.protocol)) - val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms") + val enabledAlgorithms = conf.getWithSubstitution(s"$ns.enabledAlgorithms") .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet) .orElse(defaults.map(_.enabledAlgorithms)) .getOrElse(Set.empty) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index de2f475c6895f..715cfdcc8f4ef 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -373,6 +373,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } + /** Get an optional value, applying variable substitution. */ + private[spark] def getWithSubstitution(key: String): Option[String] = { + getOption(key).map(reader.substitute(_)) + } + /** Get all parameters as a list of pairs */ def getAll: Array[(String, String)] = { settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 6fc7cea6ee94a..8eabc2b3cb958 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -22,6 +22,8 @@ import javax.net.ssl.SSLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.util.SparkConfWithEnv + class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { @@ -133,4 +135,18 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.enabledAlgorithms === Set("ABC", "DEF")) } + test("variable substitution") { + val conf = new SparkConfWithEnv(Map( + "ENV1" -> "val1", + "ENV2" -> "val2")) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", "${env:ENV1}") + conf.set("spark.ssl.trustStore", "${env:ENV2}") + + val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + assert(opts.keyStore === Some(new File("val1"))) + assert(opts.trustStore === Some(new File("val2"))) + } + } From d0bfc6733521709e453d643582df2bdd68f28de7 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 7 Jul 2017 23:33:12 -0700 Subject: [PATCH 0904/1765] [SPARK-21069][SS][DOCS] Add rate source to programming guide. ## What changes were proposed in this pull request? SPARK-20979 added a new structured streaming source: Rate source. This patch adds the corresponding documentation to programming guide. ## How was this patch tested? Tested by running jekyll locally. Author: Prashant Sharma Author: Prashant Sharma Closes #18562 from ScrapCodes/spark-21069/rate-source-docs. --- docs/structured-streaming-programming-guide.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 3bc377c9a38b5..8f64faadc32dc 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -499,6 +499,8 @@ There are a few built-in sources. - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees. + - **Rate source (for testing)** - Generates data at the specified number of rows per second, each output row contains a `timestamp` and `value`. Where `timestamp` is a `Timestamp` type containing the time of message dispatch, and `value` is of `Long` type containing the message count, starting from 0 as the first row. This source is intended for testing and benchmarking. + Some sources are not fault-tolerant because they do not guarantee that data can be replayed using checkpointed offsets after a failure. See the earlier section on [fault-tolerance semantics](#fault-tolerance-semantics). @@ -546,6 +548,19 @@ Here are the details of all the sources in Spark. No + + Rate Source + + rowsPerSecond (e.g. 100, default: 1): How many rows should be generated per second.

    + rampUpTime (e.g. 5s, default: 0s): How long to ramp up before the generating speed becomes rowsPerSecond. Using finer granularities than seconds will be truncated to integer seconds.

    + numPartitions (e.g. 10, default: Spark's default parallelism): The partition number for the generated rows.

    + + The source will try its best to reach rowsPerSecond, but the query may be resource constrained, and numPartitions can be tweaked to help reach the desired speed. + + Yes + + + Kafka Source From a7b46c627b5d2461257f337139a29f23350e0c77 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Fri, 7 Jul 2017 23:51:32 -0700 Subject: [PATCH 0905/1765] [SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer ## What changes were proposed in this pull request? For randomForest classifier, if test data contains unseen labels, it will throw an error. The StringIndexer already has the handleInvalid logic. The patch add a new method to set the underlying StringIndexer handleInvalid logic. This patch should also apply to other classifiers. This PR focuses on the main logic and randomForest classifier. I will do follow-up PR for other classifiers. ## How was this patch tested? Add a new unit test based on the error case in the JIRA. Author: wangmiao1981 Closes #18496 from wangmiao1981/handle. --- R/pkg/R/mllib_tree.R | 11 ++++++-- R/pkg/tests/fulltests/test_mllib_tree.R | 17 +++++++++++++ .../apache/spark/ml/feature/RFormula.scala | 25 +++++++++++++++++++ .../r/RandomForestClassificationWrapper.scala | 4 ++- .../spark/ml/feature/StringIndexerSuite.scala | 2 +- 5 files changed, 55 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 2f1220a752783..75b1a74ee8c7c 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE) { + maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo new("RandomForestRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(impurity)) impurity <- "gini" impurity <- match.arg(impurity, c("gini", "entropy")) jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", @@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("RandomForestClassificationModel", jobj = jobj) } ) diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R index 9b3fc8d270b25..66a0693a59a52 100644 --- a/R/pkg/tests/fulltests/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -212,6 +212,23 @@ test_that("spark.randomForest", { expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", predictions)), 50) + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10, + handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") + # spark.randomForest classification can work on libsvm data if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 4b44878784c90..61aa6463bb6da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def getFormula: String = $(formula) + /** + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). + * Default: "error" + * @group param + */ + @Since("2.3.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " + + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** @group getParam */ + @Since("2.3.0") + def getHandleInvalid: String = $(handleInvalid) + /** @group setParam */ @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setInputCol(term) .setOutputCol(indexCol) .setStringOrderType($(stringIndexerOrderType)) + .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 8a83d4e980f7b..132345fb9a6d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC seed: String, subsamplingRate: Double, maxMemoryInMB: Int, - cacheNodeIds: Boolean): RandomForestClassifierWrapper = { + cacheNodeIds: Boolean, + handleInvalid: String): RandomForestClassifierWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 806a92760c8b6..027b1fbc6657c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col From f5f02d213d3151f58070e113d64fcded4f5d401e Mon Sep 17 00:00:00 2001 From: Michael Patterson Date: Fri, 7 Jul 2017 23:59:34 -0700 Subject: [PATCH 0906/1765] [SPARK-20456][DOCS] Add examples for functions collection for pyspark ## What changes were proposed in this pull request? This adds documentation to many functions in pyspark.sql.functions.py: `upper`, `lower`, `reverse`, `unix_timestamp`, `from_unixtime`, `rand`, `randn`, `collect_list`, `collect_set`, `lit` Add units to the trigonometry functions. Renames columns in datetime examples to be more informative. Adds links between some functions. ## How was this patch tested? `./dev/lint-python` `python python/pyspark/sql/functions.py` `./python/run-tests.py --module pyspark-sql` Author: Michael Patterson Closes #17865 from map222/spark-20456. --- R/pkg/R/functions.R | 11 +- python/pyspark/sql/functions.py | 166 +++++++++++------- .../org/apache/spark/sql/functions.scala | 14 +- 3 files changed, 119 insertions(+), 72 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index c529d83060f50..f28d26a51baa0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -336,7 +336,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value. +#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. #' #' @rdname column_math_functions #' @export @@ -599,7 +600,7 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. +#' \code{cos}: Computes the cosine of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -1407,7 +1408,7 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. +#' \code{sin}: Computes the sine of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1597,7 +1598,7 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. +#' \code{tan}: Computes the tangent of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1896,7 +1897,7 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). +#' (x, y) to polar coordinates (r, theta). Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3416c4b118a07..5d8ded83f667d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -67,9 +67,14 @@ def _(): _.__doc__ = 'Window function: ' + doc return _ +_lit_doc = """ + Creates a :class:`Column` of literal value. + >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1) + [Row(height=5, spark_user=True)] + """ _functions = { - 'lit': 'Creates a :class:`Column` of literal value.', + 'lit': _lit_doc, 'col': 'Returns a :class:`Column` based on the given column name.', 'column': 'Returns a :class:`Column` based on the given column name.', 'asc': 'Returns a sort expression based on the ascending order of the given column name.', @@ -95,10 +100,13 @@ def _(): '0.0 through pi.', 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value.', + 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': 'Computes the cosine of the given value.', + 'cos': """Computes the cosine of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'cosh': 'Computes the hyperbolic cosine of the given value.', 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', @@ -109,15 +117,33 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': 'Computes the sine of the given value.', + 'sin': """Computes the sine of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': 'Computes the tangent of the given value.', + 'tan': """Computes the tangent of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'tanh': 'Computes the hyperbolic tangent of the given value.', - 'toDegrees': '.. note:: Deprecated in 2.1, use degrees instead.', - 'toRadians': '.. note:: Deprecated in 2.1, use radians instead.', + 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', + 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', } +_collect_list_doc = """ + Aggregate function: returns a list of objects with duplicates. + + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) + >>> df2.agg(collect_list('age')).collect() + [Row(collect_list(age)=[2, 5, 5])] + """ +_collect_set_doc = """ + Aggregate function: returns a set of objects with duplicate elements eliminated. + + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) + >>> df2.agg(collect_set('age')).collect() + [Row(collect_set(age)=[5, 2])] + """ _functions_1_6 = { # unary math functions 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + @@ -131,9 +157,8 @@ def _(): 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', 'skewness': 'Aggregate function: returns the skewness of the values in a group.', 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', - 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', - 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + - ' eliminated.', + 'collect_list': _collect_list_doc, + 'collect_set': _collect_set_doc } _functions_2_1 = { @@ -147,7 +172,7 @@ def _(): # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta).', + 'polar coordinates (r, theta). Units in radians.', 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } @@ -200,17 +225,20 @@ def _(): @since(1.3) def approxCountDistinct(col, rsd=None): """ - .. note:: Deprecated in 2.1, use approx_count_distinct instead. + .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ return approx_count_distinct(col, rsd) @since(2.1) def approx_count_distinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. + """Aggregate function: returns a new :class:`Column` for approximate distinct count of column `col`. - >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() - [Row(c=2)] + :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more + efficient to use :func:`countDistinct` + + >>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect() + [Row(distinct_ages=2)] """ sc = SparkContext._active_spark_context if rsd is None: @@ -267,8 +295,7 @@ def coalesce(*cols): @since(1.6) def corr(col1, col2): - """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. >>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -282,8 +309,7 @@ def corr(col1, col2): @since(2.0) def covar_pop(col1, col2): - """Returns a new :class:`Column` for the population covariance of ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the population covariance of ``col1`` and ``col2``. >>> a = [1] * 10 >>> b = [1] * 10 @@ -297,8 +323,7 @@ def covar_pop(col1, col2): @since(2.0) def covar_samp(col1, col2): - """Returns a new :class:`Column` for the sample covariance of ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the sample covariance of ``col1`` and ``col2``. >>> a = [1] * 10 >>> b = [1] * 10 @@ -450,7 +475,7 @@ def monotonically_increasing_id(): def nanvl(col1, col2): """Returns col1 if it is not NaN, or col2 if col1 is NaN. - Both inputs should be floating point columns (DoubleType or FloatType). + Both inputs should be floating point columns (:class:`DoubleType` or :class:`FloatType`). >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() @@ -460,10 +485,15 @@ def nanvl(col1, col2): return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) +@ignore_unicode_prefix @since(1.4) def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + + >>> df.withColumn('rand', rand(seed=42) * 3).collect() + [Row(age=2, name=u'Alice', rand=1.1568609015300986), + Row(age=5, name=u'Bob', rand=1.403379671529166)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -473,10 +503,15 @@ def rand(seed=None): return Column(jc) +@ignore_unicode_prefix @since(1.4) def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + + >>> df.withColumn('randn', randn(seed=42)).collect() + [Row(age=2, name=u'Alice', randn=-0.7556247885860078), + Row(age=5, name=u'Bob', randn=-0.0861619008451133)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -760,7 +795,7 @@ def ntile(n): @since(1.5) def current_date(): """ - Returns the current date as a date column. + Returns the current date as a :class:`DateType` column. """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.current_date()) @@ -768,7 +803,7 @@ def current_date(): def current_timestamp(): """ - Returns the current timestamp as a timestamp column. + Returns the current timestamp as a :class:`TimestampType` column. """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.current_timestamp()) @@ -787,8 +822,8 @@ def date_format(date, format): .. note:: Use when ever possible specialized functions like `year`. These benefit from a specialized implementation. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect() [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context @@ -800,8 +835,8 @@ def year(col): """ Extract the year of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(year('a').alias('year')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(year('dt').alias('year')).collect() [Row(year=2015)] """ sc = SparkContext._active_spark_context @@ -813,8 +848,8 @@ def quarter(col): """ Extract the quarter of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(quarter('a').alias('quarter')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(quarter('dt').alias('quarter')).collect() [Row(quarter=2)] """ sc = SparkContext._active_spark_context @@ -826,8 +861,8 @@ def month(col): """ Extract the month of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(month('a').alias('month')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(month('dt').alias('month')).collect() [Row(month=4)] """ sc = SparkContext._active_spark_context @@ -839,8 +874,8 @@ def dayofmonth(col): """ Extract the day of the month of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(dayofmonth('a').alias('day')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofmonth('dt').alias('day')).collect() [Row(day=8)] """ sc = SparkContext._active_spark_context @@ -852,8 +887,8 @@ def dayofyear(col): """ Extract the day of the year of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(dayofyear('a').alias('day')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofyear('dt').alias('day')).collect() [Row(day=98)] """ sc = SparkContext._active_spark_context @@ -865,8 +900,8 @@ def hour(col): """ Extract the hours of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(hour('a').alias('hour')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(hour('ts').alias('hour')).collect() [Row(hour=13)] """ sc = SparkContext._active_spark_context @@ -878,8 +913,8 @@ def minute(col): """ Extract the minutes of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(minute('a').alias('minute')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(minute('ts').alias('minute')).collect() [Row(minute=8)] """ sc = SparkContext._active_spark_context @@ -891,8 +926,8 @@ def second(col): """ Extract the seconds of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(second('a').alias('second')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(second('ts').alias('second')).collect() [Row(second=15)] """ sc = SparkContext._active_spark_context @@ -904,8 +939,8 @@ def weekofyear(col): """ Extract the week number of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear(df.a).alias('week')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(weekofyear(df.dt).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context @@ -917,9 +952,9 @@ def date_add(start, days): """ Returns the date that is `days` days after `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(date_add(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 4, 9))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_add(df.dt, 1).alias('next_date')).collect() + [Row(next_date=datetime.date(2015, 4, 9))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) @@ -930,9 +965,9 @@ def date_sub(start, days): """ Returns the date that is `days` days before `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(date_sub(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 4, 7))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect() + [Row(prev_date=datetime.date(2015, 4, 7))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) @@ -956,9 +991,9 @@ def add_months(start, months): """ Returns the date that is `months` months after `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(add_months(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 5, 8))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + [Row(next_month=datetime.date(2015, 5, 8))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) @@ -969,8 +1004,8 @@ def months_between(date1, date2): """ Returns the number of months between date1 and date2. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) - >>> df.select(months_between(df.t, df.d).alias('months')).collect() + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) + >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() [Row(months=3.9495967...)] """ sc = SparkContext._active_spark_context @@ -1073,12 +1108,17 @@ def last_day(date): return Column(sc._jvm.functions.last_day(_to_java_column(date))) +@ignore_unicode_prefix @since(1.5) def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): """ Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the given format. + + >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time']) + >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect() + [Row(ts=u'2015-04-08 00:00:00')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) @@ -1092,6 +1132,10 @@ def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): locale, return null if fail. if `timestamp` is None, then it returns current timestamp. + + >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect() + [Row(unix_time=1428476400)] """ sc = SparkContext._active_spark_context if timestamp is None: @@ -1106,8 +1150,8 @@ def from_utc_timestamp(timestamp, tz): that corresponds to the same time of day in the given timezone. >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() - [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1119,9 +1163,9 @@ def to_utc_timestamp(timestamp, tz): Given a timestamp, which corresponds to a certain time of day in the given timezone, returns another timestamp that corresponds to the same time of day in UTC. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() - [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -2095,7 +2139,7 @@ def _test(): sc = spark.sparkContext globs['sc'] = sc globs['spark'] = spark - globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + globs['df'] = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1263071a3ffd5..a5e4a444f33be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1321,7 +1321,8 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given value. + * Computes the tangent inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2 * * @group math_funcs * @since 1.4.0 @@ -1329,7 +1330,8 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column. + * Computes the tangent inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2 * * @group math_funcs * @since 1.4.0 @@ -1338,7 +1340,7 @@ object functions { /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * polar coordinates (r, theta). Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1470,7 +1472,7 @@ object functions { } /** - * Computes the cosine of the given value. + * Computes the cosine of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1937,7 +1939,7 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. + * Computes the sine of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1969,7 +1971,7 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. + * Computes the tangent of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 From 01f183e8497d4931f1fe5c69ff16fe84b1e41492 Mon Sep 17 00:00:00 2001 From: Joachim Hereth Date: Sat, 8 Jul 2017 08:32:45 +0100 Subject: [PATCH 0907/1765] Mesos doc fixes ## What changes were proposed in this pull request? Some link fixes for the documentation [Running Spark on Mesos](https://spark.apache.org/docs/latest/running-on-mesos.html): * Updated Link to Mesos Frameworks (Projects built on top of Mesos) * Update Link to Mesos binaries from Mesosphere (former link was redirected to dcos install page) ## How was this patch tested? Documentation was built and changed page manually/visually inspected. No code was changed, hence no dev tests. Since these changes are rather trivial I did not open a new JIRA ticket. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Joachim Hereth Closes #18564 from daten-kieker/mesos_doc_fixes. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ec130c1db8f5f..7401b63e022c1 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -10,7 +10,7 @@ Spark can run on hardware clusters managed by [Apache Mesos](http://mesos.apache The advantages of deploying Spark with Mesos include: - dynamic partitioning between Spark and other - [frameworks](https://mesos.apache.org/documentation/latest/mesos-frameworks/) + [frameworks](https://mesos.apache.org/documentation/latest/frameworks/) - scalable partitioning between multiple instances of Spark # How it Works @@ -61,7 +61,7 @@ third party projects publish binary releases that may be helpful in setting Meso One of those is Mesosphere. To install Mesos using the binary releases provided by Mesosphere: -1. Download Mesos installation package from [downloads page](http://mesosphere.io/downloads/) +1. Download Mesos installation package from [downloads page](https://open.mesosphere.com/downloads/mesos/) 2. Follow their instructions for installation and configuration The Mesosphere installation documents suggest setting up ZooKeeper to handle Mesos master failover, From 330bf5c99825afb6129577a34e6bed8b221a98cc Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 8 Jul 2017 08:34:51 +0100 Subject: [PATCH 0908/1765] [SPARK-20609][MLLIB][TEST] manually cleared 'spark.local.dir' before/after a test in ALSCleanerSuite ## What changes were proposed in this pull request? This PR is similar to #17869. Once` 'spark.local.dir'` is set. Unless this is manually cleared before/after a test. it could return the same directory even if this property is configured. and add before/after for each likewise in ALSCleanerSuite. ## How was this patch tested? existing test. Author: caoxuewen Closes #18537 from heary-cao/ALSCleanerSuite. --- .../spark/ml/recommendation/ALSSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index b57fc8d21ab34..0a0fea255c7f3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -29,6 +29,7 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.commons.io.FileUtils import org.apache.commons.io.filefilter.TrueFileFilter +import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.internal.Logging @@ -777,7 +778,20 @@ class ALSSuite } } -class ALSCleanerSuite extends SparkFunSuite { +class ALSCleanerSuite extends SparkFunSuite with BeforeAndAfterEach { + override def beforeEach(): Unit = { + super.beforeEach() + // Once `Utils.getOrCreateLocalRootDirs` is called, it is cached in `Utils.localRootDirs`. + // Unless this is manually cleared before and after a test, it returns the same directory + // set before even if 'spark.local.dir' is configured afterwards. + Utils.clearLocalRootDirs() + } + + override def afterEach(): Unit = { + Utils.clearLocalRootDirs() + super.afterEach() + } + test("ALS shuffle cleanup standalone") { val conf = new SparkConf() val localDir = Utils.createTempDir() From 0b8dd2d08460f3e6eb578727d2c336b6f11959e7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 8 Jul 2017 20:16:47 +0800 Subject: [PATCH 0909/1765] [SPARK-21345][SQL][TEST][TEST-MAVEN] SparkSessionBuilderSuite should clean up stopped sessions. ## What changes were proposed in this pull request? `SparkSessionBuilderSuite` should clean up stopped sessions. Otherwise, it leaves behind some stopped `SparkContext`s interfereing with other test suites using `ShardSQLContext`. Recently, master branch fails consequtively. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/ ## How was this patch tested? Pass the Jenkins with a updated suite. Author: Dongjoon Hyun Closes #18567 from dongjoon-hyun/SPARK-SESSION. --- .../spark/sql/SparkSessionBuilderSuite.scala | 46 ++++++++----------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 770e15629c839..c0301f2ce2d66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -17,50 +17,49 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.internal.SQLConf /** * Test cases for the builder pattern of [[SparkSession]]. */ -class SparkSessionBuilderSuite extends SparkFunSuite { +class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { - private var initialSession: SparkSession = _ + override def afterEach(): Unit = { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + } - private lazy val sparkContext: SparkContext = { - initialSession = SparkSession.builder() + test("create with config options and propagate them to SparkContext and SparkSession") { + val session = SparkSession.builder() .master("local") .config("spark.ui.enabled", value = false) .config("some-config", "v2") .getOrCreate() - initialSession.sparkContext - } - - test("create with config options and propagate them to SparkContext and SparkSession") { - // Creating a new session with config - this works by just calling the lazy val - sparkContext - assert(initialSession.sparkContext.conf.get("some-config") == "v2") - assert(initialSession.conf.get("some-config") == "v2") - SparkSession.clearDefaultSession() + assert(session.sparkContext.conf.get("some-config") == "v2") + assert(session.conf.get("some-config") == "v2") } test("use global default session") { - val session = SparkSession.builder().getOrCreate() + val session = SparkSession.builder().master("local").getOrCreate() assert(SparkSession.builder().getOrCreate() == session) - SparkSession.clearDefaultSession() } test("config options are propagated to existing SparkSession") { - val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate() + val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() assert(session1 == session2) assert(session1.conf.get("spark-config1") == "b") - SparkSession.clearDefaultSession() } test("use session from active thread session and propagate config options") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() val activeSession = defaultSession.newSession() SparkSession.setActiveSession(activeSession) val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() @@ -73,16 +72,14 @@ class SparkSessionBuilderSuite extends SparkFunSuite { SparkSession.clearActiveSession() assert(SparkSession.builder().getOrCreate() == defaultSession) - SparkSession.clearDefaultSession() } test("create a new session if the default session has been stopped") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() SparkSession.setDefaultSession(defaultSession) defaultSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != defaultSession) - newSession.stop() } test("create a new session if the active thread session has been stopped") { @@ -91,11 +88,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { activeSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != activeSession) - newSession.stop() } test("create SparkContext first then SparkSession") { - sparkContext.stop() val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") val sparkContext2 = new SparkContext(conf) val session = SparkSession.builder().config("key2", "value2").getOrCreate() @@ -105,11 +100,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { // We won't update conf for existing `SparkContext` assert(!sparkContext2.conf.contains("key2")) assert(sparkContext2.conf.get("key1") == "value1") - session.stop() } test("create SparkContext first then pass context to SparkSession") { - sparkContext.stop() val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") val newSC = new SparkContext(conf) val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() @@ -121,14 +114,12 @@ class SparkSessionBuilderSuite extends SparkFunSuite { // the conf of this sparkContext will not contain the conf set through the API config. assert(!session.sparkContext.conf.contains("key2")) assert(session.sparkContext.conf.get("spark.app.name") == "test") - session.stop() } test("SPARK-15887: hive-site.xml should be loaded") { val session = SparkSession.builder().master("local").getOrCreate() assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") - session.stop() } test("SPARK-15991: Set global Hadoop conf") { @@ -140,7 +131,6 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) } finally { session.sparkContext.hadoopConfiguration.unset(mySpecialKey) - session.stop() } } } From 9fccc3627fa41d32fbae6dbbb9bd1521e43eb4f0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sat, 8 Jul 2017 20:44:12 +0800 Subject: [PATCH 0910/1765] [SPARK-21083][SQL] Store zero size and row count when analyzing empty table ## What changes were proposed in this pull request? We should be able to store zero size and row count after analyzing empty table. This pr also enhances the test cases for re-analyzing tables. ## How was this patch tested? Added a new test case and enhanced some test cases. Author: Zhenhua Wang Closes #18292 from wzhfy/analyzeNewColumn. --- .../command/AnalyzeTableCommand.scala | 5 +- .../spark/sql/StatisticsCollectionSuite.scala | 13 +++++ .../spark/sql/hive/StatisticsSuite.scala | 52 ++++++++++++++++--- 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 42e2a9ca5c4e2..cba147c35dd99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} -import org.apache.spark.sql.execution.SQLExecution /** @@ -40,10 +39,10 @@ case class AnalyzeTableCommand( } val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) - val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) + val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(-1L) val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) var newStats: Option[CatalogStatistics] = None - if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) } // We only set rowCount when noscan is false, because otherwise: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 843ced7f0e697..b80bd80e93e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -82,6 +82,19 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("analyze empty table") { + val table = "emptyTable" + withTable(table) { + sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") + val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetchedStats1.get.sizeInBytes == 0) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats2.get.sizeInBytes == 0) + } + } + test("analyze column command - unsupported types and invalid columns") { val tableName = "column_stats_test1" withTable(tableName) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e00fa64e9f2ce..84bcea30d61a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -26,6 +26,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -210,27 +211,62 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("test elimination of the influences of the old stats") { + test("keep existing row count in stats with noscan if table is not changed") { val textTable = "textTable" withTable(textTable) { - sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") + sql(s"CREATE TABLE $textTable (key STRING, value STRING)") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") val fetchedStats1 = checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - // when the total size is not changed, the old row count is kept + // when the table is not changed, total size is the same, and the old row count is kept val fetchedStats2 = checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) + } + } - sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") - sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - // update total size and remove the old and invalid row count + test("keep existing column stats if table is not changed") { + val table = "update_col_stats_table" + withTable(table) { + sql(s"CREATE TABLE $table (c1 INT, c2 STRING, c3 DOUBLE)") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats0 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + + // Insert new data and analyze: have the latest column stats. + sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats1.colStats == Map( + "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4))) + + // Analyze another column: since the table is not changed, the precious column stats are kept. + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats2.colStats == Map( + "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4), + "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, + avgLen = 1, maxLen = 1))) + + // Insert new data and analyze: stale column stats are removed and newly collected column + // stats are added. + sql(s"INSERT INTO TABLE $table SELECT 2, 'b', 20.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1, c3") val fetchedStats3 = - checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) - assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get + assert(fetchedStats3.colStats == Map( + "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, + avgLen = 8, maxLen = 8))) } } From 9131bdb7e12bcfb2cb699b3438f554604e28aaa8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sun, 9 Jul 2017 00:24:54 +0800 Subject: [PATCH 0911/1765] [SPARK-20342][CORE] Update task accumulators before sending task end event. This makes sures that listeners get updated task information; otherwise it's possible to write incomplete task information into event logs, for example, making the information in a replayed UI inconsistent with the original application. Added a new unit test to try to detect the problem, but it's not guaranteed to fail since it's a race; but it fails pretty reliably for me without the scheduler changes. Author: Marcelo Vanzin Closes #18393 from vanzin/SPARK-20342.try2. --- .../apache/spark/scheduler/DAGScheduler.scala | 70 ++++++++++++------- .../spark/scheduler/DAGSchedulerSuite.scala | 32 ++++++++- 2 files changed, 75 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3422a5f204b12..89b4cab88109d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1122,6 +1122,25 @@ class DAGScheduler( } } + private def postTaskEnd(event: CompletionEvent): Unit = { + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulators(event.accumUpdates) + } catch { + case NonFatal(e) => + val taskId = event.taskInfo.taskId + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId, + Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, taskMetrics)) + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -1138,34 +1157,36 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // Reconstruct task metrics. Note: this may be null if the task has failed. - val taskMetrics: TaskMetrics = - if (event.accumUpdates.nonEmpty) { - try { - TaskMetrics.fromAccumulators(event.accumUpdates) - } catch { - case NonFatal(e) => - logError(s"Error when attempting to reconstruct metrics for task $taskId", e) - null - } - } else { - null - } - - // The stage may have already finished when we get this event -- eg. maybe it was a - // speculative task. It is important that we send the TaskEnd event in any case, so listeners - // are properly notified and can chose to handle it. For instance, some listeners are - // doing their own accounting and if they don't get the task end event they think - // tasks are still running when they really aren't. - listenerBus.post(SparkListenerTaskEnd( - stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics)) - if (!stageIdToStage.contains(task.stageId)) { + // The stage may have already finished when we get this event -- eg. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + postTaskEnd(event) + // Skip all the actions if the stage has been cancelled. return } val stage = stageIdToStage(task.stageId) + + // Make sure the task's accumulators are updated before any other processing happens, so that + // we can post a task end event before any jobs or stages are updated. The accumulators are + // only updated in certain cases. + event.reason match { + case Success => + stage match { + case rs: ResultStage if rs.activeJob.isEmpty => + // Ignore update if task's job has finished. + case _ => + updateAccumulators(event) + } + case _: ExceptionFailure => updateAccumulators(event) + case _ => + } + postTaskEnd(event) + event.reason match { case Success => task match { @@ -1176,7 +1197,6 @@ class DAGScheduler( resultStage.activeJob match { case Some(job) => if (!job.finished(rt.outputId)) { - updateAccumulators(event) job.finished(rt.outputId) = true job.numFinished += 1 // If the whole job has finished, remove it @@ -1203,7 +1223,6 @@ class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] - updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) @@ -1374,8 +1393,7 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Tasks failed with exceptions might still have accumulator updates. - updateAccumulators(event) + // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 453be26ed8d0c..3b5df657d45cf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -2346,6 +2346,36 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou (Success, 1))) } + test("task end event should have updated accumulators (SPARK-20342)") { + val tasks = 10 + + val accumId = new AtomicLong() + val foundCount = new AtomicLong() + val listener = new SparkListener() { + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + event.taskInfo.accumulables.find(_.id == accumId.get).foreach { _ => + foundCount.incrementAndGet() + } + } + } + sc.addSparkListener(listener) + + // Try a few times in a loop to make sure. This is not guaranteed to fail when the bug exists, + // but it should at least make the test flaky. If the bug is fixed, this should always pass. + (1 to 10).foreach { i => + foundCount.set(0L) + + val accum = sc.longAccumulator(s"accum$i") + accumId.set(accum.id) + + sc.parallelize(1 to tasks, tasks).foreach { _ => + accum.add(1L) + } + sc.listenerBus.waitUntilEmpty(1000) + assert(foundCount.get() === tasks) + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. From 062c336d06a0bd4e740a18d2349e03e311509243 Mon Sep 17 00:00:00 2001 From: jinxing Date: Sun, 9 Jul 2017 00:27:58 +0800 Subject: [PATCH 0912/1765] [SPARK-21343] Refine the document for spark.reducer.maxReqSizeShuffleToMem. ## What changes were proposed in this pull request? In current code, reducer can break the old shuffle service when `spark.reducer.maxReqSizeShuffleToMem` is enabled. Let's refine document. Author: jinxing Closes #18566 from jinxing64/SPARK-21343. --- .../org/apache/spark/internal/config/package.scala | 6 ++++-- docs/configuration.md | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a629810bf093a..512d539ee9c38 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -323,9 +323,11 @@ package object config { private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") - .internal() .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + - "above this threshold. This is to avoid a giant request takes too much memory.") + "above this threshold. This is to avoid a giant request takes too much memory. We can " + + "enable this config by setting a specific value(e.g. 200m). Note that this config can " + + "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + + " service is disabled.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) diff --git a/docs/configuration.md b/docs/configuration.md index 7dc23e441a7ba..6ca84240c1247 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -528,6 +528,16 @@ Apart from these, the following properties are also available, and may be useful By allowing it to limit the number of fetch requests, this scenario can be mitigated. + + spark.reducer.maxReqSizeShuffleToMem + Long.MaxValue + + The blocks of a shuffle request will be fetched to disk when size of the request is above + this threshold. This is to avoid a giant request takes too much memory. We can enable this + config by setting a specific value(e.g. 200m). Note that this config can be enabled only when + the shuffle shuffle service is newer than Spark-2.2 or the shuffle service is disabled. + + spark.shuffle.compress true From c3712b77a915a5cb12b6c0204bc5bd6037aad1f5 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 8 Jul 2017 11:56:19 -0700 Subject: [PATCH 0913/1765] [SPARK-21307][REVERT][SQL] Remove SQLConf parameters from the parser-related classes ## What changes were proposed in this pull request? Since we do not set active sessions when parsing the plan, we are unable to correctly use SQLConf.get to find the correct active session. Since https://github.com/apache/spark/pull/18531 breaks the build, I plan to revert it at first. ## How was this patch tested? The existing test cases Author: Xiao Li Closes #18568 from gatorsmile/revert18531. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../sql/catalyst/parser/ParseDriver.scala | 8 +- .../parser/ExpressionParserSuite.scala | 167 +++++++++--------- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../org/apache/spark/sql/functions.scala | 3 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../sql/internal/VariableSubstitution.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 10 +- .../execution/command/DDLCommandSuite.scala | 4 +- .../internal/VariableSubstitutionSuite.scala | 31 ++-- 11 files changed, 127 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 336d3d65d0dd0..c40d5f6031a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -74,7 +74,7 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - CatalystSqlParser, + new CatalystSqlParser(conf), DummyFunctionResourceLoader) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4d725904bc9b9..a616b0f773f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -45,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ + def this() = this(new SQLConf()) + protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -1457,7 +1459,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - if (SQLConf.get.escapedStringLiterals) { + if (conf.escapedStringLiterals) { ctx.STRING().asScala.map(stringWithoutUnescape).mkString } else { ctx.STRING().asScala.map(string).mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 7e1fcfefc64a5..09598ffe770c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ +class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new AstBuilder(conf) +} + +/** For test-only. */ object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder + val astBuilder = new AstBuilder(new SQLConf()) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index ac7325257a15a..45f9f72dccc45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -167,12 +167,12 @@ class ExpressionParserSuite extends PlanTest { } test("like expressions with ESCAPED_STRING_LITERALS = true") { - val parser = CatalystSqlParser - withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { - assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) - assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) - assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) - } + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true") + val parser = new CatalystSqlParser(conf) + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) } test("is null expressions") { @@ -435,85 +435,86 @@ class ExpressionParserSuite extends PlanTest { } test("strings") { - val parser = CatalystSqlParser Seq(true, false).foreach { escape => - withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> escape.toString) { - // tests that have same result whatever the conf is - // Single Strings. - assertEqual("\"hello\"", "hello", parser) - assertEqual("'hello'", "hello", parser) - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld", parser) - assertEqual("'hello' \" \" 'world'", "hello world", parser) - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%", parser) - assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) - - // tests that have different result regarding the conf - if (escape) { - // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to - // Spark 1.6 behavior. - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) - - // Escaped characters. - // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work - // when ESCAPED_STRING_LITERALS is enabled. - // It is parsed literally. - assertEqual("'\\0'", "\\0", parser) - - // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is - // enabled. - val e = intercept[ParseException](parser.parseExpression("'\''")) - assert(e.message.contains("extraneous input '''")) - - // The unescape special characters (e.g., "\\t") for 2.0+ don't work - // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. - assertEqual("'\\\"'", "\\\"", parser) // Double quote - assertEqual("'\\b'", "\\b", parser) // Backspace - assertEqual("'\\n'", "\\n", parser) // Newline - assertEqual("'\\r'", "\\r", parser) // Carriage return - assertEqual("'\\t'", "\\t", parser) // Tab character - - // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) - // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", - "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) - } else { - // Default behavior - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') - assertEqual("'\\''", "\'", parser) // Single quote - assertEqual("'\\\"'", "\"", parser) // Double quote - assertEqual("'\\b'", "\b", parser) // Backspace - assertEqual("'\\n'", "\n", parser) // Newline - assertEqual("'\\r'", "\r", parser) // Carriage return - assertEqual("'\\t'", "\t", parser) // Tab character - assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", - parser) - } + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) + val parser = new CatalystSqlParser(conf) + + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work + // when ESCAPED_STRING_LITERALS is enabled. + // It is parsed literally. + assertEqual("'\\0'", "\\0", parser) + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + // The unescape special characters (e.g., "\\t") for 2.0+ don't work + // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. + assertEqual("'\\\"'", "\\\"", parser) // Double quote + assertEqual("'\\b'", "\\b", parser) // Backspace + assertEqual("'\\n'", "\\n", parser) // Newline + assertEqual("'\\r'", "\\r", parser) // Carriage return + assertEqual("'\\t'", "\\t", parser) // Tab character + + // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) + // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", + "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') + assertEqual("'\\''", "\'", parser) // Single quote + assertEqual("'\\\"'", "\"", parser) // Double quote + assertEqual("'\\b'", "\b", parser) // Backspace + assertEqual("'\\n'", "\n", parser) // Newline + assertEqual("'\\r'", "\r", parser) // Carriage return + assertEqual("'\\t'", "\t", parser) // Tab character + assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", + parser) } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 618d027d8dc07..2f8e416e7df1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -39,11 +39,10 @@ import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. */ -class SparkSqlParser extends AbstractSqlParser { +class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new SparkSqlAstBuilder(conf) - val astBuilder = new SparkSqlAstBuilder - - private val substitutor = new VariableSubstitution + private val substitutor = new VariableSubstitution(conf) protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = { super.parse(substitutor.substitute(command))(toResult) @@ -53,11 +52,9 @@ class SparkSqlParser extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class SparkSqlAstBuilder extends AstBuilder { +class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { import org.apache.spark.sql.catalyst.parser.ParserUtils._ - private def conf: SQLConf = SQLConf.get - /** * Create a [[SetCommand]] logical plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a5e4a444f33be..0c7b483f5c836 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1275,7 +1276,7 @@ object functions { */ def expr(expr: String): Column = { val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser + new SparkSqlParser(new SQLConf) } Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 72d0ddc62303a..267f76217df84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -114,7 +114,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf` field. */ protected lazy val sqlParser: ParserInterface = { - extensions.buildParser(session, new SparkSqlParser) + extensions.buildParser(session, new SparkSqlParser(conf)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala index 2b9c574aaaf0c..4e7c813be9922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala @@ -25,9 +25,7 @@ import org.apache.spark.internal.config._ * * Variable substitution is controlled by `SQLConf.variableSubstituteEnabled`. */ -class VariableSubstitution { - - private def conf = SQLConf.get +class VariableSubstitution(conf: SQLConf) { private val provider = new ConfigProvider { override def get(key: String): Option[String] = Option(conf.getConfString(key, "")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 2e29fa43f73d9..d238c76fbeeff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType */ class SparkSqlParserSuite extends AnalysisTest { - private lazy val parser = new SparkSqlParser + val newConf = new SQLConf + private lazy val parser = new SparkSqlParser(newConf) /** * Normalizes plans: @@ -284,7 +285,6 @@ class SparkSqlParserSuite extends AnalysisTest { } test("query organization") { - val conf = SQLConf.get // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows val baseSql = "select * from t" val basePlan = @@ -293,20 +293,20 @@ class SparkSqlParserSuite extends AnalysisTest { assertEqual(s"$baseSql distribute by a, b", RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = conf.numShufflePartitions)) + numPartitions = newConf.numShufflePartitions)) assertEqual(s"$baseSql distribute by a sort by b", Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: Nil, basePlan, - numPartitions = conf.numShufflePartitions))) + numPartitions = newConf.numShufflePartitions))) assertEqual(s"$baseSql cluster by a, b", Sort(SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = conf.numShufflePartitions))) + numPartitions = newConf.numShufflePartitions))) } test("pipeline concatenation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 750574830381f..5643c58d9f847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} // TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { - private lazy val parser = new SparkSqlParser + private lazy val parser = new SparkSqlParser(new SQLConf) private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala index c5e5b70e21335..d5a946aeaac31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.AnalysisException -class VariableSubstitutionSuite extends SparkFunSuite with PlanTest { +class VariableSubstitutionSuite extends SparkFunSuite { - private lazy val sub = new VariableSubstitution + private lazy val conf = new SQLConf + private lazy val sub = new VariableSubstitution(conf) test("system property") { System.setProperty("varSubSuite.var", "abcd") @@ -34,26 +35,26 @@ class VariableSubstitutionSuite extends SparkFunSuite with PlanTest { } test("Spark configuration variable") { - withSQLConf("some-random-string-abcd" -> "1234abcd") { - assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") - } + conf.setConfString("some-random-string-abcd", "1234abcd") + assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") } test("multiple substitutes") { val q = "select ${bar} ${foo} ${doo} this is great" - withSQLConf("bar" -> "1", "foo" -> "2", "doo" -> "3") { - assert(sub.substitute(q) == "select 1 2 3 this is great") - } + conf.setConfString("bar", "1") + conf.setConfString("foo", "2") + conf.setConfString("doo", "3") + assert(sub.substitute(q) == "select 1 2 3 this is great") } test("test nested substitutes") { val q = "select ${bar} ${foo} this is great" - withSQLConf("bar" -> "1", "foo" -> "${bar}") { - assert(sub.substitute(q) == "select 1 1 this is great") - } + conf.setConfString("bar", "1") + conf.setConfString("foo", "${bar}") + assert(sub.substitute(q) == "select 1 1 this is great") } } From 08e0d033b40946b4ef5741a7aa1e7ba0bd48c6fb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 8 Jul 2017 14:24:37 -0700 Subject: [PATCH 0914/1765] [SPARK-21093][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak ## What changes were proposed in this pull request? This is a retry for #18320. This PR was reverted due to unexpected test failures with -10 error code. I was unable to reproduce in MacOS, CentOS and Ubuntu but only in Jenkins. So, the tests proceeded to verify this and revert the past try here - https://github.com/apache/spark/pull/18456 This new approach was tested in https://github.com/apache/spark/pull/18463. **Test results**: - With the part of suspicious change in the past try (https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189) Tests ran 4 times and 2 times passed and 2 time failed. - Without the part of suspicious change in the past try (https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189) Tests ran 5 times and they all passed. - With this new approach (https://github.com/apache/spark/pull/18463/commits/0a7589c09f53dfc2094497d8d3e59d6407569417) Tests ran 5 times and they all passed. It looks the cause is as below (see https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189): ```diff + exitCode <- 1 ... + data <- parallel:::readChild(child) + if (is.raw(data)) { + if (unserialize(data) == exitCode) { ... + } + } ... - parallel:::mcexit(0L) + parallel:::mcexit(0L, send = exitCode) ``` Two possibilities I think - `parallel:::mcexit(.. , send = exitCode)` https://stat.ethz.ch/R-manual/R-devel/library/parallel/html/mcfork.html > It sends send to the master (unless NULL) and then shuts down the child process. However, it looks possible that the parent attemps to terminate the child right after getting our custom exit code. So, the child gets terminated between "send" and "shuts down", failing to exit properly. - A bug between `parallel:::mcexit(..., send = ...)` and `parallel:::readChild`. **Proposal**: To resolve this, I simply decided to avoid both possibilities with this new approach here (https://github.com/apache/spark/pull/18465/commits/9ff89a7859cb9f427fc774f33c3521c7d962b723). To support this idea, I explained with some quotation of the documentation as below: https://stat.ethz.ch/R-manual/R-devel/library/parallel/html/mcfork.html > `readChild` and `readChildren` return a raw vector with a "pid" attribute if data were available, an integer vector of length one with the process ID if a child terminated or `NULL` if the child no longer exists (no children at all for `readChildren`). `readChild` returns "an integer vector of length one with the process ID if a child terminated" so we can check if it is `integer` and the same selected "process ID". I believe this makes sure that the children are exited. In case that children happen to send any data manually to parent (which is why we introduced the suspicious part of the change (https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189)), this should be raw bytes and will be discarded (and then will try to read the next and check if it is `integer` in the next loop). ## How was this patch tested? Manual tests and Jenkins tests. Author: hyukjinkwon Closes #18465 from HyukjinKwon/SPARK-21093-retry-1. --- R/pkg/inst/worker/daemon.R | 51 +++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06d..2e31dc5f728cd 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,8 +30,50 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) +# Waits indefinitely for a socket connecion by default. +selectTimeout <- NULL + while (TRUE) { - ready <- socketSelect(list(inputCon)) + ready <- socketSelect(list(inputCon), timeout = selectTimeout) + + # Note that the children should be terminated in the parent. If each child terminates + # itself, it appears that the resource is not released properly, that causes an unexpected + # termination of this daemon due to, for example, running out of file descriptors + # (see SPARK-21093). Therefore, the current implementation tries to retrieve children + # that are exited (but not terminated) and then sends a kill signal to terminate them properly + # in the parent. + # + # There are two paths that it attempts to send a signal to terminate the children in the parent. + # + # 1. Every second if any socket connection is not available and if there are child workers + # running. + # 2. Right after a socket connection is available. + # + # In other words, the parent attempts to send the signal to the children every second if + # any worker is running or right before launching other worker children from the following + # new socket connection. + + # The process IDs of exited children are returned below. + children <- parallel:::selectChildren(timeout = 0) + + if (is.integer(children)) { + lapply(children, function(child) { + # This should be the PIDs of exited children. Otherwise, this returns raw bytes if any data + # was sent from this child. In this case, we discard it. + pid <- parallel:::readChild(child) + if (is.integer(pid)) { + # This checks if the data from this child is the same pid of this selected child. + if (child == pid) { + # If so, we terminate this child. + tools::pskill(child, tools::SIGUSR1) + } + } + }) + } else if (is.null(children)) { + # If it is NULL, there are no children. Waits indefinitely for a socket connecion. + selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +86,15 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) + # Note that this mcexit does not fully terminate this child. parallel:::mcexit(0L) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } From 680b33f16694b7c460235b11b8c265bc304f795a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 9 Jul 2017 16:30:35 -0700 Subject: [PATCH 0915/1765] [SPARK-18016][SQL][FOLLOWUP] merge declareAddedFunctions, initNestedClasses and declareNestedClasses ## What changes were proposed in this pull request? These 3 methods have to be used together, so it makes more sense to merge them into one method and then the caller side only need to call one method. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18579 from cloud-fan/minor. --- .../expressions/codegen/CodeGenerator.scala | 29 +++++++------------ .../codegen/GenerateMutableProjection.scala | 5 +--- .../codegen/GenerateOrdering.scala | 5 +--- .../codegen/GeneratePredicate.scala | 5 +--- .../codegen/GenerateSafeProjection.scala | 5 +--- .../codegen/GenerateUnsafeProjection.scala | 5 +--- .../sql/execution/WholeStageCodegenExec.scala | 5 +--- .../columnar/GenerateColumnAccessor.scala | 5 +--- 8 files changed, 18 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b15bf2ca7c116..7cf9daf628608 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -302,29 +302,20 @@ class CodegenContext { } /** - * Instantiates all nested, private sub-classes as objects to the `OuterClass` + * Declares all function code. If the added functions are too many, split them into nested + * sub-classes to avoid hitting Java compiler constant pool limitation. */ - private[sql] def initNestedClasses(): String = { + def declareAddedFunctions(): String = { + val inlinedFunctions = classFunctions(outerClassName).values + // Nested, private sub-classes have no mutable state (though they do reference the outer class' // mutable state), so we declare and initialize them inline to the OuterClass. - classes.filter(_._1 != outerClassName).map { + val initNestedClasses = classes.filter(_._1 != outerClassName).map { case (className, classInstance) => s"private $className $classInstance = new $className();" - }.mkString("\n") - } - - /** - * Declares all function code that should be inlined to the `OuterClass`. - */ - private[sql] def declareAddedFunctions(): String = { - classFunctions(outerClassName).values.mkString("\n") - } + } - /** - * Declares all nested, private sub-classes and the function code that should be inlined to them. - */ - private[sql] def declareNestedClasses(): String = { - classFunctions.filterKeys(_ != outerClassName).map { + val declareNestedClasses = classFunctions.filterKeys(_ != outerClassName).map { case (className, functions) => s""" |private class $className { @@ -332,7 +323,9 @@ class CodegenContext { |} """.stripMargin } - }.mkString("\n") + + (inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n") + } final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 635766835029b..3768dcde00a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -115,8 +115,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { mutableRow = row; return this; @@ -136,8 +134,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP return mutableRow; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index a31943255b995..4e47895985209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -173,15 +173,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR ${ctx.initMutableStates()} } - ${ctx.declareAddedFunctions()} - public int compare(InternalRow a, InternalRow b) { $comparisons return 0; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index b400783bb5e55..e35b9dda6c017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -66,15 +66,12 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - public boolean eval(InternalRow ${ctx.INPUT_ROW}) { ${eval.code} return !${eval.isNull} && ${eval.value}; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index dd0419d2286d1..192701a829686 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -175,16 +175,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6be69d119bf8a..f2a66efc98e71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -391,8 +391,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - // Scala.Function1 need this public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); @@ -403,8 +401,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro return ${eval.value}; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0bd28e36135c8..1007a7d55691b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -352,14 +352,11 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - protected void processNext() throws java.io.IOException { ${code.trim} } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index fc977f2fd5530..da34643281911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -192,8 +192,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera this.columnIndexes = columnIndexes; } - ${ctx.declareAddedFunctions()} - public boolean hasNext() { if (currentRow < numRowsInBatch) { return true; @@ -222,8 +220,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera return unsafeRow; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} }""" val code = CodeFormatter.stripOverlappingComments( From 457dc9ccbf8404fef6c1ebf8f82e59e4ba480a0e Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 10 Jul 2017 11:22:28 +0800 Subject: [PATCH 0916/1765] [MINOR][DOC] Improve the docs about how to correctly set configurations ## What changes were proposed in this pull request? Spark provides several ways to set configurations, either from configuration file, or from `spark-submit` command line options, or programmatically through `SparkConf` class. It may confuses beginners why some configurations set through `SparkConf` cannot take affect. So here add some docs to address this problems and let beginners know how to correctly set configurations. ## How was this patch tested? N/A Author: jerryshao Closes #18552 from jerryshao/improve-doc. --- docs/configuration.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 6ca84240c1247..91b5befd1b1eb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -95,6 +95,13 @@ in the `spark-defaults.conf` file. A few configuration keys have been renamed si versions of Spark; in such cases, the older key names are still accepted, but take lower precedence than any instance of the newer key. +Spark properties mainly can be divided into two kinds: one is related to deploy, like +"spark.driver.memory", "spark.executor.instances", this kind of properties may not be affected when +setting programmatically through `SparkConf` in runtime, or the behavior is depending on which +cluster manager and deploy mode you choose, so it would be suggested to set through configuration +file or `spark-submit` command line options; another is mainly related to Spark runtime control, +like "spark.task.maxFailures", this kind of properties can be set in either way. + ## Viewing Spark Properties The application web UI at `http://:4040` lists Spark properties in the "Environment" tab. From 0e80ecae300f3e2033419b2d98da8bf092c105bb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 9 Jul 2017 22:53:27 -0700 Subject: [PATCH 0917/1765] [SPARK-21100][SQL][FOLLOWUP] cleanup code and add more comments for Dataset.summary ## What changes were proposed in this pull request? Some code cleanup and adding comments to make the code more readable. Changed the way to generate result rows, to be more clear. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18570 from cloud-fan/summary. --- .../scala/org/apache/spark/sql/Dataset.scala | 9 -- .../sql/execution/stat/StatFunctions.scala | 129 ++++++++---------- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 3 files changed, 56 insertions(+), 84 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5326b45b50a8b..dfb51192c69bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -224,15 +224,6 @@ class Dataset[T] private[sql]( } } - private[sql] def aggregatableColumns: Seq[Expression] = { - schema.fields - .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) - .map { n => - queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver) - .get - } - } - /** * Compose the string representing rows for output * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 436e18fdb5ff5..a75cfb3600225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.stat +import java.util.Locale + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} +import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -228,90 +231,68 @@ object StatFunctions extends Logging { val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics - val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) - val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { - val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) - val percentiles = pStrings.map { p => - try { - p.stripSuffix("%").toDouble / 100.0 - } catch { - case e: NumberFormatException => - throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) - } + val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) } - require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - (percentiles, pStrings, rest) - } else { - (Seq(), Seq(), selectedStatistics) } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - - // The list of summary statistics to compute, in the form of expressions. - val availableStatistics = Map[String, Expression => Expression]( - "count" -> ((child: Expression) => Count(child).toAggregateExpression()), - "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), - "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), - "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - - val statisticFns = remainingAggregates.map { agg => - require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") - agg -> availableStatistics(agg) - } - - def percentileAgg(child: Expression): Expression = - new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) - .toAggregateExpression() - - val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - if (hasPercentiles) { - aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + var percentileIndex = 0 + val statisticFns = selectedStatistics.map { stats => + if (stats.endsWith("%")) { + val index = percentileIndex + percentileIndex += 1 + (child: Expression) => + GetArrayItem( + new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), + Literal(index)) + } else { + stats.toLowerCase(Locale.ROOT) match { + case "count" => (child: Expression) => Count(child).toAggregateExpression() + case "mean" => (child: Expression) => Average(child).toAggregateExpression() + case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression() + case "min" => (child: Expression) => Min(child).toAggregateExpression() + case "max" => (child: Expression) => Max(child).toAggregateExpression() + case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic") + } } + } - val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + val selectedCols = ds.logicalPlan.output + .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) - // Pivot the data so each summary is one row - val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + val aggExprs = statisticFns.flatMap { func => + selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) + } - val basicStats = if (hasPercentiles) grouped.tail else grouped + // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. + lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head - val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } + // We will have one row for each selected statistic in the result. + val result = Array.fill[InternalRow](selectedStatistics.length) { + // each row has the statistic name, and statistic values of each selected column. + new GenericInternalRow(selectedCols.length + 1) + } - if (hasPercentiles) { - def nullSafeString(x: Any) = if (x == null) null else x.toString - val percentileRows = grouped.head - .map { - case a: Seq[Any] => a - case _ => Seq.fill(percentiles.length)(null: Any) - } - .transpose - .zip(percentileNames) - .map { case (values: Seq[Any], name) => - Row(name :: values.map(nullSafeString).toList: _*) - } - (rows ++ percentileRows) - .sortWith((left, right) => - selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) - } else { - rows + var rowIndex = 0 + while (rowIndex < result.length) { + val statsName = selectedStatistics(rowIndex) + result(rowIndex).update(0, UTF8String.fromString(statsName)) + for (colIndex <- selectedCols.indices) { + val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) + result(rowIndex).update(colIndex + 1, statsValue) } - } else { - // If there are no output columns, just output a single column that contains the stats. - selectedStatistics.map(Row(_)) + rowIndex += 1 } // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - // `toArray` forces materialization to make the seq serializable - Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) - } + val output = AttributeReference("summary", StringType)() +: + selectedCols.map(c => AttributeReference(c.name, StringType)()) + Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2c7051bf431c3..b2219b4eb8c17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -770,7 +770,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val fooE = intercept[IllegalArgumentException] { person2.summary("foo") } - assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") + assert(fooE.getMessage === "foo is not a recognised statistic") val parseE = intercept[IllegalArgumentException] { person2.summary("foo%") From 96d58f285bc98d4c2484150eefe7447db4784a86 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Mon, 10 Jul 2017 14:40:20 +0800 Subject: [PATCH 0918/1765] [SPARK-21219][CORE] Task retry occurs on same executor due to race condition with blacklisting ## What changes were proposed in this pull request? There's a race condition in the current TaskSetManager where a failed task is added for retry (addPendingTask), and can asynchronously be assigned to an executor *prior* to the blacklist state (updateBlacklistForFailedTask), the result is the task might re-execute on the same executor. This is particularly problematic if the executor is shutting down since the retry task immediately becomes a lost task (ExecutorLostFailure). Another side effect is that the actual failure reason gets obscured by the retry task which never actually executed. There are sample logs showing the issue in the https://issues.apache.org/jira/browse/SPARK-21219 The fix is to change the ordering of the addPendingTask and updatingBlackListForFailedTask calls in TaskSetManager.handleFailedTask ## How was this patch tested? Implemented a unit test that verifies the task is black listed before it is added to the pending task. Ran the unit test without the fix and it fails. Ran the unit test with the fix and it passes. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Vandenberg Closes #18427 from ericvandenbergfb/blacklistFix. --- .../spark/scheduler/TaskSetManager.scala | 21 ++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 44 ++++++++++++++++++- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 02d374dc37cd5..3968fb7e6356d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -198,7 +198,7 @@ private[spark] class TaskSetManager( private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. */ - private def addPendingTask(index: Int) { + private[spark] def addPendingTask(index: Int) { for (loc <- tasks(index).preferredLocations) { loc match { case e: ExecutorCacheTaskLocation => @@ -832,15 +832,6 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) - if (successful(index)) { - logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + - s" be re-executed (either because the task failed with a shuffle data fetch failure," + - s" so the previous stage needs to be re-run, or because a different copy of the task" + - s" has already succeeded).") - } else { - addPendingTask(index) - } - if (!isZombie && reason.countTowardsTaskFailures) { taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( info.host, info.executorId, index)) @@ -854,6 +845,16 @@ private[spark] class TaskSetManager( return } } + + if (successful(index)) { + logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + + s" be re-executed (either because the task failed with a shuffle data fetch failure," + + s" so the previous stage needs to be re-run, or because a different copy of the task" + + s" has already succeeded).") + } else { + addPendingTask(index) + } + maybeFinishTaskSet() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 80fb674725814..e46900e4e5049 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.{any, anyInt, anyString} -import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -1172,6 +1172,48 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(blacklistTracker.isNodeBlacklisted("host1")) } + test("update blacklist before adding pending task to avoid race condition") { + // When a task fails, it should apply the blacklist policy prior to + // retrying the task otherwise there's a race condition where run on + // the same executor that it was intended to be black listed from. + val conf = new SparkConf(). + set(config.BLACKLIST_ENABLED, true) + + // Create a task with two executors. + sc = new SparkContext("local", "test", conf) + val exec = "executor1" + val host = "host1" + val exec2 = "executor2" + val host2 = "host2" + sched = new FakeTaskScheduler(sc, (exec, host), (exec2, host2)) + val taskSet = FakeTask.createTaskSet(1) + + val clock = new ManualClock + val mockListenerBus = mock(classOf[LiveListenerBus]) + val blacklistTracker = new BlacklistTracker(mockListenerBus, conf, None, clock) + val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker)) + val taskSetManagerSpy = spy(taskSetManager) + + val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality.ANY) + + // Assert the task has been black listed on the executor it was last executed on. + when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val task = invocationOnMock.getArgumentAt(0, classOf[Int]) + assert(taskSetManager.taskSetBlacklistHelperOpt.get. + isExecutorBlacklistedForTask(exec, task)) + } + } + ) + + // Simulate a fake exception + val e = new ExceptionFailure("a", "b", Array(), "c", None) + taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, e) + + verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt()) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From c444d10868c808f4ae43becd5506bf944d9c2e9b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 10 Jul 2017 07:46:47 +0100 Subject: [PATCH 0919/1765] [MINOR][DOC] Remove obsolete `ec2-scripts.md` ## What changes were proposed in this pull request? Since this document became obsolete, we had better remove this for Apache Spark 2.3.0. The original document is removed via SPARK-12735 on January 2016, and currently it's just redirection page. The only reference in Apache Spark website will go directly to the destination in https://github.com/apache/spark-website/pull/54. ## How was this patch tested? N/A. This is a removal of documentation. Author: Dongjoon Hyun Closes #18578 from dongjoon-hyun/SPARK-REMOVE-EC2. --- docs/ec2-scripts.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 docs/ec2-scripts.md diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md deleted file mode 100644 index 6cd39dbed055d..0000000000000 --- a/docs/ec2-scripts.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -layout: global -title: Running Spark on EC2 -redirect: https://github.com/amplab/spark-ec2#readme ---- - -This document has been superseded and replaced by documentation at https://github.com/amplab/spark-ec2#readme From 647963a26a2d4468ebd9b68111ebe68bee501fde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 10 Jul 2017 15:58:34 +0800 Subject: [PATCH 0920/1765] [SPARK-20460][SQL] Make it more consistent to handle column name duplication ## What changes were proposed in this pull request? This pr made it more consistent to handle column name duplication. In the current master, error handling is different when hitting column name duplication: ``` // json scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("""{"a":1, "a":1}"""""").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("json").schema(schema).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Reference 'a' is ambiguous, could be: a#12, a#13.; at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:287) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:181) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:153) scala> spark.read.format("json").load("/tmp/data").show org.apache.spark.sql.AnalysisException: Duplicate column(s) : "a" found, cannot save to JSON format; at org.apache.spark.sql.execution.datasources.json.JsonDataSource.checkConstraints(JsonDataSource.scala:81) at org.apache.spark.sql.execution.datasources.json.JsonDataSource.inferSchema(JsonDataSource.scala:63) at org.apache.spark.sql.execution.datasources.json.JsonFileFormat.inferSchema(JsonFileFormat.scala:57) at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$7.apply(DataSource.scala:176) at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$7.apply(DataSource.scala:176) // csv scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("csv").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Reference 'a' is ambiguous, could be: a#41, a#42.; at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:287) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:181) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:153) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:152) // If `inferSchema` is true, a CSV format is duplicate-safe (See SPARK-16896) scala> spark.read.format("csv").option("header", true).load("/tmp/data").show +---+---+ | a0| a1| +---+---+ | 1| 1| +---+---+ // parquet scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq((1, 1)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet("/tmp/data") scala> spark.read.format("parquet").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Reference 'a' is ambiguous, could be: a#110, a#111.; at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:287) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:181) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:153) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:152) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) ``` When this patch applied, the results change to; ``` // json scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("""{"a":1, "a":1}"""""").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("json").schema(schema).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) scala> spark.read.format("json").load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:156) // csv scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("csv").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) scala> spark.read.format("csv").option("header", true).load("/tmp/data").show +---+---+ | a0| a1| +---+---+ | 1| 1| +---+---+ // parquet scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq((1, 1)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet("/tmp/data") scala> spark.read.format("parquet").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) ``` ## How was this patch tested? Added tests in `DataFrameReaderWriterSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17758 from maropu/SPARK-20460. --- .../sql/catalyst/catalog/SessionCatalog.scala | 16 +--- .../apache/spark/sql/util/SchemaUtils.scala | 58 +++++++++--- .../spark/sql/util/SchemaUtilsSuite.scala | 83 +++++++++++++++++ .../command/createDataSourceTables.scala | 2 - .../spark/sql/execution/command/tables.scala | 8 +- .../spark/sql/execution/command/views.scala | 9 +- .../execution/datasources/DataSource.scala | 43 +++++++-- .../InsertIntoHadoopFsRelationCommand.scala | 14 ++- .../datasources/PartitioningUtils.scala | 10 +-- .../datasources/jdbc/JdbcUtils.scala | 11 +-- .../datasources/json/JsonDataSource.scala | 15 +--- .../sql/execution/datasources/rules.scala | 36 ++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../sql/execution/command/DDLSuite.scala | 56 ++++++++---- .../spark/sql/jdbc/JDBCWriteSuite.scala | 4 +- .../sql/sources/ResolvedDataSourceSuite.scala | 5 +- .../sql/streaming/FileStreamSinkSuite.scala | 37 ++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 88 +++++++++++++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 2 +- 19 files changed, 382 insertions(+), 119 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c40d5f6031a21..b44d2ee69e1d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -188,19 +188,6 @@ class SessionCatalog( } } - private def checkDuplication(fields: Seq[StructField]): Unit = { - val columnNames = if (conf.caseSensitiveAnalysis) { - fields.map(_.name) - } else { - fields.map(_.name.toLowerCase) - } - if (columnNames.distinct.length != columnNames.length) { - val duplicateColumns = columnNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => x - } - throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}") - } - } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -353,7 +340,6 @@ class SessionCatalog( val tableIdentifier = TableIdentifier(table, Some(db)) requireDbExists(db) requireTableExists(tableIdentifier) - checkDuplication(newSchema) val catalogTable = externalCatalog.getTable(db, table) val oldSchema = catalogTable.schema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index e881685ce6262..41ca270095ffb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.util -import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.types.StructType /** @@ -25,29 +27,63 @@ import org.apache.spark.internal.Logging * * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]]. */ -private[spark] object SchemaUtils extends Logging { +private[spark] object SchemaUtils { /** - * Checks if input column names have duplicate identifiers. Prints a warning message if + * Checks if an input schema has duplicate column names. This throws an exception if the + * duplication exists. + * + * @param schema schema to check + * @param colType column type name, used in an exception message + * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not + */ + def checkSchemaColumnNameDuplication( + schema: StructType, colType: String, caseSensitiveAnalysis: Boolean = false): Unit = { + checkColumnNameDuplication(schema.map(_.name), colType, caseSensitiveAnalysis) + } + + // Returns true if a given resolver is case-sensitive + private def isCaseSensitiveAnalysis(resolver: Resolver): Boolean = { + if (resolver == caseSensitiveResolution) { + true + } else if (resolver == caseInsensitiveResolution) { + false + } else { + sys.error("A resolver to check if two identifiers are equal must be " + + "`caseSensitiveResolution` or `caseInsensitiveResolution` in o.a.s.sql.catalyst.") + } + } + + /** + * Checks if input column names have duplicate identifiers. This throws an exception if * the duplication exists. * * @param columnNames column names to check - * @param colType column type name, used in a warning message + * @param colType column type name, used in an exception message + * @param resolver resolver used to determine if two identifiers are equal + */ + def checkColumnNameDuplication( + columnNames: Seq[String], colType: String, resolver: Resolver): Unit = { + checkColumnNameDuplication(columnNames, colType, isCaseSensitiveAnalysis(resolver)) + } + + /** + * Checks if input column names have duplicate identifiers. This throws an exception if + * the duplication exists. + * + * @param columnNames column names to check + * @param colType column type name, used in an exception message * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not */ def checkColumnNameDuplication( columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { - val names = if (caseSensitiveAnalysis) { - columnNames - } else { - columnNames.map(_.toLowerCase) - } + val names = if (caseSensitiveAnalysis) columnNames else columnNames.map(_.toLowerCase) if (names.distinct.length != names.length) { val duplicateColumns = names.groupBy(identity).collect { case (x, ys) if ys.length > 1 => s"`$x`" } - logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " + - "You might need to assign different column names.") + throw new AnalysisException( + s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}") } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala new file mode 100644 index 0000000000000..a25be2fe61dbd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.types.StructType + +class SchemaUtilsSuite extends SparkFunSuite { + + private def resolver(caseSensitiveAnalysis: Boolean): Resolver = { + if (caseSensitiveAnalysis) { + caseSensitiveResolution + } else { + caseInsensitiveResolution + } + } + + Seq((true, ("a", "a"), ("b", "b")), (false, ("a", "A"), ("b", "B"))).foreach { + case (caseSensitive, (a0, a1), (b0, b1)) => + + val testType = if (caseSensitive) "case-sensitive" else "case-insensitive" + test(s"Check column name duplication in $testType cases") { + def checkExceptionCases(schemaStr: String, duplicatedColumns: Seq[String]): Unit = { + val expectedErrorMsg = "Found duplicate column(s) in SchemaUtilsSuite: " + + duplicatedColumns.map(c => s"`${c.toLowerCase}`").mkString(", ") + val schema = StructType.fromDDL(schemaStr) + var msg = intercept[AnalysisException] { + SchemaUtils.checkSchemaColumnNameDuplication( + schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + }.getMessage + assert(msg.contains(expectedErrorMsg)) + msg = intercept[AnalysisException] { + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive)) + }.getMessage + assert(msg.contains(expectedErrorMsg)) + msg = intercept[AnalysisException] { + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + }.getMessage + assert(msg.contains(expectedErrorMsg)) + } + + checkExceptionCases(s"$a0 INT, b INT, $a1 INT", a0 :: Nil) + checkExceptionCases(s"$a0 INT, b INT, $a1 INT, $a0 INT", a0 :: Nil) + checkExceptionCases(s"$a0 INT, $b0 INT, $a1 INT, $a0 INT, $b1 INT", b0 :: a0 :: Nil) + } + } + + test("Check no exception thrown for valid schemas") { + def checkNoExceptionCases(schemaStr: String, caseSensitive: Boolean): Unit = { + val schema = StructType.fromDDL(schemaStr) + SchemaUtils.checkSchemaColumnNameDuplication( + schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive)) + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + } + + checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = true) + checkNoExceptionCases("Aa INT, b INT, aA INT", caseSensitive = true) + + checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = false) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 729bd39d821c9..04b2534ca5eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.command import java.net.URI -import org.apache.hadoop.fs.Path - import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8ded1060f7bf0..fa50d12722411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI import java.nio.file.FileSystems -import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import scala.util.Try -import org.apache.commons.lang3.StringEscapeUtils import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -42,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils /** @@ -202,6 +201,11 @@ case class AlterTableAddColumnsCommand( // make sure any partition columns are at the end of the fields val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema + + SchemaUtils.checkColumnNameDuplication( + reorderedSchema.map(_.name), "in the table definition of " + table.identifier, + conf.caseSensitiveAnalysis) + catalog.alterTableSchema( table, catalogTable.schema.copy(fields = reorderedSchema.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index a6d56ca91a3ee..ffdfd527fa701 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.sql.util.SchemaUtils /** @@ -355,15 +356,15 @@ object ViewHelper { properties: Map[String, String], session: SparkSession, analyzedPlan: LogicalPlan): Map[String, String] = { + val queryOutput = analyzedPlan.schema.fieldNames + // Generate the query column names, throw an AnalysisException if there exists duplicate column // names. - val queryOutput = analyzedPlan.schema.fieldNames - assert(queryOutput.distinct.size == queryOutput.size, - s"The view output ${queryOutput.mkString("(", ",", ")")} contains duplicate column name.") + SchemaUtils.checkColumnNameDuplication( + queryOutput, "in the view definition", session.sessionState.conf.resolver) // Generate the view default database name. val viewDefaultDatabase = session.sessionState.catalog.getCurrentDatabase - removeQueryColumnNames(properties) ++ generateViewDefaultDatabase(viewDefaultDatabase) ++ generateQueryColumnNames(queryOutput) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 75e530607570f..d36a04f1fff8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -87,6 +87,14 @@ case class DataSource( lazy val providingClass: Class[_] = DataSource.lookupDataSource(className) lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = CaseInsensitiveMap(options) + private val equality = sparkSession.sessionState.conf.resolver + + bucketSpec.map { bucket => + SchemaUtils.checkColumnNameDuplication( + bucket.bucketColumnNames, "in the bucket definition", equality) + SchemaUtils.checkColumnNameDuplication( + bucket.sortColumnNames, "in the sort definition", equality) + } /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer @@ -132,7 +140,6 @@ case class DataSource( // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource val resolved = tempFileIndex.partitionSchema.map { partitionField => - val equality = sparkSession.sessionState.conf.resolver // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( partitionField) @@ -146,7 +153,6 @@ case class DataSource( inferredPartitions } else { val partitionFields = partitionColumns.map { partitionColumn => - val equality = sparkSession.sessionState.conf.resolver userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse { val inferredPartitions = tempFileIndex.partitionSchema val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn)) @@ -172,7 +178,6 @@ case class DataSource( } val dataSchema = userSpecifiedSchema.map { schema => - val equality = sparkSession.sessionState.conf.resolver StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name)))) }.orElse { format.inferSchema( @@ -184,9 +189,18 @@ case class DataSource( s"Unable to infer schema for $format. It must be specified manually.") } - SchemaUtils.checkColumnNameDuplication( - (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema", - sparkSession.sessionState.conf.caseSensitiveAnalysis) + // We just print a waring message if the data schema and partition schema have the duplicate + // columns. This is because we allow users to do so in the previous Spark releases and + // we have the existing tests for the cases (e.g., `ParquetHadoopFsRelationSuite`). + // See SPARK-18108 and SPARK-21144 for related discussions. + try { + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), + "in the data schema and the partition schema", + equality) + } catch { + case e: AnalysisException => logWarning(e.getMessage) + } (dataSchema, partitionSchema) } @@ -391,6 +405,23 @@ case class DataSource( s"$className is not a valid Spark SQL Data Source.") } + relation match { + case hs: HadoopFsRelation => + SchemaUtils.checkColumnNameDuplication( + hs.dataSchema.map(_.name), + "in the data schema", + equality) + SchemaUtils.checkColumnNameDuplication( + hs.partitionSchema.map(_.name), + "in the partition schema", + equality) + case _ => + SchemaUtils.checkColumnNameDuplication( + relation.schema.map(_.name), + "in the data schema", + equality) + } + relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 0031567d3d288..c1bcfb8610783 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -21,7 +21,6 @@ import java.io.IOException import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} @@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.util.SchemaUtils /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -64,13 +63,10 @@ case class InsertIntoHadoopFsRelationCommand( assert(children.length == 1) // Most formats don't do well with duplicate columns, so lets not allow that - if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) { - val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " + - "cannot save to file.") - } + SchemaUtils.checkSchemaColumnNameDuplication( + query.schema, + s"when inserting into $outputPath", + sparkSession.sessionState.conf.caseSensitiveAnalysis) val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) val fs = outputPath.getFileSystem(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f61c673baaa58..92358da6d6c67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -301,13 +302,8 @@ object PartitioningUtils { normalizedKey -> value } - if (normalizedPartSpec.map(_._1).distinct.length != normalizedPartSpec.length) { - val duplicateColumns = normalizedPartSpec.map(_._1).groupBy(identity).collect { - case (x, ys) if ys.length > 1 => x - } - throw new AnalysisException(s"Found duplicated columns in partition specification: " + - duplicateColumns.mkString(", ")) - } + SchemaUtils.checkColumnNameDuplication( + normalizedPartSpec.map(_._1), "in the partition schema", resolver) normalizedPartSpec.toMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 55b2539c13381..bbe9024f13a44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.NextIterator @@ -749,14 +750,8 @@ object JdbcUtils extends Logging { val nameEquality = df.sparkSession.sessionState.conf.resolver // checks duplicate columns in the user specified column types. - userSchema.fieldNames.foreach { col => - val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col)) - if (duplicatesCols.size >= 2) { - throw new AnalysisException( - "Found duplicate column(s) in createTableColumnTypes option value: " + - duplicatesCols.mkString(", ")) - } - } + SchemaUtils.checkColumnNameDuplication( + userSchema.map(_.name), "in the createTableColumnTypes option value", nameEquality) // checks if user specified column names exist in the DataFrame schema userSchema.fieldNames.foreach { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5a92a71d19e78..8b7c2709afde1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -59,9 +59,7 @@ abstract class JsonDataSource extends Serializable { inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Option[StructType] = { if (inputPaths.nonEmpty) { - val jsonSchema = infer(sparkSession, inputPaths, parsedOptions) - checkConstraints(jsonSchema) - Some(jsonSchema) + Some(infer(sparkSession, inputPaths, parsedOptions)) } else { None } @@ -71,17 +69,6 @@ abstract class JsonDataSource extends Serializable { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType - - /** Constraints to be imposed on schema to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") - } - } } object JsonDataSource { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3f4a78580f1eb..41d40aa926fbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.util.SchemaUtils /** * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files. @@ -222,12 +223,10 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi } private def normalizeCatalogTable(schema: StructType, table: CatalogTable): CatalogTable = { - val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { - schema.map(_.name) - } else { - schema.map(_.name.toLowerCase) - } - checkDuplication(columnNames, "table definition of " + table.identifier) + SchemaUtils.checkSchemaColumnNameDuplication( + schema, + "in the table definition of " + table.identifier, + sparkSession.sessionState.conf.caseSensitiveAnalysis) val normalizedPartCols = normalizePartitionColumns(schema, table) val normalizedBucketSpec = normalizeBucketSpec(schema, table) @@ -253,7 +252,10 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi partCols = table.partitionColumnNames, resolver = sparkSession.sessionState.conf.resolver) - checkDuplication(normalizedPartitionCols, "partition") + SchemaUtils.checkColumnNameDuplication( + normalizedPartitionCols, + "in the partition schema", + sparkSession.sessionState.conf.resolver) if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) { if (DDLUtils.isHiveTable(table)) { @@ -283,8 +285,15 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi tableCols = schema.map(_.name), bucketSpec = bucketSpec, resolver = sparkSession.sessionState.conf.resolver) - checkDuplication(normalizedBucketSpec.bucketColumnNames, "bucket") - checkDuplication(normalizedBucketSpec.sortColumnNames, "sort") + + SchemaUtils.checkColumnNameDuplication( + normalizedBucketSpec.bucketColumnNames, + "in the bucket definition", + sparkSession.sessionState.conf.resolver) + SchemaUtils.checkColumnNameDuplication( + normalizedBucketSpec.sortColumnNames, + "in the sort definition", + sparkSession.sessionState.conf.resolver) normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach { case dt if RowOrdering.isOrderable(dt) => // OK @@ -297,15 +306,6 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi } } - private def checkDuplication(colNames: Seq[String], colType: String): Unit = { - if (colNames.distinct.length != colNames.length) { - val duplicateColumns = colNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => x - } - failAnalysis(s"Found duplicate column(s) in $colType: ${duplicateColumns.mkString(", ")}") - } - } - private def failAnalysis(msg: String) = throw new AnalysisException(msg) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b2219b4eb8c17..a5a2e1c38d300 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1189,7 +1189,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") .write.format("parquet").save("temp") } - assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("Found duplicate column(s) when inserting into")) assert(e.getMessage.contains("column1")) assert(!e.getMessage.contains("column2")) @@ -1199,7 +1199,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .toDF("column1", "column2", "column3", "column1", "column3") .write.format("json").save("temp") } - assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("Found duplicate column(s) when inserting into")) assert(f.getMessage.contains("column1")) assert(f.getMessage.contains("column3")) assert(!f.getMessage.contains("column2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 5c40d8bb4b1ef..5c0a6aa724bf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -436,16 +436,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create table - duplicate column names in the table definition") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int, a string) USING json") - } - assert(e.message == "Found duplicate column(s) in table definition of `tbl`: a") - - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - val e2 = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int, A string) USING json") + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT, $c1 INT) USING parquet") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the table definition of `t`")) } - assert(e2.message == "Found duplicate column(s) in table definition of `tbl`: a") } } @@ -466,17 +463,33 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create table - column repeated in partition columns") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int) USING json PARTITIONED BY (a, a)") + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT) USING parquet PARTITIONED BY ($c0, $c1)") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the partition schema")) + } } - assert(e.message == "Found duplicate column(s) in partition: a") } - test("create table - column repeated in bucket columns") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int) USING json CLUSTERED BY (a, a) INTO 4 BUCKETS") + test("create table - column repeated in bucket/sort columns") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT) USING parquet CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the bucket definition")) + + errMsg = intercept[AnalysisException] { + sql(s""" + |CREATE TABLE t($c0 INT, col INT) USING parquet CLUSTERED BY (col) + | SORTED BY ($c0, $c1) INTO 2 BUCKETS + """.stripMargin) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the sort definition")) + } } - assert(e.message == "Found duplicate column(s) in bucket: a") } test("Refresh table after changing the data source table partitioning") { @@ -528,6 +541,17 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create view - duplicate column names in the view definition") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE VIEW t AS SELECT * FROM VALUES (1, 1) AS t($c0, $c1)") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the view definition")) + } + } + } + test("Alter/Describe Database") { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 92f50a095f19b..2334d5ae32dc3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.DriverManager import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter @@ -479,7 +479,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .jdbc(url1, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains( - "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) + "Found duplicate column(s) in the createTableColumnTypes option value: `name`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 0f97fd78d2ffb..308c5079c44bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -21,11 +21,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.test.SharedSQLContext -class ResolvedDataSourceSuite extends SparkFunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext { private def getProvidingClass(name: String): Class[_] = DataSource( - sparkSession = null, + sparkSession = spark, className = name, options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID) ).providingClass diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index bb6a27803bb20..6676099d426ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -352,4 +353,40 @@ class FileStreamSinkSuite extends StreamTest { assertAncestorIsNotMetadataDirectory(s"/a/b/c") assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") } + + test("SPARK-20460 Check name duplication in schema") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + try { + query = + df.writeStream + .option("checkpointLocation", checkpointDir) + .format("json") + .start(outputDir) + + inputData.addData((1, 1)) + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + if (query != null) { + query.stop() + } + } + + val errorMsg = intercept[AnalysisException] { + spark.read.schema(s"$c0 INT, $c1 INT").json(outputDir).as[(Int, Int)] + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema: ")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 306aecb5bbc86..569bac156b531 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -687,4 +688,91 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema) testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema) } + + test("SPARK-20460 Check name duplication in buckets") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF("col", c0).write.bucketBy(2, c0, c1).saveAsTable("t") + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the bucket definition")) + + errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF("col", c0).write.bucketBy(2, "col").sortBy(c0, c1).saveAsTable("t") + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the sort definition")) + } + } + } + + test("SPARK-20460 Check name duplication in schema") { + def checkWriteDataColumnDuplication( + format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF(colName0, colName1).write.format(format).mode("overwrite") + .save(tempDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) when inserting into")) + } + + def checkReadUserSpecifiedDataColumnDuplication( + df: DataFrame, format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val testDir = Utils.createTempDir(tempDir.getAbsolutePath) + df.write.format(format).mode("overwrite").save(testDir.getAbsolutePath) + val errorMsg = intercept[AnalysisException] { + spark.read.format(format).schema(s"$colName0 INT, $colName1 INT") + .load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema:")) + } + + def checkReadPartitionColumnDuplication( + format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val testDir = Utils.createTempDir(tempDir.getAbsolutePath) + Seq(1).toDF("col").write.format(format).mode("overwrite") + .save(s"${testDir.getAbsolutePath}/$colName0=1/$colName1=1") + val errorMsg = intercept[AnalysisException] { + spark.read.format(format).load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the partition schema:")) + } + + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + withTempDir { src => + // Check CSV format + checkWriteDataColumnDuplication("csv", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "csv", c0, c1, src) + // If `inferSchema` is true, a CSV format is duplicate-safe (See SPARK-16896) + var testDir = Utils.createTempDir(src.getAbsolutePath) + Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text(testDir.getAbsolutePath) + val df = spark.read.format("csv").option("inferSchema", true).option("header", true) + .load(testDir.getAbsolutePath) + checkAnswer(df, Row(1, 1)) + checkReadPartitionColumnDuplication("csv", c0, c1, src) + + // Check JSON format + checkWriteDataColumnDuplication("json", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "json", c0, c1, src) + // Inferred schema cases + testDir = Utils.createTempDir(src.getAbsolutePath) + Seq(s"""{"$c0":3, "$c1":5}""").toDF().write.mode("overwrite") + .text(testDir.getAbsolutePath) + val errorMsg = intercept[AnalysisException] { + spark.read.format("json").option("inferSchema", true).load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema:")) + checkReadPartitionColumnDuplication("json", c0, c1, src) + + // Check Parquet format + checkWriteDataColumnDuplication("parquet", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "parquet", c0, c1, src) + checkReadPartitionColumnDuplication("parquet", c0, c1, src) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 31fa3d2447467..12daf3af11abe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -345,7 +345,7 @@ class HiveDDLSuite val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int) PARTITIONED BY (a string)") } - assert(e.message == "Found duplicate column(s) in table definition of `default`.`tbl`: a") + assert(e.message == "Found duplicate column(s) in the table definition of `default`.`tbl`: `a`") } test("add/drop partition with location - managed table") { From 6a06c4b03c4dd86241fb9d11b4360371488f0e53 Mon Sep 17 00:00:00 2001 From: jinxing Date: Mon, 10 Jul 2017 21:06:58 +0800 Subject: [PATCH 0921/1765] [SPARK-21342] Fix DownloadCallback to work well with RetryingBlockFetcher. ## What changes were proposed in this pull request? When `RetryingBlockFetcher` retries fetching blocks. There could be two `DownloadCallback`s download the same content to the same target file. It could cause `ShuffleBlockFetcherIterator` reading a partial result. This pr proposes to create and delete the tmp files in `OneForOneBlockFetcher` Author: jinxing Author: Shixiong Zhu Closes #18565 from jinxing64/SPARK-21342. --- .../shuffle/ExternalShuffleClient.java | 7 ++-- .../shuffle/OneForOneBlockFetcher.java | 34 +++++++++++------- .../spark/network/shuffle/ShuffleClient.java | 13 +++++-- .../shuffle/TempShuffleFileManager.java | 36 +++++++++++++++++++ .../network/sasl/SaslIntegrationSuite.java | 2 +- .../shuffle/OneForOneBlockFetcherSuite.java | 2 +- .../spark/network/BlockTransferService.scala | 8 ++--- .../netty/NettyBlockTransferService.scala | 9 +++-- .../storage/ShuffleBlockFetcherIterator.scala | 28 ++++++++++----- .../spark/storage/BlockManagerSuite.scala | 4 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 10 +++--- 11 files changed, 108 insertions(+), 45 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6ac9302517ee0..31bd24e5038b2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,7 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; @@ -91,15 +90,15 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - File[] shuffleFiles) { + TempShuffleFileManager tempShuffleFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf, - shuffleFiles).start(); + new OneForOneBlockFetcher(client, appId, execId, + blockIds1, listener1, conf, tempShuffleFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index d46ce2e0e6b78..2f160d12af22b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -57,11 +57,21 @@ public class OneForOneBlockFetcher { private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private TransportConf transportConf = null; - private File[] shuffleFiles = null; + private final TransportConf transportConf; + private final TempShuffleFileManager tempShuffleFileManager; private StreamHandle streamHandle = null; + public OneForOneBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TransportConf transportConf) { + this(client, appId, execId, blockIds, listener, transportConf, null); + } + public OneForOneBlockFetcher( TransportClient client, String appId, @@ -69,18 +79,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - File[] shuffleFiles) { + TempShuffleFileManager tempShuffleFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - if (shuffleFiles != null) { - this.shuffleFiles = shuffleFiles; - assert this.shuffleFiles.length == blockIds.length: - "Number of shuffle files should equal to blocks"; - } + this.tempShuffleFileManager = tempShuffleFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -119,9 +125,9 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (shuffleFiles != null) { + if (tempShuffleFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), - new DownloadCallback(shuffleFiles[i], i)); + new DownloadCallback(i)); } else { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } @@ -157,8 +163,8 @@ private class DownloadCallback implements StreamCallback { private File targetFile = null; private int chunkIndex; - DownloadCallback(File targetFile, int chunkIndex) throws IOException { - this.targetFile = targetFile; + DownloadCallback(int chunkIndex) throws IOException { + this.targetFile = tempShuffleFileManager.createTempShuffleFile(); this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; } @@ -174,6 +180,9 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) { + targetFile.delete(); + } } @Override @@ -182,6 +191,7 @@ public void onFailure(String streamId, Throwable cause) throws IOException { // On receipt of a failure, fail every block from chunkIndex onwards. String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 978ff5a2a8699..9e77bee7f9ee6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -18,7 +18,6 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; -import java.io.File; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ public abstract class ShuffleClient implements Closeable { @@ -35,6 +34,16 @@ public void init(String appId) { } * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. + * + * @param host the host of the remote node. + * @param port the port of the remote node. + * @param execId the executor id. + * @param blockIds block ids to fetch. + * @param listener the listener to receive block fetching status. + * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files. + * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. */ public abstract void fetchBlocks( String host, @@ -42,5 +51,5 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - File[] shuffleFiles); + TempShuffleFileManager tempShuffleFileManager); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java new file mode 100644 index 0000000000000..84a5ed6a276bd --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; + +/** + * A manager to create temp shuffle block files to reduce the memory usage and also clean temp + * files when they won't be used any more. + */ +public interface TempShuffleFileManager { + + /** Create a temp shuffle block file. */ + File createTempShuffleFile(); + + /** + * Register a temp shuffle file to clean up when it won't be used any more. Return whether the + * file is registered successfully. If `false`, the caller should clean up the file by itself. + */ + boolean registerTempShuffleFileToClean(File file); +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 8110f1e004c73..02e6eb3a4467e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 61d82214e7d30..dc947a619bf02 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -131,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap { diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 6860214c7fe39..fe5fd2da039bb 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network -import java.io.{Closeable, File} +import java.io.Closeable import java.nio.ByteBuffer import scala.concurrent.{Future, Promise} @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit + tempShuffleFileManager: TempShuffleFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -101,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }, shuffleFiles = null) + }, tempShuffleFileManager = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b13a9c681e543..30ff93897f98a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,7 +17,6 @@ package org.apache.spark.network.netty -import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -30,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -90,14 +89,14 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit = { + tempShuffleFileManager: TempShuffleFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener, - transportConf, shuffleFiles).start() + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, + transportConf, tempShuffleFileManager).start() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a10f1feadd0af..81d822dc8a98f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -66,7 +66,7 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with Logging { + extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -135,7 +135,8 @@ final class ShuffleBlockFetcherIterator( * A set to store the files used for shuffling remote huge blocks. Files in this set will be * deleted when cleanup. This is a layer of defensiveness against disk file leaks. */ - val shuffleFilesSet = mutable.HashSet[File]() + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[File]() initialize() @@ -149,6 +150,19 @@ final class ShuffleBlockFetcherIterator( currentResult = null } + override def createTempShuffleFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + /** * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ @@ -176,7 +190,7 @@ final class ShuffleBlockFetcherIterator( } shuffleFilesSet.foreach { file => if (!file.delete()) { - logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()); + logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()) } } } @@ -221,12 +235,8 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - val shuffleFiles = blockIds.map { _ => - blockManager.diskBlockManager.createTempLocalBlock()._2 - }.toArray - shuffleFilesSet ++= shuffleFiles shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, shuffleFiles) + blockFetchingListener, this) } else { shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, blockFetchingListener, null) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 086adccea954c..755a61a438a6a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -45,7 +45,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus @@ -1382,7 +1382,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit = { + tempShuffleFileManager: TempShuffleFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 559b3faab8fd2..6a70cedf769b8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -432,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var shuffleFiles: Array[File] = null + var tempShuffleFileManager: TempShuffleFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -466,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(shuffleFiles === null) + assert(tempShuffleFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(shuffleFiles != null) + assert(tempShuffleFileManager != null) } } From 18b3b00ecfde6c694fb6fee4f4d07d04e3d08ccf Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Mon, 10 Jul 2017 09:26:42 -0700 Subject: [PATCH 0922/1765] [SPARK-21272] SortMergeJoin LeftAnti does not update numOutputRows ## What changes were proposed in this pull request? Updating numOutputRows metric was missing from one return path of LeftAnti SortMergeJoin. ## How was this patch tested? Non-zero output rows manually seen in metrics. Author: Juliusz Sompolski Closes #18494 from juliuszsompolski/SPARK-21272. --- .../sql/execution/joins/SortMergeJoinExec.scala | 1 + .../spark/sql/execution/metric/SQLMetricsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 8445c26eeee58..639b8e00c121b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -290,6 +290,7 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches if (currentRightMatches == null || currentRightMatches.length == 0) { + numOutputRows += 1 return true } var found = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cb3405b2fe19b..2911cbbeee479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -483,6 +483,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("SortMergeJoin(left-anti) metrics") { + val anti = testData2.filter("a > 2") + withTempView("antiData") { + anti.createOrReplaceTempView("antiData") + val df = spark.sql( + "SELECT * FROM testData2 ANTI JOIN antiData ON testData2.a = antiData.a") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("SortMergeJoin", Map("number of output rows" -> 4L))) + ) + } + } + test("save metrics") { withTempPath { file => // person creates a temporary view. get the DF before listing previous execution IDs From 2bfd5accdce2ae31feeeddf213a019cf8ec97663 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 10 Jul 2017 10:40:03 -0700 Subject: [PATCH 0923/1765] [SPARK-21266][R][PYTHON] Support schema a DDL-formatted string in dapply/gapply/from_json ## What changes were proposed in this pull request? This PR supports schema in a DDL formatted string for `from_json` in R/Python and `dapply` and `gapply` in R, which are commonly used and/or consistent with Scala APIs. Additionally, this PR exposes `structType` in R to allow working around in other possible corner cases. **Python** `from_json` ```python from pyspark.sql.functions import from_json data = [(1, '''{"a": 1}''')] df = spark.createDataFrame(data, ("key", "value")) df.select(from_json(df.value, "a INT").alias("json")).show() ``` **R** `from_json` ```R df <- sql("SELECT named_struct('name', 'Bob') as people") df <- mutate(df, people_json = to_json(df$people)) head(select(df, from_json(df$people_json, "name STRING"))) ``` `structType.character` ```R structType("a STRING, b INT") ``` `dapply` ```R dapply(createDataFrame(list(list(1.0)), "a"), function(x) {x}, "a DOUBLE") ``` `gapply` ```R gapply(createDataFrame(list(list(1.0)), "a"), "a", function(key, x) { x }, "a DOUBLE") ``` ## How was this patch tested? Doc tests for `from_json` in Python and unit tests `test_sparkSQL.R` in R. Author: hyukjinkwon Closes #18498 from HyukjinKwon/SPARK-21266. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 36 ++++- R/pkg/R/functions.R | 12 +- R/pkg/R/group.R | 3 + R/pkg/R/schema.R | 29 +++- R/pkg/tests/fulltests/test_sparkSQL.R | 136 ++++++++++-------- python/pyspark/sql/functions.py | 11 +- .../org/apache/spark/sql/functions.scala | 7 +- 8 files changed, 160 insertions(+), 76 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b7fdae58de459..232f5cf31f319 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -429,6 +429,7 @@ export("structField", "structField.character", "print.structField", "structType", + "structType.character", "structType.jobj", "structType.structField", "print.structType") @@ -465,5 +466,6 @@ S3method(print, summary.GBTRegressionModel) S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) +S3method(structType, character) S3method(structType, jobj) S3method(structType, structField) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3b9d42d6e7158..e7a166c3014c1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1391,6 +1391,10 @@ setMethod("summarize", }) dapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } + packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) @@ -1408,6 +1412,8 @@ dapplyInternal <- function(x, func, schema) { dataFrame(sdf) } +setClassUnion("characterOrstructType", c("character", "structType")) + #' dapply #' #' Apply a function to each partition of a SparkDataFrame. @@ -1418,10 +1424,11 @@ dapplyInternal <- function(x, func, schema) { #' to each partition will be passed. #' The output of func should be a R data.frame. #' @param schema The schema of the resulting SparkDataFrame after the function is applied. -#' It must match the output of func. +#' It must match the output of func. Since Spark 2.3, the DDL-formatted string +#' is also supported for the schema. #' @family SparkDataFrame functions #' @rdname dapply -#' @aliases dapply,SparkDataFrame,function,structType-method +#' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} #' @export @@ -1444,6 +1451,17 @@ dapplyInternal <- function(x, func, schema) { #' y <- cbind(y, y[1] + 1L) #' }, #' schema) +#' +#' # The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, d DOUBLE, c STRING, d INT" +#' df1 <- dapply( +#' df, +#' function(x) { +#' y <- x[x[1] > 1, ] +#' y <- cbind(y, y[1] + 1L) +#' }, +#' schema) +#' #' collect(df1) #' # the result #' # a b c d @@ -1452,7 +1470,7 @@ dapplyInternal <- function(x, func, schema) { #' } #' @note dapply since 2.0.0 setMethod("dapply", - signature(x = "SparkDataFrame", func = "function", schema = "structType"), + signature(x = "SparkDataFrame", func = "function", schema = "characterOrstructType"), function(x, func, schema) { dapplyInternal(x, func, schema) }) @@ -1522,6 +1540,7 @@ setMethod("dapplyCollect", #' @param schema the schema of the resulting SparkDataFrame after the function is applied. #' The schema must match to output of \code{func}. It has to be defined for each #' output column with preferred output column name and corresponding data type. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases gapply,SparkDataFrame-method @@ -1541,7 +1560,7 @@ setMethod("dapplyCollect", #' #' Here our output contains three columns, the key which is a combination of two #' columns with data types integer and string and the mean which is a double. -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' result <- gapply( #' df, @@ -1550,6 +1569,15 @@ setMethod("dapplyCollect", #' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) #' }, schema) #' +#' The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, c STRING, avg DOUBLE" +#' result <- gapply( +#' df, +#' c("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, schema) +#' #' We can also group the data and afterwards call gapply on GroupedData. #' For Example: #' gdf <- group_by(df, "a", "c") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f28d26a51baa0..86507f13f038d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2174,8 +2174,9 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @aliases from_json from_json,Column,structType-method +#' @aliases from_json from_json,Column,characterOrstructType-method #' @export #' @examples #' @@ -2188,10 +2189,15 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' df2 <- sql("SELECT named_struct('name', 'Bob') as people") #' df2 <- mutate(df2, people_json = to_json(df2$people)) #' schema <- structType(structField("name", "string")) -#' head(select(df2, from_json(df2$people_json, schema)))} +#' head(select(df2, from_json(df2$people_json, schema))) +#' head(select(df2, from_json(df2$people_json, "name STRING")))} #' @note from_json since 2.2.0 -setMethod("from_json", signature(x = "Column", schema = "structType"), +setMethod("from_json", signature(x = "Column", schema = "characterOrstructType"), function(x, schema, as.json.array = FALSE, ...) { + if (is.character(schema)) { + schema <- structType(schema) + } + if (as.json.array) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 17f5283abead1..0a7be0e993975 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -233,6 +233,9 @@ setMethod("gapplyCollect", }) gapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) broadcastArr <- lapply(ls(.broadcastNames), diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index cb5bdb90175bf..d1ed6833d5d02 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -23,18 +23,24 @@ #' Create a structType object that contains the metadata for a SparkDataFrame. Intended for #' use with createDataFrame and toDF. #' -#' @param x a structField object (created with the field() function) +#' @param x a structField object (created with the \code{structField} method). Since Spark 2.3, +#' this can be a DDL-formatted string, which is a comma separated list of field +#' definitions, e.g., "a INT, b STRING". #' @param ... additional structField objects #' @return a structType object #' @rdname structType #' @export #' @examples #'\dontrun{ -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) +#' schema <- structType("a INT, c STRING, avg DOUBLE") +#' df1 <- gapply(df, list("a", "c"), +#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, +#' schema) #' } #' @note structType since 1.4.0 structType <- function(x, ...) { @@ -68,6 +74,23 @@ structType.structField <- function(x, ...) { structType(stObj) } +#' @rdname structType +#' @method structType character +#' @export +structType.character <- function(x, ...) { + if (!is.character(x)) { + stop("schema must be a DDL-formatted string.") + } + if (length(list(...)) > 0) { + stop("multiple DDL-formatted strings are not supported") + } + + stObj <- handledCallJStatic("org.apache.spark.sql.types.StructType", + "fromDDL", + x) + structType(stObj) +} + #' Print a Spark StructType. #' #' This function prints the contents of a StructType returned from the @@ -102,7 +125,7 @@ print.structType <- function(x, ...) { #' field1 <- structField("a", "integer") #' field2 <- structField("c", "string") #' field3 <- structField("avg", "double") -#' schema <- structType(field1, field2, field3) +#' schema <- structType(field1, field2, field3) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a2bcb5aefe16d..77052d4a28345 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -146,6 +146,13 @@ test_that("structType and structField", { expect_is(testSchema, "structType") expect_is(testSchema$fields()[[2]], "structField") expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + testSchema <- structType("a STRING, b INT") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + expect_error(structType("A stri"), "DataType stri is not supported.") }) test_that("structField type strings", { @@ -1480,13 +1487,15 @@ test_that("column functions", { j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") df <- as.DataFrame(j) - schema <- structType(structField("age", "integer"), - structField("height", "double")) - s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) - expect_equal(ncol(s), 1) - expect_equal(nrow(s), 3) - expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + schemas <- list(structType(structField("age", "integer"), structField("height", "double")), + "age INT, height DOUBLE") + for (schema in schemas) { + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + } # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) @@ -1504,14 +1513,15 @@ test_that("column functions", { # check if array type in string is correctly supported. jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) - schema <- structType(structField("name", "string")) - arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) - expect_equal(ncol(arr), 1) - expect_equal(nrow(arr), 1) - expect_is(arr[[1]][[1]], "list") - expect_equal(length(arr$arrcol[[1]]), 2) - expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") - expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + for (schema in list(structType(structField("name", "string")), "name STRING")) { + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + } # Test create_array() and create_map() df <- as.DataFrame(data.frame( @@ -2885,30 +2895,33 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { expect_identical(ldf, result) # Filter and add a column - schema <- structType(structField("a", "integer"), structField("b", "double"), - structField("c", "string"), structField("d", "integer")) - df1 <- dapply( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }, - schema) - result <- collect(df1) - expected <- ldf[ldf$a > 1, ] - expected$d <- expected$a + 1L - rownames(expected) <- NULL - expect_identical(expected, result) - - result <- dapplyCollect( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }) - expected1 <- expected - names(expected1) <- names(result) - expect_identical(expected1, result) + schemas <- list(structType(structField("a", "integer"), structField("b", "double"), + structField("c", "string"), structField("d", "integer")), + "a INT, b DOUBLE, c STRING, d INT") + for (schema in schemas) { + df1 <- dapply( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }, + schema) + result <- collect(df1) + expected <- ldf[ldf$a > 1, ] + expected$d <- expected$a + 1L + rownames(expected) <- NULL + expect_identical(expected, result) + + result <- dapplyCollect( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }) + expected1 <- expected + names(expected1) <- names(result) + expect_identical(expected1, result) + } # Remove the added column df2 <- dapply( @@ -3020,29 +3033,32 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 - schema <- structType(structField("a", "integer"), structField("e", "boolean")) - df2 <- gapply( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - }, - schema) - actual <- collect(df2)$e - expected <- c(TRUE, TRUE) - expect_identical(actual, expected) - - df2Collect <- gapplyCollect( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - colnames(y) <- c("a", "e") - y - }) - actual <- df2Collect$e + schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), + "a INT, e BOOLEAN") + for (schema in schemas) { + df2 <- gapply( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) expect_identical(actual, expected) + df2Collect <- gapplyCollect( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + colnames(y) <- c("a", "e") + y + }) + actual <- df2Collect$e + expect_identical(actual, expected) + } + # Computes the arithmetic mean of the second column by grouping # on the first and third columns. Output the groupping value and the average. schema <- structType(structField("a", "integer"), structField("c", "string"), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5d8ded83f667d..f3e7d033e97cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1883,15 +1883,20 @@ def from_json(col, schema, options={}): string. :param col: string column in json format - :param schema: a StructType or ArrayType of StructType to use when parsing the json column + :param schema: a StructType or ArrayType of StructType to use when parsing the json column. :param options: options to control parsing. accepts the same options as the json datasource + .. note:: Since Spark 2.3, the DDL-formatted string or a JSON format string is also + supported for ``schema``. + >>> from pyspark.sql.types import * >>> data = [(1, '''{"a": 1}''')] >>> schema = StructType([StructField("a", IntegerType())]) >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "a INT").alias("json")).collect() + [Row(json=Row(a=1))] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) @@ -1900,7 +1905,9 @@ def from_json(col, schema, options={}): """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.from_json(_to_java_column(col), schema.json(), options) + if isinstance(schema, DataType): + schema = schema.json() + jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0c7b483f5c836..ebdeb42b0bfb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2114,7 +2114,7 @@ object functions { * Calculates the hash code of given columns, and returns the result as an int column. * * @group misc_funcs - * @since 2.0 + * @since 2.0.0 */ @scala.annotation.varargs def hash(cols: Column*): Column = withExpr { @@ -3074,9 +3074,8 @@ object functions { * string. * * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, - * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL - * format is also supported for the schema. + * @param schema the schema to use when parsing the json string as a json string, it could be a + * JSON format string or a DDL-formatted string. * * @group collection_funcs * @since 2.3.0 From d03aebbe6508ba441dc87f9546f27aeb27553d77 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 10 Jul 2017 15:21:03 -0700 Subject: [PATCH 0924/1765] [SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas ## What changes were proposed in this pull request? Integrate Apache Arrow with Spark to increase performance of `DataFrame.toPandas`. This has been done by using Arrow to convert data partitions on the executor JVM to Arrow payload byte arrays where they are then served to the Python process. The Python DataFrame can then collect the Arrow payloads where they are combined and converted to a Pandas DataFrame. Data types except complex, date, timestamp, and decimal are currently supported, otherwise an `UnsupportedOperation` exception is thrown. Additions to Spark include a Scala package private method `Dataset.toArrowPayload` that will convert data partitions in the executor JVM to `ArrowPayload`s as byte arrays so they can be easily served. A package private class/object `ArrowConverters` that provide data type mappings and conversion routines. In Python, a private method `DataFrame._collectAsArrow` is added to collect Arrow payloads and a SQLConf "spark.sql.execution.arrow.enable" can be used in `toPandas()` to enable using Arrow (uses the old conversion by default). ## How was this patch tested? Added a new test suite `ArrowConvertersSuite` that will run tests on conversion of Datasets to Arrow payloads for supported types. The suite will generate a Dataset and matching Arrow JSON data, then the dataset is converted to an Arrow payload and finally validated against the JSON data. This will ensure that the schema and data has been converted correctly. Added PySpark tests to verify the `toPandas` method is producing equal DataFrames with and without pyarrow. A roundtrip test to ensure the pandas DataFrame produced by pyspark is equal to a one made directly with pandas. Author: Bryan Cutler Author: Li Jin Author: Li Jin Author: Wes McKinney Closes #18459 from BryanCutler/toPandas_with_arrow-SPARK-13534. --- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 + dev/deps/spark-deps-hadoop-2.7 | 5 + pom.xml | 20 + python/pyspark/serializers.py | 17 + python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 78 +- .../apache/spark/sql/internal/SQLConf.scala | 22 + sql/core/pom.xml | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 20 + .../sql/execution/arrow/ArrowConverters.scala | 429 ++++++ .../arrow/ArrowConvertersSuite.scala | 1222 +++++++++++++++++ 12 files changed, 1859 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index d3b512eeb1209..dd286277c1fc1 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c1325318d52fa..1a6515be51cff 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ac5abd21807b6..09e5a4288ca50 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/pom.xml b/pom.xml index 5f524079495c0..f124ba45007b7 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 2.6 1.8 1.0.0 + 0.4.0 ${java.home} @@ -1878,6 +1879,25 @@ paranamer ${paranamer.version}
    + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + io.netty + netty-handler + + + diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ea5e00e9eeef5..d5c2a7518b18f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,6 +182,23 @@ def loads(self, obj): raise NotImplementedError +class ArrowSerializer(FramedSerializer): + """ + Serializes an Arrow stream. + """ + + def dumps(self, obj): + raise NotImplementedError + + def loads(self, obj): + import pyarrow as pa + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) + return reader.read_all() + + def __repr__(self): + return "ArrowSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 27a6dad8917d3..944739bcd2078 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,8 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1710,7 +1711,8 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. @@ -1723,18 +1725,42 @@ def toPandas(self): 1 5 Bob """ import pandas as pd + if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + try: + import pyarrow + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + return table.to_pandas() + else: + return pd.DataFrame.from_records([], columns=self.columns) + except ImportError as e: + msg = "note: pyarrow must be installed and available on calling Python process " \ + "if using spark.sql.execution.arrow.enable=true" + raise ImportError("%s\n%s" % (e.message, msg)) + else: + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + def _collectAsArrow(self): + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9db2f40474f70..bd8477e35f37a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,12 +58,21 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier -from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: + # No Arrow, but that's okay, we'll skip those tests + pass + + class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2843,6 +2852,73 @@ def __init__(self, **kwargs): _make_type_verifier(data_type, nullable=False)(obj) +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class ArrowTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] + + def assertFramesEqual(self, df_with_arrow, df_without): + msg = ("DataFrame from Arrow is not equal" + + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + + ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) + self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + + def test_unsupported_datatype(self): + schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) + df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + pdf = df.toPandas() + self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_pandas_round_trip(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + pdf = pd.DataFrame(data=data_dict) + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 25152f3e32d6b..643587a6eb09d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -855,6 +855,24 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + val ARROW_EXECUTION_ENABLE = + buildConf("spark.sql.execution.arrow.enable") + .internal() + .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + + "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + + "LongType, ShortType") + .booleanConf + .createWithDefault(false) + + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") + .internal() + .doc("When using Apache Arrow, limit the maximum number of records that can be written " + + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + .intConf + .createWithDefault(10000) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1115,6 +1133,10 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + + def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1bc34a6b069d9..661c31ded7148 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.arrow + arrow-vector + org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dfb51192c69bc..a7773831df075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2907,6 +2908,16 @@ class Dataset[T] private[sql]( } } + /** + * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + */ + private[sql] def collectAsArrowToPython(): Int = { + withNewExecutionId { + val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + PythonRDD.serveIterator(iter, "serve-Arrow") + } + } + private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -2988,4 +2999,13 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } + + /** Convert to an RDD of ArrowPayload byte arrays */ + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + val schemaCaptured = this.schema + val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + queryExecution.toRdd.mapPartitionsInternal { iter => + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala new file mode 100644 index 0000000000000..6af5c73422377 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -0,0 +1,429 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.arrow + +import java.io.ByteArrayOutputStream +import java.nio.channels.Channels + +import scala.collection.JavaConverters._ + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector._ +import org.apache.arrow.vector.BaseValueVector.BaseMutator +import org.apache.arrow.vector.file._ +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +/** + * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + */ +private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { + + /** + * Convert the ArrowPayload to an ArrowRecordBatch. + */ + def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { + ArrowConverters.byteArrayToBatch(payload, allocator) + } + + /** + * Get the ArrowPayload as a type that can be served to Python. + */ + def asPythonSerializable: Array[Byte] = payload +} + +private[sql] object ArrowPayload { + + /** + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. + */ + def apply( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): ArrowPayload = { + new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) + } +} + +private[sql] object ArrowConverters { + + /** + * Map a Spark DataType to ArrowType. + */ + private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { + dataType match { + case BooleanType => ArrowType.Bool.INSTANCE + case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) + case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case ByteType => new ArrowType.Int(8, true) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } + + /** + * Convert a Spark Dataset schema to Arrow schema. + */ + private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map { f => + new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) + } + new Schema(arrowFields.toList.asJava) + } + + /** + * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload + * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + */ + private[sql] def toPayloadIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { + new Iterator[ArrowPayload] { + private val _allocator = new RootAllocator(Long.MaxValue) + private var _nextPayload = if (rowIter.nonEmpty) convert() else null + + override def hasNext: Boolean = _nextPayload != null + + override def next(): ArrowPayload = { + val obj = _nextPayload + if (hasNext) { + if (rowIter.hasNext) { + _nextPayload = convert() + } else { + _allocator.close() + _nextPayload = null + } + } + obj + } + + private def convert(): ArrowPayload = { + val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) + ArrowPayload(batch, schema, _allocator) + } + } + } + + /** + * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed + * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, + * then rowIter will be fully consumed. + */ + private def internalRowIterToArrowBatch( + rowIter: Iterator[InternalRow], + schema: StructType, + allocator: BufferAllocator, + maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { + + val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => + ColumnWriter(field.dataType, ordinal, allocator).init() + } + + val writerLength = columnWriters.length + var recordsInBatch = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { + val row = rowIter.next() + var i = 0 + while (i < writerLength) { + columnWriters(i).write(row) + i += 1 + } + recordsInBatch += 1 + } + + val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip + val buffers = bufferArrays.flatten + + val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 + val recordBatch = new ArrowRecordBatch(rowLength, + fieldNodes.toList.asJava, buffers.toList.asJava) + + buffers.foreach(_.release()) + recordBatch + } + + /** + * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, + * the batch can no longer be used. + */ + private[arrow] def batchToByteArray( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): Array[Byte] = { + val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + // Write a batch to byte stream, ensure the batch, allocator and writer are closed + Utils.tryWithSafeFinally { + val loader = new VectorLoader(root) + loader.load(batch) + writer.writeBatch() // writeBatch can throw IOException + } { + batch.close() + root.close() + writer.close() + } + out.toByteArray + } + + /** + * Convert a byte array to an ArrowRecordBatch. + */ + private[arrow] def byteArrayToBatch( + batchBytes: Array[Byte], + allocator: BufferAllocator): ArrowRecordBatch = { + val in = new ByteArrayReadableSeekableByteChannel(batchBytes) + val reader = new ArrowFileReader(in, allocator) + + // Read a batch from a byte stream, ensure the reader is closed + Utils.tryWithSafeFinally { + val root = reader.getVectorSchemaRoot // throws IOException + val unloader = new VectorUnloader(root) + reader.loadNextBatch() // throws IOException + unloader.getRecordBatch + } { + reader.close() + } + } +} + +/** + * Interface for writing InternalRows to Arrow Buffers. + */ +private[arrow] trait ColumnWriter { + def init(): this.type + def write(row: InternalRow): Unit + + /** + * Clear the column writer and return the ArrowFieldNode and ArrowBuf. + * This should be called only once after all the data is written. + */ + def finish(): (ArrowFieldNode, Array[ArrowBuf]) +} + +/** + * Base class for flat arrow column writer, i.e., column without children. + */ +private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) + extends ColumnWriter { + + def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) + + def valueVector: BaseDataValueVector + def valueMutator: BaseMutator + + def setNull(): Unit + def setValue(row: InternalRow): Unit + + protected var count = 0 + protected var nullCount = 0 + + override def init(): this.type = { + valueVector.allocateNew() + this + } + + override def write(row: InternalRow): Unit = { + if (row.isNullAt(ordinal)) { + setNull() + nullCount += 1 + } else { + setValue(row) + } + count += 1 + } + + override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { + valueMutator.setValueCount(count) + val fieldNode = new ArrowFieldNode(count, nullCount) + val valueBuffers = valueVector.getBuffers(true) + (fieldNode, valueBuffers) + } +} + +private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBitVector + = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) +} + +private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableSmallIntVector + = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) + override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getShort(ordinal)) +} + +private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableIntVector + = new NullableIntVector("IntValue", getFieldType(dtype), allocator) + override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getInt(ordinal)) +} + +private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBigIntVector + = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getLong(ordinal)) +} + +private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat4Vector + = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getFloat(ordinal)) +} + +private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat8Vector + = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getDouble(ordinal)) +} + +private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableUInt1Vector + = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) + override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getByte(ordinal)) +} + +private[arrow] class UTF8StringColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarCharVector + = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val str = row.getUTF8String(ordinal) + valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) + } +} + +private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarBinaryVector + = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val bytes = row.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableDateDayVector + = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) + override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getInt(ordinal)) + } +} + +private[arrow] class TimeStampColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableTimeStampMicroVector + = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) + override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getLong(ordinal)) + } +} + +private[arrow] object ColumnWriter { + + /** + * Create an Arrow ColumnWriter given the type and ordinal of row. + */ + def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { + val dtype = ArrowConverters.sparkTypeToArrowType(dataType) + dataType match { + case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) + case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) + case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) + case LongType => new LongColumnWriter(dtype, ordinal, allocator) + case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) + case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) + case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) + case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) + case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) + case DateType => new DateColumnWriter(dtype, ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala new file mode 100644 index 0000000000000..159328cc0d958 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -0,0 +1,1222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.arrow + +import java.io.File +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale + +import com.google.common.io.Files +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.util.Validator +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.util.Utils + + +class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { + import testImplicits._ + + private var tempDataPath: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath + } + + test("collect to arrow record batch") { + val indexData = (1 to 6).toDF("i") + val arrowPayloads = indexData.toArrowPayload.collect() + assert(arrowPayloads.nonEmpty) + assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val rowCount = arrowRecordBatches.map(_.getLength).sum + assert(rowCount === indexData.count()) + arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("short conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_s", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] + | }, { + | "name" : "b_s", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] + | } ] + | } ] + |} + """.stripMargin + + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) + val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val df = a_s.zip(b_s).toDF("a_s", "b_s") + + collectAndValidate(df, json, "integer-16bit.json") + } + + test("int conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + collectAndValidate(df, json, "integer-32bit.json") + } + + test("long conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_l", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] + | }, { + | "name" : "b_l", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] + | } ] + | } ] + |} + """.stripMargin + + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) + val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val df = a_l.zip(b_l).toDF("a_l", "b_l") + + collectAndValidate(df, json, "integer-64bit.json") + } + + test("float conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_f", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] + | }, { + | "name" : "b_f", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) + val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val df = a_f.zip(b_f).toDF("a_f", "b_f") + + collectAndValidate(df, json, "floating_point-single_precision.json") + } + + test("double conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) + val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "floating_point-double_precision.json") + } + + test("index conversion") { + val data = List[Int](1, 2, 3, 4, 5, 6) + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + val df = data.toDF("i") + + collectAndValidate(df, json, "indexData-ints.json") + } + + test("mixed numeric type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "e", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "b", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "c", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "e", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + + val data = List(1, 2, 3, 4, 5, 6) + val data_tuples = for (d <- data) yield { + (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) + } + val df = data_tuples.toDF("a", "b", "c", "d", "e") + + collectAndValidate(df, json, "mixed_numeric_types.json") + } + + test("string type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "upper_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "lower_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "null_str", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "upper_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "A", "B", "C" ] + | }, { + | "name" : "lower_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "a", "b", "c" ] + | }, { + | "name" : "null_str", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "OFFSET" : [ 0, 2, 5, 5 ], + | "DATA" : [ "ab", "CDE", "" ] + | } ] + | } ] + |} + """.stripMargin + + val upperCase = Seq("A", "B", "C") + val lowerCase = Seq("a", "b", "c") + val nullStr = Seq("ab", "CDE", null) + val df = (upperCase, lowerCase, nullStr).zipped.toList + .toDF("upper_case", "lower_case", "null_str") + + collectAndValidate(df, json, "stringData.json") + } + + test("boolean type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_bool", + | "type" : { + | "name" : "bool" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_bool", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ true, true, false, true ] + | } ] + | } ] + |} + """.stripMargin + val df = Seq(true, true, false, true).toDF("a_bool") + collectAndValidate(df, json, "boolData.json") + } + + test("byte type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_byte", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 8 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_byte", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 64, 127 ] + | } ] + | } ] + |} + | + """.stripMargin + val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") + collectAndValidate(df, json, "byteData.json") + } + + test("binary type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_binary", + | "type" : { + | "name" : "binary" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_binary", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 3, 4, 6 ], + | "DATA" : [ "616263", "64", "6566" ] + | } ] + | } ] + |} + """.stripMargin + + val data = Seq("abc", "d", "ef") + val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) + val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) + + collectAndValidate(df, json, "binaryData.json") + } + + test("floating-point NaN") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "NaN_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "NaN_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 2, + | "columns" : [ { + | "name" : "NaN_f", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1.2000000476837158, "NaN" ] + | }, { + | "name" : "NaN_d", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ "NaN", 1.2 ] + | } ] + | } ] + |} + """.stripMargin + + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + + collectAndValidate(df, json, "nanData-floating_point.json") + } + + test("partitioned DataFrame") { + val json1 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 1, 2 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 1 ] + | } ] + | } ] + |} + """.stripMargin + val json2 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 3, 3 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 1, 2 ] + | } ] + | } ] + |} + """.stripMargin + + val arrowPayloads = testData2.toArrowPayload.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowPayloads.length === 2) + val schema = testData2.schema + + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") + Files.write(json1, tempFile1, StandardCharsets.UTF_8) + Files.write(json2, tempFile2, StandardCharsets.UTF_8) + + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) + } + + test("empty frame collect") { + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + assert(arrowPayload.isEmpty) + + val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + assert(filteredArrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("max records in batch conf") { + val totalRecords = 10 + val maxRecordsPerBatch = 3 + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") + val arrowPayloads = df.toArrowPayload.collect() + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + var recordCount = 0 + arrowRecordBatches.foreach { batch => + assert(batch.getLength > 0) + assert(batch.getLength <= maxRecordsPerBatch) + recordCount += batch.getLength + batch.close() + } + assert(recordCount == totalRecords) + allocator.close() + spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + + testQuietly("unsupported types") { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[SparkException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) + } + + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { arrayData.toDF().toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } + + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } + } + + test("test Arrow Validator") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val json_diff_col_order = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + // Different schema + intercept[IllegalArgumentException] { + collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") + } + } + + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ + private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val tempFile = new File(tempDataPath, file) + Files.write(json, tempFile, StandardCharsets.UTF_8) + validateConversion(df.schema, arrowPayload, tempFile) + } + + private def validateConversion( + sparkSchema: StructType, + arrowPayload: ArrowPayload, + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) + + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val jsonSchema = jsonReader.start() + Validator.compareSchemas(arrowSchema, jsonSchema) + + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + val arrowRecordBatch = arrowPayload.loadBatch(allocator) + vectorLoader.load(arrowRecordBatch) + val jsonRoot = jsonReader.read() + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + + jsonRoot.close() + jsonReader.close() + arrowRecordBatch.close() + arrowRoot.close() + allocator.close() + } +} From c3713fde86204bf3f027483914ff9e60e7aad261 Mon Sep 17 00:00:00 2001 From: chie8842 Date: Mon, 10 Jul 2017 18:56:54 -0700 Subject: [PATCH 0925/1765] [SPARK-21358][EXAMPLES] Argument of repartitionandsortwithinpartitions at pyspark ## What changes were proposed in this pull request? At example of repartitionAndSortWithinPartitions at rdd.py, third argument should be True or False. I proposed fix of example code. ## How was this patch tested? * I rename test_repartitionAndSortWithinPartitions to test_repartitionAndSortWIthinPartitions_asc to specify boolean argument. * I added test_repartitionAndSortWithinPartitions_desc to test False pattern at third argument. (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: chie8842 Closes #18586 from chie8842/SPARK-21358. --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7dfa17f68a943..3325b65f8b600 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -608,7 +608,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p sort records by their keys. >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)]) - >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2) + >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True) >>> rdd2.glom().collect() [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]] """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index bb13de563cdd4..73ab442dfd791 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1019,14 +1019,22 @@ def test_histogram(self): self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) - def test_repartitionAndSortWithinPartitions(self): + def test_repartitionAndSortWithinPartitions_asc(self): rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) partitions = repartitioned.glom().collect() self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) + def test_repartition_no_skewed(self): num_partitions = 20 a = self.sc.parallelize(range(int(1000)), 2) From a2bec6c92a063f4a8e9ed75a9f3f06808485b6d7 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 10 Jul 2017 20:16:29 -0700 Subject: [PATCH 0926/1765] [SPARK-21043][SQL] Add unionByName in Dataset ## What changes were proposed in this pull request? This pr added `unionByName` in `DataSet`. Here is how to use: ``` val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") df1.unionByName(df2).show // output: // +----+----+----+ // |col0|col1|col2| // +----+----+----+ // | 1| 2| 3| // | 6| 4| 5| // +----+----+----+ ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #18300 from maropu/SPARK-21043-2. --- .../scala/org/apache/spark/sql/Dataset.scala | 60 +++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 87 +++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a7773831df075..7f3ae05411516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils @@ -1734,6 +1735,65 @@ class Dataset[T] private[sql]( CombineUnions(Union(logicalPlan, other.logicalPlan)) } + /** + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. + * + * This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set + * union (that does deduplication of elements), use this function followed by a [[distinct]]. + * + * The difference between this function and [[union]] is that this function + * resolves columns by name (not by position): + * + * {{{ + * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") + * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") + * df1.unionByName(df2).show + * + * // output: + * // +----+----+----+ + * // |col0|col1|col2| + * // +----+----+----+ + * // | 1| 2| 3| + * // | 6| 4| 5| + * // +----+----+----+ + * }}} + * + * @group typedrel + * @since 2.3.0 + */ + def unionByName(other: Dataset[T]): Dataset[T] = withSetOperator { + // Check column name duplication + val resolver = sparkSession.sessionState.analyzer.resolver + val leftOutputAttrs = logicalPlan.output + val rightOutputAttrs = other.logicalPlan.output + + SchemaUtils.checkColumnNameDuplication( + leftOutputAttrs.map(_.name), + "in the left attributes", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + SchemaUtils.checkColumnNameDuplication( + rightOutputAttrs.map(_.name), + "in the right attributes", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + // Builds a project list for `other` based on `logicalPlan` output names + val rightProjectList = leftOutputAttrs.map { lattr => + rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { + throw new AnalysisException( + s"""Cannot resolve column name "${lattr.name}" among """ + + s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") + } + } + + // Delegates failure checks to `CheckAnalysis` + val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) + val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan) + + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, rightChild)) + } + /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a5a2e1c38d300..5ae27032e0e94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -111,6 +111,93 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + "but the first table has 2 columns and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") { + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) From 1471ee7af5a9952b60cf8c56d60cb6a7ec46cc69 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 11 Jul 2017 11:19:59 +0800 Subject: [PATCH 0927/1765] [SPARK-21350][SQL] Fix the error message when the number of arguments is wrong when invoking a UDF ### What changes were proposed in this pull request? Users get a very confusing error when users specify a wrong number of parameters. ```Scala val df = spark.emptyDataFrame spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") ``` ``` org.apache.spark.sql.UDFSuite$$anonfun$9$$anonfun$apply$mcV$sp$12 cannot be cast to scala.Function3 java.lang.ClassCastException: org.apache.spark.sql.UDFSuite$$anonfun$9$$anonfun$apply$mcV$sp$12 cannot be cast to scala.Function3 at org.apache.spark.sql.catalyst.expressions.ScalaUDF.(ScalaUDF.scala:109) ``` This PR is to capture the exception and issue an error message that is consistent with what we did for built-in functions. After the fix, the error message is improved to ``` Invalid number of arguments for function foo; line 1 pos 0 org.apache.spark.sql.AnalysisException: Invalid number of arguments for function foo; line 1 pos 0 at org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry.lookupFunction(FunctionRegistry.scala:119) ``` ### How was this patch tested? Added a test case Author: gatorsmile Closes #18574 from gatorsmile/statsCheck. --- .../apache/spark/sql/UDFRegistration.scala | 412 +++++++++++++----- .../org/apache/spark/sql/JavaUDFSuite.java | 8 + .../scala/org/apache/spark/sql/UDFSuite.scala | 13 +- 3 files changed, 331 insertions(+), 102 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 8bdc0221888d0..c4d0adb5236f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -111,7 +111,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == $x) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: $x; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") @@ -123,16 +128,20 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") println(s""" - |/** - | * Register a user-defined function with ${i} arguments. - | * @since 1.3.0 - | */ - |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { - | val func = f$anyCast.call($anyParams) - | functionRegistry.createOrReplaceTempFunction( - | name, - | (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) - |}""".stripMargin) + |/** + | * Register a user-defined function with ${i} arguments. + | * @since 1.3.0 + | */ + |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { + | val func = f$anyCast.call($anyParams) + |def builder(e: Seq[Expression]) = if (e.length == $i) { + | ScalaUDF(func, returnType, e) + |} else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $i; Found: " + e.length) + |} + |functionRegistry.createOrReplaceTempFunction(name, builder) + |}""".stripMargin) } */ @@ -144,7 +153,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 0) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 0; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -157,7 +171,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 1) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 1; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -170,7 +189,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 2) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 2; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -183,7 +207,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 3) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 3; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -196,7 +225,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 4) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 4; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -209,7 +243,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 5) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 5; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -222,7 +261,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 6) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 6; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -235,7 +279,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 7) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 7; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -248,7 +297,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 8) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 8; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -261,7 +315,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 9) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 9; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -274,7 +333,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 10) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 10; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -287,7 +351,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 11) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 11; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -300,7 +369,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 12) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 12; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -313,7 +387,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 13) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 13; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -326,7 +405,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 14) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 14; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -339,7 +423,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 15) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 15; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -352,7 +441,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 16) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 16; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -365,7 +459,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 17) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 17; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -378,7 +477,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 18) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 18; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -391,7 +495,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 19) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 19; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -404,7 +513,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 20) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 20; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -417,7 +531,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 21) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 21; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -430,7 +549,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 22) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 22; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -531,9 +655,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 1) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 1; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -542,9 +670,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 2) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 2; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -553,9 +685,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 3) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 3; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -564,9 +700,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 4) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 4; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -575,9 +715,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 5) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 5; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -586,9 +730,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 6) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 6; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -597,9 +745,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 7) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 7; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -608,9 +760,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 8) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 8; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -619,9 +775,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 9) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 9; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -630,9 +790,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 10) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 10; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -641,9 +805,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 11) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 11; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -652,9 +820,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 12) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 12; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -663,9 +835,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 13) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 13; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -674,9 +850,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 14) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 14; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -685,9 +865,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 15) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 15; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -696,9 +880,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 16) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 16; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -707,9 +895,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 17) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 17; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -718,9 +910,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 18) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 18; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -729,9 +925,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 19) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 19; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -740,9 +940,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 20) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 20; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -751,9 +955,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 21) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 21; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -762,9 +970,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 22) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 22; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } // scalastyle:on line.size.limit diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 250fa674d8ecc..4fb2988f24d26 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -25,6 +25,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF2; @@ -105,4 +106,11 @@ public void udf4Test() { } Assert.assertEquals(55, sum); } + + @SuppressWarnings("unchecked") + @Test(expected = AnalysisException.class) + public void udf5Test() { + spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); + List results = spark.sql("SELECT inc(1, 5)").collectAsList(); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index b4f744b193ada..335b882ace92a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -71,12 +71,21 @@ class UDFSuite extends QueryTest with SharedSQLContext { } } - test("error reporting for incorrect number of arguments") { + test("error reporting for incorrect number of arguments - builtin function") { val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("arguments")) + assert(e.getMessage.contains("Invalid number of arguments for function substr")) + } + + test("error reporting for incorrect number of arguments - udf") { + val df = spark.emptyDataFrame + val e = intercept[AnalysisException] { + spark.udf.register("foo", (_: String).length) + df.selectExpr("foo(2, 3, 4)") + } + assert(e.getMessage.contains("Invalid number of arguments for function foo")) } test("error reporting for undefined functions") { From 833eab2c9bd273ee9577fbf9e480d3e3a4b7d203 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 11 Jul 2017 11:26:17 +0800 Subject: [PATCH 0928/1765] [SPARK-21369][CORE] Don't use Scala Tuple2 in common/network-* ## What changes were proposed in this pull request? Remove all usages of Scala Tuple2 from common/network-* projects. Otherwise, Yarn users cannot use `spark.reducer.maxReqSizeShuffleToMem`. ## How was this patch tested? Jenkins. Author: Shixiong Zhu Closes #18593 from zsxwing/SPARK-21369. --- common/network-common/pom.xml | 3 ++- .../client/TransportResponseHandler.java | 20 +++++++++---------- .../server/OneForOneStreamManager.java | 17 +++++----------- common/network-shuffle/pom.xml | 1 + common/network-yarn/pom.xml | 1 + 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 066970f24205f..0254d0cefc368 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -90,7 +90,8 @@ org.apache.spark spark-tags_${scala.binary.version} - + test + + runtime log4j @@ -1859,9 +1859,9 @@ ${antlr4.version} - ${jline.groupid} + jline jline - ${jline.version} + 2.12.1 org.apache.commons @@ -1933,6 +1933,7 @@ --> org.jboss.netty org.codehaus.groovy + *:*_2.10 true @@ -1987,6 +1988,8 @@ -unchecked -deprecation -feature + -explaintypes + -Yno-adapted-args
    -Xms1024m @@ -2585,44 +2588,6 @@ - - scala-2.10 - - scala-2.10 - - - 2.10.6 - 2.10 - ${scala.version} - org.scala-lang - - - - - org.apache.maven.plugins - maven-enforcer-plugin - - - enforce-versions - - enforce - - - - - - *:*_2.11 - - - - - - - - - - - test-java-home @@ -2633,16 +2598,18 @@ + scala-2.11 - - !scala-2.10 - + + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 89b0c7a3ab7b0..41f3a0451aa8a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -87,19 +87,11 @@ object SparkBuild extends PomBuild { val projectsMap: Map[String, Seq[Setting[_]]] = Map.empty override val profiles = { - val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { + Properties.envOrNone("SBT_MAVEN_PROFILES") match { case None => Seq("sbt") case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - - if (System.getProperty("scala-2.10") == "") { - // To activate scala-2.10 profile, replace empty property value to non-empty value - // in the same way as Maven which handles -Dname as -Dname=true before executes build process. - // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.10", "true") - } - profiles } Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { @@ -234,9 +226,7 @@ object SparkBuild extends PomBuild { }, javacJVMVersion := "1.8", - // SBT Scala 2.10 build still doesn't support Java 8, because scalac 2.10 doesn't, but, - // it also doesn't touch Java 8 code and it's OK to emit Java 7 bytecode in this case - scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8"), + scalacJVMVersion := "1.8", javacOptions in Compile ++= Seq( "-encoding", "UTF-8", @@ -477,7 +467,6 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", - scalaVersion := "2.10.5", libraryDependencies := allPreviousArtifactKeys.value.flatten ) } @@ -756,13 +745,7 @@ object CopyDependencies { object TestSettings { import BuildCommons._ - private val scalaBinaryVersion = - if (System.getProperty("scala-2.10") == "true") { - "2.10" - } else { - "2.11" - } - + private val scalaBinaryVersion = "2.11" lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, diff --git a/python/run-tests.py b/python/run-tests.py index b2e50435bb192..afd3d29a0ff90 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -54,7 +54,8 @@ def print_red(text): LOGGER = logging.getLogger() # Find out where the assembly jars are located. -for scala in ["2.11", "2.10"]: +# Later, add back 2.12 to this list: +for scala in ["2.11"]: build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) if os.path.isdir(build_dir): SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") diff --git a/repl/pom.xml b/repl/pom.xml index 6d133a3cfff7d..51eb9b60dd54a 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -32,8 +32,8 @@ repl - scala-2.10/src/main/scala - scala-2.10/src/test/scala + scala-2.11/src/main/scala + scala-2.11/src/test/scala @@ -71,7 +71,7 @@ ${scala.version} - ${jline.groupid} + jline jline @@ -170,23 +170,17 @@ + + + diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala deleted file mode 100644 index fba321be91886..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import org.apache.spark.internal.Logging - -object Main extends Logging { - - initializeLogIfNecessary(true) - Signaling.cancelOnInterrupt() - - private var _interp: SparkILoop = _ - - def interp = _interp - - def interp_=(i: SparkILoop) { _interp = i } - - def main(args: Array[String]) { - _interp = new SparkILoop - _interp.process(args) - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala deleted file mode 100644 index be9b79021d2a8..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import scala.tools.nsc.{CompilerCommand, Settings} - -import org.apache.spark.annotation.DeveloperApi - -/** - * Command class enabling Spark-specific command line options (provided by - * org.apache.spark.repl.SparkRunnerSettings). - * - * @example new SparkCommandLine(Nil).settings - * - * @param args The list of command line arguments - * @param settings The underlying settings to associate with this set of - * command-line options - */ -@DeveloperApi -class SparkCommandLine(args: List[String], override val settings: Settings) - extends CompilerCommand(args, settings) { - def this(args: List[String], error: String => Unit) { - this(args, new SparkRunnerSettings(error)) - } - - def this(args: List[String]) { - // scalastyle:off println - this(args, str => Console.println("Error: " + str)) - // scalastyle:on println - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala deleted file mode 100644 index 2b5d56a895902..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ /dev/null @@ -1,114 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package org.apache.spark.repl - -import scala.tools.nsc._ -import scala.tools.nsc.interpreter._ - -import scala.reflect.internal.util.BatchSourceFile -import scala.tools.nsc.ast.parser.Tokens.EOF - -import org.apache.spark.internal.Logging - -private[repl] trait SparkExprTyper extends Logging { - val repl: SparkIMain - - import repl._ - import global.{ reporter => _, Import => _, _ } - import definitions._ - import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name } - import naming.freshInternalVarName - - object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] { - def applyRule[T](code: String, rule: UnitParser => T): T = { - reporter.reset() - val scanner = newUnitParser(code) - val result = rule(scanner) - - if (!reporter.hasErrors) - scanner.accept(EOF) - - result - } - - def defns(code: String) = stmts(code) collect { case x: DefTree => x } - def expr(code: String) = applyRule(code, _.expr()) - def stmts(code: String) = applyRule(code, _.templateStats()) - def stmt(code: String) = stmts(code).last // guaranteed nonempty - } - - /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ - def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") { - var isIncomplete = false - reporter.withIncompleteHandler((_, _) => isIncomplete = true) { - val trees = codeParser.stmts(line) - if (reporter.hasErrors) { - Some(Nil) - } else if (isIncomplete) { - None - } else { - Some(trees) - } - } - } - // def parsesAsExpr(line: String) = { - // import codeParser._ - // (opt expr line).isDefined - // } - - def symbolOfLine(code: String): Symbol = { - def asExpr(): Symbol = { - val name = freshInternalVarName() - // Typing it with a lazy val would give us the right type, but runs - // into compiler bugs with things like existentials, so we compile it - // behind a def and strip the NullaryMethodType which wraps the expr. - val line = "def " + name + " = {\n" + code + "\n}" - - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - // drop NullaryMethodType - val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType) - if (sym.info.typeSymbol eq UnitClass) NoSymbol else sym - case _ => NoSymbol - } - } - def asDefn(): Symbol = { - val old = repl.definedSymbolList.toSet - - interpretSynthetic(code) match { - case IR.Success => - repl.definedSymbolList filterNot old match { - case Nil => NoSymbol - case sym :: Nil => sym - case syms => NoSymbol.newOverloaded(NoPrefix, syms) - } - case _ => NoSymbol - } - } - beQuietDuring(asExpr()) orElse beQuietDuring(asDefn()) - } - - private var typeOfExpressionDepth = 0 - def typeOfExpression(expr: String, silent: Boolean = true): Type = { - if (typeOfExpressionDepth > 2) { - logDebug("Terminating typeOfExpression recursion for expression: " + expr) - return NoType - } - typeOfExpressionDepth += 1 - // Don't presently have a good way to suppress undesirable success output - // while letting errors through, so it is first trying it silently: if there - // is an error, and errors are desired, then it re-evaluates non-silently - // to induce the error message. - try beSilentDuring(symbolOfLine(expr).tpe) match { - case NoType if !silent => symbolOfLine(expr).tpe // generate error - case tpe => tpe - } - finally typeOfExpressionDepth -= 1 - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala deleted file mode 100644 index 955be17a73b85..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package scala.tools.nsc - -import org.apache.spark.annotation.DeveloperApi - -// NOTE: Forced to be public (and in scala.tools.nsc package) to access the -// settings "explicitParentLoader" method - -/** - * Provides exposure for the explicitParentLoader method on settings instances. - */ -@DeveloperApi -object SparkHelper { - /** - * Retrieves the explicit parent loader for the provided settings. - * - * @param settings The settings whose explicit parent loader to retrieve - * - * @return The Optional classloader representing the explicit parent loader - */ - @DeveloperApi - def explicitParentLoader(settings: Settings) = settings.explicitParentLoader -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index b7237a6ce822f..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,1145 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Alexander Spoon - */ - -package org.apache.spark.repl - - -import java.net.URL - -import scala.reflect.io.AbstractFile -import scala.tools.nsc._ -import scala.tools.nsc.backend.JavaPlatform -import scala.tools.nsc.interpreter._ -import scala.tools.nsc.interpreter.{Results => IR} -import Predef.{println => _, _} -import java.io.{BufferedReader, FileReader} -import java.net.URI -import java.util.concurrent.locks.ReentrantLock -import scala.sys.process.Process -import scala.tools.nsc.interpreter.session._ -import scala.util.Properties.{jdkHome, javaVersion} -import scala.tools.util.{Javap} -import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer -import scala.concurrent.ops -import scala.tools.nsc.util._ -import scala.tools.nsc.interpreter._ -import scala.tools.nsc.io.{File, Directory} -import scala.reflect.NameTransformer._ -import scala.tools.nsc.util.ScalaClassLoader._ -import scala.tools.util._ -import scala.language.{implicitConversions, existentials, postfixOps} -import scala.reflect.{ClassTag, classTag} -import scala.tools.reflect.StdRuntimeTags._ - -import java.lang.{Class => jClass} -import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} - -import org.apache.spark.SparkConf -import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.util.Utils - -/** The Scala interactive shell. It provides a read-eval-print loop - * around the Interpreter class. - * After instantiation, clients should call the main() method. - * - * If no in0 is specified, then input will come from the console, and - * the class will attempt to provide input editing feature such as - * input history. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - * @version 1.2 - */ -@DeveloperApi -class SparkILoop( - private val in0: Option[BufferedReader], - protected val out: JPrintWriter, - val master: Option[String] -) extends AnyRef with LoopCommands with SparkILoopInit with Logging { - def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) - def this() = this(None, new JPrintWriter(Console.out, true), None) - - private var in: InteractiveReader = _ // the input stream from which commands come - - // NOTE: Exposed in package for testing - private[repl] var settings: Settings = _ - - private[repl] var intp: SparkIMain = _ - - @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp - @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i - - /** Having inherited the difficult "var-ness" of the repl instance, - * I'm trying to work around it by moving operations into a class from - * which it will appear a stable prefix. - */ - private def onIntp[T](f: SparkIMain => T): T = f(intp) - - class IMainOps[T <: SparkIMain](val intp: T) { - import intp._ - import global._ - - def printAfterTyper(msg: => String) = - intp.reporter printMessage afterTyper(msg) - - /** Strip NullaryMethodType artifacts. */ - private def replInfo(sym: Symbol) = { - sym.info match { - case NullaryMethodType(restpe) if sym.isAccessor => restpe - case info => info - } - } - def echoTypeStructure(sym: Symbol) = - printAfterTyper("" + deconstruct.show(replInfo(sym))) - - def echoTypeSignature(sym: Symbol, verbose: Boolean) = { - if (verbose) SparkILoop.this.echo("// Type signature") - printAfterTyper("" + replInfo(sym)) - - if (verbose) { - SparkILoop.this.echo("\n// Internal Type structure") - echoTypeStructure(sym) - } - } - } - implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp) - - /** TODO - - * -n normalize - * -l label with case class parameter names - * -c complete - leave nothing out - */ - private def typeCommandInternal(expr: String, verbose: Boolean): Result = { - onIntp { intp => - val sym = intp.symbolOfLine(expr) - if (sym.exists) intp.echoTypeSignature(sym, verbose) - else "" - } - } - - // NOTE: Must be public for visibility - @DeveloperApi - var sparkContext: SparkContext = _ - - override def echoCommandMessage(msg: String) { - intp.reporter printMessage msg - } - - // def isAsync = !settings.Yreplsync.value - private[repl] def isAsync = false - // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - private def history = in.history - - /** The context class loader at the time this object was created */ - protected val originalClassLoader = Utils.getContextOrSparkClassLoader - - // classpath entries added via :cp - private var addedClasspath: String = "" - - /** A reverse list of commands to replay if the user requests a :replay */ - private var replayCommandStack: List[String] = Nil - - /** A list of commands to replay if the user requests a :replay */ - private def replayCommands = replayCommandStack.reverse - - /** Record a command for replay should the user request a :replay */ - private def addReplay(cmd: String) = replayCommandStack ::= cmd - - private def savingReplayStack[T](body: => T): T = { - val saved = replayCommandStack - try body - finally replayCommandStack = saved - } - private def savingReader[T](body: => T): T = { - val saved = in - try body - finally in = saved - } - - - private def sparkCleanUp() { - echo("Stopping spark context.") - intp.beQuietDuring { - command("sc.stop()") - } - } - /** Close the interpreter and set the var to null. */ - private def closeInterpreter() { - if (intp ne null) { - sparkCleanUp() - intp.close() - intp = null - } - } - - class SparkILoopInterpreter extends SparkIMain(settings, out) { - outer => - - override private[repl] lazy val formatting = new Formatting { - def prompt = SparkILoop.this.prompt - } - override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) - } - - /** - * Constructs a new interpreter. - */ - protected def createInterpreter() { - require(settings != null) - - if (addedClasspath != "") settings.classpath.append(addedClasspath) - val addedJars = - if (Utils.isWindows) { - // Strip any URI scheme prefix so we can add the correct path to the classpath - // e.g. file:/C:/my/path.jar -> C:/my/path.jar - getAddedJars().map { jar => new URI(jar).getPath.stripPrefix("/") } - } else { - // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). - getAddedJars().map { jar => new URI(jar).getPath } - } - // work around for Scala bug - val totalClassPath = addedJars.foldLeft( - settings.classpath.value)((l, r) => ClassPath.join(l, r)) - this.settings.classpath.value = totalClassPath - - intp = new SparkILoopInterpreter - } - - /** print a friendly help message */ - private def helpCommand(line: String): Result = { - if (line == "") helpSummary() - else uniqueCommand(line) match { - case Some(lc) => echo("\n" + lc.longHelp) - case _ => ambiguousError(line) - } - } - private def helpSummary() = { - val usageWidth = commands map (_.usageMsg.length) max - val formatStr = "%-" + usageWidth + "s %s %s" - - echo("All commands can be abbreviated, e.g. :he instead of :help.") - echo("Those marked with a * have more detailed help, e.g. :help imports.\n") - - commands foreach { cmd => - val star = if (cmd.hasLongHelp) "*" else " " - echo(formatStr.format(cmd.usageMsg, star, cmd.help)) - } - } - private def ambiguousError(cmd: String): Result = { - matchingCommands(cmd) match { - case Nil => echo(cmd + ": no such command. Type :help for help.") - case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") - } - Result(true, None) - } - private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) - private def uniqueCommand(cmd: String): Option[LoopCommand] = { - // this lets us add commands willy-nilly and only requires enough command to disambiguate - matchingCommands(cmd) match { - case List(x) => Some(x) - // exact match OK even if otherwise appears ambiguous - case xs => xs find (_.name == cmd) - } - } - private var fallbackMode = false - - private def toggleFallbackMode() { - val old = fallbackMode - fallbackMode = !old - System.setProperty("spark.repl.fallback", fallbackMode.toString) - echo(s""" - |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would - |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` - |mode as first command. - """.stripMargin) - } - - /** Show the history */ - private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { - override def usage = "[num]" - def defaultLines = 20 - - def apply(line: String): Result = { - if (history eq NoHistory) - return "No history available." - - val xs = words(line) - val current = history.index - val count = try xs.head.toInt catch { case _: Exception => defaultLines } - val lines = history.asStrings takeRight count - val offset = current - lines.size + 1 - - for ((line, index) <- lines.zipWithIndex) - echo("%3d %s".format(index + offset, line)) - } - } - - // When you know you are most likely breaking into the middle - // of a line being typed. This softens the blow. - private[repl] def echoAndRefresh(msg: String) = { - echo("\n" + msg) - in.redrawLine() - } - private[repl] def echo(msg: String) = { - out println msg - out.flush() - } - private def echoNoNL(msg: String) = { - out print msg - out.flush() - } - - /** Search the history */ - private def searchHistory(_cmdline: String) { - val cmdline = _cmdline.toLowerCase - val offset = history.index - history.size + 1 - - for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) - echo("%d %s".format(index + offset, line)) - } - - private var currentPrompt = Properties.shellPromptString - - /** - * Sets the prompt string used by the REPL. - * - * @param prompt The new prompt string - */ - @DeveloperApi - def setPrompt(prompt: String) = currentPrompt = prompt - - /** - * Represents the current prompt string used by the REPL. - * - * @return The current prompt string - */ - @DeveloperApi - def prompt = currentPrompt - - import LoopCommand.{ cmd, nullary } - - /** Standard commands */ - private lazy val standardCommands = List( - cmd("cp", "", "add a jar or directory to the classpath", addClasspath), - cmd("help", "[command]", "print this summary or command-specific help", helpCommand), - historyCommand, - cmd("h?", "", "search the history", searchHistory), - cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), - cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand), - cmd("javap", "", "disassemble a file or class name", javapCommand), - cmd("load", "", "load and interpret a Scala file", loadCommand), - nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand), -// nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the repl", () => Result(false, None)), - nullary("replay", "reset execution and replay all previous commands", replay), - nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), - shCommand, - nullary("silent", "disable/enable automatic printing of results", verbosity), - nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. - |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), - cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), - nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) - ) - - /** Power user commands */ - private lazy val powerCommands: List[LoopCommand] = List( - // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) - ) - - // private def dumpCommand(): Result = { - // echo("" + power) - // history.asStrings takeRight 30 foreach echo - // in.redrawLine() - // } - // private def valsCommand(): Result = power.valsDescription - - private val typeTransforms = List( - "scala.collection.immutable." -> "immutable.", - "scala.collection.mutable." -> "mutable.", - "scala.collection.generic." -> "generic.", - "java.lang." -> "jl.", - "scala.runtime." -> "runtime." - ) - - private def importsCommand(line: String): Result = { - val tokens = words(line) - val handlers = intp.languageWildcardHandlers ++ intp.importHandlers - val isVerbose = tokens contains "-v" - - handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { - case (handler, idx) => - val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) - val imps = handler.implicitSymbols - val found = tokens filter (handler importsSymbolNamed _) - val typeMsg = if (types.isEmpty) "" else types.size + " types" - val termMsg = if (terms.isEmpty) "" else terms.size + " terms" - val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" - val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") - val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - - intp.reporter.printMessage("%2d) %-30s %s%s".format( - idx + 1, - handler.importString, - statsMsg, - foundMsg - )) - } - } - - private def implicitsCommand(line: String): Result = onIntp { intp => - import intp._ - import global._ - - def p(x: Any) = intp.reporter.printMessage("" + x) - - // If an argument is given, only show a source with that - // in its name somewhere. - val args = line split "\\s+" - val filtered = intp.implicitSymbolsBySource filter { - case (source, syms) => - (args contains "-v") || { - if (line == "") (source.fullName.toString != "scala.Predef") - else (args exists (source.name.toString contains _)) - } - } - - if (filtered.isEmpty) - return "No implicits have been imported other than those in Predef." - - filtered foreach { - case (source, syms) => - p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") - - // This groups the members by where the symbol is defined - val byOwner = syms groupBy (_.owner) - val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) } - - sortedOwners foreach { - case (owner, members) => - // Within each owner, we cluster results based on the final result type - // if there are more than a couple, and sort each cluster based on name. - // This is really just trying to make the 100 or so implicits imported - // by default into something readable. - val memberGroups: List[List[Symbol]] = { - val groups = members groupBy (_.tpe.finalResultType) toList - val (big, small) = groups partition (_._2.size > 3) - val xss = ( - (big sortBy (_._1.toString) map (_._2)) :+ - (small flatMap (_._2)) - ) - - xss map (xs => xs sortBy (_.name.toString)) - } - - val ownerMessage = if (owner == source) " defined in " else " inherited from " - p(" /* " + members.size + ownerMessage + owner.fullName + " */") - - memberGroups foreach { group => - group foreach (s => p(" " + intp.symbolDefString(s))) - p("") - } - } - p("") - } - } - - private def findToolsJar() = { - val jdkPath = Directory(jdkHome) - val jar = jdkPath / "lib" / "tools.jar" toFile; - - if (jar isFile) - Some(jar) - else if (jdkPath.isDirectory) - jdkPath.deepFiles find (_.name == "tools.jar") - else None - } - private def addToolsJarToLoader() = { - val cl = findToolsJar match { - case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) - case _ => intp.classLoader - } - if (Javap.isAvailable(cl)) { - logDebug(":javap available.") - cl - } - else { - logDebug(":javap unavailable: no tools.jar at " + jdkHome) - intp.classLoader - } - } - - private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { - override def tryClass(path: String): Array[Byte] = { - val hd :: rest = path split '.' toList; - // If there are dots in the name, the first segment is the - // key to finding it. - if (rest.nonEmpty) { - intp optFlatName hd match { - case Some(flat) => - val clazz = flat :: rest mkString NAME_JOIN_STRING - val bytes = super.tryClass(clazz) - if (bytes.nonEmpty) bytes - else super.tryClass(clazz + MODULE_SUFFIX_STRING) - case _ => super.tryClass(path) - } - } - else { - // Look for Foo first, then Foo$, but if Foo$ is given explicitly, - // we have to drop the $ to find object Foo, then tack it back onto - // the end of the flattened name. - def className = intp flatName path - def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING - - val bytes = super.tryClass(className) - if (bytes.nonEmpty) bytes - else super.tryClass(moduleName) - } - } - } - // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) - private lazy val javap = - try newJavap() - catch { case _: Exception => null } - - // Still todo: modules. - private def typeCommand(line0: String): Result = { - line0.trim match { - case "" => ":type [-v] " - case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true) - case s => typeCommandInternal(s, false) - } - } - - private def warningsCommand(): Result = { - if (intp.lastWarnings.isEmpty) - "Can't find any cached warnings." - else - intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } - } - - private def javapCommand(line: String): Result = { - if (javap == null) - ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) - else if (javaVersion startsWith "1.7") - ":javap not yet working with java 1.7" - else if (line == "") - ":javap [-lcsvp] [path1 path2 ...]" - else - javap(words(line)) foreach { res => - if (res.isError) return "Failed: " + res.value - else res.show() - } - } - - private def wrapCommand(line: String): Result = { - def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T" - onIntp { intp => - import intp._ - import global._ - - words(line) match { - case Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => "Current execution wrapper: " + s - } - case "clear" :: Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." - } - case wrapper :: Nil => - intp.typeOfExpression(wrapper) match { - case PolyType(List(targ), MethodType(List(arg), restpe)) => - intp setExecutionWrapper intp.pathToTerm(wrapper) - "Set wrapper to '" + wrapper + "'" - case tp => - failMsg + "\nFound: " - } - case _ => failMsg - } - } - } - - private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent" - // private def phaseCommand(name: String): Result = { - // val phased: Phased = power.phased - // import phased.NoPhaseName - - // if (name == "clear") { - // phased.set(NoPhaseName) - // intp.clearExecutionWrapper() - // "Cleared active phase." - // } - // else if (name == "") phased.get match { - // case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" - // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) - // } - // else { - // val what = phased.parse(name) - // if (what.isEmpty || !phased.set(what)) - // "'" + name + "' does not appear to represent a valid phase." - // else { - // intp.setExecutionWrapper(pathToPhaseWrapper) - // val activeMessage = - // if (what.toString.length == name.length) "" + what - // else "%s (%s)".format(what, name) - - // "Active phase is now: " + activeMessage - // } - // } - // } - - /** - * Provides a list of available commands. - * - * @return The list of commands - */ - @DeveloperApi - def commands: List[LoopCommand] = standardCommands /*++ ( - if (isReplPower) powerCommands else Nil - )*/ - - private val replayQuestionMessage = - """|That entry seems to have slain the compiler. Shall I replay - |your session? I can re-run each line except the last one. - |[y/n] - """.trim.stripMargin - - private def crashRecovery(ex: Throwable): Boolean = { - echo(ex.toString) - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true - } - - /** The main read-eval-print loop for the repl. It calls - * command() for each line of input, and stops when - * command() returns false. - */ - private def loop() { - def readOneLine() = { - out.flush() - in readLine prompt - } - // return false if repl should exit - def processLine(line: String): Boolean = { - if (isAsync) { - if (!awaitInitialized()) return false - runThunks() - } - if (line eq null) false // assume null means EOF - else command(line) match { - case Result(false, _) => false - case Result(_, Some(finalLine)) => addReplay(finalLine) ; true - case _ => true - } - } - def innerLoop() { - val shouldContinue = try { - processLine(readOneLine()) - } catch {case t: Throwable => crashRecovery(t)} - if (shouldContinue) - innerLoop() - } - innerLoop() - } - - /** interpret all lines from a specified file */ - private def interpretAllFrom(file: File) { - savingReader { - savingReplayStack { - file applyReader { reader => - in = SimpleReader(reader, out, false) - echo("Loading " + file + "...") - loop() - } - } - } - } - - /** create a new interpreter and replay the given commands */ - private def replay() { - reset() - if (replayCommandStack.isEmpty) - echo("Nothing to replay.") - else for (cmd <- replayCommands) { - echo("Replaying: " + cmd) // flush because maybe cmd will have its own output - command(cmd) - echo("") - } - } - private def resetCommand() { - echo("Resetting repl state.") - if (replayCommandStack.nonEmpty) { - echo("Forgetting this session history:\n") - replayCommands foreach echo - echo("") - replayCommandStack = Nil - } - if (intp.namedDefinedTerms.nonEmpty) - echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) - if (intp.definedTypes.nonEmpty) - echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) - - reset() - } - - private def reset() { - intp.reset() - // unleashAndSetPhase() - } - - /** fork a shell and run a command */ - private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { - override def usage = "" - def apply(line: String): Result = line match { - case "" => showUsage() - case _ => - val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" - intp interpret toRun - () - } - } - - private def withFile(filename: String)(action: File => Unit) { - val f = File(filename) - - if (f.exists) action(f) - else echo("That file does not exist") - } - - private def loadCommand(arg: String) = { - var shouldReplay: Option[String] = None - withFile(arg)(f => { - interpretAllFrom(f) - shouldReplay = Some(":load " + arg) - }) - Result(true, shouldReplay) - } - - private def addAllClasspath(args: Seq[String]): Unit = { - var added = false - var totalClasspath = "" - for (arg <- args) { - val f = File(arg).normalize - if (f.exists) { - added = true - addedClasspath = ClassPath.join(addedClasspath, f.path) - totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - intp.addUrlsToClassPath(f.toURI.toURL) - sparkContext.addJar(f.toURI.toURL.getPath) - } - } - } - - private def addClasspath(arg: String): Unit = { - val f = File(arg).normalize - if (f.exists) { - addedClasspath = ClassPath.join(addedClasspath, f.path) - intp.addUrlsToClassPath(f.toURI.toURL) - sparkContext.addJar(f.toURI.toURL.getPath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString)) - } - else echo("The path '" + f + "' doesn't seem to exist.") - } - - - private def powerCmd(): Result = { - if (isReplPower) "Already in power mode." - else enablePowerMode(false) - } - - private[repl] def enablePowerMode(isDuringInit: Boolean) = { - // replProps.power setValue true - // unleashAndSetPhase() - // asyncEcho(isDuringInit, power.banner) - } - // private def unleashAndSetPhase() { -// if (isReplPower) { -// // power.unleash() -// // Set the phase to "typer" -// intp beSilentDuring phaseCommand("typer") -// } -// } - - private def asyncEcho(async: Boolean, msg: => String) { - if (async) asyncMessage(msg) - else echo(msg) - } - - private def verbosity() = { - // val old = intp.printResults - // intp.printResults = !old - // echo("Switched " + (if (old) "off" else "on") + " result printing.") - } - - /** - * Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. - */ - private[repl] def command(line: String): Result = { - if (line startsWith ":") { - val cmd = line.tail takeWhile (x => !x.isWhitespace) - uniqueCommand(cmd) match { - case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) - case _ => ambiguousError(cmd) - } - } - else if (intp.global == null) Result(false, None) // Notice failure to create compiler - else Result(true, interpretStartingWith(line)) - } - - private def readWhile(cond: String => Boolean) = { - Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) - } - - private def pasteCommand(): Result = { - echo("// Entering paste mode (ctrl-D to finish)\n") - val code = readWhile(_ => true) mkString "\n" - echo("\n// Exiting paste mode, now interpreting.\n") - intp interpret code - () - } - - private object paste extends Pasted { - val ContinueString = " | " - val PromptString = "scala> " - - def interpret(line: String): Unit = { - echo(line.trim) - intp interpret line - echo("") - } - - def transcript(start: String) = { - echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") - apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) - } - } - import paste.{ ContinueString, PromptString } - - /** - * Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ - private def interpretStartingWith(code: String): Option[String] = { - // signal completion non-completion input has been received - in.completion.resetVerbosity() - - def reallyInterpret = { - val reallyResult = intp.interpret(code) - (reallyResult, reallyResult match { - case IR.Error => None - case IR.Success => Some(code) - case IR.Incomplete => - if (in.interactive && code.endsWith("\n\n")) { - echo("You typed two blank lines. Starting a new command.") - None - } - else in.readLine(ContinueString) match { - case null => - // we know compilation is going to fail since we're at EOF and the - // parser thinks the input is still incomplete, but since this is - // a file being read non-interactively we want to fail. So we send - // it straight to the compiler for the nice error message. - intp.compileString(code) - None - - case line => interpretStartingWith(code + "\n" + line) - } - }) - } - - /** Here we place ourselves between the user and the interpreter and examine - * the input they are ostensibly submitting. We intervene in several cases: - * - * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. - * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation - * on the previous result. - * 3) If the Completion object's execute returns Some(_), we inject that value - * and avoid the interpreter, as it's likely not valid scala code. - */ - if (code == "") None - else if (!paste.running && code.trim.startsWith(PromptString)) { - paste.transcript(code) - None - } - else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { - interpretStartingWith(intp.mostRecentVar + code) - } - else if (code.trim startsWith "//") { - // line comment, do nothing - None - } - else - reallyInterpret._2 - } - - // runs :load `file` on any files passed via -i - private def loadFiles(settings: Settings) = settings match { - case settings: SparkRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - addReplay(cmd) - echo("") - } - case _ => - } - - /** Tries to create a JLineReader, falling back to SimpleReader: - * unless settings or properties are such that it should start - * with SimpleReader. - */ - private def chooseReader(settings: Settings): InteractiveReader = { - if (settings.Xnojline.value || Properties.isEmacsShell) - SimpleReader() - else try new SparkJLineReader( - if (settings.noCompletion.value) NoCompletion - else new SparkJLineCompletion(intp) - ) - catch { - case ex @ (_: Exception | _: NoClassDefFoundError) => - echo("Failed to created SparkJLineReader: " + ex + "\nFalling back to SimpleReader.") - SimpleReader() - } - } - - private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - private val m = u.runtimeMirror(Utils.getSparkClassLoader) - private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = - u.TypeTag[T]( - m, - new TypeCreator { - def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type = - m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] - }) - - private def process(settings: Settings): Boolean = savingContextLoader { - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0 match { - case Some(reader) => SimpleReader(reader, out, true) - case None => - // some post-initialization - chooseReader(settings) match { - case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x - case x => x - } - } - lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain] - // Bind intp somewhere out of the regular namespace where - // we can get at it in generated code. - addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain]))) - addThunk({ - import scala.tools.nsc.io._ - import Properties.userHome - import scala.compat.Platform.EOL - val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp()) - if (autorun.isDefined) intp.quietRun(autorun.get) - }) - - addThunk(printWelcome()) - addThunk(initializeSpark()) - - // it is broken on startup; go ahead and exit - if (intp.reporter.hasErrors) - return false - - // This is about the illusion of snappiness. We call initialize() - // which spins off a separate thread, then print the prompt and try - // our best to look ready. The interlocking lazy vals tend to - // inter-deadlock, so we break the cycle with a single asynchronous - // message to an rpcEndpoint. - if (isAsync) { - intp initialize initializedCallback() - createAsyncListener() // listens for signal to run postInitialization - } - else { - intp.initializeSynchronous() - postInitialization() - } - // printWelcome() - - loadFiles(settings) - - try loop() - catch AbstractOrMissingHandler() - finally closeInterpreter() - - true - } - - // NOTE: Must be public for visibility - @DeveloperApi - def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - val jars = getAddedJars() - val conf = new SparkConf() - .setMaster(getMaster()) - .setJars(jars) - .setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - - val builder = SparkSession.builder.config(conf) - val sparkSession = if (SparkSession.hiveClassesArePresent) { - logInfo("Creating Spark session with Hive support") - builder.enableHiveSupport().getOrCreate() - } else { - logInfo("Creating Spark session") - builder.getOrCreate() - } - sparkContext = sparkSession.sparkContext - sparkSession - } - - private def getMaster(): String = { - val master = this.master match { - case Some(m) => m - case None => - val envMaster = sys.env.get("MASTER") - val propMaster = sys.props.get("spark.master") - propMaster.orElse(envMaster).getOrElse("local[*]") - } - master - } - - /** process command-line arguments and do as they request */ - def process(args: Array[String]): Boolean = { - val command = new SparkCommandLine(args.toList, msg => echo(msg)) - def neededHelp(): String = - (if (command.settings.help.value) command.usageMsg + "\n" else "") + - (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") - - // if they asked for no help and command is valid, we call the real main - neededHelp() match { - case "" => command.ok && process(command.settings) - case help => echoNoNL(help) ; true - } - } - - @deprecated("Use `process` instead", "2.9.0") - private def main(settings: Settings): Unit = process(settings) - - @DeveloperApi - def getAddedJars(): Array[String] = { - val conf = new SparkConf().setMaster(getMaster()) - val envJars = sys.env.get("ADD_JARS") - if (envJars.isDefined) { - logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") - } - val jars = { - val userJars = Utils.getUserJars(conf, isShell = true) - if (userJars.isEmpty) { - envJars.getOrElse("") - } else { - userJars.mkString(",") - } - } - Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) - } - -} - -object SparkILoop extends Logging { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - private def echo(msg: String) = Console println msg - - // Designed primarily for use by test code: take a String with a - // bunch of code, and prints out a transcript of what it would look - // like if you'd just typed it into the repl. - private[repl] def runForTranscript(code: String, settings: Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { - override def write(str: String) = { - // completely skip continuation lines - if (str forall (ch => ch.isWhitespace || ch == '|')) () - // print a newline on empty scala prompts - else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n") - else super.write(str) - } - } - val input = new BufferedReader(new StringReader(code)) { - override def readLine(): String = { - val s = super.readLine() - // helping out by printing the line being interpreted. - if (s != null) - // scalastyle:off println - output.println(s) - // scalastyle:on println - s - } - } - val repl = new SparkILoop(input, output) - - if (settings.classpath.isDefault) - settings.classpath.value = sys.props("java.class.path") - - repl.getAddedJars().map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) - - repl process settings - } - } - } - - /** Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ - private[repl] def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new ILoop(input, output) - - if (sets.classpath.isDefault) - sets.classpath.value = sys.props("java.class.path") - - repl process sets - } - } - } - private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala deleted file mode 100644 index 5f0d92bccd809..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ /dev/null @@ -1,168 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package org.apache.spark.repl - -import scala.tools.nsc._ -import scala.tools.nsc.interpreter._ - -import scala.tools.nsc.util.stackTraceString - -import org.apache.spark.SPARK_VERSION - -/** - * Machinery for the asynchronous initialization of the repl. - */ -private[repl] trait SparkILoopInit { - self: SparkILoop => - - /** Print a welcome message */ - def printWelcome() { - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ -""".format(SPARK_VERSION)) - import Properties._ - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - protected def asyncMessage(msg: String) { - if (isReplInfo || isReplPower) - echoAndRefresh(msg) - } - - private val initLock = new java.util.concurrent.locks.ReentrantLock() - private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized - private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized - private val initStart = System.nanoTime - - private def withLock[T](body: => T): T = { - initLock.lock() - try body - finally initLock.unlock() - } - // a condition used to ensure serial access to the compiler. - @volatile private var initIsComplete = false - @volatile private var initError: String = null - private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L) - - // the method to be called when the interpreter is initialized. - // Very important this method does nothing synchronous (i.e. do - // not try to use the interpreter) because until it returns, the - // repl's lazy val `global` is still locked. - protected def initializedCallback() = withLock(initCompilerCondition.signal()) - - // Spins off a thread which awaits a single message once the interpreter - // has been initialized. - protected def createAsyncListener() = { - io.spawn { - withLock(initCompilerCondition.await()) - asyncMessage("[info] compiler init time: " + elapsed() + " s.") - postInitialization() - } - } - - // called from main repl loop - protected def awaitInitialized(): Boolean = { - if (!initIsComplete) - withLock { while (!initIsComplete) initLoopCondition.await() } - if (initError != null) { - // scalastyle:off println - println(""" - |Failed to initialize the REPL due to an unexpected error. - |This is a bug, please, report it along with the error diagnostics printed below. - |%s.""".stripMargin.format(initError) - ) - // scalastyle:on println - false - } else true - } - // private def warningsThunks = List( - // () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _), - // ) - - protected def postInitThunks = List[Option[() => Unit]]( - Some(intp.setContextClassLoader _), - if (isReplPower) Some(() => enablePowerMode(true)) else None - ).flatten - // ++ ( - // warningsThunks - // ) - // called once after init condition is signalled - protected def postInitialization() { - try { - postInitThunks foreach (f => addThunk(f())) - runThunks() - } catch { - case ex: Throwable => - initError = stackTraceString(ex) - throw ex - } finally { - initIsComplete = true - - if (isAsync) { - asyncMessage("[info] total init time: " + elapsed() + " s.") - withLock(initLoopCondition.signal()) - } - } - } - - def initializeSpark() { - intp.beQuietDuring { - command(""" - @transient val spark = org.apache.spark.repl.Main.interp.createSparkSession() - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println(s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - command("import org.apache.spark.SparkContext._") - command("import spark.implicits._") - command("import spark.sql") - command("import org.apache.spark.sql.functions._") - } - } - - // code to be executed only after the interpreter is initialized - // and the lazy val `global` can be accessed without risk of deadlock. - private var pendingThunks: List[() => Unit] = Nil - protected def addThunk(body: => Unit) = synchronized { - pendingThunks :+= (() => body) - } - protected def runThunks(): Unit = synchronized { - if (pendingThunks.nonEmpty) - logDebug("Clearing " + pendingThunks.size + " thunks.") - - while (pendingThunks.nonEmpty) { - val thunk = pendingThunks.head - pendingThunks = pendingThunks.tail - thunk() - } - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala deleted file mode 100644 index 74a04d5a42bb2..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ /dev/null @@ -1,1808 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Martin Odersky - */ - -package org.apache.spark.repl - -import java.io.File - -import scala.tools.nsc._ -import scala.tools.nsc.backend.JavaPlatform -import scala.tools.nsc.interpreter._ - -import Predef.{ println => _, _ } -import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString} -import scala.reflect.internal.util._ -import java.net.URL -import scala.sys.BooleanProp -import io.{AbstractFile, PlainFile, VirtualDirectory} - -import reporters._ -import symtab.Flags -import scala.reflect.internal.Names -import scala.tools.util.PathResolver -import ScalaClassLoader.URLClassLoader -import scala.tools.nsc.util.Exceptional.unwrap -import scala.collection.{ mutable, immutable } -import scala.util.control.Exception.{ ultimately } -import SparkIMain._ -import java.util.concurrent.Future -import typechecker.Analyzer -import scala.language.implicitConversions -import scala.reflect.runtime.{ universe => ru } -import scala.reflect.{ ClassTag, classTag } -import scala.tools.reflect.StdRuntimeTags._ -import scala.util.control.ControlThrowable - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils -import org.apache.spark.annotation.DeveloperApi - -// /** directory to save .class files to */ -// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { -// private def pp(root: AbstractFile, indentLevel: Int) { -// val spaces = " " * indentLevel -// out.println(spaces + root.name) -// if (root.isDirectory) -// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) -// } -// // print the contents hierarchically -// def show() = pp(this, 0) -// } - - /** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accommodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - */ - @DeveloperApi - class SparkIMain( - initialSettings: Settings, - val out: JPrintWriter, - propagateExceptions: Boolean = false) - extends SparkImports with Logging { imain => - - private val conf = new SparkConf() - - private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") - /** Local directory to save .class files too */ - private[repl] val outputDir = { - val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) - Utils.createTempDir(root = rootDir, namePrefix = "repl") - } - if (SPARK_DEBUG_REPL) { - echo("Output directory: " + outputDir) - } - - /** - * Returns the path to the output directory containing all generated - * class files that will be served by the REPL class server. - */ - @DeveloperApi - lazy val getClassOutputDirectory = outputDir - - private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - /** Jetty server that will serve our classes to worker nodes */ - private var currentSettings: Settings = initialSettings - private var printResults = true // whether to print result lines - private var totalSilence = false // whether to print anything - private var _initializeComplete = false // compiler is initialized - private var _isInitialized: Future[Boolean] = null // set up initialization future - private var bindExceptions = true // whether to bind the lastException variable - private var _executionWrapper = "" // code to be wrapped around all lines - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. - */ - private var _classLoader: AbstractFileClassLoader = null // active classloader - private val _compiler: Global = newCompiler(settings, reporter) // our private compiler - - private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) } - private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL - - private val nextReqId = { - var counter = 0 - () => { counter += 1 ; counter } - } - - private def compilerClasspath: Seq[URL] = ( - if (isInitializeComplete) global.classPath.asURLs - else new PathResolver(settings).result.asURLs // the compiler's classpath - ) - // NOTE: Exposed to repl package since accessed indirectly from SparkIMain - private[repl] def settings = currentSettings - private def mostRecentLine = prevRequestList match { - case Nil => "" - case req :: _ => req.originalLine - } - // Run the code body with the given boolean settings flipped to true. - private def withoutWarnings[T](body: => T): T = beQuietDuring { - val saved = settings.nowarn.value - if (!saved) - settings.nowarn.value = true - - try body - finally if (!saved) settings.nowarn.value = false - } - - /** construct an interpreter that reports to Console */ - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this() = this(new Settings()) - - private lazy val repllog: Logger = new Logger { - val out: JPrintWriter = imain.out - val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" - val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" - val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" - } - private[repl] lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - - // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop - private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) - - /** - * Determines if errors were reported (typically during compilation). - * - * @note This is not for runtime errors - * - * @return True if had errors, otherwise false - */ - @DeveloperApi - def isReportingErrors = reporter.hasErrors - - import formatting._ - import reporter.{ printMessage, withoutTruncating } - - // This exists mostly because using the reporter too early leads to deadlock. - private def echo(msg: String) { Console println msg } - private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) - private def _initialize() = { - try { - // todo. if this crashes, REPL will hang - new _compiler.Run() compileSources _initSources - _initializeComplete = true - true - } - catch AbstractOrMissingHandler() - } - private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - - // argument is a thunk to execute after init is done - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def initialize(postInitSignal: => Unit) { - synchronized { - if (_isInitialized == null) { - _isInitialized = io.spawn { - try _initialize() - finally postInitSignal - } - } - } - } - - /** - * Initializes the underlying compiler/interpreter in a blocking fashion. - * - * @note Must be executed before using SparkIMain! - */ - @DeveloperApi - def initializeSynchronous(): Unit = { - if (!isInitializeComplete) { - _initialize() - assert(global != null, global) - } - } - private def isInitializeComplete = _initializeComplete - - /** the public, go through the future compiler */ - - /** - * The underlying compiler used to generate ASTs and execute code. - */ - @DeveloperApi - lazy val global: Global = { - if (isInitializeComplete) _compiler - else { - // If init hasn't been called yet you're on your own. - if (_isInitialized == null) { - logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.") - initialize(()) - } - // // blocks until it is ; false means catastrophic failure - if (_isInitialized.get()) _compiler - else null - } - } - @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - private lazy val compiler: global.type = global - - import global._ - import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} - import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} - - private implicit class ReplTypeOps(tp: Type) { - def orElse(other: => Type): Type = if (tp ne NoType) tp else other - def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) - } - - // TODO: If we try to make naming a lazy val, we run into big time - // scalac unhappiness with what look like cycles. It has not been easy to - // reduce, but name resolution clearly takes different paths. - // NOTE: Exposed to repl package since used by SparkExprTyper - private[repl] object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. - def freshUserTermName(): TermName = { - val name = newTermName(freshUserVarName()) - if (definedNameMap contains name) freshUserTermName() - else name - } - def isUserTermName(name: Name) = isUserVarName("" + name) - def isInternalTermName(name: Name) = isInternalVarName("" + name) - } - import naming._ - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] object deconstruct extends { - val global: imain.global.type = imain.global - } with StructuredTypeStrings - - // NOTE: Exposed to repl package since used by SparkImports - private[repl] lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - /** - * Suppresses overwriting print results during the operation. - * - * @param body The block to execute - * @tparam T The return type of the block - * - * @return The result from executing the block - */ - @DeveloperApi - def beQuietDuring[T](body: => T): T = { - val saved = printResults - printResults = false - try body - finally printResults = saved - } - - /** - * Completely masks all output during the operation (minus JVM standard - * out and error). - * - * @param operation The block to execute - * @tparam T The return type of the block - * - * @return The result from executing the block - */ - @DeveloperApi - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { - case t: ControlThrowable => throw t - case t: Throwable => - logDebug(label + ": " + unwrap(t)) - logDebug(stackTraceString(unwrap(t))) - alt - } - /** takes AnyRef because it may be binding a Throwable or an Exceptional */ - - private def withLastExceptionLock[T](body: => T, alt: => T): T = { - assert(bindExceptions, "withLastExceptionLock called incorrectly.") - bindExceptions = false - - try beQuietDuring(body) - catch logAndDiscard("withLastExceptionLock", alt) - finally bindExceptions = true - } - - /** - * Contains the code (in string form) representing a wrapper around all - * code executed by this instance. - * - * @return The wrapper code as a string - */ - @DeveloperApi - def executionWrapper = _executionWrapper - - /** - * Sets the code to use as a wrapper around all code executed by this - * instance. - * - * @param code The wrapper code as a string - */ - @DeveloperApi - def setExecutionWrapper(code: String) = _executionWrapper = code - - /** - * Clears the code used as a wrapper around all code executed by - * this instance. - */ - @DeveloperApi - def clearExecutionWrapper() = _executionWrapper = "" - - /** interpreter settings */ - private lazy val isettings = new SparkISettings(this) - - /** - * Instantiates a new compiler used by SparkIMain. Overridable to provide - * own instance of a compiler. - * - * @param settings The settings to provide the compiler - * @param reporter The reporter to use for compiler output - * - * @return The compiler as a Global - */ - @DeveloperApi - protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { - settings.outputDirs setSingleOutput virtualDirectory - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) with ReplGlobal { - override def toString: String = "" - } - } - - /** - * Adds any specified jars to the compile and runtime classpaths. - * - * @note Currently only supports jars, not directories - * @param urls The list of items to add to the compile and runtime classpaths - */ - @DeveloperApi - def addUrlsToClassPath(urls: URL*): Unit = { - new Run // Needed to force initialization of "something" to correctly load Scala classes from jars - urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution - updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling - } - - private def updateCompilerClassPath(urls: URL*): Unit = { - require(!global.forMSIL) // Only support JavaPlatform - - val platform = global.platform.asInstanceOf[JavaPlatform] - - val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*) - - // NOTE: Must use reflection until this is exposed/fixed upstream in Scala - val fieldSetter = platform.getClass.getMethods - .find(_.getName.endsWith("currentClassPath_$eq")).get - fieldSetter.invoke(platform, Some(newClassPath)) - - // Reload all jars specified into our compiler - global.invalidateClassPathEntries(urls.map(_.getPath): _*) - } - - private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { - // Collect our new jars/directories and add them to the existing set of classpaths - val allClassPaths = ( - platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ - urls.map(url => { - platform.classPath.context.newClassPath( - if (url.getProtocol == "file") { - val f = new File(url.getPath) - if (f.isDirectory) - io.AbstractFile.getDirectory(f) - else - io.AbstractFile.getFile(f) - } else { - io.AbstractFile.getURL(url) - } - ) - }) - ).distinct - - // Combine all of our classpaths (old and new) into one merged classpath - new MergedClassPath(allClassPaths, platform.classPath.context) - } - - /** - * Represents the parent classloader used by this instance. Can be - * overridden to provide alternative classloader. - * - * @return The classloader used as the parent loader of this instance - */ - @DeveloperApi - protected def parentClassLoader: ClassLoader = - SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) - - /* A single class loader is used for all commands interpreted by this Interpreter. - It would also be possible to create a new class loader for each command - to interpret. The advantages of the current approach are: - - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" - - The main disadvantage is: - - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - private def resetClassLoader() = { - logDebug("Setting new classloader: was " + _classLoader) - _classLoader = null - ensureClassLoader() - } - private final def ensureClassLoader() { - if (_classLoader == null) - _classLoader = makeClassLoader() - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def classLoader: AbstractFileClassLoader = { - ensureClassLoader() - _classLoader - } - private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) { - /** Overridden here to try translating a simple name to the generated - * class name if the original attempt fails. This method is used by - * getResourceAsStream as well as findClass. - */ - override protected def findAbstractFile(name: String): AbstractFile = { - super.findAbstractFile(name) match { - // deadlocks on startup if we try to translate names too early - case null if isInitializeComplete => - generatedName(name) map (x => super.findAbstractFile(x)) orNull - case file => - file - } - } - } - private def makeClassLoader(): AbstractFileClassLoader = - new TranslatingClassLoader(parentClassLoader match { - case null => ScalaClassLoader fromURLs compilerClasspath - case p => - _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl - _runtimeClassLoader - }) - - private def getInterpreterClassLoader() = classLoader - - // Set the current Java "context" class loader to this interpreter's class loader - // NOTE: Exposed to repl package since used by SparkILoopInit - private[repl] def setContextClassLoader() = classLoader.setAsContext() - - /** - * Returns the real name of a class based on its repl-defined name. - * - * ==Example== - * Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * {{{ - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - * }}} - * - * @param simpleName The repl-defined name whose real name to retrieve - * - * @return Some real name if the simple name exists, else None - */ - @DeveloperApi - def generatedName(simpleName: String): Option[String] = { - if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) - else optFlatName(simpleName) - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def flatName(id: String) = optFlatName(id) getOrElse id - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) - - /** - * Retrieves all simple names contained in the current instance. - * - * @return A list of sorted names - */ - @DeveloperApi - def allDefinedNames = definedNameMap.keys.toList.sorted - - private def pathToType(id: String): String = pathToName(newTypeName(id)) - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id)) - - /** - * Retrieves the full code path to access the specified simple name - * content. - * - * @param name The simple name of the target whose path to determine - * - * @return The full path used to access the specified target (name) - */ - @DeveloperApi - def pathToName(name: Name): String = { - if (definedNameMap contains name) - definedNameMap(name) fullPath name - else name.toString - } - - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) - case _ => () - } - } - None - } - - /** Stubs for work in progress. */ - private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { - for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { - logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) - } - } - - private def handleTermRedefinition(name: TermName, old: Request, req: Request) = { - for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { - // Printing the types here has a tendency to cause assertion errors, like - // assertion failed: fatal: has owner value x, but a class owner is required - // so DBG is by-name now to keep it in the family. (It also traps the assertion error, - // but we don't want to unnecessarily risk hosing the compiler's internal state.) - logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) - } - } - - private def recordRequest(req: Request) { - if (req == null || referencedNameMap == null) - return - - prevRequests += req - req.referencedNames foreach (x => referencedNameMap(x) = req) - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. - for { - name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) - oldReq <- definedNameMap get name.companionName - newSym <- req.definedSymbols get name - oldSym <- oldReq.definedSymbols get name.companionName - if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule } - } { - afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")) - replwarn("Companions must be defined together; you may wish to use :paste mode for this.") - } - - // Updating the defined name map - req.definedNames foreach { name => - if (definedNameMap contains name) { - if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) - else handleTermRedefinition(name.toTermName, definedNameMap(name), req) - } - definedNameMap(name) = req - } - } - - private def replwarn(msg: => String) { - if (!settings.nowarnings.value) - printMessage(msg) - } - - private def isParseable(line: String): Boolean = { - beSilentDuring { - try parse(line) match { - case Some(xs) => xs.nonEmpty // parses as-is - case None => true // incomplete - } - catch { case x: Exception => // crashed the compiler - replwarn("Exception in isParseable(\"" + line + "\"): " + x) - false - } - } - } - - private def compileSourcesKeepingRun(sources: SourceFile*) = { - val run = new Run() - reporter.reset() - run compileSources sources.toList - (!reporter.hasErrors, run) - } - - /** - * Compiles specified source files. - * - * @param sources The sequence of source files to compile - * - * @return True if successful, otherwise false - */ - @DeveloperApi - def compileSources(sources: SourceFile*): Boolean = - compileSourcesKeepingRun(sources: _*)._1 - - /** - * Compiles a string of code. - * - * @param code The string of code to compile - * - * @return True if successful, otherwise false - */ - @DeveloperApi - def compileString(code: String): Boolean = - compileSources(new BatchSourceFile(" - UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) + val summary: NodeSeq = +
    +
      + { + if (listener.getRunningExecutions.nonEmpty) { +
    • + Running Queries: + {listener.getRunningExecutions.size} +
    • + } + } + { + if (listener.getCompletedExecutions.nonEmpty) { +
    • + Completed Queries: + {listener.getCompletedExecutions.size} +
    • + } + } + { + if (listener.getFailedExecutions.nonEmpty) { +
    • + Failed Queries: + {listener.getFailedExecutions.size} +
    • + } + } +
    +
    + UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) } } From 12e740bba110c6ab017c73c5ef940cce39dd45b7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 27 Sep 2017 23:19:10 +0900 Subject: [PATCH 1407/1765] [SPARK-22130][CORE] UTF8String.trim() scans " " twice ## What changes were proposed in this pull request? This PR allows us to scan a string including only white space (e.g. `" "`) once while the current implementation scans twice (right to left, and then left to right). ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #19355 from kiszk/SPARK-22130. --- .../org/apache/spark/unsafe/types/UTF8String.java | 11 +++++------ .../apache/spark/unsafe/types/UTF8StringSuite.java | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index ce4a06bde80c4..b0d0c44823e68 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -498,17 +498,16 @@ private UTF8String copyUTF8String(int start, int end) { public UTF8String trim() { int s = 0; - int e = this.numBytes - 1; // skip all of the space (0x20) in the left side while (s < this.numBytes && getByte(s) == 0x20) s++; - // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) == 0x20) e--; - if (s > e) { + if (s == this.numBytes) { // empty string return EMPTY_UTF8; - } else { - return copyUTF8String(s, e); } + // skip all of the space (0x20) in the right side + int e = this.numBytes - 1; + while (e > s && getByte(e) == 0x20) e--; + return copyUTF8String(s, e); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7b03d2c650fc9..9b303fa5bc6c5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -222,10 +222,13 @@ public void substring() { @Test public void trims() { + assertEquals(fromString("1"), fromString("1").trim()); + assertEquals(fromString("hello"), fromString(" hello ").trim()); assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.trim()); assertEquals(EMPTY_UTF8, fromString(" ").trim()); assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); From 09cbf3df20efea09c0941499249b7a3b2bf7e9fd Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 23:21:44 +0900 Subject: [PATCH 1408/1765] [SPARK-22125][PYSPARK][SQL] Enable Arrow Stream format for vectorized UDF. ## What changes were proposed in this pull request? Currently we use Arrow File format to communicate with Python worker when invoking vectorized UDF but we can use Arrow Stream format. This pr replaces the Arrow File format with the Arrow Stream format. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #19349 from ueshin/issues/SPARK-22125. --- .../apache/spark/api/python/PythonRDD.scala | 325 +------------ .../spark/api/python/PythonRunner.scala | 441 ++++++++++++++++++ python/pyspark/serializers.py | 70 +-- python/pyspark/worker.py | 4 +- .../execution/vectorized/ColumnarBatch.java | 5 + .../python/ArrowEvalPythonExec.scala | 54 ++- .../execution/python/ArrowPythonRunner.scala | 181 +++++++ .../python/BatchEvalPythonExec.scala | 4 +- .../execution/python/PythonUDFRunner.scala | 113 +++++ 9 files changed, 825 insertions(+), 372 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 86d0405c678a7..f6293c0dc5091 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -48,7 +48,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions: Array[Partition] = firstParent.partitions @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuse_worker) + val runner = PythonRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } } @@ -83,318 +83,9 @@ private[spark] case class PythonFunction( */ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) -/** - * Enumerate the type of command that will be sent to the Python worker - */ -private[spark] object PythonEvalType { - val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 -} - -private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { - new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), - bufferSize, - reuse_worker, - PythonEvalType.NON_UDF, - Array(Array(0))) - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). - */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuse_worker: Boolean, - evalType: Int, - argOffsets: Array[Array[Int]]) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - // All the Python functions should have the same exec, version and envvars. - private val envVars = funcs.head.funcs.head.envVars - private val pythonExec = funcs.head.funcs.head.pythonExec - private val pythonVer = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - private val accumulator = funcs.head.funcs.head.accumulator - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuse_worker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool - @volatile var released = false - - // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener { context => - writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new MonitorThread(env, worker, context).start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - val stdoutIterator = new Iterator[Array[Byte]] { - override def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - accumulator.add(update) - } - // Check whether the worker is ready to be re-used. - if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - var _nextObj = read() - - override def hasNext: Boolean = _nextObj != null - } - new InterruptibleIterator(context, stdoutIterator) - } - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Exception = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) - - /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ - def shutdownOnTaskCompletion() { - assert(context.isCompleted) - this.interrupt() - } - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Write out the TaskContextInfo - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size - dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(evalType) - if (evalType != PythonEvalType.NON_UDF) { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } - } else { - val command = funcs.head.funcs.head.command - dataOut.writeInt(command.length) - dataOut.write(command) - } - // Data values - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). - _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - setDaemon(true) - - override def run() { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. - while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } -} - /** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause) +private[spark] class PythonException(msg: String, cause: Exception) + extends RuntimeException(msg, cause) /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. @@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { - val END_OF_DATA_SECTION = -1 - val PYTHON_EXCEPTION_THROWN = -2 - val TIMING_DATA = -3 - val END_OF_STREAM = -4 - val NULL = -5 -} - private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala new file mode 100644 index 0000000000000..3688a149443c1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.util._ + + +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_PANDAS_UDF = 2 +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). + */ +private[spark] abstract class BasePythonRunner[IN, OUT]( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + // All the Python functions should have the same exec, version and envvars. + protected val envVars = funcs.head.funcs.head.envVars + protected val pythonExec = funcs.head.funcs.head.pythonExec + protected val pythonVer = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + protected val accumulator = funcs.head.funcs.head.accumulator + + def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool + val released = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener { _ => + writerThread.shutdownOnTaskCompletion() + if (!reuseWorker || !released.get) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new MonitorThread(env, worker, context).start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, writerThread, startTime, env, worker, released, context) + new InterruptibleIterator(context, stdoutIterator) + } + + protected def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[OUT] + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + abstract class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ + def shutdownOnTaskCompletion() { + assert(context.isCompleted) + this.interrupt() + } + + /** + * Writes a command section to the stream connected to the Python worker. + */ + protected def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. + */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + dataOut.flush() + + dataOut.writeInt(evalType) + writeCommand(dataOut) + writeIteratorToStream(dataOut) + + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case e: Exception if context.isCompleted || context.isInterrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). + _exception = e + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + abstract class ReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext) + extends Iterator[OUT] { + + private var nextObj: OUT = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() + } + } + + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. + */ + protected def read(): OUT + + protected def handleTimingData(): Unit = { + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + } + + protected def handlePythonException(): PythonException = { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + } + + protected def handleEndOfDataSection(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released.set(true) + } + } + eos = true + } + + protected val handleException: PartialFunction[Throwable, OUT] = { + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null.asInstanceOf[OUT] // exit silently + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. + */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. + while (!context.isInterrupted && !context.isCompleted) { + Thread.sleep(2000) + } + if (!context.isCompleted) { + try { + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } +} + +private[spark] object PythonRunner { + + def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + } +} + +/** + * A helper class to run Python mapPartition in Spark. + */ +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +private[spark] object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 + val END_OF_STREAM = -4 + val NULL = -5 + val START_ARROW_STREAM = -6 +} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7c1fbadcb82be..db77b7e150b24 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -79,6 +79,7 @@ class SpecialLengths(object): TIMING_DATA = -3 END_OF_STREAM = -4 NULL = -5 + START_ARROW_STREAM = -6 class PythonEvalType(object): @@ -211,44 +212,61 @@ def __repr__(self): return "ArrowSerializer" -class ArrowPandasSerializer(ArrowSerializer): +def _create_batch(series): + import pyarrow as pa + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + # If a nullable integer series has been promoted to floating point with NaNs, need to cast + # NOTE: this is not necessary with Arrow >= 0.7 + def cast_series(s, t): + if t is None or s.dtype == t.to_pandas_dtype(): + return s + else: + return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) + + arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + + +class ArrowStreamPandasSerializer(Serializer): """ - Serializes Pandas.Series as Arrow data. + Serializes Pandas.Series as Arrow data with Arrow streaming format. """ - def dumps(self, series): + def dump_stream(self, iterator, stream): """ - Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or + Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ import pyarrow as pa - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - # If a nullable integer series has been promoted to floating point with NaNs, need to cast - # NOTE: this is not necessary with Arrow >= 0.7 - def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): - return s - else: - return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] - batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - return super(ArrowPandasSerializer, self).dumps(batch) + writer = None + try: + for series in iterator: + batch = _create_batch(series) + if writer is None: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): """ - Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series. + Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. """ - table = super(ArrowPandasSerializer, self).loads(obj) - return [c.to_pandas() for c in table.itercolumns()] + import pyarrow as pa + reader = pa.open_stream(stream) + for batch in reader: + table = pa.Table.from_batches([batch]) + yield [c.to_pandas() for c in table.itercolumns()] def __repr__(self): - return "ArrowPandasSerializer" + return "ArrowStreamPandasSerializer" class BatchedSerializer(Serializer): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fd917c400c872..4e24789cf010d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowPandasSerializer + BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import toArrowType from pyspark import shuffle @@ -123,7 +123,7 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) if eval_type == PythonEvalType.SQL_PANDAS_UDF: - ser = ArrowPandasSerializer() + ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e782756a3e781..bc546c7c425b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -462,6 +462,11 @@ public int numValidRows() { return numRows - numRowsFiltered; } + /** + * Returns the schema that makes up this batch. + */ + public StructType schema() { return schema; } + /** * Returns the max capacity (in number of rows) for this batch. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 5e72cd255873a..f7e8cbe416121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.python +import scala.collection.JavaConverters._ + import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.types.StructType /** @@ -39,25 +40,36 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val inputIterator = ArrowConverters.toPayloadIterator( - iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable) - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets) - .compute(inputIterator, context.partitionId(), context) - - val outputRowIterator = ArrowConverters.fromPayloadIterator( - outputIterator.map(new ArrowPayload(_)), context) - - // Verify that the output schema is correct - if (outputRowIterator.hasNext) { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) - assert(schemaOut.equals(outputRowIterator.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}") - } - outputRowIterator + val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex + .map { case (attr, i) => attr.withName(s"_$i") }) + + val columnarBatchIter = new ArrowPythonRunner( + funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(iter, context.partitionId(), context) + + new Iterator[InternalRow] { + + var currentIter = if (columnarBatchIter.hasNext) { + val batch = columnarBatchIter.next() + assert(schemaOut.equals(batch.schema), + s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + batch.rowIterator.asScala + } else { + Iterator.empty + } + + override def hasNext: Boolean = currentIter.hasNext || { + if (columnarBatchIter.hasNext) { + currentIter = columnarBatchIter.next().rowIterator.asScala + hasNext + } else { + false + } + } + + override def next(): InternalRow = currentIter.next() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala new file mode 100644 index 0000000000000..bbad9d6b631fd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class ArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + batchSize: Int, + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType) + extends BasePythonRunner[InternalRow, ColumnarBatch]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[InternalRow], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + + var closed = false + + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() + } + } + + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + Utils.tryWithSafeFinally { + while (inputIterator.hasNext) { + var rowCount = 0 + while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { + val row = inputIterator.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + } { + writer.end() + root.close() + allocator.close() + closed = true + } + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + private var closed = false + + context.addTaskCompletionListener { _ => + // todo: we need something like `reader.end()`, which release all the resources, but leave + // the input stream open. `reader.close()` will close the socket and we can't reuse worker. + // So here we simply not close the reader, which is problematic. + if (!closed) { + if (root != null) { + root.close() + } + allocator.close() + } + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + batch.setNumRows(root.getRowCount) + batch + } else { + root.close() + allocator.close() + closed = true + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 2978eac50554d..26ee25f633ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan @@ -68,7 +68,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. - val outputIterator = new PythonRunner( + val outputIterator = new PythonUDFRunner( funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala new file mode 100644 index 0000000000000..e28def1c4b423 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark._ +import org.apache.spark.api.python._ + +/** + * A helper class to run Python UDFs in Spark. + */ +class PythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +object PythonUDFRunner { + + def writeUDFs( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } +} From 9b98aef6a39a5a9ea9fc5481b5a0d92620ba6347 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 27 Sep 2017 13:40:21 -0700 Subject: [PATCH 1409/1765] [HOTFIX][BUILD] Fix finalizer checkstyle error and re-disable checkstyle ## What changes were proposed in this pull request? Fix finalizer checkstyle violation by just turning it off; re-disable checkstyle as it won't be run by SBT PR builder. See https://github.com/apache/spark/pull/18887#issuecomment-332580700 ## How was this patch tested? `./dev/lint-java` runs successfully Author: Sean Owen Closes #19371 from srowen/HotfixFinalizerCheckstlye. --- .../java/org/apache/spark/io/NioBufferedFileInputStream.java | 2 -- dev/checkstyle-suppressions.xml | 2 -- dev/checkstyle.xml | 1 - pom.xml | 1 + 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index ea5f1a9abf69b..f6d1288cb263d 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,10 +130,8 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } - //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } - //checkstyle.on: NoFinalizer } diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 6e15f6955984e..bbda824dd13b4 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -40,8 +40,6 @@ files="src/main/java/org/apache/hive/service/*"/> - - diff --git a/pom.xml b/pom.xml index b0408ecca0f66..83a35006707da 100644 --- a/pom.xml +++ b/pom.xml @@ -2488,6 +2488,7 @@ maven-checkstyle-plugin 2.17 + false true ${basedir}/src/main/java,${basedir}/src/main/scala ${basedir}/src/test/java From 02bb0682e68a2ce81f3b98d33649d368da7f2b3d Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 27 Sep 2017 23:08:30 +0200 Subject: [PATCH 1410/1765] [SPARK-22143][SQL] Fix memory leak in OffHeapColumnVector ## What changes were proposed in this pull request? `WriteableColumnVector` does not close its child column vectors. This can create memory leaks for `OffHeapColumnVector` where we do not clean up the memory allocated by a vectors children. This can be especially bad for string columns (which uses a child byte column vector). ## How was this patch tested? I have updated the existing tests to always use both on-heap and off-heap vectors. Testing and diagnoses was done locally. Author: Herman van Hovell Closes #19367 from hvanhovell/SPARK-22143. --- .../vectorized/OffHeapColumnVector.java | 1 + .../vectorized/OnHeapColumnVector.java | 10 + .../vectorized/WritableColumnVector.java | 18 ++ .../vectorized/ColumnVectorSuite.scala | 102 +++++---- .../vectorized/ColumnarBatchSuite.scala | 194 ++++++++---------- 5 files changed, 165 insertions(+), 160 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e1d36858d4eee..8cbc895506d91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -85,6 +85,7 @@ public long nullsNativeAddress() { @Override public void close() { + super.close(); Platform.freeMemory(nulls); Platform.freeMemory(data); Platform.freeMemory(lengthData); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 96a452978cb35..2725a29eeabe8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -90,6 +90,16 @@ public long nullsNativeAddress() { @Override public void close() { + super.close(); + nulls = null; + byteData = null; + shortData = null; + intData = null; + longData = null; + floatData = null; + doubleData = null; + arrayLengths = null; + arrayOffsets = null; } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 0bddc351e1bed..163f2511e5f73 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -59,6 +59,24 @@ public void reset() { } } + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + childColumns[i] = null; + } + childColumns = null; + } + if (dictionaryIds != null) { + dictionaryIds.close(); + dictionaryIds = null; + } + dictionary = null; + resultStruct = null; + resultArray = null; + } + public void reserve(int requiredCapacity) { if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index f7b06c97f9db6..85da8270d4cba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -25,19 +25,24 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { - - var testVector: WritableColumnVector = _ - - private def allocate(capacity: Int, dt: DataType): WritableColumnVector = { - new OnHeapColumnVector(capacity, dt) + private def withVector( + vector: WritableColumnVector)( + block: WritableColumnVector => Unit): Unit = { + try block(vector) finally vector.close() } - override def afterEach(): Unit = { - testVector.close() + private def testVectors( + name: String, + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + test(name) { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) + } } - test("boolean") { - testVector = allocate(10, BooleanType) + testVectors("boolean", 10, BooleanType) { testVector => (0 until 10).foreach { i => testVector.appendBoolean(i % 2 == 0) } @@ -49,8 +54,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("byte") { - testVector = allocate(10, ByteType) + testVectors("byte", 10, ByteType) { testVector => (0 until 10).foreach { i => testVector.appendByte(i.toByte) } @@ -58,12 +62,11 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) (0 until 10).foreach { i => - assert(array.get(i, ByteType) === (i.toByte)) + assert(array.get(i, ByteType) === i.toByte) } } - test("short") { - testVector = allocate(10, ShortType) + testVectors("short", 10, ShortType) { testVector => (0 until 10).foreach { i => testVector.appendShort(i.toShort) } @@ -71,12 +74,11 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) (0 until 10).foreach { i => - assert(array.get(i, ShortType) === (i.toShort)) + assert(array.get(i, ShortType) === i.toShort) } } - test("int") { - testVector = allocate(10, IntegerType) + testVectors("int", 10, IntegerType) { testVector => (0 until 10).foreach { i => testVector.appendInt(i) } @@ -88,8 +90,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("long") { - testVector = allocate(10, LongType) + testVectors("long", 10, LongType) { testVector => (0 until 10).foreach { i => testVector.appendLong(i) } @@ -101,8 +102,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("float") { - testVector = allocate(10, FloatType) + testVectors("float", 10, FloatType) { testVector => (0 until 10).foreach { i => testVector.appendFloat(i.toFloat) } @@ -114,8 +114,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("double") { - testVector = allocate(10, DoubleType) + testVectors("double", 10, DoubleType) { testVector => (0 until 10).foreach { i => testVector.appendDouble(i.toDouble) } @@ -127,8 +126,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("string") { - testVector = allocate(10, StringType) + testVectors("string", 10, StringType) { testVector => (0 until 10).map { i => val utf8 = s"str$i".getBytes("utf8") testVector.appendByteArray(utf8, 0, utf8.length) @@ -141,8 +139,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("binary") { - testVector = allocate(10, BinaryType) + testVectors("binary", 10, BinaryType) { testVector => (0 until 10).map { i => val utf8 = s"str$i".getBytes("utf8") testVector.appendByteArray(utf8, 0, utf8.length) @@ -156,9 +153,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("array") { - val arrayType = ArrayType(IntegerType, true) - testVector = allocate(10, arrayType) + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) + testVectors("array", 10, arrayType) { testVector => val data = testVector.arrayData() var i = 0 @@ -181,9 +177,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) } - test("struct") { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - testVector = allocate(10, schema) + val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) + testVectors("struct", 10, structType) { testVector => val c1 = testVector.getChildColumn(0) val c2 = testVector.getChildColumn(1) c1.putInt(0, 123) @@ -193,35 +188,34 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) - assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) - assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) - assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) - assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { - val arrayType = ArrayType(IntegerType, true) - testVector = new OffHeapColumnVector(8, arrayType) + withVector(new OffHeapColumnVector(8, arrayType)) { testVector => + val data = testVector.arrayData() + (0 until 8).foreach(i => data.putInt(i, i)) + (0 until 8).foreach(i => testVector.putArray(i, i, 1)) - val data = testVector.arrayData() - (0 until 8).foreach(i => data.putInt(i, i)) - (0 until 8).foreach(i => testVector.putArray(i, i, 1)) + // Increase vector's capacity and reallocate the data to new bigger buffers. + testVector.reserve(16) - // Increase vector's capacity and reallocate the data to new bigger buffers. - testVector.reserve(16) - - // Check that none of the values got lost/overwritten. - val array = new ColumnVector.Array(testVector) - (0 until 8).foreach { i => - assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + // Check that none of the values got lost/overwritten. + val array = new ColumnVector.Array(testVector) + (0 until 8).foreach { i => + assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + } } } test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { - val structType = new StructType().add("int", IntegerType).add("double", DoubleType) - testVector = new OffHeapColumnVector(8, structType) - (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) - testVector.reserve(16) - (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + withVector(new OffHeapColumnVector(8, structType)) { testVector => + (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) + testVector.reserve(16) + (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index ebf76613343ba..983eb103682c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { - def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + private def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { if (memMode == MemoryMode.OFF_HEAP) { new OffHeapColumnVector(capacity, dt) } else { @@ -46,23 +46,36 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - test("Null Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val reference = mutable.ArrayBuffer.empty[Boolean] + private def testVector( + name: String, + size: Int, + dt: DataType)( + block: (WritableColumnVector, MemoryMode) => Unit): Unit = { + test(name) { + Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode => + val vector = allocate(size, dt, mode) + try block(vector, mode) finally { + vector.close() + } + } + } + } - val column = allocate(1024, IntegerType, memMode) + testVector("Null APIs", 1024, IntegerType) { + (column, memMode) => + val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNull() reference += false - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNull() @@ -113,16 +126,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) } } - column.close - }} } - test("Byte Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Byte APIs", 1024, ByteType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[Byte] - val column = allocate(1024, ByteType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray column.appendBytes(2, values, 0) reference += 10.toByte @@ -170,17 +179,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getByte(null, addr + v._2)) } } - }} } - test("Short Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Short APIs", 1024, ShortType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] - val column = allocate(1024, ShortType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray column.appendShorts(2, values, 0) reference += 10.toShort @@ -248,19 +254,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getShort(null, addr + 2 * v._2)) } } - - column.close - }} } - test("Int Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Int APIs", 1024, IntegerType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = allocate(1024, IntegerType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray column.appendInts(2, values, 0) reference += 10 @@ -334,18 +335,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) } } - column.close - }} } - test("Long Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Long APIs", 1024, LongType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] - val column = allocate(1024, LongType, memMode) - var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray column.appendLongs(2, values, 0) reference += 10L @@ -422,17 +419,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) } } - }} } - test("Float APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Float APIs", 1024, FloatType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] - val column = allocate(1024, FloatType, memMode) - var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray column.appendFloats(2, values, 0) reference += .1f @@ -512,18 +506,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) } } - column.close - }} } - test("Double APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Double APIs", 1024, DoubleType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = allocate(1024, DoubleType, memMode) - var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray column.appendDoubles(2, values, 0) reference += .1 @@ -603,15 +593,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) } } - column.close - }} } - test("String APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("String APIs", 6, StringType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[String] - val column = allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) val str = "string" @@ -663,15 +650,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.reset() assert(column.arrayData().elementsAppended == 0) - }} } - test("Int Array") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val column = allocate(10, new ArrayType(IntegerType, true), memMode) + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { + (column, _) => // Fill the underlying data with all the arrays back to back. - val data = column.arrayData(); + val data = column.arrayData() var i = 0 while (i < 6) { data.putInt(i, i) @@ -709,7 +694,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(3).getInt(2) == 5) // Add a longer array which requires resizing - column.reset + column.reset() val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) assert(data.capacity == 10) data.reserve(array.length) @@ -718,63 +703,67 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, array.length) assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] === array) - }} } test("toArray for primitive types") { - // (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val len = 4 val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode) val boolArray = Array(false, true, false, true) - boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } + boolArray.zipWithIndex.foreach { case (v, i) => columnBool.arrayData.putBoolean(i, v) } columnBool.putArray(0, 0, len) assert(columnBool.getArray(0).toBooleanArray === boolArray) + columnBool.close() val columnByte = allocate(len, new ArrayType(ByteType, false), memMode) val byteArray = Array[Byte](0, 1, 2, 3) - byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } + byteArray.zipWithIndex.foreach { case (v, i) => columnByte.arrayData.putByte(i, v) } columnByte.putArray(0, 0, len) assert(columnByte.getArray(0).toByteArray === byteArray) + columnByte.close() val columnShort = allocate(len, new ArrayType(ShortType, false), memMode) val shortArray = Array[Short](0, 1, 2, 3) - shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } + shortArray.zipWithIndex.foreach { case (v, i) => columnShort.arrayData.putShort(i, v) } columnShort.putArray(0, 0, len) assert(columnShort.getArray(0).toShortArray === shortArray) + columnShort.close() val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode) val intArray = Array(0, 1, 2, 3) - intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } + intArray.zipWithIndex.foreach { case (v, i) => columnInt.arrayData.putInt(i, v) } columnInt.putArray(0, 0, len) assert(columnInt.getArray(0).toIntArray === intArray) + columnInt.close() val columnLong = allocate(len, new ArrayType(LongType, false), memMode) val longArray = Array[Long](0, 1, 2, 3) - longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } + longArray.zipWithIndex.foreach { case (v, i) => columnLong.arrayData.putLong(i, v) } columnLong.putArray(0, 0, len) assert(columnLong.getArray(0).toLongArray === longArray) + columnLong.close() val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode) val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) - floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } + floatArray.zipWithIndex.foreach { case (v, i) => columnFloat.arrayData.putFloat(i, v) } columnFloat.putArray(0, 0, len) assert(columnFloat.getArray(0).toFloatArray === floatArray) + columnFloat.close() val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode) val doubleArray = Array(0.0, 1.1, 2.2, 3.3) - doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } + doubleArray.zipWithIndex.foreach { case (v, i) => columnDouble.arrayData.putDouble(i, v) } columnDouble.putArray(0, 0, len) assert(columnDouble.getArray(0).toDoubleArray === doubleArray) - }} + columnDouble.close() + } } - test("Struct Column") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - val column = allocate(1024, schema, memMode) - + testVector( + "Struct Column", + 10, + new StructType().add("int", IntegerType).add("double", DoubleType)) { (column, _) => val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) assert(c1.dataType() == IntegerType) @@ -797,13 +786,10 @@ class ColumnarBatchSuite extends SparkFunSuite { val s2 = column.getStruct(1) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) - }} } - test("Nest Array in Array.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), - memMode) + testVector("Nest Array in Array", 10, new ArrayType(new ArrayType(IntegerType, true), true)) { + (column, _) => val childColumn = column.arrayData() val data = column.arrayData().arrayData() (0 until 6).foreach { @@ -829,13 +815,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(2).getArray(1).getInt(1) === 4) assert(column.getArray(2).getArray(1).getInt(2) === 5) assert(column.isNullAt(3)) - } } - test("Nest Struct in Array.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val schema = new StructType().add("int", IntegerType).add("long", LongType) - val column = allocate(10, new ArrayType(schema, true), memMode) + private val structType: StructType = new StructType().add("i", IntegerType).add("l", LongType) + + testVector( + "Nest Struct in Array", + 10, + new ArrayType(structType, true)) { (column, _) => val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -850,22 +837,21 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(1, 1, 3) column.putArray(2, 4, 2) - assert(column.getArray(0).getStruct(0, 2).toSeq(schema) === Seq(0, 0)) - assert(column.getArray(0).getStruct(1, 2).toSeq(schema) === Seq(1, 10)) - assert(column.getArray(1).getStruct(0, 2).toSeq(schema) === Seq(1, 10)) - assert(column.getArray(1).getStruct(1, 2).toSeq(schema) === Seq(2, 20)) - assert(column.getArray(1).getStruct(2, 2).toSeq(schema) === Seq(3, 30)) - assert(column.getArray(2).getStruct(0, 2).toSeq(schema) === Seq(4, 40)) - assert(column.getArray(2).getStruct(1, 2).toSeq(schema) === Seq(5, 50)) - } + assert(column.getArray(0).getStruct(0, 2).toSeq(structType) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(structType) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(structType) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(structType) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(structType) === Seq(5, 50)) } - test("Nest Array in Struct.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val schema = new StructType() - .add("int", IntegerType) - .add("array", new ArrayType(IntegerType, true)) - val column = allocate(10, schema, memMode) + testVector( + "Nest Array in Struct", + 10, + new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true))) { (column, _) => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -886,18 +872,15 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) assert(column.getStruct(2).getInt(0) === 2) assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) - } } - test("Nest Struct in Struct.") { - (MemoryMode.ON_HEAP :: Nil).foreach { memMode => - val subSchema = new StructType() - .add("int", IntegerType) - .add("int", IntegerType) - val schema = new StructType() - .add("int", IntegerType) - .add("struct", subSchema) - val column = allocate(10, schema, memMode) + private val subSchema: StructType = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + testVector( + "Nest Struct in Struct", + 10, + new StructType().add("int", IntegerType).add("struct", subSchema)) { (column, _) => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -919,7 +902,6 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) assert(column.getStruct(2).getInt(0) === 2) assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) - } } test("ColumnarBatch basic") { @@ -1040,7 +1022,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val it4 = batch.rowIterator() rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) - batch.close + batch.close() }} } From 9244957b500cb2b458c32db2c63293a1444690d7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 27 Sep 2017 17:03:42 -0700 Subject: [PATCH 1411/1765] [SPARK-22140] Add TPCDSQuerySuite ## What changes were proposed in this pull request? Now, we are not running TPC-DS queries as regular test cases. Thus, we need to add a test suite using empty tables for ensuring the new code changes will not break them. For example, optimizer/analyzer batches should not exceed the max iteration. ## How was this patch tested? N/A Author: gatorsmile Closes #19361 from gatorsmile/tpcdsQuerySuite. --- .../apache/spark/sql/TPCDSQuerySuite.scala | 348 ++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala new file mode 100644 index 0000000000000..c0797fa55f5da --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + + /** + * Drop all the tables + */ + protected override def afterAll(): Unit = { + try { + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + override def beforeAll() { + super.beforeAll() + sql( + """ + |CREATE TABLE `catalog_page` ( + |`cp_catalog_page_sk` INT, `cp_catalog_page_id` STRING, `cp_start_date_sk` INT, + |`cp_end_date_sk` INT, `cp_department` STRING, `cp_catalog_number` INT, + |`cp_catalog_page_number` INT, `cp_description` STRING, `cp_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_returns` ( + |`cr_returned_date_sk` INT, `cr_returned_time_sk` INT, `cr_item_sk` INT, + |`cr_refunded_customer_sk` INT, `cr_refunded_cdemo_sk` INT, `cr_refunded_hdemo_sk` INT, + |`cr_refunded_addr_sk` INT, `cr_returning_customer_sk` INT, `cr_returning_cdemo_sk` INT, + |`cr_returning_hdemo_sk` INT, `cr_returning_addr_sk` INT, `cr_call_center_sk` INT, + |`cr_catalog_page_sk` INT, `cr_ship_mode_sk` INT, `cr_warehouse_sk` INT, `cr_reason_sk` INT, + |`cr_order_number` INT, `cr_return_quantity` INT, `cr_return_amount` DECIMAL(7,2), + |`cr_return_tax` DECIMAL(7,2), `cr_return_amt_inc_tax` DECIMAL(7,2), `cr_fee` DECIMAL(7,2), + |`cr_return_ship_cost` DECIMAL(7,2), `cr_refunded_cash` DECIMAL(7,2), + |`cr_reversed_charge` DECIMAL(7,2), `cr_store_credit` DECIMAL(7,2), + |`cr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer` ( + |`c_customer_sk` INT, `c_customer_id` STRING, `c_current_cdemo_sk` INT, + |`c_current_hdemo_sk` INT, `c_current_addr_sk` INT, `c_first_shipto_date_sk` INT, + |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, + |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, + |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, + |`c_email_address` STRING, `c_last_review_date` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_address` ( + |`ca_address_sk` INT, `ca_address_id` STRING, `ca_street_number` STRING, + |`ca_street_name` STRING, `ca_street_type` STRING, `ca_suite_number` STRING, + |`ca_city` STRING, `ca_county` STRING, `ca_state` STRING, `ca_zip` STRING, + |`ca_country` STRING, `ca_gmt_offset` DECIMAL(5,2), `ca_location_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_demographics` ( + |`cd_demo_sk` INT, `cd_gender` STRING, `cd_marital_status` STRING, + |`cd_education_status` STRING, `cd_purchase_estimate` INT, `cd_credit_rating` STRING, + |`cd_dep_count` INT, `cd_dep_employed_count` INT, `cd_dep_college_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `date_dim` ( + |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, + |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, + |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, + |`d_weekend` STRING, `d_following_holiday` STRING, `d_first_dom` INT, `d_last_dom` INT, + |`d_same_day_ly` INT, `d_same_day_lq` INT, `d_current_day` STRING, `d_current_week` STRING, + |`d_current_month` STRING, `d_current_quarter` STRING, `d_current_year` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `household_demographics` ( + |`hd_demo_sk` INT, `hd_income_band_sk` INT, `hd_buy_potential` STRING, `hd_dep_count` INT, + |`hd_vehicle_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `inventory` (`inv_date_sk` INT, `inv_item_sk` INT, `inv_warehouse_sk` INT, + |`inv_quantity_on_hand` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, + |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, + |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, + |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, + |`i_units` STRING, `i_container` STRING, `i_manager_id` INT, `i_product_name` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `promotion` ( + |`p_promo_sk` INT, `p_promo_id` STRING, `p_start_date_sk` INT, `p_end_date_sk` INT, + |`p_item_sk` INT, `p_cost` DECIMAL(15,2), `p_response_target` INT, `p_promo_name` STRING, + |`p_channel_dmail` STRING, `p_channel_email` STRING, `p_channel_catalog` STRING, + |`p_channel_tv` STRING, `p_channel_radio` STRING, `p_channel_press` STRING, + |`p_channel_event` STRING, `p_channel_demo` STRING, `p_channel_details` STRING, + |`p_purpose` STRING, `p_discount_active` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store` ( + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, + |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, + |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, + |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, + |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, + |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, + |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, + |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_returns` ( + |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, + |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, + |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, + |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), + |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), + |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), + |`sr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_sales` ( + |`cs_sold_date_sk` INT, `cs_sold_time_sk` INT, `cs_ship_date_sk` INT, + |`cs_bill_customer_sk` INT, `cs_bill_cdemo_sk` INT, `cs_bill_hdemo_sk` INT, + |`cs_bill_addr_sk` INT, `cs_ship_customer_sk` INT, `cs_ship_cdemo_sk` INT, + |`cs_ship_hdemo_sk` INT, `cs_ship_addr_sk` INT, `cs_call_center_sk` INT, + |`cs_catalog_page_sk` INT, `cs_ship_mode_sk` INT, `cs_warehouse_sk` INT, + |`cs_item_sk` INT, `cs_promo_sk` INT, `cs_order_number` INT, `cs_quantity` INT, + |`cs_wholesale_cost` DECIMAL(7,2), `cs_list_price` DECIMAL(7,2), + |`cs_sales_price` DECIMAL(7,2), `cs_ext_discount_amt` DECIMAL(7,2), + |`cs_ext_sales_price` DECIMAL(7,2), `cs_ext_wholesale_cost` DECIMAL(7,2), + |`cs_ext_list_price` DECIMAL(7,2), `cs_ext_tax` DECIMAL(7,2), `cs_coupon_amt` DECIMAL(7,2), + |`cs_ext_ship_cost` DECIMAL(7,2), `cs_net_paid` DECIMAL(7,2), + |`cs_net_paid_inc_tax` DECIMAL(7,2), `cs_net_paid_inc_ship` DECIMAL(7,2), + |`cs_net_paid_inc_ship_tax` DECIMAL(7,2), `cs_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_sales` ( + |`ws_sold_date_sk` INT, `ws_sold_time_sk` INT, `ws_ship_date_sk` INT, `ws_item_sk` INT, + |`ws_bill_customer_sk` INT, `ws_bill_cdemo_sk` INT, `ws_bill_hdemo_sk` INT, + |`ws_bill_addr_sk` INT, `ws_ship_customer_sk` INT, `ws_ship_cdemo_sk` INT, + |`ws_ship_hdemo_sk` INT, `ws_ship_addr_sk` INT, `ws_web_page_sk` INT, `ws_web_site_sk` INT, + |`ws_ship_mode_sk` INT, `ws_warehouse_sk` INT, `ws_promo_sk` INT, `ws_order_number` INT, + |`ws_quantity` INT, `ws_wholesale_cost` DECIMAL(7,2), `ws_list_price` DECIMAL(7,2), + |`ws_sales_price` DECIMAL(7,2), `ws_ext_discount_amt` DECIMAL(7,2), + |`ws_ext_sales_price` DECIMAL(7,2), `ws_ext_wholesale_cost` DECIMAL(7,2), + |`ws_ext_list_price` DECIMAL(7,2), `ws_ext_tax` DECIMAL(7,2), + |`ws_coupon_amt` DECIMAL(7,2), `ws_ext_ship_cost` DECIMAL(7,2), `ws_net_paid` DECIMAL(7,2), + |`ws_net_paid_inc_tax` DECIMAL(7,2), `ws_net_paid_inc_ship` DECIMAL(7,2), + |`ws_net_paid_inc_ship_tax` DECIMAL(7,2), `ws_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_sales` ( + |`ss_sold_date_sk` INT, `ss_sold_time_sk` INT, `ss_item_sk` INT, `ss_customer_sk` INT, + |`ss_cdemo_sk` INT, `ss_hdemo_sk` INT, `ss_addr_sk` INT, `ss_store_sk` INT, + |`ss_promo_sk` INT, `ss_ticket_number` INT, `ss_quantity` INT, + |`ss_wholesale_cost` DECIMAL(7,2), `ss_list_price` DECIMAL(7,2), + |`ss_sales_price` DECIMAL(7,2), `ss_ext_discount_amt` DECIMAL(7,2), + |`ss_ext_sales_price` DECIMAL(7,2), `ss_ext_wholesale_cost` DECIMAL(7,2), + |`ss_ext_list_price` DECIMAL(7,2), `ss_ext_tax` DECIMAL(7,2), + |`ss_coupon_amt` DECIMAL(7,2), `ss_net_paid` DECIMAL(7,2), + |`ss_net_paid_inc_tax` DECIMAL(7,2), `ss_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_returns` ( + |`wr_returned_date_sk` BIGINT, `wr_returned_time_sk` BIGINT, `wr_item_sk` BIGINT, + |`wr_refunded_customer_sk` BIGINT, `wr_refunded_cdemo_sk` BIGINT, + |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, + |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, + |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), + |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), + |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), + |`wr_reversed_charge` DECIMAL(7,2), `wr_account_credit` DECIMAL(7,2), + |`wr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_site` ( + |`web_site_sk` INT, `web_site_id` STRING, `web_rec_start_date` DATE, + |`web_rec_end_date` DATE, `web_name` STRING, `web_open_date_sk` INT, + |`web_close_date_sk` INT, `web_class` STRING, `web_manager` STRING, `web_mkt_id` INT, + |`web_mkt_class` STRING, `web_mkt_desc` STRING, `web_market_manager` STRING, + |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, + |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, + |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, + |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `reason` ( + |`r_reason_sk` INT, `r_reason_id` STRING, `r_reason_desc` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `call_center` ( + |`cc_call_center_sk` INT, `cc_call_center_id` STRING, `cc_rec_start_date` DATE, + |`cc_rec_end_date` DATE, `cc_closed_date_sk` INT, `cc_open_date_sk` INT, `cc_name` STRING, + |`cc_class` STRING, `cc_employees` INT, `cc_sq_ft` INT, `cc_hours` STRING, + |`cc_manager` STRING, `cc_mkt_id` INT, `cc_mkt_class` STRING, `cc_mkt_desc` STRING, + |`cc_market_manager` STRING, `cc_division` INT, `cc_division_name` STRING, `cc_company` INT, + |`cc_company_name` STRING, `cc_street_number` STRING, `cc_street_name` STRING, + |`cc_street_type` STRING, `cc_suite_number` STRING, `cc_city` STRING, `cc_county` STRING, + |`cc_state` STRING, `cc_zip` STRING, `cc_country` STRING, `cc_gmt_offset` DECIMAL(5,2), + |`cc_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `warehouse` ( + |`w_warehouse_sk` INT, `w_warehouse_id` STRING, `w_warehouse_name` STRING, + |`w_warehouse_sq_ft` INT, `w_street_number` STRING, `w_street_name` STRING, + |`w_street_type` STRING, `w_suite_number` STRING, `w_city` STRING, `w_county` STRING, + |`w_state` STRING, `w_zip` STRING, `w_country` STRING, `w_gmt_offset` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `ship_mode` ( + |`sm_ship_mode_sk` INT, `sm_ship_mode_id` STRING, `sm_type` STRING, `sm_code` STRING, + |`sm_carrier` STRING, `sm_contract` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `income_band` ( + |`ib_income_band_sk` INT, `ib_lower_bound` INT, `ib_upper_bound` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `time_dim` ( + |`t_time_sk` INT, `t_time_id` STRING, `t_time` INT, `t_hour` INT, `t_minute` INT, + |`t_second` INT, `t_am_pm` STRING, `t_shift` STRING, `t_sub_shift` STRING, + |`t_meal_time` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_page` (`wp_web_page_sk` INT, `wp_web_page_id` STRING, + |`wp_rec_start_date` DATE, `wp_rec_end_date` DATE, `wp_creation_date_sk` INT, + |`wp_access_date_sk` INT, `wp_autogen_flag` STRING, `wp_customer_sk` INT, + |`wp_url` STRING, `wp_type` STRING, `wp_char_count` INT, `wp_link_count` INT, + |`wp_image_count` INT, `wp_max_ad_count` INT) + |USING parquet + """.stripMargin) + } + + val tpcdsQueries = Seq( + "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", + "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", + "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40", + "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50", + "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60", + "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70", + "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80", + "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", + "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + + tpcdsQueries.foreach { name => + val queryString = resourceToString(s"tpcds/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + sql(queryString).collect() + } + } + } +} From 7bf4da8a33c33b03bbfddc698335fe9b86ce1e0e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 28 Sep 2017 10:24:51 +0900 Subject: [PATCH 1412/1765] [MINOR] Fixed up pandas_udf related docs and formatting ## What changes were proposed in this pull request? Fixed some minor issues with pandas_udf related docs and formatting. ## How was this patch tested? NA Author: Bryan Cutler Closes #19375 from BryanCutler/arrow-pandas_udf-cleanup-minor. --- python/pyspark/serializers.py | 6 +++--- python/pyspark/sql/functions.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index db77b7e150b24..ad18bd0c81eaa 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -191,7 +191,7 @@ def loads(self, obj): class ArrowSerializer(FramedSerializer): """ - Serializes an Arrow stream. + Serializes bytes as Arrow data with the Arrow file format. """ def dumps(self, batch): @@ -239,7 +239,7 @@ class ArrowStreamPandasSerializer(Serializer): def dump_stream(self, iterator, stream): """ - Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or + Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ import pyarrow as pa @@ -257,7 +257,7 @@ def dump_stream(self, iterator, stream): def load_stream(self, stream): """ - Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ import pyarrow as pa reader = pa.open_stream(stream) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 63e9a830bbc9e..b45a59db93679 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2199,16 +2199,14 @@ def pandas_udf(f=None, returnType=StringType()): ... >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP + ... .show() # doctest: +SKIP +----------+--------------+------------+ |slen(name)|to_upper(name)|add_one(age)| +----------+--------------+------------+ | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True) - - return wrapped_udf + return _create_udf(f, returnType=returnType, vectorized=True) blacklist = ['map', 'since', 'ignore_unicode_prefix'] From 3b117d631e1ff387b70ed8efba229594f4594db5 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 28 Sep 2017 09:25:21 +0800 Subject: [PATCH 1413/1765] [SPARK-22123][CORE] Add latest failure reason for task set blacklist ## What changes were proposed in this pull request? This patch add latest failure reason for task set blacklist.Which can be showed on spark ui and let user know failure reason directly. Till now , every job which aborted by completed blacklist just show log like below which has no more information: `Aborting $taskSet because task $indexInTaskSet (partition $partition) cannot run anywhere due to node and executor blacklist. Blacklisting behavior cannot run anywhere due to node and executor blacklist.Blacklisting behavior can be configured via spark.blacklist.*."` **After modify:** ``` Aborting TaskSet 0.0 because task 0 (partition 0) cannot run anywhere due to node and executor blacklist. Most recent failure: Some(Lost task 0.1 in stage 0.0 (TID 3,xxx, executor 1): java.lang.Exception: Fake error! at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:73) at org.apache.spark.scheduler.Task.run(Task.scala:99) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:305) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) ). Blacklisting behavior can be configured via spark.blacklist.*. ``` ## How was this patch tested? Unit test and manually test. Author: zhoukang Closes #19338 from caneGuy/zhoukang/improve-blacklist. --- .../spark/scheduler/TaskSetBlacklist.scala | 14 ++++- .../spark/scheduler/TaskSetManager.scala | 15 +++-- .../scheduler/BlacklistIntegrationSuite.scala | 5 +- .../scheduler/BlacklistTrackerSuite.scala | 60 ++++++++++++------- .../scheduler/TaskSchedulerImplSuite.scala | 11 +++- .../scheduler/TaskSetBlacklistSuite.scala | 45 +++++++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 2 +- 7 files changed, 104 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index e815b7e0cf6c9..233781f3d9719 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -61,6 +61,16 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private val blacklistedExecs = new HashSet[String]() private val blacklistedNodes = new HashSet[String]() + private var latestFailureReason: String = null + + /** + * Get the most recent failure reason of this TaskSet. + * @return + */ + def getLatestFailureReason: String = { + latestFailureReason + } + /** * Return true if this executor is blacklisted for the given task. This does *not* * need to return true if the executor is blacklisted for the entire stage, or blacklisted @@ -94,7 +104,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private[scheduler] def updateBlacklistForFailedTask( host: String, exec: String, - index: Int): Unit = { + index: Int, + failureReason: String): Unit = { + latestFailureReason = failureReason val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) execFailures.updateWithFailure(index, clock.getTimeMillis()) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3804ea863b4f9..bb867416a4fac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -670,9 +670,14 @@ private[spark] class TaskSetManager( } if (blacklistedEverywhere) { val partition = tasks(indexInTaskSet).partitionId - abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior " + - s"can be configured via spark.blacklist.*.") + abort(s""" + |Aborting $taskSet because task $indexInTaskSet (partition $partition) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${taskSetBlacklist.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. + |""".stripMargin) } } } @@ -837,9 +842,9 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (!isZombie && reason.countTowardsTaskFailures) { - taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( - info.host, info.executorId, index)) assert (null != failureReason) + taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( + info.host, info.executorId, index, failureReason)) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index f6015cd51c2bd..d3bbfd11d406d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -115,8 +115,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) awaitJobTermination(jobFuture, duration) - val pattern = ("Aborting TaskSet 0.0 because task .* " + - "cannot run anywhere due to node and executor blacklist").r + val pattern = ( + s"""|Aborting TaskSet 0.0 because task .* + |cannot run anywhere due to node and executor blacklist""".stripMargin).r assert(pattern.findFirstIn(failure.getMessage).isDefined, s"Couldn't find $pattern in ${failure.getMessage()}") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index a136d69b36d6c..cd1b7a9e5ab18 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -110,7 +110,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist = createTaskSetBlacklist(stageId) if (stageId % 2 == 0) { // fail one task in every other taskset - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") failuresSoFar += 1 } blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) @@ -132,7 +133,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // for many different stages, executor 1 fails a task, and then the taskSet fails. (0 until failuresUntilBlacklisted * 10).foreach { stage => val taskSetBlacklist = createTaskSetBlacklist(stage) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") } assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) } @@ -147,7 +149,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val numFailures = math.max(conf.get(config.MAX_FAILURES_PER_EXEC), conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)) (0 until numFailures).foreach { index => - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = index) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = index, failureReason = "testing") } assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) @@ -170,7 +173,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assert(blacklist.nodeBlacklist() === Set()) @@ -183,7 +187,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) assert(blacklist.nodeBlacklist() === Set("hostA")) @@ -207,7 +212,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail one more task, but executor isn't put back into blacklist since the count of failures // on that executor should have been reset to 0. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(2, 0, taskSetBlacklist2.execToFailures) assert(blacklist.nodeBlacklist() === Set()) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) @@ -221,7 +227,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Lets say that executor 1 dies completely. We get some task failures, but // the taskset then finishes successfully (elsewhere). (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("1") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -236,7 +243,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Now another executor gets spun up on that host, but it also dies. val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("2") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -279,7 +287,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M def failOneTaskInTaskSet(exec: String): Unit = { val taskSetBlacklist = createTaskSetBlacklist(stageId = stageId) - taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0) + taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0, "testing") blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) stageId += 1 } @@ -354,12 +362,12 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) // Taskset1 has one failure immediately - taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Then we have a *long* delay, much longer than the timeout, before any other failures or // taskset completion clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS * 5) // After the long delay, we have one failure on taskset 2, on the same executor - taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Finally, we complete both tasksets. Its important here to complete taskset2 *first*. We // want to make sure that when taskset 1 finishes, even though we've now got two task failures, // we realize that the task failure we just added was well before the timeout. @@ -377,16 +385,20 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // we blacklist executors on two different hosts -- make sure that doesn't lead to any // node blacklisting val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 2)) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(1, 0, taskSetBlacklist1.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 2)) @@ -395,8 +407,10 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Finally, blacklist another executor on the same node as the original blacklisted executor, // and make sure this time we *do* blacklist the node. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 1) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 0, failureReason = "testing") + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2)) @@ -486,7 +500,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) @@ -497,7 +512,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) @@ -512,7 +528,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) @@ -523,7 +540,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist3.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist3.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index b8626bf777598..6003899bb7bef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -660,9 +660,14 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.isZombie) assert(failedTaskSet) val idx = failedTask.index - assert(failedTaskSetReason === s"Aborting TaskSet 0.0 because task $idx (partition $idx) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior can be " + - s"configured via spark.blacklist.*.") + assert(failedTaskSetReason === s""" + |Aborting $taskSet because task $idx (partition $idx) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${tsm.taskSetBlacklistHelperOpt.get.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. + |""".stripMargin) } test("don't abort if there is an executor available, though it hasn't had scheduled tasks yet") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index f1392e9db6bfd..18981d5be2f94 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -37,7 +37,8 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // First, mark task 0 as failed on exec1. // task 0 should be blacklisted on exec1, and nowhere else - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 0, failureReason = "testing") for { executor <- (1 to 4).map(_.toString) index <- 0 until 10 @@ -49,17 +50,20 @@ class TaskSetBlacklistSuite extends SparkFunSuite { assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -108,34 +112,41 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) // Fail a task twice on hostA, exec:1 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail the same task once more on hostA, exec:2 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 2) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 2, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 3) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 4) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 3, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 4, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) } @@ -147,13 +158,17 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ae43f4cadc037..5c712bd6a545b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1146,7 +1146,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. verify(blacklist, never()) - .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + .updateBlacklistForFailedTask(anyString(), anyString(), anyInt(), anyString()) } test("update application blacklist for shuffle-fetch") { From f20be4d70bf321f377020d1bde761a43e5c72f0a Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Thu, 28 Sep 2017 14:43:31 +0800 Subject: [PATCH 1414/1765] [SPARK-22135][MESOS] metrics in spark-dispatcher not being registered properly ## What changes were proposed in this pull request? Fix a trivial bug with how metrics are registered in the mesos dispatcher. Bug resulted in creating a new registry each time the metricRegistry() method was called. ## How was this patch tested? Verified manually on local mesos setup Author: Paul Mackles Closes #19358 from pmackles/SPARK-22135. --- .../cluster/mesos/MesosClusterSchedulerSource.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala index 1fe94974c8e36..76aded4edb431 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -23,8 +23,9 @@ import org.apache.spark.metrics.source.Source private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) extends Source { - override def sourceName: String = "mesos_cluster" - override def metricRegistry: MetricRegistry = new MetricRegistry() + + override val sourceName: String = "mesos_cluster" + override val metricRegistry: MetricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { override def getValue: Int = scheduler.getQueuedDriversSize From 01bd00d13532af1c7328997cbec446b0d3e21459 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 28 Sep 2017 08:22:48 +0100 Subject: [PATCH 1415/1765] [SPARK-22128][CORE] Update paranamer to 2.8 to avoid BytecodeReadingParanamer ArrayIndexOutOfBoundsException with Scala 2.12 + Java 8 lambda ## What changes were proposed in this pull request? Un-manage jackson-module-paranamer version to let it use the version desired by jackson-module-scala; manage paranamer up from 2.8 for jackson-module-scala 2.7.9, to override avro 1.7.7's desired paranamer 2.3 ## How was this patch tested? Existing tests Author: Sean Owen Closes #19352 from srowen/SPARK-22128. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 10 ++++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e534e38213fb1..76fcbd15869f1 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -93,7 +93,7 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-paranamer-2.6.7.jar +jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.0.jar @@ -153,7 +153,7 @@ orc-core-1.4.0-nohive.jar orc-mapreduce-1.4.0-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.6.jar +paranamer-2.8.jar parquet-column-1.8.2.jar parquet-common-1.8.2.jar parquet-encoding-1.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 02c5a19d173be..cb20072bf8b30 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -93,7 +93,7 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-paranamer-2.6.7.jar +jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.0.jar @@ -154,7 +154,7 @@ orc-core-1.4.0-nohive.jar orc-mapreduce-1.4.0-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.6.jar +paranamer-2.8.jar parquet-column-1.8.2.jar parquet-common-1.8.2.jar parquet-encoding-1.8.2.jar diff --git a/pom.xml b/pom.xml index 83a35006707da..87a468c3a6f55 100644 --- a/pom.xml +++ b/pom.xml @@ -179,7 +179,10 @@ 4.7 1.1 2.52.0 - 2.6 + + 2.8 1.8 1.0.0 0.4.0 @@ -637,11 +640,6 @@
    - - com.fasterxml.jackson.module - jackson-module-paranamer - ${fasterxml.jackson.version} - com.fasterxml.jackson.module jackson-module-jaxb-annotations From d74dee1336e7152cc0fb7d2b3bf1a44f4f452025 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Sep 2017 09:20:37 -0700 Subject: [PATCH 1416/1765] [SPARK-22153][SQL] Rename ShuffleExchange -> ShuffleExchangeExec ## What changes were proposed in this pull request? For some reason when we added the Exec suffix to all physical operators, we missed this one. I was looking for this physical operator today and couldn't find it, because I was looking for ExchangeExec. ## How was this patch tested? This is a simple rename and should be covered by existing tests. Author: Reynold Xin Closes #19376 from rxin/SPARK-22153. --- .../spark/sql/execution/SparkStrategies.scala | 6 +-- .../exchange/EnsureRequirements.scala | 26 ++++++------- .../exchange/ExchangeCoordinator.scala | 38 +++++++++---------- ...change.scala => ShuffleExchangeExec.scala} | 10 ++--- .../apache/spark/sql/execution/limit.scala | 6 +-- .../streaming/IncrementalExecution.scala | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 5 ++- .../org/apache/spark/sql/DataFrameSuite.scala | 10 ++--- .../org/apache/spark/sql/DatasetSuite.scala | 4 +- .../execution/ExchangeCoordinatorSuite.scala | 22 +++++------ .../spark/sql/execution/ExchangeSuite.scala | 12 +++--- .../spark/sql/execution/PlannerSuite.scala | 32 ++++++++-------- .../spark/sql/sources/BucketedReadSuite.scala | 10 ++--- .../EnsureStatefulOpPartitioningSuite.scala | 4 +- 14 files changed, 95 insertions(+), 94 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/{ShuffleExchange.scala => ShuffleExchangeExec.scala} (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4da7a73469537..92eaab5cd8f81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -411,7 +411,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } @@ -446,7 +446,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case logical.RepartitionByExpression(expressions, child, numPartitions) => - exchange.ShuffleExchange(HashPartitioning( + exchange.ShuffleExchangeExec(HashPartitioning( expressions, numPartitions), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 1da72f2e92329..d28ce60e276d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.internal.SQLConf * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the - * input partition ordering requirements are met. + * each operator by inserting [[ShuffleExchangeExec]] Operators where required. Also ensure that + * the input partition ordering requirements are met. */ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions @@ -57,17 +57,17 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } /** - * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled - * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]]. + * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled + * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. */ private def withExchangeCoordinator( children: Seq[SparkPlan], requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { val supportsCoordinator = - if (children.exists(_.isInstanceOf[ShuffleExchange])) { + if (children.exists(_.isInstanceOf[ShuffleExchangeExec])) { // Right now, ExchangeCoordinator only support HashPartitionings. children.forall { - case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case e @ ShuffleExchangeExec(hash: HashPartitioning, _, _) => true case child => child.outputPartitioning match { case hash: HashPartitioning => true @@ -94,7 +94,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { targetPostShuffleInputSize, minNumPostShufflePartitions) children.zip(requiredChildDistributions).map { - case (e: ShuffleExchange, _) => + case (e: ShuffleExchangeExec, _) => // This child is an Exchange, we need to add the coordinator. e.copy(coordinator = Some(coordinator)) case (child, distribution) => @@ -138,7 +138,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val targetPartitioning = createPartitioning(distribution, defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) - ShuffleExchange(targetPartitioning, child, Some(coordinator)) + ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } } else { // If we do not need ExchangeCoordinator, the original children are returned. @@ -162,7 +162,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } // If the operator has multiple children and specifies child output distributions (e.g. join), @@ -215,8 +215,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { child match { // If child is an exchange, we replace it with // a new one having targetPartitioning. - case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c) - case _ => ShuffleExchange(targetPartitioning, child) + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) + case _ => ShuffleExchangeExec(targetPartitioning, child) } } } @@ -246,9 +246,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchange(partitioning, child, _) => + case operator @ ShuffleExchangeExec(partitioning, child, _) => child.children match { - case ShuffleExchange(childPartitioning, baseChild, _)::Nil => + case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => if (childPartitioning.guarantees(partitioning)) child else operator case _ => operator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 9fc4ffb651ec8..78f11ca8d8c78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * * A coordinator is constructed with three parameters, `numExchanges`, * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered - * to this coordinator. So, when we start to do any actual work, we have a way to make sure that - * we have got expected number of [[ShuffleExchange]]s. + * - `numExchanges` is used to indicated that how many [[ShuffleExchangeExec]]s that will be + * registered to this coordinator. So, when we start to do any actual work, we have a way to + * make sure that we have got expected number of [[ShuffleExchangeExec]]s. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through @@ -47,28 +47,28 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchange]] operator, + * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchangeExec]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, a [[ShuffleExchange]] registered to this + * - Once we start to execute a physical plan, a [[ShuffleExchangeExec]] registered to this * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle * [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] + * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchangeExec]] * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the + * registered [[ShuffleExchangeExec]]s to submit their pre-shuffle stages. Then, based on the * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[ShuffleExchange]]s. So, when a [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator - * can lookup the corresponding [[RDD]]. + * [[ShuffleExchangeExec]]s. So, when a [[ShuffleExchangeExec]] calls `postShuffleRDD`, this + * coordinator can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. * To determine the number of post-shuffle partitions, we have a target input size for a * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and - * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until + * corresponding to the registered [[ShuffleExchangeExec]]s, we will do a pass of those statistics + * and pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until * adding another pre-shuffle partition would cause the size of a post-shuffle partition to be * greater than the target size. * @@ -89,11 +89,11 @@ class ExchangeCoordinator( extends Logging { // The registered Exchange operators. - private[this] val exchanges = ArrayBuffer[ShuffleExchange]() + private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]() // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] = - new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) + private[this] val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = + new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. // This variable will only be updated by doEstimationIfNecessary, which is protected by @@ -101,11 +101,11 @@ class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers a [[ShuffleExchange]] operator to this coordinator. This method is only allowed to - * be called in the `doPrepare` method of a [[ShuffleExchange]] operator. + * Registers a [[ShuffleExchangeExec]] operator to this coordinator. This method is only allowed + * to be called in the `doPrepare` method of a [[ShuffleExchangeExec]] operator. */ @GuardedBy("this") - def registerExchange(exchange: ShuffleExchange): Unit = synchronized { + def registerExchange(exchange: ShuffleExchangeExec): Unit = synchronized { exchanges += exchange } @@ -200,7 +200,7 @@ class ExchangeCoordinator( // Make sure we have the expected number of registered Exchange operators. assert(exchanges.length == numExchanges) - val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) + val newPostShuffleRDDs = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // Submit all map stages val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() @@ -255,7 +255,7 @@ class ExchangeCoordinator( } } - def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = { + def postShuffleRDD(exchange: ShuffleExchangeExec): ShuffledRowRDD = { doEstimationIfNecessary() if (!postShuffleRDDs.containsKey(exchange)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 0d06d83fb2f3c..11c4aa9b4acf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.MutablePair /** * Performs a shuffle that will result in the desired `newPartitioning`. */ -case class ShuffleExchange( +case class ShuffleExchangeExec( var newPartitioning: Partitioning, child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { @@ -84,7 +84,7 @@ case class ShuffleExchange( */ private[exchange] def prepareShuffleDependency() : ShuffleDependency[Int, InternalRow, InternalRow] = { - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( child.execute(), child.output, newPartitioning, serializer) } @@ -129,9 +129,9 @@ case class ShuffleExchange( } } -object ShuffleExchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { - ShuffleExchange(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) +object ShuffleExchangeExec { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchangeExec = { + ShuffleExchangeExec(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1f515e29b4af5..13da4b26a5dcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils /** @@ -40,7 +40,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( locallyLimited, child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } @@ -153,7 +153,7 @@ case class TakeOrderedAndProjectExec( } } val shuffled = new ShuffledRowRDD( - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 8e0aae39cabb6..82f879c763c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.streaming.OutputMode /** @@ -155,7 +155,7 @@ object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { child.execute().getNumPartitions == expectedPartitioning.numPartitions) { child } else { - ShuffleExchange(expectedPartitioning, child) + ShuffleExchangeExec(expectedPartitioning, child) } } so.withNewChildren(children) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 3e4f619431599..1e52445f28fc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -420,7 +420,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) + assert( + df.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6178661cf7b2b..0e2f2e5a193e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -1529,7 +1529,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { fail("Should not have back to back Aggregates") } atFirstAgg = true - case e: ShuffleExchange => atFirstAgg = false + case e: ShuffleExchangeExec => atFirstAgg = false case _ => } } @@ -1710,19 +1710,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val plan = join.queryExecution.executedPlan checkAnswer(join, df) assert( - join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size === 1) assert( join.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 1) val broadcasted = broadcast(join) val join2 = join.join(broadcasted, "id").join(broadcasted, "id") checkAnswer(join2, df) assert( - join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) assert( join2.queryExecution.executedPlan .collect { case e: BroadcastExchangeExec => true }.size === 1) assert( - join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4) + join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5015f3709f131..dace6825ee40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1206,7 +1206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = cp.groupBy('id % 2).agg(count('id)) agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchange(_, _: RDDScanExec, _) => + case ShuffleExchangeExec(_, _: RDDScanExec, _) => case BroadcastExchangeExec(_, _: RDDScanExec) => }.foreach { _ => fail( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index f1b5e3be5b63f..737eeb0af586e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -300,13 +300,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -314,7 +314,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -351,13 +351,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -365,7 +365,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -407,13 +407,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -459,13 +459,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 59eaf4d1c29b7..aac8d56ba6201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext @@ -31,7 +31,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ShuffleExchange(SinglePartition, plan), + plan => ShuffleExchangeExec(SinglePartition, plan), input.map(Row.fromTuple) ) } @@ -81,12 +81,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) - val exchange1 = ShuffleExchange(part1, plan) - val exchange2 = ShuffleExchange(part1, plan) + val exchange1 = ShuffleExchangeExec(part1, plan) + val exchange2 = ShuffleExchangeExec(part1, plan) val part2 = HashPartitioning(output, 2) - val exchange3 = ShuffleExchange(part2, plan) + val exchange3 = ShuffleExchangeExec(part2, plan) val part3 = HashPartitioning(output ++ output, 2) - val exchange4 = ShuffleExchange(part3, plan) + val exchange4 = ShuffleExchangeExec(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4) assert(exchange1 sameResult exchange1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 63e17c7f372b0..86066362da9dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -214,7 +214,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -229,7 +229,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -300,7 +300,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -338,7 +338,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -358,7 +358,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -381,7 +381,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -400,7 +400,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -411,7 +411,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -420,7 +420,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } @@ -430,7 +430,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val shuffle = ShuffleExchange(finalPartitioning, + val shuffle = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -449,7 +449,7 @@ class PlannerSuite extends SharedSQLContext { if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } - if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size != 1) { fail(s"Should have only one shuffle:\n$outputPlan") } @@ -459,14 +459,14 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Inner, None, - ShuffleExchange(finalPartitioning, inputPlan), - ShuffleExchange(finalPartitioning, inputPlan)) + ShuffleExchangeExec(finalPartitioning, inputPlan), + ShuffleExchangeExec(finalPartitioning, inputPlan)) val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } - if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + if (outputPlan2.collect { case e: ShuffleExchangeExec => true }.size != 2) { fail(s"Should have only two shuffles:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index eb9e6458fc61c..ab18905e2ddb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -302,10 +302,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { // check existence of shuffle assert( - joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft, + joinOperator.left.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight, + joinOperator.right.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort @@ -506,7 +506,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } @@ -520,7 +520,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala index 044bb03480aa4..ed9823fbddfda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} import org.apache.spark.sql.test.SharedSQLContext @@ -93,7 +93,7 @@ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLCont fail(s"Was expecting an exchange but didn't get one in:\n$executed") } assert(exchange.get === - ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan), + ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan), s"Exchange didn't have expected properties:\n${exchange.get}") } else { assert(!executed.children.exists(_.isInstanceOf[Exchange]), From d29d1e87995e02cb57ba3026c945c3cd66bb06e2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Sep 2017 15:59:05 -0700 Subject: [PATCH 1417/1765] [SPARK-22159][SQL] Make config names consistently end with "enabled". ## What changes were proposed in this pull request? spark.sql.execution.arrow.enable and spark.sql.codegen.aggregate.map.twolevel.enable -> enabled ## How was this patch tested? N/A Author: Reynold Xin Closes #19384 from rxin/SPARK-22159. --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d00c672487532..358cf62149070 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -668,7 +668,7 @@ object SQLConf { .createWithDefault(40) val ENABLE_TWOLEVEL_AGG_MAP = - buildConf("spark.sql.codegen.aggregate.map.twolevel.enable") + buildConf("spark.sql.codegen.aggregate.map.twolevel.enabled") .internal() .doc("Enable two-level aggregate hash map. When enabled, records will first be " + "inserted/looked-up at a 1st-level, small, fast map, and then fallback to a " + @@ -908,7 +908,7 @@ object SQLConf { .createWithDefault(false) val ARROW_EXECUTION_ENABLE = - buildConf("spark.sql.execution.arrow.enable") + buildConf("spark.sql.execution.arrow.enabled") .internal() .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + From 323806e68f91f3c7521327186a37ddd1436267d0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Sep 2017 21:07:12 -0700 Subject: [PATCH 1418/1765] [SPARK-22160][SQL] Make sample points per partition (in range partitioner) configurable and bump the default value up to 100 ## What changes were proposed in this pull request? Spark's RangePartitioner hard codes the number of sampling points per partition to be 20. This is sometimes too low. This ticket makes it configurable, via spark.sql.execution.rangeExchange.sampleSizePerPartition, and raises the default in Spark SQL to be 100. ## How was this patch tested? Added a pretty sophisticated test based on chi square test ... Author: Reynold Xin Closes #19387 from rxin/SPARK-22160. --- .../scala/org/apache/spark/Partitioner.scala | 15 ++++- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../exchange/ShuffleExchangeExec.scala | 7 +- .../spark/sql/ConfigBehaviorSuite.scala | 66 +++++++++++++++++++ 4 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 1484f29525a4e..debbd8d7c26c9 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -108,11 +108,21 @@ class HashPartitioner(partitions: Int) extends Partitioner { class RangePartitioner[K : Ordering : ClassTag, V]( partitions: Int, rdd: RDD[_ <: Product2[K, V]], - private var ascending: Boolean = true) + private var ascending: Boolean = true, + val samplePointsPerPartitionHint: Int = 20) extends Partitioner { + // A constructor declared in order to maintain backward compatibility for Java, when we add the + // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160. + // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor. + def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = { + this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20) + } + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings. require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + require(samplePointsPerPartitionHint > 0, + s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint") private var ordering = implicitly[Ordering[K]] @@ -122,7 +132,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. - val sampleSize = math.min(20.0 * partitions, 1e6) + // Cast to double to avoid overflowing ints or longs + val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6) // Assume the input partitions are roughly balanced and over-sample a little bit. val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 358cf62149070..1a73d168b9b6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -907,6 +907,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION = + buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition") + .internal() + .doc("Number of points to sample per partition in order to determine the range boundaries" + + " for range partitioning, typically used in global sorting (without limit).") + .intConf + .createWithDefault(100) + val ARROW_EXECUTION_ENABLE = buildConf("spark.sql.execution.arrow.enabled") .internal() @@ -1199,6 +1207,8 @@ class SQLConf extends Serializable with Logging { def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME) + def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 11c4aa9b4acf0..5a1e217082bc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.MutablePair /** @@ -218,7 +219,11 @@ object ShuffleExchangeExec { iter.map(row => mutablePair.update(row.copy(), null)) } implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) - new RangePartitioner(numPartitions, rddForSampling, ascending = true) + new RangePartitioner( + numPartitions, + rddForSampling, + ascending = true, + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new Partitioner { override def numPartitions: Int = 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala new file mode 100644 index 0000000000000..2c1e5db5fd9bb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.commons.math3.stat.inference.ChiSquareTest + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + + +class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") { + // In this test, we run a sort and compute the histogram for partition size post shuffle. + // With a high sample count, the partition size should be more evenly distributed, and has a + // low chi-sq test value. + // Also the whole code path for range partitioning as implemented should be deterministic + // (it uses the partition id as the seed), so this test shouldn't be flaky. + + val numPartitions = 4 + + def computeChiSquareTest(): Double = { + val n = 10000 + // Trigger a sort + val data = spark.range(0, n, 1, 1).sort('id) + .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() + + // Compute histogram for the number of records per partition post sort + val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray + assert(dist.length == 4) + + new ChiSquareTest().chiSquare( + Array.fill(numPartitions) { n.toDouble / numPartitions }, + dist) + } + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) { + // The default chi-sq value should be low + assert(computeChiSquareTest() < 100) + + withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { + // If we only sample one point, the range boundaries will be pretty bad and the + // chi-sq value would be very high. + assert(computeChiSquareTest() > 1000) + } + } + } + +} From 161ba7eaa4539f0a7f20d9e2a493e0e323ca5249 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 28 Sep 2017 23:14:53 -0700 Subject: [PATCH 1419/1765] [SPARK-22146] FileNotFoundException while reading ORC files containing special characters ## What changes were proposed in this pull request? Reading ORC files containing special characters like '%' fails with a FileNotFoundException. This PR aims to fix the problem. ## How was this patch tested? Added UT. Author: Marco Gaido Author: Marco Gaido Closes #19368 from mgaido91/SPARK-22146. --- .../apache/spark/sql/hive/orc/OrcFileFormat.scala | 2 +- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 4d92a67044373..c76f0ebb36a60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -58,7 +58,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { OrcFileOperator.readSchema( - files.map(_.getPath.toUri.toString), + files.map(_.getPath.toString), Some(sparkSession.sessionState.newHadoopConf()) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 29b0e6c8533ef..f5d41c91270a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -993,7 +993,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv spark.sql("""drop database if exists testdb8156 CASCADE""") } - test("skip hive metadata on table creation") { withTempDir { tempPath => val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) @@ -1345,6 +1344,17 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"SPARK-22146: read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + spark.read.format(format).load(tmpFile) + } + } + } + private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { From 0fa4dbe4f4d7b988be2105b46590b5207f7c8121 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 28 Sep 2017 23:23:30 -0700 Subject: [PATCH 1420/1765] [SPARK-22141][FOLLOWUP][SQL] Add comments for the order of batches ## What changes were proposed in this pull request? Add comments for specifying the position of batch "Check Cartesian Products", as rxin suggested in https://github.com/apache/spark/pull/19362 . ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19379 from gengliangwang/SPARK-22141-followup. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a391c513ad384..b9fa39d6dad4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -134,6 +134,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: + // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Batch("OptimizeCodegen", Once, @@ -1089,6 +1090,9 @@ object CombineLimits extends Rule[LogicalPlan] { * SELECT * from R, S where R.r = S.s, * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. + * + * This rule must be run AFTER the batch "LocalRelation", since a join with empty relation should + * not be a cartesian product. */ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { /** From a2516f41aef68e39df7f6380fd2618cc148a609e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 29 Sep 2017 08:26:53 +0100 Subject: [PATCH 1421/1765] [SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile ## What changes were proposed in this pull request? Add 'flume' profile to enable Flume-related integration modules ## How was this patch tested? Existing tests; no functional change Author: Sean Owen Closes #19365 from srowen/SPARK-22142. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 + dev/sparktestsupport/modules.py | 20 +++++++++++++++++++- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 6 ++++++ pom.xml | 13 ++++++++++--- project/SparkBuild.scala | 17 +++++++++-------- python/pyspark/streaming/tests.py | 16 +++++++++++++--- 9 files changed, 62 insertions(+), 19 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 8de1d6a37dc25..c548a0a4e4bee 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index fdb21f5007cf2..1e3ca9700bc07 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index e5aa589869535..89ecc8abd6f8c 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,6 +25,7 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ + -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 50e14b60545af..91d5667ed1f07 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,6 +279,12 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -291,6 +297,12 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume/test", ] @@ -302,7 +314,13 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ] + ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + } ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index c7714578bd005..58b295d4f6e00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index 57baa503259c1..e1532de16108d 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,6 +100,12 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. +## Building with Flume support + +Apache Flume support must be explicitly enabled with the `flume` profile. + + ./build/mvn -Pflume -DskipTests clean package + ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/pom.xml b/pom.xml index 87a468c3a6f55..9fac8b1e53788 100644 --- a/pom.xml +++ b/pom.xml @@ -98,15 +98,13 @@ sql/core sql/hive assembly - external/flume - external/flume-sink - external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + @@ -2583,6 +2581,15 @@
    + + flume + + external/flume + external/flume-sink + external/flume-assembly + + + spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a568d264cb2db..9501eed1e906b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,11 +43,8 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka010 - ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" - ).map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq(streaming, streamingKafka010) = + Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -56,9 +53,13 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, + streamingFlumeSink, streamingFlume, + streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, + dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", + "streaming-flume-sink", "streaming-flume", + "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 229cf53e47359..5b86c1cb2c390 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,6 +1516,9 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in modules.py +flume_test_environ_var = "ENABLE_FLUME_TESTS" +are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1538,9 +1541,16 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] + if are_flume_tests_enabled: + testcases.append(FlumeStreamTests) + testcases.append(FlumePollingStreamTests) + else: + sys.stderr.write( + "Skipped test_flume_stream (enable by setting environment variable %s=1" + % flume_test_environ_var) + if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From ecbe416ab5001b32737966c5a2407597a1dafc32 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 29 Sep 2017 08:04:14 -0700 Subject: [PATCH 1422/1765] [SPARK-22129][SPARK-22138] Release script improvements ## What changes were proposed in this pull request? Use the GPG_KEY param, fix lsof to non-hardcoded path, remove version swap since it wasn't really needed. Use EXPORT on JAVA_HOME for downstream scripts as well. ## How was this patch tested? Rolled 2.1.2 RC2 Author: Holden Karau Closes #19359 from holdenk/SPARK-22129-fix-signing. --- dev/create-release/release-build.sh | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c548a0a4e4bee..7e8d5c7075195 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -74,7 +74,7 @@ GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -GPG="gpg --no-tty --batch" +GPG="gpg -u $GPG_KEY --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) @@ -125,7 +125,7 @@ else echo "Please set JAVA_HOME correctly." exit 1 else - JAVA_HOME="$JAVA_7_HOME" + export JAVA_HOME="$JAVA_7_HOME" fi fi fi @@ -140,7 +140,7 @@ DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" function LFTP { SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" COMMANDS=$(cat < Date: Fri, 29 Sep 2017 08:59:42 -0700 Subject: [PATCH 1423/1765] [SPARK-22161][SQL] Add Impala-modified TPC-DS queries ## What changes were proposed in this pull request? Added IMPALA-modified TPCDS queries to TPC-DS query suites. - Ref: https://github.com/cloudera/impala-tpcds-kit/tree/master/queries ## How was this patch tested? N/A Author: gatorsmile Closes #19386 from gatorsmile/addImpalaQueries. --- .../resources/tpcds-modifiedQueries/q10.sql | 70 ++++++ .../resources/tpcds-modifiedQueries/q19.sql | 38 +++ .../resources/tpcds-modifiedQueries/q27.sql | 43 ++++ .../resources/tpcds-modifiedQueries/q3.sql | 228 ++++++++++++++++++ .../resources/tpcds-modifiedQueries/q34.sql | 45 ++++ .../resources/tpcds-modifiedQueries/q42.sql | 28 +++ .../resources/tpcds-modifiedQueries/q43.sql | 36 +++ .../resources/tpcds-modifiedQueries/q46.sql | 80 ++++++ .../resources/tpcds-modifiedQueries/q52.sql | 27 +++ .../resources/tpcds-modifiedQueries/q53.sql | 37 +++ .../resources/tpcds-modifiedQueries/q55.sql | 24 ++ .../resources/tpcds-modifiedQueries/q59.sql | 83 +++++++ .../resources/tpcds-modifiedQueries/q63.sql | 29 +++ .../resources/tpcds-modifiedQueries/q65.sql | 58 +++++ .../resources/tpcds-modifiedQueries/q68.sql | 62 +++++ .../resources/tpcds-modifiedQueries/q7.sql | 31 +++ .../resources/tpcds-modifiedQueries/q73.sql | 49 ++++ .../resources/tpcds-modifiedQueries/q79.sql | 59 +++++ .../resources/tpcds-modifiedQueries/q89.sql | 43 ++++ .../resources/tpcds-modifiedQueries/q98.sql | 32 +++ .../tpcds-modifiedQueries/ss_max.sql | 14 ++ .../apache/spark/sql/TPCDSQuerySuite.scala | 26 +- 22 files changed, 1141 insertions(+), 1 deletion(-) create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql new file mode 100755 index 0000000000000..79dd3d516e8c7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql @@ -0,0 +1,70 @@ +-- start query 10 in stream 0 using template query10.tpl +with +v1 as ( + select + ws_bill_customer_sk as customer_sk + from web_sales, + date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 + union all + select + cs_ship_customer_sk as customer_sk + from catalog_sales, + date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +), +v2 as ( + select + ss_customer_sk as customer_sk + from store_sales, + date_dim + where ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +) +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from customer c +join customer_address ca on (c.c_current_addr_sk = ca.ca_address_sk) +join customer_demographics on (cd_demo_sk = c.c_current_cdemo_sk) +left semi join v1 on (v1.customer_sk = c.c_customer_sk) +left semi join v2 on (v2.customer_sk = c.c_customer_sk) +where + ca_county in ('Walker County','Richland County','Gaines County','Douglas County','Dona Ana County') +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 +-- end query 10 in stream 0 using template query10.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql new file mode 100755 index 0000000000000..1799827762916 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql @@ -0,0 +1,38 @@ +-- start query 19 in stream 0 using template query19.tpl +select + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item, + customer, + customer_address, + store +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 7 + and d_moy = 11 + and d_year = 1999 + and ss_customer_sk = c_customer_sk + and c_current_addr_sk = ca_address_sk + and substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) + and ss_store_sk = s_store_sk + and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter +group by + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +order by + ext_price desc, + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +limit 100 +-- end query 19 in stream 0 using template query19.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql new file mode 100755 index 0000000000000..dedbc62a2ab2e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql @@ -0,0 +1,43 @@ +-- start query 27 in stream 0 using template query27.tpl + with results as + (select i_item_id, + s_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + --0 as g_state, + --avg(ss_quantity) agg1, + --avg(ss_list_price) agg2, + --avg(ss_coupon_amt) agg3, + --avg(ss_sales_price) agg4 + from store_sales, customer_demographics, date_dim, store, item + where ss_sold_date_sk = d_date_sk and + ss_sold_date_sk between 2451545 and 2451910 and + ss_item_sk = i_item_sk and + ss_store_sk = s_store_sk and + ss_cdemo_sk = cd_demo_sk and + cd_gender = 'F' and + cd_marital_status = 'D' and + cd_education_status = 'Primary' and + d_year = 2000 and + s_state in ('TN','AL', 'SD', 'SD', 'SD', 'SD') + --group by i_item_id, s_state + ) + + select i_item_id, + s_state, g_state, agg1, agg2, agg3, agg4 + from ( + select i_item_id, s_state, 0 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, avg(agg4) agg4 from results + group by i_item_id, s_state + union all + select i_item_id, NULL AS s_state, 1 AS g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + group by i_item_id + union all + select NULL AS i_item_id, NULL as s_state, 1 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + ) foo + order by i_item_id, s_state + limit 100 +-- end query 27 in stream 0 using template query27.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql new file mode 100755 index 0000000000000..35b0a20f80a4e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql @@ -0,0 +1,228 @@ +-- start query 3 in stream 0 using template query3.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_net_profit) sum_agg +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manufact_id = 436 + and dt.d_moy = 12 + -- partition key filters + and ( +ss_sold_date_sk between 2415355 and 2415385 +or ss_sold_date_sk between 2415720 and 2415750 +or ss_sold_date_sk between 2416085 and 2416115 +or ss_sold_date_sk between 2416450 and 2416480 +or ss_sold_date_sk between 2416816 and 2416846 +or ss_sold_date_sk between 2417181 and 2417211 +or ss_sold_date_sk between 2417546 and 2417576 +or ss_sold_date_sk between 2417911 and 2417941 +or ss_sold_date_sk between 2418277 and 2418307 +or ss_sold_date_sk between 2418642 and 2418672 +or ss_sold_date_sk between 2419007 and 2419037 +or ss_sold_date_sk between 2419372 and 2419402 +or ss_sold_date_sk between 2419738 and 2419768 +or ss_sold_date_sk between 2420103 and 2420133 +or ss_sold_date_sk between 2420468 and 2420498 +or ss_sold_date_sk between 2420833 and 2420863 +or ss_sold_date_sk between 2421199 and 2421229 +or ss_sold_date_sk between 2421564 and 2421594 +or ss_sold_date_sk between 2421929 and 2421959 +or ss_sold_date_sk between 2422294 and 2422324 +or ss_sold_date_sk between 2422660 and 2422690 +or ss_sold_date_sk between 2423025 and 2423055 +or ss_sold_date_sk between 2423390 and 2423420 +or ss_sold_date_sk between 2423755 and 2423785 +or ss_sold_date_sk between 2424121 and 2424151 +or ss_sold_date_sk between 2424486 and 2424516 +or ss_sold_date_sk between 2424851 and 2424881 +or ss_sold_date_sk between 2425216 and 2425246 +or ss_sold_date_sk between 2425582 and 2425612 +or ss_sold_date_sk between 2425947 and 2425977 +or ss_sold_date_sk between 2426312 and 2426342 +or ss_sold_date_sk between 2426677 and 2426707 +or ss_sold_date_sk between 2427043 and 2427073 +or ss_sold_date_sk between 2427408 and 2427438 +or ss_sold_date_sk between 2427773 and 2427803 +or ss_sold_date_sk between 2428138 and 2428168 +or ss_sold_date_sk between 2428504 and 2428534 +or ss_sold_date_sk between 2428869 and 2428899 +or ss_sold_date_sk between 2429234 and 2429264 +or ss_sold_date_sk between 2429599 and 2429629 +or ss_sold_date_sk between 2429965 and 2429995 +or ss_sold_date_sk between 2430330 and 2430360 +or ss_sold_date_sk between 2430695 and 2430725 +or ss_sold_date_sk between 2431060 and 2431090 +or ss_sold_date_sk between 2431426 and 2431456 +or ss_sold_date_sk between 2431791 and 2431821 +or ss_sold_date_sk between 2432156 and 2432186 +or ss_sold_date_sk between 2432521 and 2432551 +or ss_sold_date_sk between 2432887 and 2432917 +or ss_sold_date_sk between 2433252 and 2433282 +or ss_sold_date_sk between 2433617 and 2433647 +or ss_sold_date_sk between 2433982 and 2434012 +or ss_sold_date_sk between 2434348 and 2434378 +or ss_sold_date_sk between 2434713 and 2434743 +or ss_sold_date_sk between 2435078 and 2435108 +or ss_sold_date_sk between 2435443 and 2435473 +or ss_sold_date_sk between 2435809 and 2435839 +or ss_sold_date_sk between 2436174 and 2436204 +or ss_sold_date_sk between 2436539 and 2436569 +or ss_sold_date_sk between 2436904 and 2436934 +or ss_sold_date_sk between 2437270 and 2437300 +or ss_sold_date_sk between 2437635 and 2437665 +or ss_sold_date_sk between 2438000 and 2438030 +or ss_sold_date_sk between 2438365 and 2438395 +or ss_sold_date_sk between 2438731 and 2438761 +or ss_sold_date_sk between 2439096 and 2439126 +or ss_sold_date_sk between 2439461 and 2439491 +or ss_sold_date_sk between 2439826 and 2439856 +or ss_sold_date_sk between 2440192 and 2440222 +or ss_sold_date_sk between 2440557 and 2440587 +or ss_sold_date_sk between 2440922 and 2440952 +or ss_sold_date_sk between 2441287 and 2441317 +or ss_sold_date_sk between 2441653 and 2441683 +or ss_sold_date_sk between 2442018 and 2442048 +or ss_sold_date_sk between 2442383 and 2442413 +or ss_sold_date_sk between 2442748 and 2442778 +or ss_sold_date_sk between 2443114 and 2443144 +or ss_sold_date_sk between 2443479 and 2443509 +or ss_sold_date_sk between 2443844 and 2443874 +or ss_sold_date_sk between 2444209 and 2444239 +or ss_sold_date_sk between 2444575 and 2444605 +or ss_sold_date_sk between 2444940 and 2444970 +or ss_sold_date_sk between 2445305 and 2445335 +or ss_sold_date_sk between 2445670 and 2445700 +or ss_sold_date_sk between 2446036 and 2446066 +or ss_sold_date_sk between 2446401 and 2446431 +or ss_sold_date_sk between 2446766 and 2446796 +or ss_sold_date_sk between 2447131 and 2447161 +or ss_sold_date_sk between 2447497 and 2447527 +or ss_sold_date_sk between 2447862 and 2447892 +or ss_sold_date_sk between 2448227 and 2448257 +or ss_sold_date_sk between 2448592 and 2448622 +or ss_sold_date_sk between 2448958 and 2448988 +or ss_sold_date_sk between 2449323 and 2449353 +or ss_sold_date_sk between 2449688 and 2449718 +or ss_sold_date_sk between 2450053 and 2450083 +or ss_sold_date_sk between 2450419 and 2450449 +or ss_sold_date_sk between 2450784 and 2450814 +or ss_sold_date_sk between 2451149 and 2451179 +or ss_sold_date_sk between 2451514 and 2451544 +or ss_sold_date_sk between 2451880 and 2451910 +or ss_sold_date_sk between 2452245 and 2452275 +or ss_sold_date_sk between 2452610 and 2452640 +or ss_sold_date_sk between 2452975 and 2453005 +or ss_sold_date_sk between 2453341 and 2453371 +or ss_sold_date_sk between 2453706 and 2453736 +or ss_sold_date_sk between 2454071 and 2454101 +or ss_sold_date_sk between 2454436 and 2454466 +or ss_sold_date_sk between 2454802 and 2454832 +or ss_sold_date_sk between 2455167 and 2455197 +or ss_sold_date_sk between 2455532 and 2455562 +or ss_sold_date_sk between 2455897 and 2455927 +or ss_sold_date_sk between 2456263 and 2456293 +or ss_sold_date_sk between 2456628 and 2456658 +or ss_sold_date_sk between 2456993 and 2457023 +or ss_sold_date_sk between 2457358 and 2457388 +or ss_sold_date_sk between 2457724 and 2457754 +or ss_sold_date_sk between 2458089 and 2458119 +or ss_sold_date_sk between 2458454 and 2458484 +or ss_sold_date_sk between 2458819 and 2458849 +or ss_sold_date_sk between 2459185 and 2459215 +or ss_sold_date_sk between 2459550 and 2459580 +or ss_sold_date_sk between 2459915 and 2459945 +or ss_sold_date_sk between 2460280 and 2460310 +or ss_sold_date_sk between 2460646 and 2460676 +or ss_sold_date_sk between 2461011 and 2461041 +or ss_sold_date_sk between 2461376 and 2461406 +or ss_sold_date_sk between 2461741 and 2461771 +or ss_sold_date_sk between 2462107 and 2462137 +or ss_sold_date_sk between 2462472 and 2462502 +or ss_sold_date_sk between 2462837 and 2462867 +or ss_sold_date_sk between 2463202 and 2463232 +or ss_sold_date_sk between 2463568 and 2463598 +or ss_sold_date_sk between 2463933 and 2463963 +or ss_sold_date_sk between 2464298 and 2464328 +or ss_sold_date_sk between 2464663 and 2464693 +or ss_sold_date_sk between 2465029 and 2465059 +or ss_sold_date_sk between 2465394 and 2465424 +or ss_sold_date_sk between 2465759 and 2465789 +or ss_sold_date_sk between 2466124 and 2466154 +or ss_sold_date_sk between 2466490 and 2466520 +or ss_sold_date_sk between 2466855 and 2466885 +or ss_sold_date_sk between 2467220 and 2467250 +or ss_sold_date_sk between 2467585 and 2467615 +or ss_sold_date_sk between 2467951 and 2467981 +or ss_sold_date_sk between 2468316 and 2468346 +or ss_sold_date_sk between 2468681 and 2468711 +or ss_sold_date_sk between 2469046 and 2469076 +or ss_sold_date_sk between 2469412 and 2469442 +or ss_sold_date_sk between 2469777 and 2469807 +or ss_sold_date_sk between 2470142 and 2470172 +or ss_sold_date_sk between 2470507 and 2470537 +or ss_sold_date_sk between 2470873 and 2470903 +or ss_sold_date_sk between 2471238 and 2471268 +or ss_sold_date_sk between 2471603 and 2471633 +or ss_sold_date_sk between 2471968 and 2471998 +or ss_sold_date_sk between 2472334 and 2472364 +or ss_sold_date_sk between 2472699 and 2472729 +or ss_sold_date_sk between 2473064 and 2473094 +or ss_sold_date_sk between 2473429 and 2473459 +or ss_sold_date_sk between 2473795 and 2473825 +or ss_sold_date_sk between 2474160 and 2474190 +or ss_sold_date_sk between 2474525 and 2474555 +or ss_sold_date_sk between 2474890 and 2474920 +or ss_sold_date_sk between 2475256 and 2475286 +or ss_sold_date_sk between 2475621 and 2475651 +or ss_sold_date_sk between 2475986 and 2476016 +or ss_sold_date_sk between 2476351 and 2476381 +or ss_sold_date_sk between 2476717 and 2476747 +or ss_sold_date_sk between 2477082 and 2477112 +or ss_sold_date_sk between 2477447 and 2477477 +or ss_sold_date_sk between 2477812 and 2477842 +or ss_sold_date_sk between 2478178 and 2478208 +or ss_sold_date_sk between 2478543 and 2478573 +or ss_sold_date_sk between 2478908 and 2478938 +or ss_sold_date_sk between 2479273 and 2479303 +or ss_sold_date_sk between 2479639 and 2479669 +or ss_sold_date_sk between 2480004 and 2480034 +or ss_sold_date_sk between 2480369 and 2480399 +or ss_sold_date_sk between 2480734 and 2480764 +or ss_sold_date_sk between 2481100 and 2481130 +or ss_sold_date_sk between 2481465 and 2481495 +or ss_sold_date_sk between 2481830 and 2481860 +or ss_sold_date_sk between 2482195 and 2482225 +or ss_sold_date_sk between 2482561 and 2482591 +or ss_sold_date_sk between 2482926 and 2482956 +or ss_sold_date_sk between 2483291 and 2483321 +or ss_sold_date_sk between 2483656 and 2483686 +or ss_sold_date_sk between 2484022 and 2484052 +or ss_sold_date_sk between 2484387 and 2484417 +or ss_sold_date_sk between 2484752 and 2484782 +or ss_sold_date_sk between 2485117 and 2485147 +or ss_sold_date_sk between 2485483 and 2485513 +or ss_sold_date_sk between 2485848 and 2485878 +or ss_sold_date_sk between 2486213 and 2486243 +or ss_sold_date_sk between 2486578 and 2486608 +or ss_sold_date_sk between 2486944 and 2486974 +or ss_sold_date_sk between 2487309 and 2487339 +or ss_sold_date_sk between 2487674 and 2487704 +or ss_sold_date_sk between 2488039 and 2488069 +) +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + sum_agg desc, + brand_id +limit 100 +-- end query 3 in stream 0 using template query3.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql new file mode 100755 index 0000000000000..d11696e5e0c34 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql @@ -0,0 +1,45 @@ +-- start query 34 in stream 0 using template query34.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (date_dim.d_dom between 1 and 3 + or date_dim.d_dom between 25 and 28) + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and (case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end) > 1.2 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Saginaw County', 'Sumner County', 'Appanoose County', 'Daviess County', 'Fairfield County', 'Raleigh County', 'Ziebach County', 'Williamson County') + and ss_sold_date_sk between 2450816 and 2451910 -- partition key filter + group by + ss_ticket_number, + ss_customer_sk + ) dn, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 15 and 20 +order by + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag desc +-- end query 34 in stream 0 using template query34.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql new file mode 100755 index 0000000000000..b6332a8afbebe --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql @@ -0,0 +1,28 @@ +-- start query 42 in stream 0 using template query42.tpl +select + dt.d_year, + item.i_category_id, + item.i_category, + sum(ss_ext_sales_price) +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter +group by + dt.d_year, + item.i_category_id, + item.i_category +order by + sum(ss_ext_sales_price) desc, + dt.d_year, + item.i_category_id, + item.i_category +limit 100 +-- end query 42 in stream 0 using template query42.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql new file mode 100755 index 0000000000000..cc2040b2fdb7c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql @@ -0,0 +1,36 @@ +-- start query 43 in stream 0 using template query43.tpl +select + s_store_name, + s_store_id, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales +from + date_dim, + store_sales, + store +where + d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_gmt_offset = -5 + and d_year = 1998 + and ss_sold_date_sk between 2450816 and 2451179 -- partition key filter +group by + s_store_name, + s_store_id +order by + s_store_name, + s_store_id, + sun_sales, + mon_sales, + tue_sales, + wed_sales, + thu_sales, + fri_sales, + sat_sales +limit 100 +-- end query 43 in stream 0 using template query43.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql new file mode 100755 index 0000000000000..52b7ba4f4b86b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql @@ -0,0 +1,80 @@ +-- start query 46 in stream 0 using template query46.tpl +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_dow in (6, 0) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Concord', 'Spring Hill', 'Brownsville', 'Greenville') + -- partition key filter + and ss_sold_date_sk in (2451181, 2451182, 2451188, 2451189, 2451195, 2451196, 2451202, 2451203, 2451209, 2451210, 2451216, 2451217, + 2451223, 2451224, 2451230, 2451231, 2451237, 2451238, 2451244, 2451245, 2451251, 2451252, 2451258, 2451259, + 2451265, 2451266, 2451272, 2451273, 2451279, 2451280, 2451286, 2451287, 2451293, 2451294, 2451300, 2451301, + 2451307, 2451308, 2451314, 2451315, 2451321, 2451322, 2451328, 2451329, 2451335, 2451336, 2451342, 2451343, + 2451349, 2451350, 2451356, 2451357, 2451363, 2451364, 2451370, 2451371, 2451377, 2451378, 2451384, 2451385, + 2451391, 2451392, 2451398, 2451399, 2451405, 2451406, 2451412, 2451413, 2451419, 2451420, 2451426, 2451427, + 2451433, 2451434, 2451440, 2451441, 2451447, 2451448, 2451454, 2451455, 2451461, 2451462, 2451468, 2451469, + 2451475, 2451476, 2451482, 2451483, 2451489, 2451490, 2451496, 2451497, 2451503, 2451504, 2451510, 2451511, + 2451517, 2451518, 2451524, 2451525, 2451531, 2451532, 2451538, 2451539, 2451545, 2451546, 2451552, 2451553, + 2451559, 2451560, 2451566, 2451567, 2451573, 2451574, 2451580, 2451581, 2451587, 2451588, 2451594, 2451595, + 2451601, 2451602, 2451608, 2451609, 2451615, 2451616, 2451622, 2451623, 2451629, 2451630, 2451636, 2451637, + 2451643, 2451644, 2451650, 2451651, 2451657, 2451658, 2451664, 2451665, 2451671, 2451672, 2451678, 2451679, + 2451685, 2451686, 2451692, 2451693, 2451699, 2451700, 2451706, 2451707, 2451713, 2451714, 2451720, 2451721, + 2451727, 2451728, 2451734, 2451735, 2451741, 2451742, 2451748, 2451749, 2451755, 2451756, 2451762, 2451763, + 2451769, 2451770, 2451776, 2451777, 2451783, 2451784, 2451790, 2451791, 2451797, 2451798, 2451804, 2451805, + 2451811, 2451812, 2451818, 2451819, 2451825, 2451826, 2451832, 2451833, 2451839, 2451840, 2451846, 2451847, + 2451853, 2451854, 2451860, 2451861, 2451867, 2451868, 2451874, 2451875, 2451881, 2451882, 2451888, 2451889, + 2451895, 2451896, 2451902, 2451903, 2451909, 2451910, 2451916, 2451917, 2451923, 2451924, 2451930, 2451931, + 2451937, 2451938, 2451944, 2451945, 2451951, 2451952, 2451958, 2451959, 2451965, 2451966, 2451972, 2451973, + 2451979, 2451980, 2451986, 2451987, 2451993, 2451994, 2452000, 2452001, 2452007, 2452008, 2452014, 2452015, + 2452021, 2452022, 2452028, 2452029, 2452035, 2452036, 2452042, 2452043, 2452049, 2452050, 2452056, 2452057, + 2452063, 2452064, 2452070, 2452071, 2452077, 2452078, 2452084, 2452085, 2452091, 2452092, 2452098, 2452099, + 2452105, 2452106, 2452112, 2452113, 2452119, 2452120, 2452126, 2452127, 2452133, 2452134, 2452140, 2452141, + 2452147, 2452148, 2452154, 2452155, 2452161, 2452162, 2452168, 2452169, 2452175, 2452176, 2452182, 2452183, + 2452189, 2452190, 2452196, 2452197, 2452203, 2452204, 2452210, 2452211, 2452217, 2452218, 2452224, 2452225, + 2452231, 2452232, 2452238, 2452239, 2452245, 2452246, 2452252, 2452253, 2452259, 2452260, 2452266, 2452267, + 2452273, 2452274) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number +limit 100 +-- end query 46 in stream 0 using template query46.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql new file mode 100755 index 0000000000000..a510eefb13e17 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql @@ -0,0 +1,27 @@ +-- start query 52 in stream 0 using template query52.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- added for partition pruning +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + ext_price desc, + brand_id +limit 100 +-- end query 52 in stream 0 using template query52.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql new file mode 100755 index 0000000000000..fb7bb75183858 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql @@ -0,0 +1,37 @@ +-- start query 53 in stream 0 using template query53.tpl +select + * +from + (select + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) + and ((i_category in ('Books', 'Children', 'Electronics') + and i_class in ('personal', 'portable', 'reference', 'self-help') + and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')) + or (i_category in ('Women', 'Music', 'Men') + and i_class in ('accessories', 'classical', 'fragrances', 'pants') + and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1'))) + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + i_manufact_id, + d_qoy + ) tmp1 +where + case when avg_quarterly_sales > 0 then abs (sum_sales - avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1 +order by + avg_quarterly_sales, + sum_sales, + i_manufact_id +limit 100 +-- end query 53 in stream 0 using template query53.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql new file mode 100755 index 0000000000000..47b1f0292d901 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql @@ -0,0 +1,24 @@ +-- start query 55 in stream 0 using template query55.tpl +select + i_brand_id brand_id, + i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 48 + and d_moy = 11 + and d_year = 2001 + and ss_sold_date_sk between 2452215 and 2452244 +group by + i_brand, + i_brand_id +order by + ext_price desc, + i_brand_id +limit 100 +-- end query 55 in stream 0 using template query55.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql new file mode 100755 index 0000000000000..3d5c4e9d64419 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql @@ -0,0 +1,83 @@ +-- start query 59 in stream 0 using template query59.tpl +with + wss as + (select + d_week_seq, + ss_store_sk, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales + from + store_sales, + date_dim + where + d_date_sk = ss_sold_date_sk + group by + d_week_seq, + ss_store_sk + ) +select + s_store_name1, + s_store_id1, + d_week_seq1, + sun_sales1 / sun_sales2, + mon_sales1 / mon_sales2, + tue_sales1 / tue_sales1, + wed_sales1 / wed_sales2, + thu_sales1 / thu_sales2, + fri_sales1 / fri_sales2, + sat_sales1 / sat_sales2 +from + (select + s_store_name s_store_name1, + wss.d_week_seq d_week_seq1, + s_store_id s_store_id1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 and 1185 + 11 + ) y, + (select + s_store_name s_store_name2, + wss.d_week_seq d_week_seq2, + s_store_id s_store_id2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 + 12 and 1185 + 23 + ) x +where + s_store_id1 = s_store_id2 + and d_week_seq1 = d_week_seq2 - 52 +order by + s_store_name1, + s_store_id1, + d_week_seq1 +limit 100 +-- end query 59 in stream 0 using template query59.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql new file mode 100755 index 0000000000000..b71199ab17d0b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql @@ -0,0 +1,29 @@ +-- start query 63 in stream 0 using template query63.tpl +select * +from (select i_manager_id + ,sum(ss_sales_price) sum_sales + ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales + from item + ,store_sales + ,date_dim + ,store + where ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2452123 and 2452487 + and ss_store_sk = s_store_sk + and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11) + and (( i_category in ('Books','Children','Electronics') + and i_class in ('personal','portable','reference','self-help') + and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', + 'exportiunivamalg #9','scholaramalgamalg #9')) + or( i_category in ('Women','Music','Men') + and i_class in ('accessories','classical','fragrances','pants') + and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', + 'importoamalg #1'))) +group by i_manager_id, d_moy) tmp1 +where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 +order by i_manager_id + ,avg_monthly_sales + ,sum_sales +limit 100 +-- end query 63 in stream 0 using template query63.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql new file mode 100755 index 0000000000000..7344feeff6a9f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql @@ -0,0 +1,58 @@ +-- start query 65 in stream 0 using template query65.tpl +select + s_store_name, + i_item_desc, + sc.revenue, + i_current_price, + i_wholesale_cost, + i_brand +from + store, + item, + (select + ss_store_sk, + avg(revenue) as ave + from + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sa + group by + ss_store_sk + ) sb, + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sc +where + sb.ss_store_sk = sc.ss_store_sk + and sc.revenue <= 0.1 * sb.ave + and s_store_sk = sc.ss_store_sk + and i_item_sk = sc.ss_item_sk +order by + s_store_name, + i_item_desc +limit 100 +-- end query 65 in stream 0 using template query65.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql new file mode 100755 index 0000000000000..94df4b3f57a90 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql @@ -0,0 +1,62 @@ +-- start query 68 in stream 0 using template query68.tpl +-- changed to match exact same partitions in original query +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + extended_price, + extended_tax, + list_price +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_ext_sales_price) extended_price, + sum(ss_ext_list_price) list_price, + sum(ss_ext_tax) extended_tax + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Fairview') + -- partition key filter + and ss_sold_date_sk in (2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, 2451331, + 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, 2451485, + 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, 2451666, + 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, 2451820, + 2451850, 2451851, 2451880, 2451881, 2451911, 2451912, 2451942, 2451943, 2451970, 2451971, 2452001, + 2452002, 2452031, 2452032, 2452062, 2452063, 2452092, 2452093, 2452123, 2452124, 2452154, 2452155, + 2452184, 2452185, 2452215, 2452216, 2452245, 2452246) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + --and d_date between '1999-01-01' and '1999-03-31' + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + ss_ticket_number +limit 100 +-- end query 68 in stream 0 using template query68.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql new file mode 100755 index 0000000000000..c61a2d0d2a8fa --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql @@ -0,0 +1,31 @@ +-- start query 7 in stream 0 using template query7.tpl +select + i_item_id, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +from + store_sales, + customer_demographics, + date_dim, + item, + promotion +where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_cdemo_sk = cd_demo_sk + and ss_promo_sk = p_promo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and (p_channel_email = 'N' + or p_channel_event = 'N') + and d_year = 1998 + and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter +group by + i_item_id +order by + i_item_id +limit 100 +-- end query 7 in stream 0 using template query7.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql new file mode 100755 index 0000000000000..8703910b305a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql @@ -0,0 +1,49 @@ +-- start query 73 in stream 0 using template query73.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end > 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Fairfield County','Ziebach County','Bronx County','Barrow County') + -- partition key filter + and ss_sold_date_sk in (2450815, 2450816, 2450846, 2450847, 2450874, 2450875, 2450905, 2450906, 2450935, 2450936, 2450966, 2450967, + 2450996, 2450997, 2451027, 2451028, 2451058, 2451059, 2451088, 2451089, 2451119, 2451120, 2451149, + 2451150, 2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, + 2451331, 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, + 2451485, 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, + 2451666, 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, + 2451820, 2451850, 2451851, 2451880, 2451881) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + group by + ss_ticket_number, + ss_customer_sk + ) dj, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 1 and 5 +order by + cnt desc +-- end query 73 in stream 0 using template query73.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql new file mode 100755 index 0000000000000..4254310ecd10b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql @@ -0,0 +1,59 @@ +-- start query 79 in stream 0 using template query79.tpl +select + c_last_name, + c_first_name, + substr(s_city, 1, 30), + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (household_demographics.hd_dep_count = 8 + or household_demographics.hd_vehicle_count > 0) + and date_dim.d_dow = 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_number_employees between 200 and 295 + and ss_sold_date_sk between 2450819 and 2451904 + -- partition key filter + --and ss_sold_date_sk in (2450819, 2450826, 2450833, 2450840, 2450847, 2450854, 2450861, 2450868, 2450875, 2450882, 2450889, + -- 2450896, 2450903, 2450910, 2450917, 2450924, 2450931, 2450938, 2450945, 2450952, 2450959, 2450966, 2450973, 2450980, 2450987, + -- 2450994, 2451001, 2451008, 2451015, 2451022, 2451029, 2451036, 2451043, 2451050, 2451057, 2451064, 2451071, 2451078, 2451085, + -- 2451092, 2451099, 2451106, 2451113, 2451120, 2451127, 2451134, 2451141, 2451148, 2451155, 2451162, 2451169, 2451176, 2451183, + -- 2451190, 2451197, 2451204, 2451211, 2451218, 2451225, 2451232, 2451239, 2451246, 2451253, 2451260, 2451267, 2451274, 2451281, + -- 2451288, 2451295, 2451302, 2451309, 2451316, 2451323, 2451330, 2451337, 2451344, 2451351, 2451358, 2451365, 2451372, 2451379, + -- 2451386, 2451393, 2451400, 2451407, 2451414, 2451421, 2451428, 2451435, 2451442, 2451449, 2451456, 2451463, 2451470, 2451477, + -- 2451484, 2451491, 2451498, 2451505, 2451512, 2451519, 2451526, 2451533, 2451540, 2451547, 2451554, 2451561, 2451568, 2451575, + -- 2451582, 2451589, 2451596, 2451603, 2451610, 2451617, 2451624, 2451631, 2451638, 2451645, 2451652, 2451659, 2451666, 2451673, + -- 2451680, 2451687, 2451694, 2451701, 2451708, 2451715, 2451722, 2451729, 2451736, 2451743, 2451750, 2451757, 2451764, 2451771, + -- 2451778, 2451785, 2451792, 2451799, 2451806, 2451813, 2451820, 2451827, 2451834, 2451841, 2451848, 2451855, 2451862, 2451869, + -- 2451876, 2451883, 2451890, 2451897, 2451904) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + store.s_city + ) ms, + customer +where + ss_customer_sk = c_customer_sk +order by + c_last_name, + c_first_name, + substr(s_city, 1, 30), + profit + limit 100 +-- end query 79 in stream 0 using template query79.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql new file mode 100755 index 0000000000000..b1d814af5e57a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql @@ -0,0 +1,43 @@ +-- start query 89 in stream 0 using template query89.tpl +select + * +from + (select + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_year in (2000) + and ((i_category in ('Home', 'Books', 'Electronics') + and i_class in ('wallpaper', 'parenting', 'musical')) + or (i_category in ('Shoes', 'Jewelry', 'Men') + and i_class in ('womens', 'birdal', 'pants'))) + and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter + group by + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy + ) tmp1 +where + case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 +order by + sum_sales - avg_monthly_sales, + s_store_name +limit 100 +-- end query 89 in stream 0 using template query89.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql new file mode 100755 index 0000000000000..f53f2f5f9c5b6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql @@ -0,0 +1,32 @@ +-- start query 98 in stream 0 using template query98.tpl +select + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) as itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) over (partition by i_class) as revenueratio +from + store_sales, + item, + date_dim +where + ss_item_sk = i_item_sk + and i_category in ('Jewelry', 'Sports', 'Books') + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2451911 and 2451941 -- partition key filter (1 calendar month) + and d_date between '2001-01-01' and '2001-01-31' +group by + i_item_id, + i_item_desc, + i_category, + i_class, + i_current_price +order by + i_category, + i_class, + i_item_id, + i_item_desc, + revenueratio +--limit 1000; -- added limit +-- end query 98 in stream 0 using template query98.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql new file mode 100755 index 0000000000000..bf58b4bb3c5a5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql @@ -0,0 +1,14 @@ +select + count(*) as total, + count(ss_sold_date_sk) as not_null_total, + count(distinct ss_sold_date_sk) as unique_days, + max(ss_sold_date_sk) as max_ss_sold_date_sk, + max(ss_sold_time_sk) as max_ss_sold_time_sk, + max(ss_item_sk) as max_ss_item_sk, + max(ss_customer_sk) as max_ss_customer_sk, + max(ss_cdemo_sk) as max_ss_cdemo_sk, + max(ss_hdemo_sk) as max_ss_hdemo_sk, + max(ss_addr_sk) as max_ss_addr_sk, + max(ss_store_sk) as max_ss_store_sk, + max(ss_promo_sk) as max_ss_promo_sk +from store_sales diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index c0797fa55f5da..e47d4b0ee25d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -22,9 +22,18 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils +/** + * This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized + * without hitting the max iteration threshold. + */ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting + // the max iteration of analyzer/optimizer batches. + assert(Utils.isTesting, "spark.testing is not set to true") + /** * Drop all the tables */ @@ -341,8 +350,23 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte classLoader = Thread.currentThread().getContextClassLoader) test(name) { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - sql(queryString).collect() + // Just check the plans can be properly generated + sql(queryString).queryExecution.executedPlan } } } + + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries + val modifiedTPCDSQueries = Seq( + "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", + "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max") + + modifiedTPCDSQueries.foreach { name => + val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"modified-$name") { + // Just check the plans can be properly generated + sql(queryString).queryExecution.executedPlan + } + } } From 472864014c42da08b9d3f3fffbe657c6fcf1e2ef Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Sep 2017 11:45:58 -0700 Subject: [PATCH 1424/1765] Revert "[SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile" This reverts commit a2516f41aef68e39df7f6380fd2618cc148a609e. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 - dev/sparktestsupport/modules.py | 20 +------------------- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 6 ------ pom.xml | 13 +++---------- project/SparkBuild.scala | 17 ++++++++--------- python/pyspark/streaming/tests.py | 16 +++------------- 9 files changed, 19 insertions(+), 62 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 7e8d5c7075195..5390f5916fc0d 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index 1e3ca9700bc07..fdb21f5007cf2 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8c..e5aa589869535 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,7 +25,6 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ - -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 91d5667ed1f07..50e14b60545af 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,12 +279,6 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -297,12 +291,6 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - }, sbt_test_goals=[ "streaming-flume/test", ] @@ -314,13 +302,7 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - } + ] ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e00..c7714578bd005 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index e1532de16108d..57baa503259c1 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,12 +100,6 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. -## Building with Flume support - -Apache Flume support must be explicitly enabled with the `flume` profile. - - ./build/mvn -Pflume -DskipTests clean package - ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/pom.xml b/pom.xml index 9fac8b1e53788..87a468c3a6f55 100644 --- a/pom.xml +++ b/pom.xml @@ -98,13 +98,15 @@ sql/core sql/hive assembly + external/flume + external/flume-sink + external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql - @@ -2581,15 +2583,6 @@ - - flume - - external/flume - external/flume-sink - external/flume-assembly - - - spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9501eed1e906b..a568d264cb2db 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,8 +43,11 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq(streaming, streamingKafka010) = - Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq( + streaming, streamingFlumeSink, streamingFlume, streamingKafka010 + ) = Seq( + "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" + ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -53,13 +56,9 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, - streamingFlumeSink, streamingFlume, - streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", - "streaming-flume-sink", "streaming-flume", - "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, + streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..229cf53e47359 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,9 +1516,6 @@ def search_kinesis_asl_assembly_jar(): return jars[0] -# Must be same as the variable and condition defined in modules.py -flume_test_environ_var = "ENABLE_FLUME_TESTS" -are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1541,16 +1538,9 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, + FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] - if are_flume_tests_enabled: - testcases.append(FlumeStreamTests) - testcases.append(FlumePollingStreamTests) - else: - sys.stderr.write( - "Skipped test_flume_stream (enable by setting environment variable %s=1" - % flume_test_environ_var) - if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From 530fe683297cb11b920a4df6630eff5d7e7ddce2 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Sep 2017 19:35:32 -0700 Subject: [PATCH 1425/1765] [SPARK-21904][SQL] Rename tempTables to tempViews in SessionCatalog ### What changes were proposed in this pull request? `tempTables` is not right. To be consistent, we need to rename the internal variable names/comments to tempViews in SessionCatalog too. ### How was this patch tested? N/A Author: gatorsmile Closes #19117 from gatorsmile/renameTempTablesToTempViews. --- .../sql/catalyst/catalog/SessionCatalog.scala | 79 +++++++++---------- .../sql/execution/command/DDLSuite.scala | 10 +-- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 9407b727bca4c..6ba9ee5446a01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.catalog -import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.Locale import java.util.concurrent.Callable @@ -25,7 +24,6 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -41,7 +39,6 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -52,7 +49,7 @@ object SessionCatalog { /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary - * tables and functions of the Spark Session that it belongs to. + * views and functions of the Spark Session that it belongs to. * * This class must be thread-safe. */ @@ -90,13 +87,13 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } - /** List of temporary tables, mapping from table name to their logical plan. */ + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") - protected val tempTables = new mutable.HashMap[String, LogicalPlan] + protected val tempViews = new mutable.HashMap[String, LogicalPlan] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). In these cases we must first - // check whether the temporary table or function exists, then, if not, operate on + // check whether the temporary view or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) @@ -272,8 +269,8 @@ class SessionCatalog( // ---------------------------------------------------------------------------- // Tables // ---------------------------------------------------------------------------- - // There are two kinds of tables, temporary tables and metastore tables. - // Temporary tables are isolated across sessions and do not belong to any + // There are two kinds of tables, temporary views and metastore tables. + // Temporary views are isolated across sessions and do not belong to any // particular database. Metastore tables can be used across multiple // sessions as their metadata is persisted in the underlying catalog. // ---------------------------------------------------------------------------- @@ -462,10 +459,10 @@ class SessionCatalog( tableDefinition: LogicalPlan, overrideIfExists: Boolean): Unit = synchronized { val table = formatTableName(name) - if (tempTables.contains(table) && !overrideIfExists) { + if (tempViews.contains(table) && !overrideIfExists) { throw new TempTableAlreadyExistsException(name) } - tempTables.put(table, tableDefinition) + tempViews.put(table, tableDefinition) } /** @@ -487,7 +484,7 @@ class SessionCatalog( viewDefinition: LogicalPlan): Boolean = synchronized { val viewName = formatTableName(name.table) if (name.database.isEmpty) { - if (tempTables.contains(viewName)) { + if (tempViews.contains(viewName)) { createTempView(viewName, viewDefinition, overrideIfExists = true) true } else { @@ -504,7 +501,7 @@ class SessionCatalog( * Return a local temporary view exactly as it was stored. */ def getTempView(name: String): Option[LogicalPlan] = synchronized { - tempTables.get(formatTableName(name)) + tempViews.get(formatTableName(name)) } /** @@ -520,7 +517,7 @@ class SessionCatalog( * Returns true if this view is dropped successfully, false otherwise. */ def dropTempView(name: String): Boolean = synchronized { - tempTables.remove(formatTableName(name)).isDefined + tempViews.remove(formatTableName(name)).isDefined } /** @@ -572,7 +569,7 @@ class SessionCatalog( * Rename a table. * * If a database is specified in `oldName`, this will rename the table in that database. - * If no database is specified, this will first attempt to rename a temporary table with + * If no database is specified, this will first attempt to rename a temporary view with * the same name, then, if that does not exist, rename the table in the current database. * * This assumes the database specified in `newName` matches the one in `oldName`. @@ -592,7 +589,7 @@ class SessionCatalog( globalTempViewManager.rename(oldTableName, newTableName) } else { requireDbExists(db) - if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + if (oldName.database.isDefined || !tempViews.contains(oldTableName)) { requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) @@ -600,16 +597,16 @@ class SessionCatalog( } else { if (newName.database.isDefined) { throw new AnalysisException( - s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': cannot specify database " + + s"RENAME TEMPORARY VIEW from '$oldName' to '$newName': cannot specify database " + s"name '${newName.database.get}' in the destination table") } - if (tempTables.contains(newTableName)) { - throw new AnalysisException(s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': " + + if (tempViews.contains(newTableName)) { + throw new AnalysisException(s"RENAME TEMPORARY VIEW from '$oldName' to '$newName': " + "destination table already exists") } - val table = tempTables(oldTableName) - tempTables.remove(oldTableName) - tempTables.put(newTableName, table) + val table = tempViews(oldTableName) + tempViews.remove(oldTableName) + tempViews.put(newTableName, table) } } } @@ -618,7 +615,7 @@ class SessionCatalog( * Drop a table. * * If a database is specified in `name`, this will drop the table from that database. - * If no database is specified, this will first attempt to drop a temporary table with + * If no database is specified, this will first attempt to drop a temporary view with * the same name, then, if that does not exist, drop the table from the current database. */ def dropTable( @@ -633,7 +630,7 @@ class SessionCatalog( throw new NoSuchTableException(globalTempViewManager.database, table) } } else { - if (name.database.isDefined || !tempTables.contains(table)) { + if (name.database.isDefined || !tempViews.contains(table)) { requireDbExists(db) // When ignoreIfNotExists is false, no exception is issued when the table does not exist. // Instead, log it as an error message. @@ -643,7 +640,7 @@ class SessionCatalog( throw new NoSuchTableException(db = db, table = table) } } else { - tempTables.remove(table) + tempViews.remove(table) } } } @@ -652,7 +649,7 @@ class SessionCatalog( * Return a [[LogicalPlan]] that represents the given table or view. * * If a database is specified in `name`, this will return the table/view from that database. - * If no database is specified, this will first attempt to return a temporary table/view with + * If no database is specified, this will first attempt to return a temporary view with * the same name, then, if that does not exist, return the table/view from the current database. * * Note that, the global temp view database is also valid here, this will return the global temp @@ -671,7 +668,7 @@ class SessionCatalog( globalTempViewManager.get(table).map { viewDef => SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) - } else if (name.database.isDefined || !tempTables.contains(table)) { + } else if (name.database.isDefined || !tempViews.contains(table)) { val metadata = externalCatalog.getTable(db, table) if (metadata.tableType == CatalogTableType.VIEW) { val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) @@ -687,21 +684,21 @@ class SessionCatalog( SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) } } else { - SubqueryAlias(table, tempTables(table)) + SubqueryAlias(table, tempViews(table)) } } } /** - * Return whether a table with the specified name is a temporary table. + * Return whether a table with the specified name is a temporary view. * - * Note: The temporary table cache is checked only when database is not + * Note: The temporary view cache is checked only when database is not * explicitly specified. */ def isTemporaryTable(name: TableIdentifier): Boolean = synchronized { val table = formatTableName(name.table) if (name.database.isEmpty) { - tempTables.contains(table) + tempViews.contains(table) } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { globalTempViewManager.get(table).isDefined } else { @@ -710,7 +707,7 @@ class SessionCatalog( } /** - * List all tables in the specified database, including local temporary tables. + * List all tables in the specified database, including local temporary views. * * Note that, if the specified database is global temporary view database, we will list global * temporary views. @@ -718,7 +715,7 @@ class SessionCatalog( def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") /** - * List all matching tables in the specified database, including local temporary tables. + * List all matching tables in the specified database, including local temporary views. * * Note that, if the specified database is global temporary view database, we will list global * temporary views. @@ -736,7 +733,7 @@ class SessionCatalog( } } val localTempViews = synchronized { - StringUtils.filterPattern(tempTables.keys.toSeq, pattern).map { name => + StringUtils.filterPattern(tempViews.keys.toSeq, pattern).map { name => TableIdentifier(name) } } @@ -750,11 +747,11 @@ class SessionCatalog( val dbName = formatDatabaseName(name.database.getOrElse(currentDb)) val tableName = formatTableName(name.table) - // Go through temporary tables and invalidate them. + // Go through temporary views and invalidate them. // If the database is defined, this may be a global temporary view. - // If the database is not defined, there is a good chance this is a temp table. + // If the database is not defined, there is a good chance this is a temp view. if (name.database.isEmpty) { - tempTables.get(tableName).foreach(_.refresh()) + tempViews.get(tableName).foreach(_.refresh()) } else if (dbName == globalTempViewManager.database) { globalTempViewManager.get(tableName).foreach(_.refresh()) } @@ -765,11 +762,11 @@ class SessionCatalog( } /** - * Drop all existing temporary tables. + * Drop all existing temporary views. * For testing only. */ def clearTempTables(): Unit = synchronized { - tempTables.clear() + tempViews.clear() } // ---------------------------------------------------------------------------- @@ -1337,7 +1334,7 @@ class SessionCatalog( */ private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { target.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) + // copy over temporary views + tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index d19cfeef7d19f..4ed2cecc5faff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -795,7 +795,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("teachers"), df) } - test("rename temporary table - destination table with database name") { + test("rename temporary view - destination table with database name") { withTempView("tab1") { sql( """ @@ -812,7 +812,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO default.tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`default`.`tab2`': " + + "RENAME TEMPORARY VIEW from '`tab1`' to '`default`.`tab2`': " + "cannot specify database name 'default' in the destination table")) val catalog = spark.sessionState.catalog @@ -820,7 +820,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table") { + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") sql("ALTER TABLE tab1 RENAME TO tab2") @@ -832,7 +832,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table - destination table already exists") { + test("rename temporary view - destination table already exists") { withTempView("tab1", "tab2") { sql( """ @@ -860,7 +860,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`tab2`': destination table already exists")) + "RENAME TEMPORARY VIEW from '`tab1`' to '`tab2`': destination table already exists")) val catalog = spark.sessionState.catalog assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"), TableIdentifier("tab2"))) From c6610a997f69148a1f1bbf69360e8f39e24cb70a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 29 Sep 2017 21:36:52 -0700 Subject: [PATCH 1426/1765] [SPARK-22122][SQL] Use analyzed logical plans to count input rows in TPCDSQueryBenchmark ## What changes were proposed in this pull request? Since the current code ignores WITH clauses to check input relations in TPCDS queries, this leads to inaccurate per-row processing time for benchmark results. For example, in `q2`, this fix could catch all the input relations: `web_sales`, `date_dim`, and `catalog_sales` (the current code catches `date_dim` only). The one-third of the TPCDS queries uses WITH clauses, so I think it is worth fixing this. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #19344 from maropu/RespectWithInTPCDSBench. --- .../benchmark/TPCDSQueryBenchmark.scala | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 99c6df7389205..69247d7f4e9aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.util.Benchmark /** @@ -66,24 +65,15 @@ object TPCDSQueryBenchmark extends Logging { classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the - // logical plan and adding up the sizes of all tables that appear in the plan. Note that this - // currently doesn't take WITH subqueries into account which might lead to fairly inaccurate - // per-row processing time for those cases. + // logical plan and adding up the sizes of all tables that appear in the plan. val queryRelations = scala.collection.mutable.HashSet[String]() - spark.sql(queryString).queryExecution.logical.map { - case UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case lp: LogicalPlan => - lp.expressions.foreach { _ foreach { - case subquery: SubqueryExpression => - subquery.plan.foreach { - case UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case _ => - } - case _ => - } - } + spark.sql(queryString).queryExecution.analyzed.foreach { + case SubqueryAlias(alias, _: LogicalRelation) => + queryRelations.add(alias) + case LogicalRelation(_, _, Some(catalogTable), _) => + queryRelations.add(catalogTable.identifier.table) + case HiveTableRelation(tableMeta, _, _) => + queryRelations.add(tableMeta.identifier.table) case _ => } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum From 02c91e03f975c2a6a05a9d5327057bb6b3c4a66f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 1 Oct 2017 18:42:45 +0900 Subject: [PATCH 1427/1765] [SPARK-22063][R] Fixes lint check failures in R by latest commit sha1 ID of lint-r ## What changes were proposed in this pull request? Currently, we set lintr to jimhester/lintra769c0b (see [this](https://github.com/apache/spark/commit/7d1175011c976756efcd4e4e4f70a8fd6f287026) and [SPARK-14074](https://issues.apache.org/jira/browse/SPARK-14074)). I first tested and checked lintr-1.0.1 but it looks many important fixes are missing (for example, checking 100 length). So, I instead tried the latest commit, https://github.com/jimhester/lintr/commit/5431140ffea65071f1327625d4a8de9688fa7e72, in my local and fixed the check failures. It looks it has fixed many bugs and now finds many instances that I have observed and thought should be caught time to time, here I filed [the results](https://gist.github.com/HyukjinKwon/4f59ddcc7b6487a02da81800baca533c). The downside looks it now takes about 7ish mins, (it was 2ish mins before) in my local. ## How was this patch tested? Manually, `./dev/lint-r` after manually updating the lintr package. Author: hyukjinkwon Author: zuotingbing Closes #19290 from HyukjinKwon/upgrade-r-lint. --- R/pkg/.lintr | 2 +- R/pkg/R/DataFrame.R | 30 ++-- R/pkg/R/RDD.R | 6 +- R/pkg/R/WindowSpec.R | 2 +- R/pkg/R/column.R | 2 + R/pkg/R/context.R | 2 +- R/pkg/R/deserialize.R | 2 +- R/pkg/R/functions.R | 79 ++++++----- R/pkg/R/generics.R | 4 +- R/pkg/R/group.R | 4 +- R/pkg/R/mllib_classification.R | 137 +++++++++++-------- R/pkg/R/mllib_clustering.R | 15 +- R/pkg/R/mllib_regression.R | 62 +++++---- R/pkg/R/mllib_tree.R | 36 +++-- R/pkg/R/pairRDD.R | 4 +- R/pkg/R/schema.R | 2 +- R/pkg/R/stats.R | 14 +- R/pkg/R/utils.R | 4 +- R/pkg/inst/worker/worker.R | 2 +- R/pkg/tests/fulltests/test_binary_function.R | 2 +- R/pkg/tests/fulltests/test_rdd.R | 6 +- R/pkg/tests/fulltests/test_sparkSQL.R | 14 +- dev/lint-r.R | 4 +- 23 files changed, 242 insertions(+), 193 deletions(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index ae50b28ec6166..c83ad2adfe0ef 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0728141fa483e..176bb3b8a8d0c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1923,13 +1923,15 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param i,subset (Optional) a logical expression to filter on rows. #' For extract operator [[ and replacement operator [[<-, the indexing parameter for #' a single Column. -#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. +#' @param j,select expression for the single Column or a list of columns to select from the +#' SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. #' Otherwise, a SparkDataFrame will always be returned. #' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}. #' If \code{NULL}, the specified Column is dropped. #' @param ... currently not used. -#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. +#' @return A new SparkDataFrame containing only the rows that meet the condition with selected +#' columns. #' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method @@ -2608,12 +2610,12 @@ setMethod("merge", } else { # if by or both by.x and by.y have length 0, use Cartesian Product joinRes <- crossJoin(x, y) - return (joinRes) + return(joinRes) } # sets alias for making colnames unique in dataframes 'x' and 'y' - colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1]) - colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2]) + colsX <- genAliasesForIntersectedCols(x, by, suffixes[1]) + colsY <- genAliasesForIntersectedCols(y, by, suffixes[2]) # selects columns with their aliases from dataframes # in case same column names are present in both data frames @@ -2661,9 +2663,8 @@ setMethod("merge", #' @param intersectedColNames a list of intersected column names of the SparkDataFrame #' @param suffix a suffix for the column name #' @return list of columns -#' -#' @note generateAliasesForIntersectedCols since 1.6.0 -generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { +#' @noRd +genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { allColNames <- names(x) # sets alias for making colnames unique in dataframe 'x' cols <- lapply(allColNames, function(colName) { @@ -2671,7 +2672,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { if (colName %in% intersectedColNames) { newJoin <- paste(colName, suffix, sep = "") if (newJoin %in% allColNames){ - stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", + stop("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.") } col <- alias(col, newJoin) @@ -3058,7 +3059,8 @@ setMethod("describe", #' summary(select(df, "age", "height")) #' } #' @note summary(SparkDataFrame) since 1.5.0 -#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for previous defaults. +#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for +#' previous defaults. #' @seealso \link{describe} setMethod("summary", signature(object = "SparkDataFrame"), @@ -3765,8 +3767,8 @@ setMethod("checkpoint", #' #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. #' -#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to -#' direct application of \link{agg}. +#' If grouping expression is missing \code{cube} creates a single global aggregate and is +#' equivalent to direct application of \link{agg}. #' #' @param x a SparkDataFrame. #' @param ... character name(s) or Column(s) to group on. @@ -3800,8 +3802,8 @@ setMethod("cube", #' #' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. #' -#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to -#' direct application of \link{agg}. +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is +#' equivalent to direct application of \link{agg}. #' #' @param x a SparkDataFrame. #' @param ... character name(s) or Column(s) to group on. diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 15ca212acf87f..6e89b4bb4d964 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -131,7 +131,7 @@ PipelinedRDD <- function(prev, func) { # Return the serialization mode for an RDD. setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) # For normal RDDs we can directly read the serializedMode -setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode) # For pipelined RDDs if jrdd_val is set then serializedMode should exist # if not we return the defaultSerialization mode of "byte" as we don't know the serialization # mode at this point in time. @@ -145,7 +145,7 @@ setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), }) # The jrdd accessor function. -setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd) setMethod("getJRDD", signature(rdd = "PipelinedRDD"), function(rdd, serializedMode = "byte") { if (!is.null(rdd@env$jrdd_val)) { @@ -893,7 +893,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- stats::rpois(1, fraction) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 81beac9ea9925..debc7cbde55e7 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -73,7 +73,7 @@ setMethod("show", "WindowSpec", setMethod("partitionBy", signature(x = "WindowSpec"), function(x, col, ...) { - stopifnot (class(col) %in% c("character", "Column")) + stopifnot(class(col) %in% c("character", "Column")) if (class(col) == "character") { windowSpec(callJMethod(x@sws, "partitionBy", col, list(...))) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a5c2ea81f2490..3095adb918b67 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -238,8 +238,10 @@ setMethod("between", signature(x = "Column"), #' @param x a Column. #' @param dataType a character object describing the target data type. #' See +# nolint start #' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{ #' Spark Data Types} for available data types. +# nolint end #' @rdname cast #' @name cast #' @family colum_func diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8349b57a30a93..443c2ff8f9ace 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -329,7 +329,7 @@ spark.addFile <- function(path, recursive = FALSE) { #' spark.getSparkFilesRootDirectory() #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 -spark.getSparkFilesRootDirectory <- function() { +spark.getSparkFilesRootDirectory <- function() { # nolint if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { # Running on driver. callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 0e99b171cabeb..a90f7d381026b 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -43,7 +43,7 @@ readObject <- function(con) { } readTypedObject <- function(con, type) { - switch (type, + switch(type, "i" = readInt(con), "c" = readString(con), "b" = readBoolean(con), diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f286263c2162..0143a3e63ba61 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,7 +38,8 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. +#' @param x Column to compute on. In \code{window}, it must be a time Column of +#' \code{TimestampType}. #' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse #' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string #' to use to specify the truncation method. For example, "year", "yyyy", "yy" for @@ -90,8 +91,8 @@ NULL #' #' Math functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, -#' this is the number of bits to shift. +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and +#' \code{shiftRightUnsigned}, this is the number of bits to shift. #' @param y Column to compute on. #' @param ... additional argument(s). #' @name column_math_functions @@ -480,7 +481,7 @@ setMethod("ceiling", setMethod("coalesce", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -676,7 +677,7 @@ setMethod("crc32", setMethod("hash", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -1310,9 +1311,9 @@ setMethod("round", #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, -#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left -#' of the decimal point when \code{scale} < 0. +#' @param scale round to \code{scale} digits to the right of the decimal point when +#' \code{scale} > 0, the nearest even number when \code{scale} = 0, and \code{scale} digits +#' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method #' @export @@ -2005,8 +2006,9 @@ setMethod("months_between", signature(y = "Column"), }) #' @details -#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if -#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column +#' (\code{x}) if the first column is NaN. Both inputs should be floating point columns +#' (DoubleType or FloatType). #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method @@ -2061,7 +2063,7 @@ setMethod("approxCountDistinct", setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2090,7 +2092,7 @@ setMethod("countDistinct", setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2110,7 +2112,7 @@ setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2130,7 +2132,7 @@ setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2406,8 +2408,8 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), }) #' @details -#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long +#' value, it will return a long value else it will return an integer value. #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method @@ -2505,9 +2507,10 @@ setMethod("format_string", signature(format = "character", x = "Column"), }) #' @details -#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a -#' string representing the timestamp of that moment in the current system time zone in the JVM in the -#' given format. See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) +#' to a string representing the timestamp of that moment in the current system time zone in the JVM +#' in the given format. +#' See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ #' Customizing Formats} for available options. #' #' @rdname column_datetime_functions @@ -2634,8 +2637,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), }) #' @details -#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples -#' from U[0.0, 1.0]. +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) +#' samples from U[0.0, 1.0]. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2664,8 +2667,8 @@ setMethod("rand", signature(seed = "numeric"), }) #' @details -#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from -#' the standard normal distribution. +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples +#' from the standard normal distribution. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -2831,8 +2834,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), }) #' @details -#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. -#' For unmatched expressions null is returned. +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result +#' expressions. For unmatched expressions null is returned. #' #' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. @@ -2859,8 +2862,8 @@ setMethod("when", signature(condition = "Column", value = "ANY"), }) #' @details -#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. -#' Otherwise \code{no} is returned for unmatched conditions. +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are +#' satisfied. Otherwise \code{no} is returned for unmatched conditions. #' #' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. @@ -2990,7 +2993,8 @@ setMethod("ntile", }) #' @details -#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window +#' partition. #' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). #' This is equivalent to the \code{PERCENT_RANK} function in SQL. #' The method should be used with no argument. @@ -3160,7 +3164,8 @@ setMethod("posexplode", }) #' @details -#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. +#' \code{create_array}: Creates a new array column. The input columns must all have the same data +#' type. #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method @@ -3169,7 +3174,7 @@ setMethod("posexplode", setMethod("create_array", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3178,8 +3183,8 @@ setMethod("create_array", }) #' @details -#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, -#' e.g. (key1, value1, key2, value2, ...). +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value +#' pairs, e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' @@ -3190,7 +3195,7 @@ setMethod("create_array", setMethod("create_map", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3352,9 +3357,9 @@ setMethod("not", }) #' @details -#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, -#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL -#' and \code{grouping} function in Scala. +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or +#' not, returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} +#' in SQL and \code{grouping} function in Scala. #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method @@ -3412,7 +3417,7 @@ setMethod("grouping_bit", setMethod("grouping_id", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0fe8f0453b064..4e427489f6860 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,7 +385,7 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @return A SparkDataFrame. #' @rdname summarize #' @export -setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias #' @@ -731,7 +731,7 @@ setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select #' @export -setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) +setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr #' @export diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 0a7be0e993975..54ef9f07d6fae 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -133,8 +133,8 @@ setMethod("summarize", # Aggregate Functions by name methods <- c("avg", "max", "mean", "min", "sum") -# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", -# "variance", "var_samp", "var_pop" +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", +# "stddev_pop", "variance", "var_samp", "var_pop" #' Pivot a column of the GroupedData and perform the specified aggregation. #' diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 15af8298ba484..7cd072a1d6f89 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -58,22 +58,25 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. -#' @param standardization Whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. +#' @param standardization Whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. #' @param threshold The threshold in binary classification applied to the linear model prediction. #' This threshold can be any real number, where Inf will make all predictions 0.0 #' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.svmLinear} returns a fitted linear SVM model. #' @rdname spark.svmLinear @@ -175,62 +178,80 @@ function(object, path, overwrite = FALSE) { #' Logistic Regression Model #' -#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression -#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. -#' Users can print, make predictions on the produced model and save the model to the input path. +#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary +#' logistic regression with pivoting; "multinomial": Multinomial logistic (softmax) regression +#' without pivoting, similar to glmnet. Users can print, make predictions on the produced model +#' and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param regParam the regularization parameter. -#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. -#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination -#' of L1 and L2. Default is 0.0 which is an L2 penalty. +#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 +#' penalty. For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, +#' the penalty is a combination of L1 and L2. Default is 0.0 which is an +#' L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param family the name of family which is a description of the label distribution to be used in the model. +#' @param family the name of family which is a description of the label distribution to be used +#' in the model. #' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". #' Else, set to "multinomial".} #' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} +#' \item{"multinomial": Multinomial logistic (softmax) regression without +#' pivoting.} #' } -#' @param standardization whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. -#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 -#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 -#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of -#' predicting each class. Array must have length equal to the number of classes, with values > 0, -#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. +#' @param standardization whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. Default is TRUE, same as +#' glmnet. +#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of +#' class label 1 is > threshold, then predict 1, else 0. A high threshold +#' encourages the model to predict 0 more often; a low threshold encourages the +#' model to predict 1 more often. Note: Setting this with threshold p is +#' equivalent to setting thresholds c(1-p, p). In multiclass (or binary) +#' classification to adjust the probability of predicting each class. Array must +#' have length equal to the number of classes, with values > 0, excepting that +#' at most one value may be 0. The class with largest value p/t is predicted, +#' where p is the original probability of that class and t is the class's +#' threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. -#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. This is an expert parameter. Default +#' value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. -#' The bounds vector size must be equal to 1 for binomial regression, or the number -#' of classes for multinomial regression. -#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. -#' The bound vector size must be equal to 1 for binomial regression, or the number +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bounds vector size must be equal to 1 for binomial regression, +#' or the number #' of classes for multinomial regression. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bound vector size must be equal to 1 for binomial regression, +#' or the number of classes for multinomial regression. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -412,11 +433,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @param seed seed parameter for weights initialization. #' @param initialWeights initialWeights parameter for weights initialization, it should be a #' numeric vector. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp @@ -452,11 +474,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } layers <- as.integer(na.omit(layers)) if (length(layers) <= 1) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } if (!is.null(seed)) { seed <- as.character(as.integer(seed)) @@ -538,11 +560,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param smoothing smoothing parameter. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 97c9fa1b45840..a25bf81c6d977 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -60,9 +60,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @param maxIter maximum iteration number. #' @param seed the random seed. #' @param minDivisibleClusterSize The minimum number of points (if greater than or equal to 1.0) -#' or the minimum proportion of points (if less than 1.0) of a divisible cluster. -#' Note that it is an expert parameter. The default value should be good enough -#' for most cases. +#' or the minimum proportion of points (if less than 1.0) of a +#' divisible cluster. Note that it is an expert parameter. The +#' default value should be good enough for most cases. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.bisectingKmeans} returns a fitted bisecting k-means model. #' @rdname spark.bisectingKmeans @@ -325,10 +325,11 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' Note that the response variable of formula is empty in spark.kmeans. #' @param k number of centers. #' @param maxIter maximum iteration number. -#' @param initMode the initialization algorithm choosen to fit the model. +#' @param initMode the initialization algorithm chosen to fit the model. #' @param seed the random seed for cluster initialization. #' @param initSteps the number of steps for the k-means|| initialization mode. -#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0. +#' This is an advanced setting, the default of 2 is almost always enough. +#' Must be > 0. #' @param tol convergence tolerance of iterations. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.kmeans} returns a fitted k-means model. @@ -548,8 +549,8 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' \item{\code{topics}}{top 10 terms and their weights of all topics} #' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file #' used as training set} -#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set, -#' given the current parameter estimates: +#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the +#' training set, given the current parameter estimates: #' log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) #' It is only for distributed LDA model (i.e., optimizer = "em")} #' \item{\code{logPrior}}{Log probability of the current parameter estimate: diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index ebaeae970218a..f734a0865ec3b 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -58,8 +58,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' Note that there are two ways to specify the tweedie family. #' \itemize{ #' \item Set \code{family = "tweedie"} and specify the var.power and link.power; -#' \item When package \code{statmod} is loaded, the tweedie family is specified using the -#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' \item When package \code{statmod} is loaded, the tweedie family is specified +#' using the family definition therein, i.e., \code{tweedie(var.power, link.power)}. #' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. @@ -71,13 +71,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -197,13 +199,15 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -233,11 +237,11 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' @param object a fitted generalized linear model. #' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes -#' coefficients, standard error of coefficients, t value and p value), +#' The list of components includes at least the \code{coefficients} (coefficients matrix, +#' which includes coefficients, standard error of coefficients, t value and p value), #' \code{null.deviance} (null/residual degrees of freedom), \code{aic} (AIC) -#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, -#' the coefficients matrix only provides coefficients. +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in +#' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -457,15 +461,17 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this +#' param could be adjusted to a larger size. This is an expert parameter. +#' Default value should be good for most cases. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 33c4653f4c184..89a58bf0aadae 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -132,10 +132,12 @@ print.summary.decisionTree <- function(x) { #' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and #' \code{write.ml}/\code{read.ml} to save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ #' GBT Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ #' GBT Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -164,11 +166,12 @@ print.summary.decisionTree <- function(x) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.gbt,SparkDataFrame,formula-method #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. @@ -352,10 +355,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ #' Random Forest Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ #' Random Forest Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -382,11 +387,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -567,10 +573,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ #' Decision Tree Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ #' Decision Tree Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -592,11 +600,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.decisionTree,SparkDataFrame,formula-method #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. @@ -671,7 +680,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of +#' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method #' @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 8fa21be3076b5..9c2e57d3067db 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -860,7 +860,7 @@ setMethod("subtractByKey", other, numPartitions = numPartitions), filterFunction), - function (v) { v[[1]] }) + function(v) { v[[1]] }) }) #' Return a subset of this RDD sampled by key. @@ -925,7 +925,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- stats::rpois(1, frac) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index d1ed6833d5d02..65f418740c643 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -155,7 +155,7 @@ checkType <- function(type) { } else { # Check complex types firstChar <- substr(type, 1, 1) - switch (firstChar, + switch(firstChar, a = { # Array type m <- regexec("^array<(.+)>$", type) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 9a9fa84044ce6..c8af798830b30 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -29,9 +29,9 @@ setOldClass("jobj") #' @param col1 name of the first column. Distinct items will make the first item of each row. #' @param col2 name of the second column. Distinct items will make the column names of the output. #' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of \code{col1} and the column names will be the distinct values -#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs -#' that have no occurrences will have zero as their counts. +#' will be the distinct values of \code{col1} and the column names will be the distinct +#' values of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". +#' Pairs that have no occurrences will have zero as their counts. #' #' @rdname crosstab #' @name crosstab @@ -53,8 +53,8 @@ setMethod("crosstab", }) #' @details -#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical -#' columns of \emph{one} SparkDataFrame. +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two +#' numerical columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column @@ -159,8 +159,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @param relativeError The relative target precision to achieve (>= 0). If set to zero, #' the exact quantiles are computed, which could be very expensive. #' Note that values greater than 1 are accepted but give the same result as 1. -#' @return The approximate quantiles at the given probabilities. If the input is a single column name, -#' the output is a list of approximate quantiles in that column; If the input is +#' @return The approximate quantiles at the given probabilities. If the input is a single column +#' name, the output is a list of approximate quantiles in that column; If the input is #' multiple column names, the output should be a list, and each element in it is a list of #' numeric values which represents the approximate quantiles in corresponding column. #' diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 91483a4d23d9b..4b716995f2c46 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -625,7 +625,7 @@ appendPartitionLengths <- function(x, other) { x <- lapplyPartition(x, appendLength) other <- lapplyPartition(other, appendLength) } - list (x, other) + list(x, other) } # Perform zip or cartesian between elements from two RDDs in each partition @@ -657,7 +657,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[ (lengthOfKeys + 1) : (len - 1) ] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 03e7450147865..00789d815bba8 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -68,7 +68,7 @@ compute <- function(mode, partition, serializer, deserializer, key, } else { output <- computeFunc(partition, inputData) } - return (output) + return(output) } outputResult <- function(serializer, output, outputCon) { diff --git a/R/pkg/tests/fulltests/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R index 442bed509bb1d..c5d240f3e7344 100644 --- a/R/pkg/tests/fulltests/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -73,7 +73,7 @@ test_that("zipPartitions() on RDDs", { rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3, - func = function(x, y, z) { list(list(x, y, z))} )) + func = function(x, y, z) { list(list(x, y, z))})) expect_equal(actual, list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) diff --git a/R/pkg/tests/fulltests/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R index 6ee1fceffd822..0c702ea897f7c 100644 --- a/R/pkg/tests/fulltests/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -698,14 +698,14 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - numPairsRdd <- map(rdd, function(x) { list (x, x) }) + numPairsRdd <- map(rdd, function(x) { list(x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) - numPairs <- lapply(nums, function(x) { list (x, x) }) + numPairs <- lapply(nums, function(x) { list(x, x) }) expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) - numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + numPairsRdd2 <- map(rdd2, function(x) { list(x, x) }) sortedRdd2 <- sortByKey(numPairsRdd2) actual <- collectRDD(sortedRdd2) expect_equal(actual, numPairs) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4e62be9b4d619..7f781f2f66a7f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -560,9 +560,9 @@ test_that("Collect DataFrame with complex types", { expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) expect_equal(names(ldf), c("c1", "c2", "c3")) - expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) - expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) - expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list(7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list(7.0, 8.0, 9.0))) # MapType schema <- structType(structField("name", "string"), @@ -1524,7 +1524,7 @@ test_that("column functions", { expect_equal(ncol(s), 1) expect_equal(nrow(s), 3) expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } # passing option @@ -2710,7 +2710,7 @@ test_that("freqItems() on a DataFrame", { input <- 1:1000 rdf <- data.frame(numbers = input, letters = as.character(input), negDoubles = input * -1.0, stringsAsFactors = F) - rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + rdf[input %% 3 == 0, ] <- c(1, "1", -1) df <- createDataFrame(rdf) multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) @@ -3064,7 +3064,7 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - df <- createDataFrame ( + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) expected <- collect(df) @@ -3135,7 +3135,7 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { actual <- df3Collect[order(df3Collect$a), ] expect_identical(actual$avg, expected$avg) - irisDF <- suppressWarnings(createDataFrame (iris)) + irisDF <- suppressWarnings(createDataFrame(iris)) schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) # Groups by `Sepal_Length` and computes the average for `Sepal_Width` df4 <- gapply( diff --git a/dev/lint-r.R b/dev/lint-r.R index 87ee36d5c9b68..a4261d266bbc0 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -26,8 +26,8 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { # Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. -if ("lintr" %in% row.names(installed.packages()) == FALSE) { - devtools::install_github("jimhester/lintr@a769c0b") +if ("lintr" %in% row.names(installed.packages()) == FALSE) { + devtools::install_github("jimhester/lintr@5431140") } library(lintr) From 3ca367083e196e6487207211e6c49d4bbfe31288 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 1 Oct 2017 10:49:22 -0700 Subject: [PATCH 1428/1765] [SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns at one pass ## What changes were proposed in this pull request? SPARK-21690 makes one-pass `Imputer` by parallelizing the computation of all input columns. When we transform dataset with `ImputerModel`, we do `withColumn` on all input columns sequentially. We can also do this on all input columns at once by adding a `withColumns` API to `Dataset`. The new `withColumns` API is for internal use only now. ## How was this patch tested? Existing tests for `ImputerModel`'s change. Added tests for `withColumns` API. Author: Liang-Chi Hsieh Closes #19229 from viirya/SPARK-22001. --- .../org/apache/spark/ml/feature/Imputer.scala | 10 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 42 ++++++++++----- .../org/apache/spark/sql/DataFrameSuite.scala | 52 +++++++++++++++++++ 3 files changed, 86 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1f36eced3d08f..4663f16b5f5dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -223,20 +223,18 @@ class ImputerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - var outputDF = dataset val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq - $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType val ic = col(inputCol) - outputDF = outputDF.withColumn(outputCol, - when(ic.isNull, surrogate) + when(ic.isNull, surrogate) .when(ic === $(missingValue), surrogate) .otherwise(ic) - .cast(inputType)) + .cast(inputType) } - outputDF.toDF() + dataset.withColumns($(outputCols), newCols).toDF() } override def transformSchema(schema: StructType): StructType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ab0c4126bcbdd..f2a76a506eb6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2083,22 +2083,40 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): DataFrame = { + def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) + + /** + * Returns a new Dataset by adding columns or replacing the existing columns that has + * the same names. + */ + private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = { + require(colNames.size == cols.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of columns: ${cols.size}") + SchemaUtils.checkColumnNameDuplication( + colNames, + "in given column names", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output - val shouldReplace = output.exists(f => resolver(f.name, colName)) - if (shouldReplace) { - val columns = output.map { field => - if (resolver(field.name, colName)) { - col.as(colName) - } else { - Column(field) - } + + val columnMap = colNames.zip(cols).toMap + + val replacedAndExistingColumns = output.map { field => + columnMap.find { case (colName, _) => + resolver(field.name, colName) + } match { + case Some((colName: String, col: Column)) => col.as(colName) + case _ => Column(field) } - select(columns : _*) - } else { - select(Column("*"), col.as(colName)) } + + val newColumns = columnMap.filter { case (colName, col) => + !output.exists(f => resolver(f.name, colName)) + }.map { case (colName, col) => col.as(colName) } + + select(replacedAndExistingColumns ++ newColumns : _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0e2f2e5a193e1..672deeac597f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } + test("withColumns") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2")) + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns(Seq("newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert( + err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2")) + + val err2 = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err2.getMessage.contains("Found duplicate column(s)")) + } + + test("withColumns: case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1")) + + val err = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err.getMessage.contains("Found duplicate column(s)")) + } + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) @@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(2) :: Row(3) :: Row(4) :: Nil) } + test("replace column using withColumns") { + val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y") + val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"), + Seq(df2("x") + 1, df2("y"), df2("y") + 1)) + checkAnswer( + df3.select("x", "newCol1", "newCol2"), + Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil) + } + test("drop column using drop") { val df = testData.drop("key") checkAnswer( From 405c0e99e7697bfa88aa4abc9a55ce5e043e48b1 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 2 Oct 2017 08:07:56 +0100 Subject: [PATCH 1429/1765] [SPARK-22173][WEB-UI] Table CSS style needs to be adjusted in History Page and in Executors Page. ## What changes were proposed in this pull request? There is a problem with table CSS style. 1. At present, table CSS style is too crowded, and the table width cannot adapt itself. 2. Table CSS style is different from job page, stage page, task page, master page, worker page, etc. The Spark web UI needs to be consistent. fix before: ![01](https://user-images.githubusercontent.com/26266482/31041261-c6766c3a-a5c4-11e7-97a7-96bd51ef12bd.png) ![02](https://user-images.githubusercontent.com/26266482/31041266-d75b6a32-a5c4-11e7-8071-e3bbbba39b80.png) ---------------------------------------------------------------------------------------------------------- fix after: ![1](https://user-images.githubusercontent.com/26266482/31041162-808a5a3e-a5c3-11e7-8d92-d763b500ce53.png) ![2](https://user-images.githubusercontent.com/26266482/31041166-86e583e0-a5c3-11e7-949c-11c370db9e27.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19397 from guoxiaolongzte/SPARK-22173. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- .../main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index af14717633409..6399dccc1676a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val content =
    -
    +
      {providerConfig.map { case (k, v) =>
    • {k}: {v}
    • }}
    @@ -58,7 +58,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { if (allAppsSize > 0) { ++ - ++ +
    ++ ++ ++ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index d63381c78bc3b..7b2767f0be3cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage(
    ++ - ++ +
    ++ ++ ++ From 8fab7995d36c7bc4524393b20a4e524dbf6bbf62 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 2 Oct 2017 11:46:51 -0700 Subject: [PATCH 1430/1765] [SPARK-22167][R][BUILD] sparkr packaging issue allow zinc ## What changes were proposed in this pull request? When zinc is running the pwd might be in the root of the project. A quick solution to this is to not go a level up incase we are in the root rather than root/core/. If we are in the root everything works fine, if we are in core add a script which goes and runs the level up ## How was this patch tested? set -x in the SparkR install scripts. Author: Holden Karau Closes #19402 from holdenk/SPARK-22167-sparkr-packaging-issue-allow-zinc. --- R/install-dev.sh | 1 + core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/R/install-dev.sh b/R/install-dev.sh index d613552718307..9fbc999f2e805 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -28,6 +28,7 @@ set -o pipefail set -e +set -x FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" diff --git a/core/pom.xml b/core/pom.xml index 09669149d8123..54f7a34a6c37e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -499,7 +499,7 @@ - ..${file.separator}R${file.separator}install-dev${script.extension} + ${project.basedir}${file.separator}..${file.separator}R${file.separator}install-dev${script.extension} From e5431f2cfddc8e96194827a2123b92716c7a1467 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 2 Oct 2017 15:00:26 -0700 Subject: [PATCH 1431/1765] [SPARK-22158][SQL] convertMetastore should not ignore table property ## What changes were proposed in this pull request? From the beginning, convertMetastoreOrc ignores table properties and use an empty map instead. This PR fixes that. For the diff, please see [this](https://github.com/apache/spark/pull/19382/files?w=1). convertMetastoreParquet also ignore. ```scala val options = Map[String, String]() ``` - [SPARK-14070: HiveMetastoreCatalog.scala](https://github.com/apache/spark/pull/11891/files#diff-ee66e11b56c21364760a5ed2b783f863R650) - [Master branch: HiveStrategies.scala](https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala#L197 ) ## How was this patch tested? Pass the Jenkins with an updated test suite. Author: Dongjoon Hyun Closes #19382 from dongjoon-hyun/SPARK-22158. --- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 54 ++++++++++++++++--- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 805b3171cdaab..3592b8f4846d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -189,12 +189,12 @@ case class RelationConversions( private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { - val options = Map(ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = Map[String, String]() + val options = relation.tableMeta.storage.properties sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 668da5fb47323..02e26bbe876a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -23,6 +23,8 @@ import java.net.URI import scala.language.existentials import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER +import org.apache.parquet.hadoop.ParquetFileReader import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException @@ -32,6 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAl import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -1455,12 +1458,8 @@ class HiveDDLSuite sql("INSERT INTO t SELECT 1") checkAnswer(spark.table("t"), Row(1)) // Check if this is compressed as ZLIB. - val maybeOrcFile = path.listFiles().find(!_.getName.endsWith(".crc")) - assert(maybeOrcFile.isDefined) - val orcFilePath = maybeOrcFile.get.toPath.toString - val expectedCompressionKind = - OrcFileOperator.getFileReader(orcFilePath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + val maybeOrcFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeOrcFile, "orc", "ZLIB") sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2") val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2")) @@ -2009,4 +2008,47 @@ class HiveDDLSuite } } } + + private def assertCompression(maybeFile: Option[File], format: String, compression: String) = { + assert(maybeFile.isDefined) + + val actualCompression = format match { + case "orc" => + OrcFileOperator.getFileReader(maybeFile.get.toPath.toString).get.getCompression.name + + case "parquet" => + val footer = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, new Path(maybeFile.get.getPath), NO_FILTER) + footer.getBlocks.get(0).getColumns.get(0).getCodec.toString + } + + assert(compression === actualCompression) + } + + Seq(("orc", "ZLIB"), ("parquet", "GZIP")).foreach { case (fileFormat, compression) => + test(s"SPARK-22158 convertMetastore should not ignore table property - $fileFormat") { + withSQLConf(CONVERT_METASTORE_ORC.key -> "true", CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) USING hive + |OPTIONS(fileFormat '$fileFormat', compression '$compression') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains(fileFormat)) + assert(table.storage.properties.get("compression") == Some(compression)) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeFile, fileFormat, compression) + } + } + } + } + } } From 4329eb2e73181819bb712f57ca9c7feac0d640ea Mon Sep 17 00:00:00 2001 From: Gene Pang Date: Mon, 2 Oct 2017 15:09:11 -0700 Subject: [PATCH 1432/1765] [SPARK-16944][Mesos] Improve data locality when launching new executors when dynamic allocation is enabled ## What changes were proposed in this pull request? Improve the Spark-Mesos coarse-grained scheduler to consider the preferred locations when dynamic allocation is enabled. ## How was this patch tested? Added a unittest, and performed manual testing on AWS. Author: Gene Pang Closes #18098 from gpang/mesos_data_locality. --- .../spark/internal/config/package.scala | 4 ++ .../spark/scheduler/TaskSetManager.scala | 6 +- .../MesosCoarseGrainedSchedulerBackend.scala | 52 ++++++++++++++-- ...osCoarseGrainedSchedulerBackendSuite.scala | 62 +++++++++++++++++++ .../spark/scheduler/cluster/mesos/Utils.scala | 6 ++ 5 files changed, 123 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 44a2815b81a73..d85b6a0200b8d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("3s") + private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index bb867416a4fac..3bdede6743d1b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap @@ -980,7 +980,7 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3s") + val defaultWait = conf.get(config.LOCALITY_WAIT) val localityWaitKey = level match { case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" @@ -989,7 +989,7 @@ private[spark] class TaskSetManager( } if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait.toString) } else { 0L } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 26699873145b4..80c0a041b7322 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -99,6 +99,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private var totalCoresAcquired = 0 private var totalGpusAcquired = 0 + // The amount of time to wait for locality scheduling + private val localityWait = conf.get(config.LOCALITY_WAIT) + // The start of the waiting, for data local scheduling + private var localityWaitStartTime = System.currentTimeMillis() + // If true, the scheduler is in the process of launching executors to reach the requested + // executor limit + private var launchingExecutors = false + // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because // we need to maintain e.g. failure state and connection state. @@ -311,6 +319,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( return } + if (numExecutors >= executorLimit) { + logDebug("Executor limit reached. numExecutors: " + numExecutors + + " executorLimit: " + executorLimit) + offers.asScala.map(_.getId).foreach(d.declineOffer) + launchingExecutors = false + return + } else { + if (!launchingExecutors) { + launchingExecutors = true + localityWaitStartTime = System.currentTimeMillis() + } + } + logDebug(s"Received ${offers.size} resource offers.") val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => @@ -413,7 +434,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val offerId = offer.getId.getValue val resources = remainingResources(offerId) - if (canLaunchTask(slaveId, resources)) { + if (canLaunchTask(slaveId, offer.getHostname, resources)) { // Create a task launchTasks = true val taskId = newMesosTaskId() @@ -477,7 +498,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } - private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + private def canLaunchTask(slaveId: String, offerHostname: String, + resources: JList[Resource]): Boolean = { val offerMem = getResource(resources, "mem") val offerCPUs = getResource(resources, "cpus").toInt val cpus = executorCores(offerCPUs) @@ -489,9 +511,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpus <= offerCPUs && cpus + totalCoresAcquired <= maxCores && mem <= offerMem && - numExecutors() < executorLimit && + numExecutors < executorLimit && slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES && - meetsPortRequirements + meetsPortRequirements && + satisfiesLocality(offerHostname) } private def executorCores(offerCPUs: Int): Int = { @@ -500,6 +523,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( ) } + private def satisfiesLocality(offerHostname: String): Boolean = { + if (!Utils.isDynamicAllocationEnabled(conf) || hostToLocalTaskCount.isEmpty) { + return true + } + + // Check the locality information + val currentHosts = slaves.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet + val allDesiredHosts = hostToLocalTaskCount.keys.toSet + // Try to match locality for hosts which do not have executors yet, to potentially + // increase coverage. + val remainingHosts = allDesiredHosts -- currentHosts + if (!remainingHosts.contains(offerHostname) && + (System.currentTimeMillis() - localityWaitStartTime <= localityWait)) { + logDebug("Skipping host and waiting for locality. host: " + offerHostname) + return false + } + return true + } + override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue val slaveId = status.getSlaveId.getValue @@ -646,6 +688,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // since at coarse grain it depends on the amount of slaves available. logInfo("Capping the total amount of executors to " + requestedTotal) executorLimitOption = Some(requestedTotal) + // Update the locality wait start time to continue trying for locality. + localityWaitStartTime = System.currentTimeMillis() true } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f6bae01c3af59..6c40792112f49 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -604,6 +604,55 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(backend.isReady) } + test("supports data locality with dynamic allocation") { + setBackend(Map( + "spark.dynamicAllocation.enabled" -> "true", + "spark.dynamicAllocation.testing" -> "true", + "spark.locality.wait" -> "1s")) + + assert(backend.getExecutorIds().isEmpty) + + backend.requestTotalExecutors(2, 2, Map("hosts10" -> 1, "hosts11" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(1, false) + offerResourcesAndVerify(2, false) + + // Offer local resource + offerResourcesAndVerify(10, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(1, true) + + // Update total executors + backend.requestTotalExecutors(3, 3, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Update total executors + backend.requestTotalExecutors(4, 4, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1, + "hosts13" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Offer local resource + offerResourcesAndVerify(13, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(2, true) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { @@ -631,6 +680,19 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.resourceOffers(driver, mesosOffers.asJava) } + private def offerResourcesAndVerify(id: Int, expectAccept: Boolean): Unit = { + offerResources(List(Resources(backend.executorMemory(sc), 1)), id) + if (expectAccept) { + val numExecutors = backend.getExecutorIds().size + val launchedTasks = verifyTaskLaunched(driver, s"o$id") + assert(s"s$id" == launchedTasks.head.getSlaveId.getValue) + registerMockExecutor(launchedTasks.head.getTaskId.getValue, s"s$id", 1) + assert(backend.getExecutorIds().size == numExecutors + 1) + } else { + verifyTaskNotLaunched(driver, s"o$id") + } + } + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId).build()) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 2a67cbc913ffe..833db0c1ff334 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -84,6 +84,12 @@ object Utils { captor.getValue.asScala.toList } + def verifyTaskNotLaunched(driver: SchedulerDriver, offerId: String): Unit = { + verify(driver, times(0)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + Matchers.any(classOf[java.util.Collection[TaskInfo]])) + } + def createOfferId(offerId: String): OfferID = { OfferID.newBuilder().setValue(offerId).build() } From fa225da7463e384529da14706e44f4a09772e5c1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 2 Oct 2017 15:25:33 -0700 Subject: [PATCH 1433/1765] [SPARK-22176][SQL] Fix overflow issue in Dataset.show ## What changes were proposed in this pull request? This pr fixed an overflow issue below in `Dataset.show`: ``` scala> Seq((1, 2), (3, 4)).toDF("a", "b").show(Int.MaxValue) org.apache.spark.sql.AnalysisException: The limit expression must be equal to or greater than 0, but got -2147483648;; GlobalLimit -2147483648 +- LocalLimit -2147483648 +- Project [_1#27218 AS a#27221, _2#27219 AS b#27222] +- LocalRelation [_1#27218, _2#27219] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:89) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.org$apache$spark$sql$catalyst$analysis$CheckAnalysis$$checkLimitClause(CheckAnalysis.scala:70) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:234) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:80) at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:127) ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #19401 from maropu/MaxValueInShowString. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f2a76a506eb6f..b70dfc05330f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,7 +237,7 @@ class Dataset[T] private[sql]( */ private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0) + val numRows = _numRows.max(0).min(Int.MaxValue - 1) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 672deeac597f1..dd8f54b690f64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1045,6 +1045,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(Int.MaxValue)") { + val df = Seq((1, 2), (3, 4)).toDF("a", "b") + val expectedAnswer = """+---+---+ + || a| b| + |+---+---+ + || 1| 2| + || 3| 4| + |+---+---+ + |""".stripMargin + assert(df.showString(Int.MaxValue) === expectedAnswer) + } + test("showString(0), vertical = true") { val expectedAnswer = "(0 rows)\n" assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) From 4c5158eec9101ef105274df6b488e292a56156a2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 3 Oct 2017 12:38:13 -0700 Subject: [PATCH 1434/1765] [SPARK-21644][SQL] LocalLimit.maxRows is defined incorrectly ## What changes were proposed in this pull request? The definition of `maxRows` in `LocalLimit` operator was simply wrong. This patch introduces a new `maxRowsPerPartition` method and uses that in pruning. The patch also adds more documentation on why we need local limit vs global limit. Note that this previously has never been a bug because the way the code is structured, but future use of the maxRows could lead to bugs. ## How was this patch tested? Should be covered by existing test cases. Closes #18851 Author: gatorsmile Author: Reynold Xin Closes #19393 from gatorsmile/pr-18851. --- .../sql/catalyst/optimizer/Optimizer.scala | 29 ++++++----- .../catalyst/plans/logical/LogicalPlan.scala | 5 ++ .../plans/logical/basicLogicalOperators.scala | 49 ++++++++++++++++++- .../execution/basicPhysicalOperators.scala | 3 ++ 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b9fa39d6dad4c..bc2d4a824cb49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -305,13 +305,20 @@ object LimitPushDown extends Rule[LogicalPlan] { } } - private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { - (limitExp, plan.maxRows) match { - case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + private def maybePushLocalLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRowsPerPartition) match { + case (IntegerLiteral(newLimit), Some(childMaxRows)) if newLimit < childMaxRows => + // If the child has a cap on max rows per partition and the cap is larger than + // the new limit, put a new LocalLimit there. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + // If the child has no cap, put the new LocalLimit. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) - case _ => plan + + case _ => + // Otherwise, don't put a new LocalLimit. + plan } } @@ -323,7 +330,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to // pushdown Limit. case LocalLimit(exp, Union(children)) => - LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side // because we need to ensure that rows from the limited side still have an opportunity to match @@ -335,19 +342,19 @@ object LimitPushDown extends Rule[LogicalPlan] { // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { - case RightOuter => join.copy(right = maybePushLimit(exp, right)) - case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { - join.copy(left = maybePushLimit(exp, left)) + join.copy(left = maybePushLocalLimit(exp, left)) } else { - join.copy(right = maybePushLimit(exp, right)) + join.copy(right = maybePushLocalLimit(exp, right)) } case (Some(_), Some(_)) => join - case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) - case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + case (Some(_), None) => join.copy(left = maybePushLocalLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLocalLimit(exp, right)) } case _ => join diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 68aae720e026a..14188829db2af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -97,6 +97,11 @@ abstract class LogicalPlan */ def maxRows: Option[Long] = None + /** + * Returns the maximum number of rows this plan may compute on each partition. + */ + def maxRowsPerPartition: Option[Long] = maxRows + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f443cd5a69de3..80243d3d356ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -191,6 +191,9 @@ object Union { } } +/** + * Logical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + */ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { @@ -200,6 +203,17 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } } + /** + * Note the definition has assumption about how union is implemented physically. + */ + override def maxRowsPerPartition: Option[Long] = { + if (children.exists(_.maxRowsPerPartition.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRowsPerPartition).sum) + } + } + // updating nullability to make all the children consistent override def output: Seq[Attribute] = children.map(_.output).transpose.map(attrs => @@ -669,6 +683,27 @@ case class Pivot( } } +/** + * A constructor for creating a logical limit, which is split into two separate logical nodes: + * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]]. + * + * This muds the water for clean logical/physical separation, and is done for better limit pushdown. + * In distributed query processing, a non-terminal global limit is actually an expensive operation + * because it requires coordination (in Spark this is done using a shuffle). + * + * In most cases when we want to push down limit, it is often better to only push some partition + * local limit. Consider the following: + * + * GlobalLimit(Union(A, B)) + * + * It is better to do + * GlobalLimit(Union(LocalLimit(A), LocalLimit(B))) + * + * than + * Union(GlobalLimit(A), GlobalLimit(B)). + * + * So we introduced LocalLimit and GlobalLimit in the logical plan node for limit pushdown. + */ object Limit { def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) @@ -682,6 +717,11 @@ object Limit { } } +/** + * A global (coordinated) limit. This operator can emit at most `limitExpr` number in total. + * + * See [[Limit]] for more information. + */ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { @@ -692,9 +732,16 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } } +/** + * A partition-local (non-coordinated) limit. This operator can emit at most `limitExpr` number + * of tuples on each physical partition. + * + * See [[Limit]] for more information. + */ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def maxRows: Option[Long] = { + + override def maxRowsPerPartition: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 8389e2f3d5be9..63cd1691f4cd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -554,6 +554,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) /** * Physical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + * + * If we change how this is implemented physically, we'd need to update + * [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]]. */ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { override def output: Seq[Attribute] = From e65b6b7ca1a7cff1b91ad2262bb7941e6bf057cd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 3 Oct 2017 12:40:22 -0700 Subject: [PATCH 1435/1765] [SPARK-22178][SQL] Refresh Persistent Views by REFRESH TABLE Command ## What changes were proposed in this pull request? The underlying tables of persistent views are not refreshed when users issue the REFRESH TABLE command against the persistent views. ## How was this patch tested? Added a test case Author: gatorsmile Closes #19405 from gatorsmile/refreshView. --- .../apache/spark/sql/internal/CatalogImpl.scala | 15 +++++++++++---- .../spark/sql/hive/HiveMetadataCacheSuite.scala | 14 +++++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 142b005850a49..fdd25330c5e67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -474,13 +474,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def refreshTable(tableName: String): Unit = { val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) - // Temp tables: refresh (or invalidate) any metadata/data cached in the plan recursively. - // Non-temp tables: refresh the metadata cache. - sessionCatalog.refreshTable(tableIdent) + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val table = sparkSession.table(tableIdent) + + if (tableMetadata.tableType == CatalogTableType.VIEW) { + // Temp or persistent views: refresh (or invalidate) any metadata/data cached + // in the plan recursively. + table.queryExecution.analyzed.foreach(_.refresh()) + } else { + // Non-temp tables: refresh the metadata cache. + sessionCatalog.refreshTable(tableIdent) + } // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val table = sparkSession.table(tableIdent) if (isCached(table)) { // Uncache the logicalPlan. sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 0c28a1b609bb8..e71aba72c31fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -31,14 +31,22 @@ import org.apache.spark.sql.test.SQLTestUtils class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-16337 temporary view refresh") { - withTempView("view_refresh") { + checkRefreshView(isTemp = true) + } + + test("view refresh") { + checkRefreshView(isTemp = false) + } + + private def checkRefreshView(isTemp: Boolean) { + withView("view_refresh") { withTable("view_table") { // Create a Parquet directory spark.range(start = 0, end = 100, step = 1, numPartitions = 3) .write.saveAsTable("view_table") - // Read the table in - spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh") + val temp = if (isTemp) "TEMPORARY" else "" + spark.sql(s"CREATE $temp VIEW view_refresh AS SELECT * FROM view_table WHERE id > -1") assert(sql("select count(*) from view_refresh").first().getLong(0) == 100) // Delete a file using the Hadoop file system interface since the path returned by From e36ec38d89472df0dfe12222b6af54cd6eea8e98 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Tue, 3 Oct 2017 16:53:32 -0700 Subject: [PATCH 1436/1765] [SPARK-20466][CORE] HadoopRDD#addLocalConfiguration throws NPE ## What changes were proposed in this pull request? Fix for SPARK-20466, full description of the issue in the JIRA. To summarize, `HadoopRDD` uses a metadata cache to cache `JobConf` objects. The cache uses soft-references, which means the JVM can delete entries from the cache whenever there is GC pressure. `HadoopRDD#getJobConf` had a bug where it would check if the cache contained the `JobConf`, if it did it would get the `JobConf` from the cache and return it. This doesn't work when soft-references are used as the JVM can delete the entry between the existence check and the get call. ## How was this patch tested? Haven't thought of a good way to test this yet given the issue only occurs sometimes, and happens during high GC pressure. Was thinking of using mocks to verify `#getJobConf` is doing the right thing. I deleted the method `HadoopRDD#containsCachedMetadata` so that we don't hit this issue again. Author: Sahil Takiar Closes #19413 from sahilTakiar/master. --- .../org/apache/spark/rdd/HadoopRDD.scala | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 76ea8b86c53d2..23b344230e490 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -157,20 +157,25 @@ class HadoopRDD[K, V]( if (conf.isInstanceOf[JobConf]) { logDebug("Re-using user-broadcasted JobConf") conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - logDebug("Re-using cached JobConf") - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). - HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { - logDebug("Creating new JobConf and caching it for later re-use") - val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + Option(HadoopRDD.getCachedMetadata(jobConfCacheKey)) + .map { conf => + logDebug("Re-using cached JobConf") + conf.asInstanceOf[JobConf] + } + .getOrElse { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in + // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary + // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, + // HADOOP-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } } } @@ -360,8 +365,6 @@ private[spark] object HadoopRDD extends Logging { */ def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) From 5f694334534e4425fb9e8abf5b7e3e5efdfcef50 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 3 Oct 2017 21:27:58 -0700 Subject: [PATCH 1437/1765] [SPARK-22171][SQL] Describe Table Extended Failed when Table Owner is Empty ## What changes were proposed in this pull request? Users could hit `java.lang.NullPointerException` when the tables were created by Hive and the table's owner is `null` that are got from Hive metastore. `DESC EXTENDED` failed with the error: > SQLExecutionException: java.lang.NullPointerException at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47) at scala.collection.immutable.StringOps.length(StringOps.scala:47) at scala.collection.IndexedSeqOptimized$class.isEmpty(IndexedSeqOptimized.scala:27) at scala.collection.immutable.StringOps.isEmpty(StringOps.scala:29) at scala.collection.TraversableOnce$class.nonEmpty(TraversableOnce.scala:111) at scala.collection.immutable.StringOps.nonEmpty(StringOps.scala:29) at org.apache.spark.sql.catalyst.catalog.CatalogTable.toLinkedHashMap(interface.scala:300) at org.apache.spark.sql.execution.command.DescribeTableCommand.describeFormattedTableInfo(tables.scala:565) at org.apache.spark.sql.execution.command.DescribeTableCommand.run(tables.scala:543) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:66) at ## How was this patch tested? Added a unit test case Author: gatorsmile Closes #19395 from gatorsmile/desc. --- .../sql/catalyst/catalog/interface.scala | 2 +- .../sql/catalyst/analysis/CatalogSuite.scala | 37 +++++++++++++++++++ .../sql/hive/client/HiveClientImpl.scala | 2 +- 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1965144e81197..fe2af910a0ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -307,7 +307,7 @@ case class CatalogTable( identifier.database.foreach(map.put("Database", _)) map.put("Table", identifier.table) - if (owner.nonEmpty) map.put("Owner", owner) + if (owner != null && owner.nonEmpty) map.put("Owner", owner) map.put("Created Time", new Date(createTime).toString) map.put("Last Access", new Date(lastAccessTime).toString) map.put("Created By", "Spark " + createVersion) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala new file mode 100644 index 0000000000000..d670053ba1b5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.types.StructType + + +class CatalogSuite extends AnalysisTest { + + test("desc table when owner is set to null") { + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + owner = null, + schema = new StructType().add("col1", "int").add("col2", "string"), + provider = Some("parquet")) + table.toLinkedHashMap + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c4e48c9360db7..66165c7228bca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -461,7 +461,7 @@ private[hive] class HiveClientImpl( // in table properties. This means, if we have bucket spec in both hive metastore and // table properties, we will trust the one in table properties. bucketSpec = bucketSpec, - owner = h.getOwner, + owner = Option(h.getOwner).getOrElse(""), createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( From 3099c574c56cab86c3fcf759864f89151643f837 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Oct 2017 21:42:51 -0700 Subject: [PATCH 1438/1765] [SPARK-22136][SS] Implement stream-stream outer joins. ## What changes were proposed in this pull request? Allow one-sided outer joins between two streams when a watermark is defined. ## How was this patch tested? new unit tests Author: Jose Torres Closes #19327 from joseph-torres/outerjoin. --- .../analysis/StreamingJoinHelper.scala | 286 +++++++++++++++++ .../UnsupportedOperationChecker.scala | 53 +++- .../analysis/StreamingJoinHelperSuite.scala | 140 ++++++++ .../analysis/UnsupportedOperationsSuite.scala | 108 ++++++- .../StreamingSymmetricHashJoinExec.scala | 152 +++++++-- .../StreamingSymmetricHashJoinHelper.scala | 241 +------------- .../state/SymmetricHashJoinStateManager.scala | 200 +++++++++--- .../SymmetricHashJoinStateManagerSuite.scala | 6 +- .../sql/streaming/StreamingJoinSuite.scala | 298 +++++++++++------- 9 files changed, 1051 insertions(+), 433 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala new file mode 100644 index 0000000000000..072dc954879ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]] in SQL for more details. + */ +object StreamingJoinHelper extends PredicateHelper with Logging { + + /** + * Check the provided logical plan to see if its join keys contain a watermark attribute. + * + * Will return false if the plan is not an equijoin. + * @param plan the logical plan to check + */ + def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { + plan match { + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + (leftKeys ++ rightKeys).exists { + case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) + case _ => false + } + case _ => false + } + } + + /** + * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canoncalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. + * This function is supposed to make best-effort attempt to get the state watermark. If there is + * any error, it will return None. + * + * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds, is possible. + */ + def getStateValueWatermark( + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + + // If condition or event time watermark is not provided, then cannot calculate state watermark + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + // If there is not watermark attribute, then cannot define state watermark + if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatermarkFromLessThenPredicate( + l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + + // The generated the state watermark cleanup expression is inclusive of the state watermark. + // If state watermark is W, all state where timestamp <= W will be cleaned up. + // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean + // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. + val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) + case GreaterThan(l, r) => getStateWatermarkSafely(r, l) + case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case _ => None + } + if (stateWatermark.nonEmpty) { + logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") + } + stateWatermark + } + allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) + } + + /** + * Extract the state value watermark (milliseconds) from the condition + * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for + * leftTime using the watermark on the rightTime. Example: + * + * Input: rightTime-with-watermark + c1 < leftTime + c2 + * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 + * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime + * With watermark value: watermark-value + c1 + (-c2) < leftTime + */ + private def getStateWatermarkFromLessThenPredicate( + leftExpr: Expression, + rightExpr: Expression, + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + eventWatermark: Option[Long]): Option[Long] = { + + val attributesInCondition = AttributeSet( + leftExpr.collect { case a: AttributeReference => a } ++ + rightExpr.collect { case a: AttributeReference => a } + ) + if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || + attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { + // If more than attributes present in condition from one side, then it cannot be solved + return None + } + + def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { + e.collectLeaves().collectFirst { + case a @ AttributeReference(_, _, _, _) + if attributesToFindStateWatermarkFor.contains(a) => a + }.nonEmpty + } + + // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 + val allOnLeftExpr = Subtract(leftExpr, rightExpr) + logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") + + // Canonicalization step 2: extract commutative terms + // rightTime-with-watermark, c1, -leftTime, -c2 + val terms = ExpressionSet(collectTerms(allOnLeftExpr)) + logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) + + // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor + val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) + + // Verify there is only one correct constraint term and of the correct type + if (constraintTerms.size > 1) { + logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + if (constraintTerms.isEmpty) { + logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + val constraintTerm = constraintTerms.head + if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` + // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. + // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark + // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about + // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, + // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. + return None + } + + // Replace watermark attribute with watermark value, and generate the resolved expression + // from the other terms. That is, + // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) + logDebug(s"Constraint term from join condition:\t$constraintTerm") + val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => + term.transform { + case a @ AttributeReference(_, _, _, metadata) + if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => + Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) + } + }.reduceLeft(Add) + + // Calculate the constraint value + logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") + val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] + Some((Double2double(constraintValue) / 1000.0).toLong) + } + + /** + * Collect all the terms present in an expression after converting it into the form + * a + b + c + d where each term be either an attribute or a literal casted to long, + * optionally wrapped in a unary minus. + */ + private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { + var invalid = false + + /** Wrap a term with UnaryMinus if its needs to be negated. */ + def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { + if (minus) UnaryMinus(expr) else expr + } + + /** + * Recursively split the expression into its leaf terms contains attributes or literals. + * Returns terms only of the forms: + * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), + * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) + * Multiply(Literal), UnaryMinus(Multiply(Literal)) + * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) + * + * Note: + * - If term needs to be negated for making it a commutative term, + * then it will be wrapped in UnaryMinus(...) + * - Each terms will be representing timestamp value or time interval in microseconds, + * typed as doubles. + */ + def collect(expr: Expression, negate: Boolean): Seq[Expression] = { + expr match { + case Add(left, right) => + collect(left, negate) ++ collect(right, negate) + case Subtract(left, right) => + collect(left, negate) ++ collect(right, !negate) + case TimeAdd(left, right, _) => + collect(left, negate) ++ collect(right, negate) + case TimeSub(left, right, _) => + collect(left, negate) ++ collect(right, !negate) + case UnaryMinus(child) => + collect(child, !negate) + case CheckOverflow(child, _) => + collect(child, negate) + case Cast(child, dataType, _) => + dataType match { + case _: NumericType | _: TimestampType => collect(child, negate) + case _ => + invalid = true + Seq.empty + } + case a: AttributeReference => + val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a + Seq(negateIfNeeded(castedRef, negate)) + case lit: Literal => + // If literal of type calendar interval, then explicitly convert to millis + // Convert other number like literal to doubles representing millis (by x1000) + val castedLit = lit.dataType match { + case CalendarIntervalType => + val calendarInterval = lit.value.asInstanceOf[CalendarInterval] + if (calendarInterval.months > 0) { + invalid = true + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom " + + s"as imprecise intervals like months and years cannot be used for" + + s"watermark calculation. Use interval in terms of day instead.") + Literal(0.0) + } else { + Literal(calendarInterval.microseconds.toDouble) + } + case DoubleType => + Multiply(lit, Literal(1000000.0)) + case _: NumericType => + Multiply(Cast(lit, DoubleType), Literal(1000000.0)) + case _: TimestampType => + Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) + } + Seq(negateIfNeeded(castedLit, negate)) + case a @ _ => + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") + invalid = true + Seq.empty + } + } + + val terms = collect(exprToCollectFrom, negate = false) + if (!invalid) terms else Seq.empty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d1d705691b076..dee6fbe9d1514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -217,7 +218,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, _) => + case Join(left, right, joinType, condition) => joinType match { @@ -233,16 +234,52 @@ object UnsupportedOperationChecker { throwError("Full outer joins with streaming DataFrames/Datasets are not supported") } - case LeftOuter | LeftSemi | LeftAnti => + case LeftSemi | LeftAnti => if (right.isStreaming) { - throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " + - "on the right is not supported") + throwError("Left semi/anti joins with a streaming DataFrame/Dataset " + + "on the right are not supported") } + // We support streaming left outer joins with static on the right always, and with + // stream on both sides under the appropriate conditions. + case LeftOuter => + if (!left.isStreaming && right.isStreaming) { + throwError("Left outer join with a streaming DataFrame/Dataset " + + "on the right and a static DataFrame/Dataset on the left is not supported") + } else if (left.isStreaming && right.isStreaming) { + val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + left.outputSet, right.outputSet, condition, Some(1000000)).isDefined + + if (!watermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } + } + + // We support streaming right outer joins with static on the left always, and with + // stream on both sides under the appropriate conditions. case RightOuter => - if (left.isStreaming) { - throwError("Right outer join with a streaming DataFrame/Dataset on the left is " + - "not supported") + if (left.isStreaming && !right.isStreaming) { + throwError("Right outer join with a streaming DataFrame/Dataset on the left and " + + "a static DataFrame/DataSet on the right not supported") + } else if (left.isStreaming && right.isStreaming) { + val isWatermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + // Check if the nullable side has a watermark, and there's a range condition which + // implies a state value watermark on the first side. + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + right.outputSet, left.outputSet, condition, Some(1000000)).isDefined + + if (!isWatermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } } case NaturalJoin(_) | UsingJoin(_, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala new file mode 100644 index 0000000000000..8cf41a02320d2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter, LeafNode, LocalRelation} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, TimestampType} + +class StreamingJoinHelperSuite extends AnalysisTest { + + test("extract watermark from time condition") { + val attributesToFindConstraintFor = Seq( + AttributeReference("leftTime", TimestampType)(), + AttributeReference("leftOther", IntegerType)()) + val metadataWithWatermark = new MetadataBuilder() + .putLong(EventTimeWatermark.delayKey, 1000) + .build() + val attributesWithWatermark = Seq( + AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), + AttributeReference("rightOther", IntegerType)()) + + case class DummyLeafNode() extends LeafNode { + override def output: Seq[Attribute] = + attributesToFindConstraintFor ++ attributesWithWatermark + } + + def watermarkFrom( + conditionStr: String, + rightWatermark: Option[Long] = Some(10000)): Option[Long] = { + val conditionExpr = Some(conditionStr).map { str => + val plan = + Filter( + CatalystSqlParser.parseExpression(str), + DummyLeafNode()) + val optimized = SimpleTestOptimizer.execute(SimpleAnalyzer.execute(plan)) + optimized.asInstanceOf[Filter].condition + } + StreamingJoinHelper.getStateValueWatermark( + AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), + conditionExpr, rightWatermark) + } + + // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, + // then cannot define constraint on leftTime. + assert(watermarkFrom("leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) + assert(watermarkFrom("leftTime < rightTime") === None) + assert(watermarkFrom("leftTime <= rightTime") === None) + assert(watermarkFrom("rightTime > leftTime") === None) + assert(watermarkFrom("rightTime >= leftTime") === None) + assert(watermarkFrom("rightTime < leftTime") === Some(10000)) + assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) + + // Test type conversions + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) + + // Test with timestamp type + calendar interval on either side of equation + // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. + assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) + assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) + assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) + assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) + assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") + === Some(12000)) + + // Test with casted long type + constants on either side of equation + // Note: long type and constants commute, so more combinations to test. + // -- Constants on the right + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") + === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") + === Some(10100)) + assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") + === Some(10200)) + // -- Constants on the left + assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) + assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) + assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") + === Some(7000)) + assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") + === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") + === Some(10100)) + // -- Constants on both sides, mixed types + assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") + === Some(13000)) + + // Test multiple conditions, should return minimum watermark + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === + Some(7000)) // first condition wins + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === + Some(6000)) // second condition wins + + // Test invalid comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes + assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed + + // Test static comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) + + // Test non-positive results + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 10") === Some(0)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 100") === Some(-90000)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 11f48a39c1e25..e5057c451d5b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -417,9 +418,57 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBinaryOperationInStreamingPlan( "left outer join", _.join(_, joinType = LeftOuter), - streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + streamStreamSupported = false, + expectedMsg = "outer join") + + // Left outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Left outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute > rightTimeWithWatermark + 10)) + }, + OutputMode.Append()) + + // Left outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute < rightTimeWithWatermark + 10)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Left semi joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -427,7 +476,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftSemi), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Left anti joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -435,14 +484,63 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftAnti), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Right outer joins: stream-* not allowed testBinaryOperationInStreamingPlan( "right outer join", _.join(_, joinType = RightOuter), + streamBatchSupported = false, streamStreamSupported = false, - streamBatchSupported = false) + expectedMsg = "outer join") + + // Right outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Right outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 < attribute)) + }, + OutputMode.Append()) + + // Right outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 > attribute)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Cogroup: only batch-batch is allowed testBinaryOperationInStreamingPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 44f1fa58599d2..9bd2127a28ff6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -146,7 +146,14 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } - require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType") + private def throwBadJoinTypeException(): Nothing = { + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $joinType as the JoinType") + } + + require( + joinType == Inner || joinType == LeftOuter || joinType == RightOuter, + s"${getClass.getSimpleName} should not take $joinType as the JoinType") require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) private val storeConf = new StateStoreConf(sqlContext.conf) @@ -157,11 +164,18 @@ case class StreamingSymmetricHashJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = joinType match { + case _: InnerLike => left.output ++ right.output + case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case _ => throwBadJoinTypeException() + } override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning)) + case RightOuter => PartitioningCollection(Seq(right.outputPartitioning)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -207,31 +221,108 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(inputRow).withRight(matchedRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(matchedRow).withRight(inputRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) } // Filter the joined rows based on the given condition. - val outputFilterFunction = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _ - val filteredOutputIter = - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row => - numOutputRows += 1 - row - } + val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ + + // We need to save the time that the inner join output iterator completes, since outer join + // output counts as both update and removal time. + var innerOutputCompletionTimeNs: Long = 0 + def onInnerOutputCompletion = { + innerOutputCompletionTimeNs = System.nanoTime + } + val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) + + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists( + rightValue => { + outputFilterFunction( + joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + }) + } + + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists( + leftValue => { + outputFilterFunction( + joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + }) + } + + val outputIter: Iterator[InternalRow] = joinType match { + case Inner => + filteredInnerOutputIter + case LeftOuter => + // We generate the outer join input by: + // * Getting an iterator over the rows that have aged out on the left side. These rows are + // candidates for being null joined. Note that to avoid doing two passes, this iterator + // removes the rows from the state manager as they're processed. + // * Checking whether the current row matches a key in the right side state, and that key + // has any value which satisfies the filter function when joined. If it doesn't, + // we know we can join with null, since there was never (including this batch) a match + // within the watermark period. If it does, there must have been a match at some point, so + // we know we can't join with null. + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + val removedRowIter = leftSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithRightSideState(pair)) + .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + + filteredInnerOutputIter ++ outerOutputIter + case RightOuter => + // See comments for left outer case. + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val removedRowIter = rightSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithLeftSideState(pair)) + .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + + filteredInnerOutputIter ++ outerOutputIter + case _ => throwBadJoinTypeException() + } + + val outputIterWithMetrics = outputIter.map { row => + numOutputRows += 1 + row + } // Function to remove old state after all the input has been consumed and output generated def onOutputCompletion = { + // All processing time counts as update time. allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Remove old state if needed + // Processing time between inner output completion and here comes from the outer portion of a + // join, and thus counts as removal time as we remove old state from one side while iterating. + if (innerOutputCompletionTimeNs != 0) { + allRemovalsTimeMs += + math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0) + } + allRemovalsTimeMs += timeTakenMs { - leftSideJoiner.removeOldState() - rightSideJoiner.removeOldState() + // Remove any remaining state rows which aren't needed because they're below the watermark. + // + // For inner joins, we have to remove unnecessary state rows from both sides if possible. + // For outer joins, we have already removed unnecessary state rows from the outer side + // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we + // have to remove unnecessary state rows from the other side (e.g., right side for the left + // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal + // needs to be done greedily by immediately consuming the returned iterator. + val cleanupIter = joinType match { + case Inner => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case LeftOuter => rightSideJoiner.removeOldState() + case RightOuter => leftSideJoiner.removeOldState() + case _ => throwBadJoinTypeException() + } + while (cleanupIter.hasNext) { + cleanupIter.next() + } } // Commit all state changes and update state store metrics @@ -251,7 +342,8 @@ case class StreamingSymmetricHashJoinExec( } } - CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterWithMetrics, onOutputCompletion) } /** @@ -324,14 +416,32 @@ case class StreamingSymmetricHashJoinExec( } } - /** Remove old buffered state rows using watermarks for state keys and values */ - def removeOldState(): Unit = { + /** + * Get an iterator over the values stored in this joiner's state manager for the given key. + * + * Should not be interleaved with mutations. + */ + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + joinStateManager.get(key) + } + + /** + * Builds an iterator over old state key-value pairs, removing them lazily as they're produced. + * + * @note This iterator must be consumed fully before any other operations are made + * against this joiner's join state manager. For efficiency reasons, the intermediate states of + * the iterator leave the state manager in an undefined state. + * + * We do this to avoid requiring either two passes or full materialization when + * processing the rows for outer join. + */ + def removeOldState(): Iterator[UnsafeRowPair] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) case Some(JoinStateValueWatermarkPredicate(expr)) => joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc) - case _ => + case _ => Iterator.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index e50274a1baba1..64c7189f72ac3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression @@ -34,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * Helper object for [[StreamingSymmetricHashJoinExec]]. See that object for more details. */ -object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { +object StreamingSymmetricHashJoinHelper extends Logging { sealed trait JoinSide case object LeftSide extends JoinSide { override def toString(): String = "left" } @@ -111,7 +112,7 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs - val stateValueWatermark = getStateValueWatermark( + val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, @@ -132,242 +133,6 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) } - /** - * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) - * given the join condition and the event time watermark. This is how it works. - * - The condition is split into conjunctive predicates, and we find the predicates of the - * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). - * - We canoncalize the predicate and solve it with the event time watermark value to find the - * value of the state watermark. - * This function is supposed to make best-effort attempt to get the state watermark. If there is - * any error, it will return None. - * - * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark - * is to be calculated - * @param attributesWithEventWatermark attributes of the other side which has a watermark column - * @param joinCondition join condition - * @param eventWatermark watermark defined on the input event data - * @return state value watermark in milliseconds, is possible. - */ - def getStateValueWatermark( - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - joinCondition: Option[Expression], - eventWatermark: Option[Long]): Option[Long] = { - - // If condition or event time watermark is not provided, then cannot calculate state watermark - if (joinCondition.isEmpty || eventWatermark.isEmpty) return None - - // If there is not watermark attribute, then cannot define state watermark - if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None - - def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { - try { - getStateWatermarkFromLessThenPredicate( - l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) - } catch { - case NonFatal(e) => - logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) - None - } - } - - val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => - - // The generated the state watermark cleanup expression is inclusive of the state watermark. - // If state watermark is W, all state where timestamp <= W will be cleaned up. - // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean - // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. - val stateWatermark = predicate match { - case LessThan(l, r) => getStateWatermarkSafely(l, r) - case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) - case GreaterThan(l, r) => getStateWatermarkSafely(r, l) - case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) - case _ => None - } - if (stateWatermark.nonEmpty) { - logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") - } - stateWatermark - } - allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) - } - - /** - * Extract the state value watermark (milliseconds) from the condition - * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for - * leftTime using the watermark on the rightTime. Example: - * - * Input: rightTime-with-watermark + c1 < leftTime + c2 - * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 - * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime - * With watermark value: watermark-value + c1 + (-c2) < leftTime - */ - private def getStateWatermarkFromLessThenPredicate( - leftExpr: Expression, - rightExpr: Expression, - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - eventWatermark: Option[Long]): Option[Long] = { - - val attributesInCondition = AttributeSet( - leftExpr.collect { case a: AttributeReference => a } ++ - rightExpr.collect { case a: AttributeReference => a } - ) - if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || - attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { - // If more than attributes present in condition from one side, then it cannot be solved - return None - } - - def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { - e.collectLeaves().collectFirst { - case a @ AttributeReference(_, TimestampType, _, _) - if attributesToFindStateWatermarkFor.contains(a) => a - }.nonEmpty - } - - // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 - val allOnLeftExpr = Subtract(leftExpr, rightExpr) - logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") - - // Canonicalization step 2: extract commutative terms - // rightTime-with-watermark, c1, -leftTime, -c2 - val terms = ExpressionSet(collectTerms(allOnLeftExpr)) - logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) - - - - // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor - val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) - - // Verify there is only one correct constraint term and of the correct type - if (constraintTerms.size > 1) { - logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - if (constraintTerms.isEmpty) { - logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - val constraintTerm = constraintTerms.head - if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { - // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` - // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. - // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark - // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about - // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, - // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. - return None - } - - // Replace watermark attribute with watermark value, and generate the resolved expression - // from the other terms. That is, - // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) - logDebug(s"Constraint term from join condition:\t$constraintTerm") - val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => - term.transform { - case a @ AttributeReference(_, TimestampType, _, metadata) - if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => - Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) - } - }.reduceLeft(Add) - - // Calculate the constraint value - logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") - val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] - Some((Double2double(constraintValue) / 1000.0).toLong) - } - - /** - * Collect all the terms present in an expression after converting it into the form - * a + b + c + d where each term be either an attribute or a literal casted to long, - * optionally wrapped in a unary minus. - */ - private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { - var invalid = false - - /** Wrap a term with UnaryMinus if its needs to be negated. */ - def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { - if (minus) UnaryMinus(expr) else expr - } - - /** - * Recursively split the expression into its leaf terms contains attributes or literals. - * Returns terms only of the forms: - * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), - * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) - * Multiply(Literal), UnaryMinus(Multiply(Literal)) - * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) - * - * Note: - * - If term needs to be negated for making it a commutative term, - * then it will be wrapped in UnaryMinus(...) - * - Each terms will be representing timestamp value or time interval in microseconds, - * typed as doubles. - */ - def collect(expr: Expression, negate: Boolean): Seq[Expression] = { - expr match { - case Add(left, right) => - collect(left, negate) ++ collect(right, negate) - case Subtract(left, right) => - collect(left, negate) ++ collect(right, !negate) - case TimeAdd(left, right, _) => - collect(left, negate) ++ collect(right, negate) - case TimeSub(left, right, _) => - collect(left, negate) ++ collect(right, !negate) - case UnaryMinus(child) => - collect(child, !negate) - case CheckOverflow(child, _) => - collect(child, negate) - case Cast(child, dataType, _) => - dataType match { - case _: NumericType | _: TimestampType => collect(child, negate) - case _ => - invalid = true - Seq.empty - } - case a: AttributeReference => - val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a - Seq(negateIfNeeded(castedRef, negate)) - case lit: Literal => - // If literal of type calendar interval, then explicitly convert to millis - // Convert other number like literal to doubles representing millis (by x1000) - val castedLit = lit.dataType match { - case CalendarIntervalType => - val calendarInterval = lit.value.asInstanceOf[CalendarInterval] - if (calendarInterval.months > 0) { - invalid = true - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom " + - s"as imprecise intervals like months and years cannot be used for" + - s"watermark calculation. Use interval in terms of day instead.") - Literal(0.0) - } else { - Literal(calendarInterval.microseconds.toDouble) - } - case DoubleType => - Multiply(lit, Literal(1000000.0)) - case _: NumericType => - Multiply(Cast(lit, DoubleType), Literal(1000000.0)) - case _: TimestampType => - Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) - } - Seq(negateIfNeeded(castedLit, negate)) - case a @ _ => - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") - invalid = true - Seq.empty - } - } - - val terms = collect(exprToCollectFrom, negate = false) - if (!invalid) terms else Seq.empty - } - /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 37648710dfc2a..d256fb578d921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -76,7 +76,7 @@ class SymmetricHashJoinStateManager( /** Get all the values of a key */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues) + keyWithIndexToValue.getAll(key, numValues).map(_.value) } /** Append a new value to the key */ @@ -87,70 +87,163 @@ class SymmetricHashJoinStateManager( } /** - * Remove using a predicate on keys. See class docs for more context and implement details. + * Remove using a predicate on keys. + * + * This produces an iterator over the (key, value) pairs satisfying condition(key), where the + * underlying store is updated as a side-effect of producing next. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator - - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - if (condition(keyToNumValue.key)) { - keyToNumValues.remove(keyToNumValue.key) - keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue) + def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKeyToNumValue: KeyAndNumValues = null + private var currentValues: Iterator[KeyWithIndexAndValue] = null + + private def currentKey = currentKeyToNumValue.key + + private val reusedPair = new UnsafeRowPair() + + private def getAndRemoveValue() = { + val keyWithIndexAndValue = currentValues.next() + keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) + reusedPair.withRows(currentKey, keyWithIndexAndValue.value) + } + + override def getNext(): UnsafeRowPair = { + // If there are more values for the current key, remove and return the next one. + if (currentValues != null && currentValues.hasNext) { + return getAndRemoveValue() + } + + // If there weren't any values left, try and find the next key that satisfies the removal + // condition and has values. + while (allKeyToNumValues.hasNext) { + currentKeyToNumValue = allKeyToNumValues.next() + if (removalCondition(currentKey)) { + currentValues = keyWithIndexToValue.getAll( + currentKey, currentKeyToNumValue.numValue) + keyToNumValues.remove(currentKey) + + if (currentValues.hasNext) { + return getAndRemoveValue() + } + } + } + + // We only reach here if there were no satisfying keys left, which means we're done. + finished = true + return null } + + override def close: Unit = {} } } /** - * Remove using a predicate on values. See class docs for more context and implementation details. + * Remove using a predicate on values. + * + * At a high level, this produces an iterator over the (key, value) pairs such that value + * satisfies the predicate, where producing an element removes the value from the state store + * and producing all elements with a given key updates it accordingly. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator + def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - val key = keyToNumValue.key + // Reuse this object to avoid creation+GC overhead. + private val reusedPair = new UnsafeRowPair() - var numValues: Long = keyToNumValue.numValue - var index: Long = 0L - var valueRemoved: Boolean = false - var valueForIndex: UnsafeRow = null + private val allKeyToNumValues = keyToNumValues.iterator - while (index < numValues) { - if (valueForIndex == null) { - valueForIndex = keyWithIndexToValue.get(key, index) + private var currentKey: UnsafeRow = null + private var numValues: Long = 0L + private var index: Long = 0L + private var valueRemoved: Boolean = false + + // Push the data for the current key to the numValues store, and reset the tracking variables + // to their empty state. + private def updateNumValueForCurrentKey(): Unit = { + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(currentKey, numValues) + } else { + keyToNumValues.remove(currentKey) + } } - if (condition(valueForIndex)) { - if (numValues > 1) { - val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1) - keyWithIndexToValue.put(key, index, valueAtMaxIndex) - keyWithIndexToValue.remove(key, numValues - 1) - valueForIndex = valueAtMaxIndex + + currentKey = null + numValues = 0 + index = 0 + valueRemoved = false + } + + // Find the next value satisfying the condition, updating `currentKey` and `numValues` if + // needed. Returns null when no value can be found. + private def findNextValueForIndex(): UnsafeRow = { + // Loop across all values for the current key, and then all other keys, until we find a + // value satisfying the removal condition. + def hasMoreValuesForCurrentKey = currentKey != null && index < numValues + def hasMoreKeys = allKeyToNumValues.hasNext + while (hasMoreValuesForCurrentKey || hasMoreKeys) { + if (hasMoreValuesForCurrentKey) { + // First search the values for the current key. + val currentValue = keyWithIndexToValue.get(currentKey, index) + if (removalCondition(currentValue)) { + return currentValue + } else { + index += 1 + } + } else if (hasMoreKeys) { + // If we can't find a value for the current key, cleanup and start looking at the next. + // This will also happen the first time the iterator is called. + updateNumValueForCurrentKey() + + val currentKeyToNumValue = allKeyToNumValues.next() + currentKey = currentKeyToNumValue.key + numValues = currentKeyToNumValue.numValue } else { - keyWithIndexToValue.remove(key, 0) - valueForIndex = null + // Should be unreachable, but in any case means a value couldn't be found. + return null } - numValues -= 1 - valueRemoved = true - } else { - valueForIndex = null - index += 1 } + + // We tried and failed to find the next value. + return null } - if (valueRemoved) { - if (numValues >= 1) { - keyToNumValues.put(key, numValues) + + override def getNext(): UnsafeRowPair = { + val currentValue = findNextValueForIndex() + + // If there's no value, clean up and finish. There aren't any more available. + if (currentValue == null) { + updateNumValueForCurrentKey() + finished = true + return null + } + + // The backing store is arraylike - we as the caller are responsible for filling back in + // any hole. So we swap the last element into the hole and decrement numValues to shorten. + // clean + if (numValues > 1) { + val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) + keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) + keyWithIndexToValue.remove(currentKey, numValues - 1) } else { - keyToNumValues.remove(key) + keyWithIndexToValue.remove(currentKey, 0) } + numValues -= 1 + valueRemoved = true + + return reusedPair.withRows(currentKey, currentValue) } - } - } - def iterator(): Iterator[UnsafeRowPair] = { - val pair = new UnsafeRowPair() - keyWithIndexToValue.iterator.map { x => - pair.withRows(x.key, x.value) + override def close: Unit = {} } } @@ -309,19 +402,24 @@ class SymmetricHashJoinStateManager( stateStore.get(keyWithIndexRow(key, valueIndex)) } - /** Get all the values for key and all indices. */ - def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = { + /** + * Get all values and indices for the provided key. + * Should not return null. + */ + def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { + val keyWithIndexAndValue = new KeyWithIndexAndValue() var index = 0 - new NextIterator[UnsafeRow] { - override protected def getNext(): UnsafeRow = { + new NextIterator[KeyWithIndexAndValue] { + override protected def getNext(): KeyWithIndexAndValue = { if (index >= numValues) { finished = true null } else { val keyWithIndex = keyWithIndexRow(key, index) val value = stateStore.get(keyWithIndex) + keyWithIndexAndValue.withNew(key, index, value) index += 1 - value + keyWithIndexAndValue } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index ffa4c3c22a194..d44af1d14c27a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -137,14 +137,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter BoundReference( 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable), Literal(threshold)) - manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + while (iter.hasNext) iter.next() } /** Remove values where `time <= threshold` */ def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) - manager.removeByValueCondition( + val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) + while (iter.hasNext) iter.next() } def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 533e1165fd59c..a6593b71e51de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -24,8 +24,9 @@ import scala.util.Random import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} @@ -35,7 +36,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { +class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -322,111 +323,6 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } - testQuietly("extract watermark from time condition") { - val attributesToFindConstraintFor = Seq( - AttributeReference("leftTime", TimestampType)(), - AttributeReference("leftOther", IntegerType)()) - val metadataWithWatermark = new MetadataBuilder() - .putLong(EventTimeWatermark.delayKey, 1000) - .build() - val attributesWithWatermark = Seq( - AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), - AttributeReference("rightOther", IntegerType)()) - - def watermarkFrom( - conditionStr: String, - rightWatermark: Option[Long] = Some(10000)): Option[Long] = { - val conditionExpr = Some(conditionStr).map { str => - val plan = - Filter( - spark.sessionState.sqlParser.parseExpression(str), - LogicalRDD( - attributesToFindConstraintFor ++ attributesWithWatermark, - spark.sparkContext.emptyRDD)(spark)) - plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition - } - StreamingSymmetricHashJoinHelper.getStateValueWatermark( - AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), - conditionExpr, rightWatermark) - } - - // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, - // then cannot define constraint on leftTime. - assert(watermarkFrom("leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) - assert(watermarkFrom("leftTime < rightTime") === None) - assert(watermarkFrom("leftTime <= rightTime") === None) - assert(watermarkFrom("rightTime > leftTime") === None) - assert(watermarkFrom("rightTime >= leftTime") === None) - assert(watermarkFrom("rightTime < leftTime") === Some(10000)) - assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) - - // Test type conversions - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) - - // Test with timestamp type + calendar interval on either side of equation - // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. - assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) - assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) - assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) - assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) - assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") - === Some(12000)) - - // Test with casted long type + constants on either side of equation - // Note: long type and constants commute, so more combinations to test. - // -- Constants on the right - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") - === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") - === Some(10100)) - assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") - === Some(10200)) - // -- Constants on the left - assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) - assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) - assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") - === Some(7000)) - assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") - === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") - === Some(10100)) - // -- Constants on both sides, mixed types - assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") - === Some(13000)) - - // Test multiple conditions, should return minimum watermark - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === - Some(7000)) // first condition wins - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === - Some(6000)) // second condition wins - - // Test invalid comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes - assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed - - // Test static comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) - } - test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ @@ -470,3 +366,189 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo } } } + +class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + import org.apache.spark.sql.functions._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { + val input = MemoryStream[Int] + val df = input.toDF + .select( + 'value as "key", + 'value.cast("timestamp") as s"${prefix}Time", + ('value * multiplier) as s"${prefix}Value") + .withWatermark(s"${prefix}Time", "10 seconds") + + return (input, df) + } + + private def setupWindowedJoin(joinType: String): + (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (input1, df1) = setupStream("left", 2) + val (input2, df2) = setupStream("right", 3) + val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) + val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + (input1, input2, joined) + } + + test("windowed left outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + test("windowed right outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + Seq( + ("left_outer", Row(3, null, 5, null)), + ("right_outer", Row(null, 2, null, 5)) + ).foreach { case (joinType: String, outerResult) => + test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join( + df2, + expr("leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), + joinType) + .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + testStream(joined)( + AddData(leftInput, (1, 5), (3, 5)), + CheckAnswer(), + AddData(rightInput, (1, 10), (2, 5)), + CheckLastBatch((1, 1, 5, 10)), + AddData(rightInput, (1, 11)), + CheckLastBatch(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 1), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 7), (1, 30)), + CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + assertNumStateRows(total = 7, updated = 2), + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 1), + AddData(rightInput, (0, 30)), + CheckLastBatch(outerResult), + assertNumStateRows(total = 3, updated = 1) + ) + } + } + + // When the join condition isn't true, the outer null rows must be generated, even if the join + // keys themselves have a match. + test("left outer join with non-key condition violated on left") { + val (leftInput, simpleLeftDf) = setupStream("left", 2) + val (rightInput, simpleRightDf) = setupStream("right", 3) + + val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) + val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + + val joined = left.join( + right, + left("key") === right("key") && left("window") === right("window") && + 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + // leftValue <= 10 should generate outer join rows even though it matches right keys + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 1, 2, 3), + CheckLastBatch(), + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 20), + CheckLastBatch( + Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows + AddData(leftInput, 40, 41), + AddData(rightInput, 40, 41), + CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), + AddData(leftInput, 70), + AddData(rightInput, 71), + CheckLastBatch(), + assertNumStateRows(total = 6, updated = 2), + AddData(rightInput, 70), + CheckLastBatch((70, 80, 140, 210)), + assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left + AddData(leftInput, 101, 102, 103), + AddData(rightInput, 101, 102, 103), + CheckLastBatch(), + AddData(leftInput, 1000), + AddData(rightInput, 1001), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 1000), + CheckLastBatch( + Row(1000, 1010, 2000, 3000), + Row(101, 110, 202, null), + Row(102, 110, 204, null), + Row(103, 110, 206, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } +} + From d54670192a6acd892d13b511dfb62390be6ad39c Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 4 Oct 2017 07:11:00 +0100 Subject: [PATCH 1439/1765] [SPARK-22193][SQL] Minor typo fix ## What changes were proposed in this pull request? [SPARK-22193][SQL] Minor typo fix in SortMergeJoinExec. Nothing major, but it bothered me going into.Hence fixing ## How was this patch tested? existing tests Author: Rekha Joshi Author: rjoshi2 Closes #19422 from rekhajoshm/SPARK-22193. --- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 14de2dc23e3c0..4e02803552e82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -402,7 +402,7 @@ case class SortMergeJoinExec( } } - private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => s""" |if (comp == 0) { @@ -463,7 +463,7 @@ case class SortMergeJoinExec( | continue; | } | if (!$matches.isEmpty()) { - | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } @@ -484,7 +484,7 @@ case class SortMergeJoinExec( | } | ${rightKeyVars.map(_.code).mkString("\n")} | } - | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} | if (comp > 0) { | $rightRow = null; | } else if (comp < 0) { From 64df08b64779bab629a8a90a3797d8bd70f61703 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Oct 2017 15:06:44 +0800 Subject: [PATCH 1440/1765] [SPARK-20783][SQL] Create ColumnVector to abstract existing compressed column (batch method) ## What changes were proposed in this pull request? This PR abstracts data compressed by `CompressibleColumnAccessor` using `ColumnVector` in batch method. When `ColumnAccessor.decompress` is called, `ColumnVector` will have uncompressed data. This batch decompress does not use `InternalRow` to reduce the number of memory accesses. As first step of this implementation, this JIRA supports primitive data types. Another PR will support array and other data types. This implementation decompress data in batch into uncompressed column batch, as rxin suggested at [here](https://github.com/apache/spark/pull/18468#issuecomment-316914076). Another implementation uses adapter approach [as cloud-fan suggested](https://github.com/apache/spark/pull/18468). ## How was this patch tested? Added test suites Author: Kazuaki Ishizaki Closes #18704 from kiszk/SPARK-20783a. --- .../execution/columnar/ColumnDictionary.java | 58 +++ .../vectorized/OffHeapColumnVector.java | 18 + .../vectorized/OnHeapColumnVector.java | 18 + .../vectorized/WritableColumnVector.java | 76 ++-- .../execution/columnar/ColumnAccessor.scala | 16 +- .../sql/execution/columnar/ColumnType.scala | 33 ++ .../CompressibleColumnAccessor.scala | 4 + .../compression/CompressionScheme.scala | 3 + .../compression/compressionSchemes.scala | 340 +++++++++++++++++- .../compression/BooleanBitSetSuite.scala | 52 +++ .../compression/DictionaryEncodingSuite.scala | 72 +++- .../compression/IntegralDeltaSuite.scala | 72 ++++ .../PassThroughEncodingSuite.scala | 189 ++++++++++ .../compression/RunLengthEncodingSuite.scala | 89 ++++- .../TestCompressibleColumnBuilder.scala | 9 +- .../vectorized/ColumnVectorSuite.scala | 183 +++++++++- .../vectorized/ColumnarBatchSuite.scala | 4 +- 17 files changed, 1192 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java new file mode 100644 index 0000000000000..f1785853a94ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar; + +import org.apache.spark.sql.execution.vectorized.Dictionary; + +public final class ColumnDictionary implements Dictionary { + private int[] intDictionary; + private long[] longDictionary; + + public ColumnDictionary(int[] dictionary) { + this.intDictionary = dictionary; + } + + public ColumnDictionary(long[] dictionary) { + this.longDictionary = dictionary; + } + + @Override + public int decodeToInt(int id) { + return intDictionary[id]; + } + + @Override + public long decodeToLong(int id) { + return longDictionary[id]; + } + + @Override + public float decodeToFloat(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support float"); + } + + @Override + public double decodeToDouble(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support double"); + } + + @Override + public byte[] decodeToBinary(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support String"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 8cbc895506d91..a7522ebf5821a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -228,6 +228,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { null, data + 2 * rowId, count * 2); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -268,6 +274,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { null, data + 4 * rowId, count * 4); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { @@ -334,6 +346,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { null, data + 8 * rowId, count * 8); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 2725a29eeabe8..166a39e0fabd9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -233,6 +233,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { System.arraycopy(src, srcIndex, shortData, rowId, count); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, + Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -272,6 +278,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { System.arraycopy(src, srcIndex, intData, rowId, count); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, + Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; @@ -332,6 +344,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { System.arraycopy(src, srcIndex, longData, rowId, count); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, + Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 163f2511e5f73..da72954ddc448 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -113,138 +113,156 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { protected abstract void reserveInternal(int capacity); /** - * Sets the value at rowId to null/not null. + * Sets null/not null to the value at rowId. */ public abstract void putNotNull(int rowId); public abstract void putNull(int rowId); /** - * Sets the values from [rowId, rowId + count) to null/not null. + * Sets null/not null to the values at [rowId, rowId + count). */ public abstract void putNulls(int rowId, int count); public abstract void putNotNulls(int rowId, int count); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putBoolean(int rowId, boolean value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBooleans(int rowId, int count, boolean value); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putByte(int rowId, byte value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBytes(int rowId, int count, byte value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putShort(int rowId, short value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putShorts(int rowId, int count, short value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets values from [src[srcIndex], src[srcIndex + count * 2]) to [rowId, rowId + count) + * The data in src must be 2-byte platform native endian shorts. + */ + public abstract void putShorts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets `value` to the value at rowId. */ public abstract void putInt(int rowId, int value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putInts(int rowId, int count, int value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putInts(int rowId, int count, int[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be 4-byte platform native endian ints. + */ + public abstract void putInts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) * The data in src must be 4-byte little endian ints. */ public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putLong(int rowId, long value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putLongs(int rowId, int count, long value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be 8-byte platform native endian longs. + */ + public abstract void putLongs(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src + srcIndex, src + srcIndex + count * 8) to [rowId, rowId + count) * The data in src must be 8-byte little endian longs. */ public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putFloat(int rowId, float value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putFloats(int rowId, int count, float value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted floats. + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be ieee formatted floats in platform native endian. */ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putDouble(int rowId, double value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putDoubles(int rowId, int count, double value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted doubles. + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be ieee formatted doubles in platform native endian. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -254,7 +272,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract void putArray(int rowId, int offset, int length); /** - * Sets the value at rowId to `value`. + * Sets values from [value + offset, value + offset + count) to the values at rowId. */ public abstract int putByteArray(int rowId, byte[] value, int offset, int count); public final int putByteArray(int rowId, byte[] value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 6241b79d9affc..24c8ac81420cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -24,6 +24,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ /** @@ -62,6 +63,9 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( } protected def underlyingBuffer = buffer + + def getByteBuffer: ByteBuffer = + buffer.duplicate.order(ByteOrder.nativeOrder()) } private[columnar] class NullColumnAccessor(buffer: ByteBuffer) @@ -122,7 +126,7 @@ private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[columnar] object ColumnAccessor { +private[sql] object ColumnAccessor { @tailrec def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) @@ -149,4 +153,14 @@ private[columnar] object ColumnAccessor { throw new Exception(s"not support type: $other") } } + + def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int): + Unit = { + if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) { + val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]] + nativeAccessor.decompress(columnVector, numRows) + } else { + throw new RuntimeException("Not support non-primitive type now") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 5cfb003e4f150..e9b150fd86095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -43,6 +43,12 @@ import org.apache.spark.unsafe.types.UTF8String * WARNING: This only works with HeapByteBuffer */ private[columnar] object ByteBufferHelper { + def getShort(buffer: ByteBuffer): Short = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.getShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -66,6 +72,33 @@ private[columnar] object ByteBufferHelper { buffer.position(pos + 8) Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) } + + def putShort(buffer: ByteBuffer, value: Short): Unit = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.putShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putInt(buffer: ByteBuffer, value: Int): Unit = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.putInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putLong(buffer: ByteBuffer, value: Long): Unit = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.putLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def copyMemory(src: ByteBuffer, dst: ByteBuffer, len: Int): Unit = { + val srcPos = src.position() + val dstPos = dst.position() + src.position(srcPos + len) + dst.position(dstPos + len) + Platform.copyMemory(src.array(), Platform.BYTE_ARRAY_OFFSET + srcPos, + dst.array(), Platform.BYTE_ARRAY_OFFSET + dstPos, len) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index e1d13ad0e94e5..774011f1e3de8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { @@ -36,4 +37,7 @@ private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends Colu override def extractSingle(row: InternalRow, ordinal: Int): Unit = { decoder.next(row, ordinal) } + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = + decoder.decompress(columnVector, capacity) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 6e4f1c5b80684..f8aeba44257d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait Encoder[T <: AtomicType] { @@ -41,6 +42,8 @@ private[columnar] trait Decoder[T <: AtomicType] { def next(row: InternalRow, ordinal: Int): Unit def hasNext: Boolean + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit } private[columnar] trait CompressionScheme { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index ee99c90a751d9..bf00ad997c76e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer +import java.nio.ByteOrder import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ @@ -61,6 +63,101 @@ private[columnar] case object PassThrough extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + private def putBooleans( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + for (i <- 0 until len) { + columnVector.putBoolean(pos + i, (buffer.get(bufferPos + i) != 0)) + } + } + + private def putBytes( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putBytes(pos, len, buffer.array, bufferPos) + } + + private def putShorts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putShorts(pos, len, buffer.array, bufferPos) + } + + private def putInts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putInts(pos, len, buffer.array, bufferPos) + } + + private def putLongs( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putLongs(pos, len, buffer.array, bufferPos) + } + + private def putFloats( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putFloats(pos, len, buffer.array, bufferPos) + } + + private def putDoubles( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putDoubles(pos, len, buffer.array, bufferPos) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + unitSize: Int, + putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity + var pos = 0 + var seenNulls = 0 + var bufferPos = buffer.position + while (pos < capacity) { + if (pos != nextNullIndex) { + val len = nextNullIndex - pos + assert(len * unitSize < Int.MaxValue) + putFunction(columnVector, pos, bufferPos, len) + bufferPos += len * unitSize + pos += len + } else { + seenNulls += 1 + nextNullIndex = if (seenNulls < nullCount) { + ByteBufferHelper.getInt(nullsBuffer) + } else { + capacity + } + columnVector.putNull(pos) + pos += 1 + } + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBooleans) + case _: ByteType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBytes) + case _: ShortType => + val unitSize = 2 + decompress0(columnVector, capacity, unitSize, putShorts) + case _: IntegerType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putInts) + case _: LongType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putLongs) + case _: FloatType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putFloats) + case _: DoubleType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putDoubles) + } + } } } @@ -169,6 +266,94 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { } override def hasNext: Boolean = valueCount < run || buffer.hasRemaining + + private def putBoolean(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putBoolean(pos, value == 1) + } + + private def getByte(buffer: ByteBuffer): Long = { + buffer.get().toLong + } + + private def putByte(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putByte(pos, value.toByte) + } + + private def getShort(buffer: ByteBuffer): Long = { + buffer.getShort().toLong + } + + private def putShort(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putShort(pos, value.toShort) + } + + private def getInt(buffer: ByteBuffer): Long = { + buffer.getInt().toLong + } + + private def putInt(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putInt(pos, value.toInt) + } + + private def getLong(buffer: ByteBuffer): Long = { + buffer.getLong() + } + + private def putLong(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putLong(pos, value) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + getFunction: (ByteBuffer) => Long, + putFunction: (WritableColumnVector, Int, Long) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + var runLocal = 0 + var valueCountLocal = 0 + var currentValueLocal: Long = 0 + + while (valueCountLocal < runLocal || (pos < capacity)) { + if (pos != nextNullIndex) { + if (valueCountLocal == runLocal) { + currentValueLocal = getFunction(buffer) + runLocal = ByteBufferHelper.getInt(buffer) + valueCountLocal = 1 + } else { + valueCountLocal += 1 + } + putFunction(columnVector, pos, currentValueLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + decompress0(columnVector, capacity, getByte, putBoolean) + case _: ByteType => + decompress0(columnVector, capacity, getByte, putByte) + case _: ShortType => + decompress0(columnVector, capacity, getShort, putShort) + case _: IntegerType => + decompress0(columnVector, capacity, getInt, putInt) + case _: LongType => + decompress0(columnVector, capacity, getLong, putLong) + case _ => throw new IllegalStateException("Not supported type in RunLengthEncoding.") + } + } } } @@ -266,11 +451,32 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) - extends compression.Decoder[T] { - - private val dictionary: Array[Any] = { - val elementNum = ByteBufferHelper.getInt(buffer) - Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) + extends compression.Decoder[T] { + val elementNum = ByteBufferHelper.getInt(buffer) + private val dictionary: Array[Any] = new Array[Any](elementNum) + private var intDictionary: Array[Int] = null + private var longDictionary: Array[Long] = null + + columnType.dataType match { + case _: IntegerType => + intDictionary = new Array[Int](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Int] + intDictionary(i) = v + dictionary(i) = v + } + case _: LongType => + longDictionary = new Array[Long](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Long] + longDictionary(i) = v + dictionary(i) = v + } + case _: StringType => + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Any] + dictionary(i) = v + } } override def next(row: InternalRow, ordinal: Int): Unit = { @@ -278,6 +484,46 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + columnType.dataType match { + case _: IntegerType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(intDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + columnVector.putNull(pos) + } + pos += 1 + } + case _: LongType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(longDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + case _ => throw new IllegalStateException("Not supported type in DictionaryEncoding.") + } + } } } @@ -368,6 +614,38 @@ private[columnar] case object BooleanBitSet extends CompressionScheme { } override def hasNext: Boolean = visited < count + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val countLocal = count + var currentWordLocal: Long = 0 + var visitedLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (visitedLocal < countLocal) { + if (pos != nextNullIndex) { + val bit = visitedLocal % BITS_PER_LONG + + visitedLocal += 1 + if (bit == 0) { + currentWordLocal = ByteBufferHelper.getLong(buffer) + } + + columnVector.putBoolean(pos, ((currentWordLocal >> bit) & 1) != 0) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -448,6 +726,32 @@ private[columnar] case object IntDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getInt(buffer) } + columnVector.putInt(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -528,5 +832,31 @@ private[columnar] case object LongDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Long = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get() + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getLong(buffer) } + columnVector.putLong(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index d01bf911e3a77..2d71a42628dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.BooleanType class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ @@ -85,6 +87,36 @@ class BooleanBitSetSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(count: Int) { + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) + val values = rows.map(_.getBoolean(0)) + + rows.foreach(builder.appendFrom(_, 0)) + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val columnVector = new OnHeapColumnVector(values.length, BooleanType) + decoder.decompress(columnVector, values.length) + + if (values.nonEmpty) { + values.zipWithIndex.foreach { case (b: Boolean, index: Int) => + assertResult(b, s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + } + } + } + test(s"$BooleanBitSet: empty") { skeleton(0) } @@ -104,4 +136,24 @@ class BooleanBitSetSuite extends SparkFunSuite { test(s"$BooleanBitSet: multiple words and 1 more bit") { skeleton(BITS_PER_LONG * 2 + 1) } + + test(s"$BooleanBitSet: empty for decompression()") { + skeletonForDecompress(0) + } + + test(s"$BooleanBitSet: less than 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG - 1) + } + + test(s"$BooleanBitSet: exactly 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG) + } + + test(s"$BooleanBitSet: multiple whole words for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2) + } + + test(s"$BooleanBitSet: multiple words and 1 more bit for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2 + 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 67139b13d7882..28950b74cf1c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -23,16 +23,19 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { + val nullValue = -1 testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) - testDictionaryEncoding(new StringColumnStats, STRING) + testDictionaryEncoding(new StringColumnStats, STRING, false) def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -113,6 +116,58 @@ class DictionaryEncodingSuite extends SparkFunSuite { } } + def skeletonForDecompress(uniqueValueCount: Int, inputSeq: Seq[Int]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val dictValues = stableDistinct(inputSeq) + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = DictionaryEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { case (i: Any, index: Int) => + if (i == nullValue) { + assertResult(true, s"Wrong null ${index}-th position") { + columnVector.isNullAt(index) + } + } else { + columnType match { + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + } + } + } + } + test(s"$DictionaryEncoding with $typeName: empty") { skeleton(0, Seq.empty) } @@ -124,5 +179,18 @@ class DictionaryEncodingSuite extends SparkFunSuite { test(s"$DictionaryEncoding with $typeName: dictionary overflow") { skeleton(DictionaryEncoding.MAX_DICT_SIZE + 1, 0 to DictionaryEncoding.MAX_DICT_SIZE) } + + test(s"$DictionaryEncoding with $typeName: empty for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$DictionaryEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0, nullValue, 0, nullValue)) + } + + test(s"$DictionaryEncoding with $typeName: dictionary overflow for decompress()") { + skeletonForDecompress(DictionaryEncoding.MAX_DICT_SIZE + 2, + Seq(nullValue) ++ (0 to DictionaryEncoding.MAX_DICT_SIZE - 1) ++ Seq(nullValue)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 411d31fa0e29b..0d9f1fb0c02c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -21,9 +21,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { + val nullValue = -1 testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) @@ -109,6 +111,53 @@ class IntegralDeltaSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(input: Seq[I#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, scheme) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = scheme.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => + fail("Unsupported type") + } + } + } + test(s"$scheme: empty column") { skeleton(Seq.empty) } @@ -127,5 +176,28 @@ class IntegralDeltaSuite extends SparkFunSuite { val input = Array.fill[Any](10000)(makeRandomValue(columnType)) skeleton(input.map(_.asInstanceOf[I#InternalType])) } + + + test(s"$scheme: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$scheme: simple case for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } + + test(s"$scheme: simple case with null for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala new file mode 100644 index 0000000000000..b6f0b5e6277b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar.compression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.AtomicType + +class PassThroughSuite extends SparkFunSuite { + val nullValue = -1 + testPassThrough(new ByteColumnStats, BYTE) + testPassThrough(new ShortColumnStats, SHORT) + testPassThrough(new IntColumnStats, INT) + testPassThrough(new LongColumnStats, LONG) + testPassThrough(new FloatColumnStats, FLOAT) + testPassThrough(new DoubleColumnStats, DOUBLE) + + def testPassThrough[T <: AtomicType]( + columnStats: ColumnStats, + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def skeleton(input: Seq[T#InternalType]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + + input.map { value => + val row = new GenericInternalRow(1) + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + + val buffer = builder.build() + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + compressed contents + val compressedSize = 4 + input.size * columnType.defaultSize + + // 4 extra bytes for compression scheme type ID + assertResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + if (input.nonEmpty) { + input.foreach { value => + assertResult(value, "Wrong value")(columnType.extract(buffer)) + } + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = PassThrough.decoder(buffer, columnType) + val mutableRow = new GenericInternalRow(1) + + if (input.nonEmpty) { + input.foreach{ + assert(decoder.hasNext) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } + } + } + assert(!decoder.hasNext) + } + + def skeletonForDecompress(input: Seq[T#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = PassThrough.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Byte, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case (expected: Short, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case (expected: Float, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded float value") { + columnVector.getFloat(index) + } + case (expected: Double, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded double value") { + columnVector.getDouble(index) + } + case _ => fail("Unsupported type") + } + } + } + + test(s"$PassThrough with $typeName: empty column") { + skeleton(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeleton(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series for decompress()") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: simple case with null for decompress()") { + val input = columnType match { + case BYTE => Seq(2: Byte, 1: Byte, 2: Byte, nullValue.toByte: Byte, 5: Byte) + case SHORT => Seq(2: Short, 1: Short, 2: Short, nullValue.toShort: Short, 5: Short) + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + case FLOAT => Seq(2: Float, 1: Float, 2: Float, nullValue: Float, 5: Float) + case DOUBLE => Seq(2: Double, 1: Double, 2: Double, nullValue: Double, 5: Double) + } + + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index dffa9b364ebfe..eb1cdd9bbceff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -21,19 +21,22 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { + val nullValue = -1 testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new StringColumnStats, STRING, false) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -95,6 +98,72 @@ class RunLengthEncodingSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val inputSeq = inputRuns.flatMap { case (index, run) => + Seq.fill(run)(index) + } + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = RunLengthEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (i: Int, index: Int) => + columnType match { + case BOOLEAN => + assertResult(values(i), s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + case BYTE => + assertResult(values(i), s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case SHORT => + assertResult(values(i), s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + case _ => fail("Unsupported type") + } + } + } + test(s"$RunLengthEncoding with $typeName: empty column") { skeleton(0, Seq.empty) } @@ -110,5 +179,21 @@ class RunLengthEncodingSuite extends SparkFunSuite { test(s"$RunLengthEncoding with $typeName: single long run") { skeleton(1, Seq(0 -> 1000)) } + + test(s"$RunLengthEncoding with $typeName: empty column for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$RunLengthEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, 1 -> 2)) + } + + test(s"$RunLengthEncoding with $typeName: single long run for decompress()") { + skeletonForDecompress(1, Seq(0 -> 1000)) + } + + test(s"$RunLengthEncoding with $typeName: single case with null for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, nullValue -> 2)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5e078f251375a..310cb0be5f5a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.types.AtomicType +import org.apache.spark.sql.types.{AtomicType, DataType} class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, @@ -42,3 +42,10 @@ object TestCompressibleColumnBuilder { builder } } + +object ColumnBuilderHelper { + def apply( + dataType: DataType, batchSize: Int, name: String, useCompression: Boolean): ColumnBuilder = { + ColumnBuilder(dataType, batchSize, name, useCompression) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 85da8270d4cba..c5c8ae3a17c6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -20,7 +20,10 @@ package org.apache.spark.sql.execution.vectorized import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.columnar.ColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -31,14 +34,21 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { try block(vector) finally vector.close() } + private def withVectors( + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) + } + private def testVectors( name: String, size: Int, dt: DataType)( block: WritableColumnVector => Unit): Unit = { test(name) { - withVector(new OnHeapColumnVector(size, dt))(block) - withVector(new OffHeapColumnVector(size, dt))(block) + withVectors(size, dt)(block) } } @@ -218,4 +228,173 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) } } + + test("CachedBatch boolean Apis") { + val dataType = BooleanType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setBoolean(0, i % 2 == 0) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getBoolean(i) == (i % 2 == 0)) + } + } + } + + test("CachedBatch byte Apis") { + val dataType = ByteType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setByte(0, i.toByte) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getByte(i) == i) + } + } + } + + test("CachedBatch short Apis") { + val dataType = ShortType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setShort(0, i.toShort) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getShort(i) == i) + } + } + } + + test("CachedBatch int Apis") { + val dataType = IntegerType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setInt(0, i) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getInt(i) == i) + } + } + } + + test("CachedBatch long Apis") { + val dataType = LongType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setLong(0, i.toLong) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getLong(i) == i.toLong) + } + } + } + + test("CachedBatch float Apis") { + val dataType = FloatType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setFloat(0, i.toFloat) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getFloat(i) == i.toFloat) + } + } + } + + test("CachedBatch double Apis") { + val dataType = DoubleType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setDouble(0, i.toDouble) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getDouble(i) == i.toDouble) + } + } + } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 983eb103682c1..0b179aa97c479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -413,7 +413,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getLong(v._2), "idx=" + v._2 + - " Seed = " + seed + " MemMode=" + memMode) + " Seed = " + seed + " MemMode=" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) @@ -1120,7 +1120,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } batch.close() } - }} + }} /** * This test generates a random schema data, serializes it to column batches and verifies the From 4a779bdac3e75c17b7d36c5a009ba6c948fa9fb6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Oct 2017 10:08:24 -0700 Subject: [PATCH 1441/1765] [SPARK-21871][SQL] Check actual bytecode size when compiling generated code ## What changes were proposed in this pull request? This pr added code to check actual bytecode size when compiling generated code. In #18810, we added code to give up code compilation and use interpreter execution in `SparkPlan` if the line number of generated functions goes over `maxLinesPerFunction`. But, we already have code to collect metrics for compiled bytecode size in `CodeGenerator` object. So,we could easily reuse the code for this purpose. ## How was this patch tested? Added tests in `WholeStageCodegenSuite`. Author: Takeshi Yamamuro Closes #19083 from maropu/SPARK-21871. --- .../expressions/codegen/CodeFormatter.scala | 8 --- .../expressions/codegen/CodeGenerator.scala | 59 +++++++++---------- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateOrdering.scala | 3 +- .../codegen/GeneratePredicate.scala | 3 +- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 15 ++--- .../codegen/CodeFormatterSuite.scala | 32 ---------- .../sql/execution/WholeStageCodegenExec.scala | 25 ++++---- .../columnar/GenerateColumnAccessor.scala | 3 +- .../execution/WholeStageCodegenSuite.scala | 43 ++++---------- .../benchmark/AggregateBenchmark.scala | 36 +++++------ 14 files changed, 94 insertions(+), 149 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 7b398f424cead..60e600d8dbd8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -89,14 +89,6 @@ object CodeFormatter { } new CodeAndComment(code.result().trim(), map) } - - def stripExtraNewLinesAndComments(input: String): String = { - val commentReg = - ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ - """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment - val codeWithoutComment = commentReg.replaceAllIn(input, "") - codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines - } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f3b45799c5688..f9c5ef8439085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -373,20 +373,6 @@ class CodegenContext { */ private val placeHolderToComments = new mutable.HashMap[String, String] - /** - * It will count the lines of every Java function generated by whole-stage codegen, - * if there is a function of length greater than spark.sql.codegen.maxLinesPerFunction, - * it will return true. - */ - def isTooLongGeneratedFunction: Boolean = { - classFunctions.values.exists { _.values.exists { - code => - val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code) - codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction - } - } - } - /** * Returns a term name that is unique within this instance of a `CodegenContext`. */ @@ -1020,10 +1006,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } object CodeGenerator extends Logging { + + // This is the value of HugeMethodLimit in the OpenJDK JVM settings + val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + /** * Compile the Java source code into a Java class, using Janino. + * + * @return a pair of a generated class and the max bytecode size of generated functions. */ - def compile(code: CodeAndComment): GeneratedClass = try { + def compile(code: CodeAndComment): (GeneratedClass, Int) = try { cache.get(code) } catch { // Cache.get() may wrap the original exception. See the following URL @@ -1036,7 +1028,7 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - private[this] def doCompile(code: CodeAndComment): GeneratedClass = { + private[this] def doCompile(code: CodeAndComment): (GeneratedClass, Int) = { val evaluator = new ClassBodyEvaluator() // A special classloader used to wrap the actual parent classloader of @@ -1075,9 +1067,9 @@ object CodeGenerator extends Logging { s"\n${CodeFormatter.format(code)}" }) - try { + val maxCodeSize = try { evaluator.cook("generated.java", code.body) - recordCompilationStats(evaluator) + updateAndGetCompilationStats(evaluator) } catch { case e: JaninoRuntimeException => val msg = s"failed to compile: $e" @@ -1092,13 +1084,15 @@ object CodeGenerator extends Logging { logInfo(s"\n${CodeFormatter.format(code, maxLines)}") throw new CompileException(msg, e.getLocation) } - evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] + + (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** - * Records the generated class and method bytecode sizes by inspecting janino private fields. + * Returns the max bytecode size of the generated functions by inspecting janino private fields. + * Also, this method updates the metrics information. */ - private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = { + private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = { // First retrieve the generated classes. val classes = { val resultField = classOf[SimpleCompiler].getDeclaredField("result") @@ -1113,23 +1107,26 @@ object CodeGenerator extends Logging { val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") val codeAttrField = codeAttr.getDeclaredField("code") codeAttrField.setAccessible(true) - classes.foreach { case (_, classBytes) => + val codeSizes = classes.flatMap { case (_, classBytes) => CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - cf.methodInfos.asScala.foreach { method => - method.getAttributes().foreach { a => - if (a.getClass.getName == codeAttr.getName) { - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( - codeAttrField.get(a).asInstanceOf[Array[Byte]].length) - } + val stats = cf.methodInfos.asScala.flatMap { method => + method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + byteCodeSize } } + Some(stats) } catch { case NonFatal(e) => logWarning("Error calculating stats of compiled class.", e) + None } - } + }.flatten + + codeSizes.max } /** @@ -1144,8 +1141,8 @@ object CodeGenerator extends Logging { private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[CodeAndComment, GeneratedClass]() { - override def load(code: CodeAndComment): GeneratedClass = { + new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { + override def load(code: CodeAndComment): (GeneratedClass, Int) = { val startTime = System.nanoTime() val result = doCompile(code) val endTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3768dcde00a4e..b5429fade53cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -142,7 +142,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4e47895985209..1639d1b9dda1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -185,7 +185,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index e35b9dda6c017..e0fabad6d089a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -78,6 +78,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 192701a829686..1e4ac3f2afd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -189,8 +189,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) + val (clazz, _) = CodeGenerator.compile(code) val resultRow = new SpecificInternalRow(expressions.map(_.dataType)) - c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] + clazz.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index f2a66efc98e71..4bd50aee05514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -409,7 +409,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 4aa5ec82471ec..6bc72a0d75c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val code = CodeFormatter.stripOverlappingComments(new CodeAndComment(codeBody, Map.empty)) logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1a73d168b9b6e..58323740b80cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -575,15 +576,15 @@ object SQLConf { "disable logging or -1 to apply no limit.") .createWithDefault(1000) - val WHOLESTAGE_MAX_LINES_PER_FUNCTION = buildConf("spark.sql.codegen.maxLinesPerFunction") + val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() - .doc("The maximum lines of a single Java function generated by whole-stage codegen. " + - "When the generated function exceeds this threshold, " + + .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + + "codegen. When the compiled function exceeds this threshold, " + "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - "The default value 4000 is the max length of byte code JIT supported " + - "for a single function(8000) divided by 2.") + s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + + "this is a limit in the OpenJDK JVM implementation.") .intConf - .createWithDefault(4000) + .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") @@ -1058,7 +1059,7 @@ class SQLConf extends Serializable with Logging { def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) - def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION) + def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index a0f1a64b0ab08..9d0a41661beaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,38 +53,6 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } - test("removing extra new lines and comments") { - val code = - """ - |/* - | * multi - | * line - | * comments - | */ - | - |public function() { - |/*comment*/ - | /*comment_with_space*/ - |code_body - |//comment - |code_body - | //comment_with_space - | - |code_body - |} - """.stripMargin - - val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) - assert(reducedCode === - """ - |public function() { - |code_body - |code_body - |code_body - |} - """.stripMargin) - } - testCase("basic example") { """ |class A { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 268ccfa4edfa0..9073d599ac43d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -380,16 +380,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def doExecute(): RDD[InternalRow] = { val (ctx, cleanedSource) = doCodeGen() - if (ctx.isTooLongGeneratedFunction) { - logWarning("Found too long generated codes and JIT optimization might not work, " + - "Whole-stage codegen disabled for this plan, " + - "You can change the config spark.sql.codegen.MaxFunctionLength " + - "to adjust the function length limit:\n " - + s"$treeString") - return child.execute() - } // try to compile and fallback if it failed - try { + val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => @@ -397,6 +389,17 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") return child.execute() } + + // Check if compiled code has a too large function + if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { + logWarning(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size was $maxCodeSize, this value went over the limit " + + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + + s"for this plan. To avoid this, you can raise the limit " + + s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString") + return child.execute() + } + val references = ctx.references.toArray val durationMs = longMetric("pipelineTime") @@ -405,7 +408,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co assert(rdds.size <= 2, "Up to two input RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitionsWithIndex { (index, iter) => - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(iter)) new Iterator[InternalRow] { @@ -424,7 +427,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // a small hack to obtain the correct partition index }.mapPartitionsWithIndex { (index, zippedIter) => val (leftIter, rightIter) = zippedIter.next() - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index da34643281911..ae600c1ffae8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -227,6 +227,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index beeee6a97c8dd..aaa77b3ee6201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{Column, Dataset, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -151,7 +149,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } } - def genGroupByCodeGenContext(caseNum: Int): CodegenContext = { + def genGroupByCode(caseNum: Int): CodeAndComment = { val caseExp = (1 to caseNum).map { i => s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i" }.toList @@ -176,34 +174,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { }) assert(wholeStageCodeGenExec.isDefined) - wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1 + wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21603 check there is a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === true) - } - } - - test("SPARK-21603 check there is not a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is a too long generated function when threshold is 0") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === true) - } + test("SPARK-21871 check if we can get large code size when compiling too long functions") { + val codeWithShortFunctions = genGroupByCode(3) + val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) + assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val codeWithLongFunctions = genGroupByCode(20) + val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) + assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 691fa9ac5e1e7..aca1be01fa3da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -24,6 +24,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.execution.vectorized.AggregateHashMap +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 @@ -301,10 +302,10 @@ class AggregateBenchmark extends BenchmarkBase { */ } - ignore("max function length of wholestagecodegen") { + ignore("max function bytecode size of wholestagecodegen") { val N = 20 << 15 - val benchmark = new Benchmark("max function length of wholestagecodegen", N) + val benchmark = new Benchmark("max function bytecode size", N) def f(): Unit = sparkSession.range(N) .selectExpr( "id", @@ -333,33 +334,34 @@ class AggregateBenchmark extends BenchmarkBase { .sum() .collect() - benchmark.addCase(s"codegen = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + benchmark.addCase("codegen = F") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000") + benchmark.addCase("codegen = T hugeMethodLimit = 10000") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "10000") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500") + benchmark.addCase("codegen = T hugeMethodLimit = 1500") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "1500") f() } benchmark.run() /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1 - Intel64 Family 6 Model 58 Stepping 9, GenuineIntel - max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ---------------------------------------------------------------------------------------------- - codegen = F 462 / 533 1.4 704.4 1.0X - codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X - codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 1.0X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + + max function bytecode size: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + codegen = F 709 / 803 0.9 1082.1 1.0X + codegen = T hugeMethodLimit = 10000 3485 / 3548 0.2 5317.7 0.2X + codegen = T hugeMethodLimit = 1500 636 / 701 1.0 969.9 1.1X */ } From bb035f1ee5cdf88e476b7ed83d59140d669fbe12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Oct 2017 13:13:51 -0700 Subject: [PATCH 1442/1765] [SPARK-22169][SQL] support byte length literal as identifier ## What changes were proposed in this pull request? By definition the table name in Spark can be something like `123x`, `25a`, etc., with exceptions for literals like `12L`, `23BD`, etc. However, Spark SQL has a special byte length literal, which stops users to use digits followed by `b`, `k`, `m`, `g` as identifiers. byte length literal is not a standard sql literal and is only used in the `tableSample` parser rule. This PR move the parsing of byte length literal from lexer to parser, so that users can use it as identifiers. ## How was this patch tested? regression test Author: Wenchen Fan Closes #19392 from cloud-fan/parser-bug. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 25 +++++++----------- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 26 +++++++++++++------ .../sql/catalyst/parser/PlanParserSuite.scala | 1 + .../execution/command/DDLParserSuite.scala | 19 ++++++++++++++ 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d0a54288780ea..17c8404f8a79c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -25,7 +25,7 @@ grammar SqlBase; * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. - * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' * which is not a digit or letter or underscore. */ @@ -40,10 +40,6 @@ grammar SqlBase; } } -tokens { - DELIMITER -} - singleStatement : statement EOF ; @@ -447,12 +443,15 @@ joinCriteria ; sample - : TABLESAMPLE '(' - ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) - | (expression sampleType=ROWS) - | sampleType=BYTELENGTH_LITERAL - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) - ')' + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + | bytes=expression #sampleByBytes ; identifierList @@ -1004,10 +1003,6 @@ TINYINT_LITERAL : DIGIT+ 'Y' ; -BYTELENGTH_LITERAL - : DIGIT+ ('B' | 'K' | 'M' | 'G') - ; - INTEGER_VALUE : DIGIT+ ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6ba9ee5446a01..95bc3d674b4f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -99,7 +99,7 @@ class SessionCatalog( protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** - * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), + * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. * * This method is intended to have the same behavior of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 85b492e83446e..ce367145bc637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -699,20 +699,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } - ctx.sampleType.getType match { - case SqlBaseParser.ROWS => + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query) - case SqlBaseParser.PERCENTLIT => + case ctx: SampleByPercentileContext => val fraction = ctx.percentage.getText.toDouble val sign = if (ctx.negativeSign == null) 1 else -1 sample(sign * fraction / 100.0d) - case SqlBaseParser.BYTELENGTH_LITERAL => - throw new ParseException( - "TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException( + bytesStr + " is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } - case SqlBaseParser.BUCKET if ctx.ON != null => + case ctx: SampleByBucketContext if ctx.ON() != null => if (ctx.identifier != null) { throw new ParseException( "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) @@ -721,7 +731,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) } - case SqlBaseParser.BUCKET => + case ctx: SampleByBucketContext => sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 306e6f2cfbd37..d34a83c42c67e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -110,6 +110,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) assertEqual("select from tbl", OneRowRelation().select('from.as("tbl"))) + assertEqual("select a from 1k.2m", table("1k", "2m").select('a)) } test("reverse select query") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index fa5172ca8a3e7..eb7c33590b602 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -525,6 +525,25 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(e.message.contains("you can only specify one of them.")) } + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("2g", Some("1m")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType), + provider = Some("parquet")) + + parser.parsePlan(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { From 969ffd631746125eb2b83722baf6f6e7ddd2092c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 4 Oct 2017 19:25:22 -0700 Subject: [PATCH 1443/1765] [SPARK-22187][SS] Update unsaferow format for saved state such that we can set timeouts when state is null ## What changes were proposed in this pull request? Currently, the group state of user-defined-type is encoded as top-level columns in the UnsafeRows stores in the state store. The timeout timestamp is also saved as (when needed) as the last top-level column. Since the group state is serialized to top-level columns, you cannot save "null" as a value of state (setting null in all the top-level columns is not equivalent). So we don't let the user set the timeout without initializing the state for a key. Based on user experience, this leads to confusion. This PR is to change the row format such that the state is saved as nested columns. This would allow the state to be set to null, and avoid these confusing corner cases. ## How was this patch tested? Refactored tests. Author: Tathagata Das Closes #19416 from tdas/SPARK-22187. --- .../FlatMapGroupsWithStateExec.scala | 133 +++------------ .../FlatMapGroupsWithState_StateManager.scala | 153 ++++++++++++++++++ .../FlatMapGroupsWithStateSuite.scala | 130 ++++++++------- 3 files changed, 246 insertions(+), 170 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index ab690fd5fbbca..aab06d611a5ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -62,26 +60,7 @@ case class FlatMapGroupsWithStateExec( import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - + val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -109,11 +88,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -128,7 +107,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -143,7 +122,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -152,14 +131,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -168,20 +139,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -190,12 +160,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutKeys = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutKeys.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -205,72 +174,43 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. * - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: FlatMapGroupsWithState_StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. - if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -279,28 +219,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala new file mode 100644 index 0000000000000..d077836da847c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} + + +/** + * Class to serialize/write/read/deserialize state for + * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. + */ +class FlatMapGroupsWithState_StateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends Serializable { + + /** Schema of the state rows saved in the state store */ + val stateSchema = { + val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema + } + + /** Get deserialized state and corresponding timeout timestamp for a key */ + def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + /** Removed all information related to a key */ + def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + /** Get all the keys and corresponding state rows in the state store */ + def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = { + val stateDataForGetAllState = FlatMapGroupsWithState_StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + // Ordinals of the information stored in the state row + private lazy val nestedStateOrdinal = 0 + private lazy val timeoutTimestampOrdinal = 1 + + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val nestedStateExpr = CreateNamedStruct( + stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + if (shouldStoreTimestamp) { + Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nestedStateExpr) + } + } + + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = { + val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) + val deser = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen() + } + + // Converters for translating state between rows and Java objects + private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateSchema.toAttributes) + private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Reusable instance for returning state information + private lazy val stateDataForGets = FlatMapGroupsWithState_StateData() + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow == null) null + else getStateObjFromRow(stateRow) + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + val row = getStateRowFromObj(obj) + if (obj == null) { + row.setNullAt(nestedStateOrdinal) + } + row + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinal) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) + } +} + +/** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. + */ +case class FlatMapGroupsWithState_StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9d74a5c701ef1..d2e8beb2f5290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -289,13 +289,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -322,7 +322,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -365,6 +365,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -375,10 +387,36 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { + state.update(5); state.setTimeoutTimestamp(5000) + }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -397,50 +435,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -924,7 +935,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -946,7 +957,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -973,21 +984,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -998,15 +1008,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } From c8affec21c91d638009524955515fc143ad86f20 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 4 Oct 2017 20:58:48 -0700 Subject: [PATCH 1444/1765] [SPARK-22203][SQL] Add job description for file listing Spark jobs ## What changes were proposed in this pull request? The user may be confused about some 10000-tasks jobs. We can add a job description for these jobs so that the user can figure it out. ## How was this patch tested? The new unit test. Before: screen shot 2017-10-04 at 3 22 09 pm After: screen shot 2017-10-04 at 3 13 51 pm Author: Shixiong Zhu Closes #19432 from zsxwing/SPARK-22203. --- .../datasources/InMemoryFileIndex.scala | 85 +++++++++++-------- .../sql/test/DataFrameReaderWriterSuite.scala | 31 +++++++ 2 files changed, 81 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 203d449717512..318ada0ceefc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession @@ -187,42 +188,56 @@ object InMemoryFileIndex extends Logging { // in case of large #defaultParallelism. val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) + val previousJobDescription = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + val statusMap = try { + val description = paths.size match { + case 0 => + s"Listing leaf files and directories 0 paths" + case 1 => + s"Listing leaf files and directories for 1 path:
    ${paths(0)}" + case s => + s"Listing leaf files and directories for $s paths:
    ${paths(0)}, ..." } - (path.toString, serializableStatuses) - }.collect() + sparkContext.setJobDescription(description) + sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + } finally { + sparkContext.setJobDescription(previousJobDescription) + } // turn SerializableFileStatus back to Status statusMap.map { case (path, serializableStatuses) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 569bac156b531..a5d7e6257a6df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -21,10 +21,14 @@ import java.io.File import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf @@ -775,4 +779,31 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("use Spark jobs to list files") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + withTempDir { dir => + val jobDescriptions = new ConcurrentLinkedQueue[String]() + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescriptions.add(jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + } + sparkContext.addSparkListener(jobListener) + try { + spark.range(0, 3).map(i => (i, i)) + .write.partitionBy("_1").mode("overwrite").parquet(dir.getCanonicalPath) + // normal file paths + checkDatasetUnorderly( + spark.read.parquet(dir.getCanonicalPath).as[(Long, Long)], + 0L -> 0L, 1L -> 1L, 2L -> 2L) + sparkContext.listenerBus.waitUntilEmpty(10000) + assert(jobDescriptions.asScala.toList.exists( + _.contains("Listing leaf files and directories for 3 paths"))) + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + } + } } From ae61f187aa0471242c046fdeac6ed55b9b98a3f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Oct 2017 23:36:18 +0900 Subject: [PATCH 1445/1765] [SPARK-22206][SQL][SPARKR] gapply in R can't work on empty grouping columns ## What changes were proposed in this pull request? Looks like `FlatMapGroupsInRExec.requiredChildDistribution` didn't consider empty grouping attributes. It should be a problem when running `EnsureRequirements` and `gapply` in R can't work on empty grouping columns. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19436 from viirya/fix-flatmapinr-distribution. --- R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ .../main/scala/org/apache/spark/sql/execution/objects.scala | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7f781f2f66a7f..bbea25bc4da5c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3075,6 +3075,11 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x }) expect_identical(df1Collect, expected) + # gapply on empty grouping columns. + df1 <- gapply(df, c(), function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 5a3fcad38888e..c68975bea490f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -394,7 +394,11 @@ case class FlatMapGroupsInRExec( override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) From 83488cc3180ca18f829516f550766efb3095881e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 5 Oct 2017 23:33:49 -0700 Subject: [PATCH 1446/1765] [SPARK-21871][SQL] Fix infinite loop when bytecode size is larger than spark.sql.codegen.hugeMethodLimit ## What changes were proposed in this pull request? When exceeding `spark.sql.codegen.hugeMethodLimit`, the runtime fallbacks to the Volcano iterator solution. This could cause an infinite loop when `FileSourceScanExec` can use the columnar batch to read the data. This PR is to fix the issue. ## How was this patch tested? Added a test Author: gatorsmile Closes #19440 from gatorsmile/testt. --- .../sql/execution/WholeStageCodegenExec.scala | 12 ++++++---- .../execution/WholeStageCodegenSuite.scala | 23 +++++++++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9073d599ac43d..1aaaf896692d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -392,12 +392,16 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // Check if compiled code has a too large function if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { - logWarning(s"Found too long generated codes and JIT optimization might not work: " + - s"the bytecode size was $maxCodeSize, this value went over the limit " + + logInfo(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + s"for this plan. To avoid this, you can raise the limit " + - s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString") - return child.execute() + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") + child match { + // The fallback solution of batch file source scan still uses WholeStageCodegenExec + case f: FileSourceScanExec if f.supportsBatch => // do nothing + case _ => return child.execute() + } } val references = ctx.references.toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index aaa77b3ee6201..098e4cfeb15b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { +class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") @@ -185,4 +185,23 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } + + test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", + SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "2000") { + // wide table batch scan causes the byte code of codegen exceeds the limit of + // WHOLESTAGE_HUGE_METHOD_LIMIT + val df2 = spark.read.parquet(path) + val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + assert(fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch) + checkAnswer(df2, df) + } + } + } } From 0c03297bf0e87944f9fe0535fdae5518228e3e29 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 6 Oct 2017 15:08:28 +0100 Subject: [PATCH 1447/1765] [SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile, take 2 ## What changes were proposed in this pull request? Move flume behind a profile, take 2. See https://github.com/apache/spark/pull/19365 for most of the back-story. This change should fix the problem by removing the examples module dependency and moving Flume examples to the module itself. It also adds deprecation messages, per a discussion on dev about deprecating for 2.3.0. ## How was this patch tested? Existing tests, which still enable flume integration. Author: Sean Owen Closes #19412 from srowen/SPARK-22142.2. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 + dev/sparktestsupport/modules.py | 20 ++++++++++++++++++- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 7 +++++++ docs/streaming-flume-integration.md | 13 +++++------- examples/pom.xml | 7 ------- .../spark/examples}/JavaFlumeEventCount.java | 2 -- .../spark/examples}/FlumeEventCount.scala | 2 -- .../examples}/FlumePollingEventCount.scala | 2 -- .../spark/streaming/flume/FlumeUtils.scala | 1 + pom.xml | 13 +++++++++--- project/SparkBuild.scala | 17 ++++++++-------- python/pyspark/streaming/flume.py | 4 ++++ python/pyspark/streaming/tests.py | 16 ++++++++++++--- 16 files changed, 73 insertions(+), 40 deletions(-) rename {examples/src/main/java/org/apache/spark/examples/streaming => external/flume/src/main/java/org/apache/spark/examples}/JavaFlumeEventCount.java (98%) rename {examples/src/main/scala/org/apache/spark/examples/streaming => external/flume/src/main/scala/org/apache/spark/examples}/FlumeEventCount.scala (98%) rename {examples/src/main/scala/org/apache/spark/examples/streaming => external/flume/src/main/scala/org/apache/spark/examples}/FlumePollingEventCount.scala (98%) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5390f5916fc0d..7e8d5c7075195 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index fdb21f5007cf2..1e3ca9700bc07 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index e5aa589869535..89ecc8abd6f8c 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,6 +25,7 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ + -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 50e14b60545af..91d5667ed1f07 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,6 +279,12 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -291,6 +297,12 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume/test", ] @@ -302,7 +314,13 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ] + ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + } ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index c7714578bd005..58b295d4f6e00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index 57baa503259c1..98f7df155456f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,6 +100,13 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. +## Building with Flume support + +Apache Flume support must be explicitly enabled with the `flume` profile. +Note: Flume support is deprecated as of Spark 2.3.0. + + ./build/mvn -Pflume -DskipTests clean package + ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index a5d36da5b6de9..257a4f7d4f3ca 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +**Note: Flume support is deprecated as of Spark 2.3.0.** + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -44,8 +46,7 @@ configuring Flume agents. val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala). + See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$).
    import org.apache.spark.streaming.flume.*; @@ -53,8 +54,7 @@ configuring Flume agents. JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); - See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). + See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html).
    from pyspark.streaming.flume import FlumeUtils @@ -62,8 +62,7 @@ configuring Flume agents. flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. - See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py). + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils).
    @@ -162,8 +161,6 @@ configuring Flume agents. - See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). - Note that each input DStream can be configured to receive data from multiple sinks. 3. **Deploying:** This is same as the first approach. diff --git a/examples/pom.xml b/examples/pom.xml index 52a6764ae26a5..1791dbaad775e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,7 +34,6 @@ examples none package - provided provided provided provided @@ -78,12 +77,6 @@ ${project.version} provided - - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-streaming-kafka-0-10_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java similarity index 98% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java rename to external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java index 0c651049d0ffa..4e3420d9c3b06 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java @@ -48,8 +48,6 @@ public static void main(String[] args) throws Exception { System.exit(1); } - StreamingExamples.setStreamingLogLevels(); - String host = args[0]; int port = Integer.parseInt(args[1]); diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala index 91e52e4eff5a7..f877f79391b37 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala @@ -47,8 +47,6 @@ object FlumeEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala index dd725d72c23ef..79a4027ca5bde 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala @@ -44,8 +44,6 @@ object FlumePollingEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 3e3ed712f0dbf..707193a957700 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -30,6 +30,7 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream +@deprecated("Deprecated without replacement", "2.3.0") object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/pom.xml b/pom.xml index 87a468c3a6f55..9fac8b1e53788 100644 --- a/pom.xml +++ b/pom.xml @@ -98,15 +98,13 @@ sql/core sql/hive assembly - external/flume - external/flume-sink - external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + @@ -2583,6 +2581,15 @@
    + + flume + + external/flume + external/flume-sink + external/flume-assembly + + + spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a568d264cb2db..9501eed1e906b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,11 +43,8 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka010 - ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" - ).map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq(streaming, streamingKafka010) = + Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -56,9 +53,13 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, + streamingFlumeSink, streamingFlume, + streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, + dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", + "streaming-flume-sink", "streaming-flume", + "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cd30483fc636a..2fed5940b31ea 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -53,6 +53,8 @@ def createStream(ssc, hostname, port, :param enableDecompression: Should netty server decompress input stream :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0 """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) @@ -79,6 +81,8 @@ def createPollingStream(ssc, addresses, will result in this stream using more threads :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0 """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 229cf53e47359..5b86c1cb2c390 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,6 +1516,9 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in modules.py +flume_test_environ_var = "ENABLE_FLUME_TESTS" +are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1538,9 +1541,16 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] + if are_flume_tests_enabled: + testcases.append(FlumeStreamTests) + testcases.append(FlumePollingStreamTests) + else: + sys.stderr.write( + "Skipped test_flume_stream (enable by setting environment variable %s=1" + % flume_test_environ_var) + if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From c7b46d4d8aa8da24131d79d2bfa36e8db19662e4 Mon Sep 17 00:00:00 2001 From: minixalpha Date: Fri, 6 Oct 2017 23:38:47 +0900 Subject: [PATCH 1448/1765] [SPARK-21877][DEPLOY, WINDOWS] Handle quotes in Windows command scripts ## What changes were proposed in this pull request? All the windows command scripts can not handle quotes in parameter. Run a windows command shell with parameter which has quotes can reproduce the bug: ``` C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7> bin\spark-shell --driver-java-options " -Dfile.encoding=utf-8 " 'C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7\bin\spark-shell2.cmd" --driver-java-options "' is not recognized as an internal or external command, operable program or batch file. ``` Windows recognize "--driver-java-options" as part of the command. All the Windows command script has the following code have the bug. ``` cmd /V /E /C "" %* ``` We should quote command and parameters like ``` cmd /V /E /C """ %*" ``` ## How was this patch tested? Test manually on Windows 10 and Windows 7 We can verify it by the following demo: ``` C:\Users\meng\program\demo>cat a.cmd echo off cmd /V /E /C "b.cmd" %* C:\Users\meng\program\demo>cat b.cmd echo off echo %* C:\Users\meng\program\demo>cat c.cmd echo off cmd /V /E /C ""b.cmd" %*" C:\Users\meng\program\demo>a.cmd "123" 'b.cmd" "123' is not recognized as an internal or external command, operable program or batch file. C:\Users\meng\program\demo>c.cmd "123" "123" ``` With the spark-shell.cmd example, change it to the following code will make the command execute succeed. ``` cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" ``` ``` C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7> bin\spark-shell --driver-java-options " -Dfile.encoding=utf-8 " Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). ... ``` Author: minixalpha Closes #19090 from minixalpha/master. --- bin/beeline.cmd | 4 +++- bin/pyspark.cmd | 4 +++- bin/run-example.cmd | 5 ++++- bin/spark-class.cmd | 4 +++- bin/spark-shell.cmd | 4 +++- bin/spark-submit.cmd | 4 +++- bin/sparkR.cmd | 4 +++- 7 files changed, 22 insertions(+), 7 deletions(-) diff --git a/bin/beeline.cmd b/bin/beeline.cmd index 02464bd088792..288059a28cd74 100644 --- a/bin/beeline.cmd +++ b/bin/beeline.cmd @@ -17,4 +17,6 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -cmd /V /E /C "%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %*" diff --git a/bin/pyspark.cmd b/bin/pyspark.cmd index 72d046a4ba2cf..3dcf1d45a8189 100644 --- a/bin/pyspark.cmd +++ b/bin/pyspark.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running PySpark. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0pyspark2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0pyspark2.cmd" %*" diff --git a/bin/run-example.cmd b/bin/run-example.cmd index f9b786e92b823..efa5f81d08f7f 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -19,4 +19,7 @@ rem set SPARK_HOME=%~dp0.. set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] -cmd /V /E /C "%~dp0spark-submit.cmd" run-example %* + +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit.cmd" run-example %*" diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd index 3bf3d20cb57b5..b22536ab6f458 100644 --- a/bin/spark-class.cmd +++ b/bin/spark-class.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running a Spark class. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-class2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class2.cmd" %*" diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 991423da6ab99..e734f13097d61 100644 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark shell. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-shell2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f301606933a95..da62a8777524d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-submit2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit2.cmd" %*" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd index 1e5ea6a623219..fcd172b083e1e 100644 --- a/bin/sparkR.cmd +++ b/bin/sparkR.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running SparkR. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0sparkR2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0sparkR2.cmd" %*" From 08b204fd2c731e87d3bc2cc0bccb6339ef7e3a6e Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 6 Oct 2017 12:53:35 -0700 Subject: [PATCH 1449/1765] [SPARK-22214][SQL] Refactor the list hive partitions code ## What changes were proposed in this pull request? In this PR we make a few changes to the list hive partitions code, to make the code more extensible. The following changes are made: 1. In `HiveClientImpl.getPartitions()`, call `client.getPartitions` instead of `shim.getAllPartitions` when `spec` is empty; 2. In `HiveTableScanExec`, previously we always call `listPartitionsByFilter` if the config `metastorePartitionPruning` is enabled, but actually, we'd better call `listPartitions` if `partitionPruningPred` is empty; 3. We should use sessionCatalog instead of SharedState.externalCatalog in `HiveTableScanExec`. ## How was this patch tested? Tested by existing test cases since this is code refactor, no regression or behavior change is expected. Author: Xingbo Jiang Closes #19444 from jiangxb1987/hivePartitions. --- .../sql/catalyst/catalog/interface.scala | 5 ++++ .../sql/hive/client/HiveClientImpl.scala | 7 +++-- .../hive/execution/HiveTableScanExec.scala | 28 +++++++++---------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index fe2af910a0ae5..975b084aa6188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -405,6 +405,11 @@ object CatalogTypes { * Specifications of a table partition. Mapping column name to column value. */ type TablePartitionSpec = Map[String, String] + + /** + * Initialize an empty spec. + */ + lazy val emptyTablePartitionSpec: TablePartitionSpec = Map.empty[String, String] } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 66165c7228bca..a01c312d5e497 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -638,12 +638,13 @@ private[hive] class HiveClientImpl( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table, Some(userName)) - val parts = spec match { - case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) + val partSpec = spec match { + case None => CatalogTypes.emptyTablePartitionSpec case Some(s) => assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid") - client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) + s } + val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 48d0b4a63e54a..4f8dab9cd6172 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -162,21 +162,19 @@ case class HiveTableScanExec( // exposed for tests @transient lazy val rawPartitions = { - val prunedPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - // Retrieve the original attributes based on expression ID so that capitalization matches. - val normalizedFilters = partitionPruningPred.map(_.transform { - case a: AttributeReference => originalAttributes(a) - }) - sparkSession.sharedState.externalCatalog.listPartitionsByFilter( - relation.tableMeta.database, - relation.tableMeta.identifier.table, - normalizedFilters, - sparkSession.sessionState.conf.sessionLocalTimeZone) - } else { - sparkSession.sharedState.externalCatalog.listPartitions( - relation.tableMeta.database, - relation.tableMeta.identifier.table) - } + val prunedPartitions = + if (sparkSession.sessionState.conf.metastorePartitionPruning && + partitionPruningPred.size > 0) { + // Retrieve the original attributes based on expression ID so that capitalization matches. + val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sessionState.catalog.listPartitionsByFilter( + relation.tableMeta.identifier, + normalizedFilters) + } else { + sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier) + } prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) } From debcbec7491d3a23b19ef149e50d2887590b6de0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 6 Oct 2017 13:10:04 -0700 Subject: [PATCH 1450/1765] [SPARK-21947][SS] Check and report error when monotonically_increasing_id is used in streaming query ## What changes were proposed in this pull request? `monotonically_increasing_id` doesn't work in Structured Streaming. We should throw an exception if a streaming query uses it. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19336 from viirya/SPARK-21947. --- .../analysis/UnsupportedOperationChecker.scala | 15 ++++++++++++++- .../analysis/UnsupportedOperationsSuite.scala | 10 +++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index dee6fbe9d1514..04502d04d9509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ @@ -129,6 +129,16 @@ object UnsupportedOperationChecker { !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete) } + def checkUnsupportedExpressions(implicit operator: LogicalPlan): Unit = { + val unsupportedExprs = operator.expressions.flatMap(_.collect { + case m: MonotonicallyIncreasingID => m + }).distinct + if (unsupportedExprs.nonEmpty) { + throwError("Expression(s): " + unsupportedExprs.map(_.sql).mkString(", ") + + " is not supported with streaming DataFrames/Datasets") + } + } + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan @@ -323,6 +333,9 @@ object UnsupportedOperationChecker { case _ => } + + // Check if there are unsupported expressions in streaming query plan. + checkUnsupportedExpressions(subPlan) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index e5057c451d5b8..60d1351fda264 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} @@ -614,6 +614,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testOutputMode(Update, shouldSupportAggregation = true, shouldSupportNonAggregation = true) testOutputMode(Complete, shouldSupportAggregation = true, shouldSupportNonAggregation = false) + // Unsupported expressions in streaming plan + assertNotSupportedInStreamingPlan( + "MonotonicallyIncreasingID", + streamRelation.select(MonotonicallyIncreasingID()), + outputMode = Append, + expectedMsgs = Seq("monotonically_increasing_id")) + + /* ======================================================================================= TESTING FUNCTIONS From 2030f19511f656e9534f3fd692e622e45f9a074e Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Fri, 6 Oct 2017 20:43:53 -0700 Subject: [PATCH 1451/1765] [SPARK-21549][CORE] Respect OutputFormats with no output directory provided ## What changes were proposed in this pull request? Fix for https://issues.apache.org/jira/browse/SPARK-21549 JIRA issue. Since version 2.2 Spark does not respect OutputFormat with no output paths provided. The examples of such formats are [Cassandra OutputFormat](https://github.com/finn-no/cassandra-hadoop/blob/08dfa3a7ac727bb87269f27a1c82ece54e3f67e6/src/main/java/org/apache/cassandra/hadoop2/AbstractColumnFamilyOutputFormat.java), [Aerospike OutputFormat](https://github.com/aerospike/aerospike-hadoop/blob/master/mapreduce/src/main/java/com/aerospike/hadoop/mapreduce/AerospikeOutputFormat.java), etc. which do not have an ability to rollback the results written to an external systems on job failure. Provided output directory is required by Spark to allows files to be committed to an absolute output location, that is not the case for output formats which write data to external systems. This pull request prevents accessing `absPathStagingDir` method that causes the error described in SPARK-21549 unless there are files to rename in `addedAbsPathFiles`. ## How was this patch tested? Unit tests Author: Sergey Zhemzhitsky Closes #19294 from szhem/SPARK-21549-abs-output-commits. --- .../io/HadoopMapReduceCommitProtocol.scala | 28 ++++++++++++---- .../spark/rdd/PairRDDFunctionsSuite.scala | 33 ++++++++++++++++++- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index b1d07ab2c9199..a7e6859ef6b64 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -35,6 +35,9 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * (from the newer mapreduce API, not the old mapred API). * * Unlike Hadoop's OutputCommitter, this implementation is serializable. + * + * @param jobId the job's or stage's id + * @param path the job's output path, or null if committer acts as a noop */ class HadoopMapReduceCommitProtocol(jobId: String, path: String) extends FileCommitProtocol with Serializable with Logging { @@ -57,6 +60,15 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) */ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + /** + * Checks whether there are files to be committed to an absolute output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether the output path is specified. Output path may not be required + * for committers not writing to distributed file systems. + */ + private def hasAbsPathFiles: Boolean = path != null + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -130,17 +142,21 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - for ((src, dst) <- filesToMove) { - fs.rename(new Path(src), new Path(dst)) + if (hasAbsPathFiles) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } - fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + if (hasAbsPathFiles) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } } override def setupTask(taskContext: TaskAttemptContext): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 44dd955ce8690..07579c5098014 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -26,7 +26,7 @@ import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistr import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +import org.apache.hadoop.mapreduce.{Job => NewJob, JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.hadoop.util.Progressable @@ -568,6 +568,37 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } + test("saveAsNewAPIHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with + // java.lang.IllegalArgumentException: Can not create a Path from a null string + pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } + + test("saveAsHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val conf = new JobConf() + conf.setOutputKeyClass(classOf[Integer]) + conf.setOutputValueClass(classOf[Integer]) + conf.setOutputFormat(classOf[FakeOutputFormat]) + conf.setOutputCommitter(classOf[FakeOutputCommitter]) + + FakeOutputCommitter.ran = false + pairs.saveAsHadoopDataset(conf) + + assert(FakeOutputCommitter.ran, "OutputCommitter was never called") + } + test("lookup") { val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) From 5eacc3bfa9b9c1435ce04222ac7f943b5f930cf4 Mon Sep 17 00:00:00 2001 From: Kento NOZAWA Date: Sat, 7 Oct 2017 08:30:48 +0100 Subject: [PATCH 1452/1765] [SPARK-22156][MLLIB] Fix update equation of learning rate in Word2Vec.scala ## What changes were proposed in this pull request? Current equation of learning rate is incorrect when `numIterations` > `1`. This PR is based on [original C code](https://github.com/tmikolov/word2vec/blob/master/word2vec.c#L393). cc: mengxr ## How was this patch tested? manual tests I modified [this example code](https://spark.apache.org/docs/2.1.1/mllib-feature-extraction.html#example). ### `numIteration=1` #### Code ```scala import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) val word2vec = new Word2Vec() val model = word2vec.fit(input) val synonyms = model.findSynonyms("1", 5) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } ``` #### Result ``` 2 0.175856813788414 0 0.10971353203058243 4 0.09818313270807266 3 0.012947646901011467 9 -0.09881238639354706 ``` ### `numIteration=5` #### Code ```scala import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) val word2vec = new Word2Vec() word2vec.setNumIterations(5) val model = word2vec.fit(input) val synonyms = model.findSynonyms("1", 5) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } ``` #### Result ``` 0 0.9898583889007568 2 0.9808019399642944 4 0.9794934391975403 3 0.9506527781486511 9 -0.9065656661987305 ``` Author: Kento NOZAWA Closes #19372 from nzw0301/master. --- .../org/apache/spark/mllib/feature/Word2Vec.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 6f96813497b62..b8c306d86bace 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -353,11 +353,14 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) + val totalWordsCounts = numIterations * trainWordsCount + 1 var alpha = learningRate for (k <- 1 to numIterations) { val bcSyn0Global = sc.broadcast(syn0Global) val bcSyn1Global = sc.broadcast(syn1Global) + val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) @@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging { var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - // TODO: discount by iteration? - alpha = - learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + alpha = learningRate * + (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) / + totalWordsCounts) if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 - logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " + + s"alpha = $alpha") } wc += sentence.length var pos = 0 From c998a2ae0ea019dfb9b39cef6ddfac07c496e083 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Sun, 8 Oct 2017 12:58:39 +0100 Subject: [PATCH 1453/1765] [SPARK-22147][CORE] Removed redundant allocations from BlockId ## What changes were proposed in this pull request? Prior to this commit BlockId.hashCode and BlockId.equals were defined in terms of BlockId.name. This allowed the subclasses to be concise and enforced BlockId.name as a single unique identifier for a block. All subclasses override BlockId.name with an expression involving an allocation of StringBuilder and ultimatelly String. This is suboptimal since it induced unnecessary GC pressure on the dirver, see BlockManagerMasterEndpoint. The commit removes the definition of hashCode and equals from the base class. No other change is necessary since all subclasses are in fact case classes and therefore have auto-generated hashCode and equals. No change of behaviour is expected. Sidenote: you might be wondering, why did the subclasses use the base implementation and the auto-generated one? Apparently, this behaviour is documented in the spec. See this SO answer for details https://stackoverflow.com/a/44990210/262432. ## How was this patch tested? BlockIdSuite Author: Sergei Lebedev Closes #19369 from superbobry/blockid-equals-hashcode. --- .../netty/NettyBlockTransferService.scala | 2 +- .../org/apache/spark/storage/BlockId.scala | 5 -- .../org/apache/spark/storage/DiskStore.scala | 8 +-- .../BlockManagerReplicationSuite.scala | 49 ------------------- 4 files changed, 5 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ac4d85004bad1..6a29e18bf3cbb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -151,7 +151,7 @@ private[spark] class NettyBlockTransferService( // Convert or copy nio buffer into array in order to serialize it. val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer, + client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, new RpcResponseCallback { override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..a441baed2800e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -41,11 +41,6 @@ sealed abstract class BlockId { def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name - override def hashCode: Int = name.hashCode - override def equals(other: Any): Boolean = other match { - case o: BlockId => getClass == o.getClass && name.equals(o.name) - case _ => false - } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 3579acf8d83d9..97abd92d4b70f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -47,9 +47,9 @@ private[spark] class DiskStore( private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString) - private val blockSizes = new ConcurrentHashMap[String, Long]() + private val blockSizes = new ConcurrentHashMap[BlockId, Long]() - def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) + def getSize(blockId: BlockId): Long = blockSizes.get(blockId) /** * Invokes the provided callback function to write the specific block. @@ -67,7 +67,7 @@ private[spark] class DiskStore( var threwException: Boolean = true try { writeFunc(out) - blockSizes.put(blockId.name, out.getCount) + blockSizes.put(blockId, out.getCount) threwException = false } finally { try { @@ -113,7 +113,7 @@ private[spark] class DiskStore( } def remove(blockId: BlockId): Boolean = { - blockSizes.remove(blockId.name) + blockSizes.remove(blockId) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index dd61dcd11bcda..c2101ba828553 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -198,55 +198,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } - test("block replication - deterministic node selection") { - val blockSize = 1000 - val storeSize = 10000 - val stores = (1 to 5).map { - i => makeBlockManager(storeSize, s"store$i") - } - val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2 - val storageLevel3x = StorageLevel(true, true, false, true, 3) - val storageLevel4x = StorageLevel(true, true, false, true, 4) - - def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { - stores.head.putSingle(blockId, new Array[Byte](blockSize), level) - val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet - stores.foreach { _.removeBlock(blockId) } - master.removeBlock(blockId) - locations - } - - // Test if two attempts to 2x replication returns same set of locations - val a1Locs = putBlockAndGetLocations("a1", storageLevel2x) - assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs, - "Inserting a 2x replicated block second time gave different locations from the first") - - // Test if two attempts to 3x replication returns same set of locations - val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x) - assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x, - "Inserting a 3x replicated block second time gave different locations from the first") - - // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication - val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x) - assert( - a2Locs2x.subsetOf(a2Locs3x), - "Inserting a with 2x replication gave locations that are not a subset of locations" + - s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}" - ) - - // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication - val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x) - assert( - a2Locs3x.subsetOf(a2Locs4x), - "Inserting a with 4x replication gave locations that are not a superset of locations " + - s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}" - ) - - // Test if 3x replication of two different blocks gives two different sets of locations - val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x) - assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication") - } - test("block replication - replication failures") { /* Create a system of three block managers / stores. One of them (say, failableStore) From fe7b219ae3e8a045655a836cbb77219036ec5740 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 9 Oct 2017 14:16:25 +0800 Subject: [PATCH 1454/1765] [SPARK-22074][CORE] Task killed by other attempt task should not be resubmitted ## What changes were proposed in this pull request? As the detail scenario described in [SPARK-22074](https://issues.apache.org/jira/browse/SPARK-22074), unnecessary resubmitted may cause stage hanging in currently release versions. This patch add a new var in TaskInfo to mark this task killed by other attempt or not. ## How was this patch tested? Add a new UT `[SPARK-22074] Task killed by other attempt task should not be resubmitted` in TaskSetManagerSuite, this UT recreate the scenario in JIRA description, it failed without the changes in this PR and passed conversely. Author: Yuanjian Li Closes #19287 from xuanyuanking/SPARK-22074. --- .../spark/scheduler/TaskSetManager.scala | 8 +- .../org/apache/spark/scheduler/FakeTask.scala | 20 +++- .../spark/scheduler/TaskSetManagerSuite.scala | 107 ++++++++++++++++++ 3 files changed, 132 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3bdede6743d1b..de4711f461df2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -83,6 +83,11 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) + // Set the coresponding index of Boolean var when the task killed by other attempt tasks, + // this happened while we set the `spark.speculation` to true. The task killed by others + // should not resubmit while executor lost. + private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -729,6 +734,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") + killedByOtherAttempt(index) = true sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -915,7 +921,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index)) { + if (successful(index) && !killedByOtherAttempt(index)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index fe6de2bd98850..109d4a0a870b8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,8 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.SparkEnv -import org.apache.spark.TaskContext +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.executor.TaskMetrics class FakeTask( @@ -58,4 +57,21 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createShuffleMapTaskSet( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new ShuffleMapTask(stageId, stageAttemptId, null, new Partition { + override def index: Int = i + }, prefLocs(i), new Properties, + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + } + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 5c712bd6a545b..2ce81ae27daf6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -744,6 +744,113 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(resubmittedTasks === 0) } + + test("[SPARK-22074] Task killed by other attempt task should not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. + assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. + var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((exec, host) <- Seq( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(exec, host, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 2 tasks and leave 2 task in running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) + // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 + manager.executorLost("exec3", "host3", SlaveLost()) + // Check the resubmittedTasks + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( From 98057583dd2787c0e396c2658c7dd76412f86936 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 9 Oct 2017 10:42:33 +0200 Subject: [PATCH 1455/1765] [SPARK-20679][ML] Support recommending for a subset of users/items in ALSModel This PR adds methods `recommendForUserSubset` and `recommendForItemSubset` to `ALSModel`. These allow recommending for a specified set of user / item ids rather than for every user / item (as in the `recommendForAllX` methods). The subset methods take a `DataFrame` as input, containing ids in the column specified by the param `userCol` or `itemCol`. The model will generate recommendations for each _unique_ id in this input dataframe. ## How was this patch tested? New unit tests in `ALSSuite` and Python doctests in `ALS`. Ran updated examples locally. Author: Nick Pentreath Closes #18748 from MLnick/als-recommend-df. --- .../spark/examples/ml/JavaALSExample.java | 9 ++ examples/src/main/python/ml/als_example.py | 9 ++ .../apache/spark/examples/ml/ALSExample.scala | 9 ++ .../apache/spark/ml/recommendation/ALS.scala | 48 +++++++++ .../spark/ml/recommendation/ALSSuite.scala | 100 ++++++++++++++++-- python/pyspark/ml/recommendation.py | 38 +++++++ 6 files changed, 205 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index fe4d6bc83f04a..27052be87b82e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -118,9 +118,18 @@ public static void main(String[] args) { Dataset userRecs = model.recommendForAllUsers(10); // Generate top 10 user recommendations for each movie Dataset movieRecs = model.recommendForAllItems(10); + + // Generate top 10 movie recommendations for a specified set of users + Dataset users = ratings.select(als.getUserCol()).distinct().limit(3); + Dataset userSubsetRecs = model.recommendForUserSubset(users, 10); + // Generate top 10 user recommendations for a specified set of movies + Dataset movies = ratings.select(als.getItemCol()).distinct().limit(3); + Dataset movieSubSetRecs = model.recommendForItemSubset(movies, 10); // $example off$ userRecs.show(); movieRecs.show(); + userSubsetRecs.show(); + movieSubSetRecs.show(); spark.stop(); } diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1672d552eb1d5..8b7ec9c439f9f 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -60,8 +60,17 @@ userRecs = model.recommendForAllUsers(10) # Generate top 10 user recommendations for each movie movieRecs = model.recommendForAllItems(10) + + # Generate top 10 movie recommendations for a specified set of users + users = ratings.select(als.getUserCol()).distinct().limit(3) + userSubsetRecs = model.recommendForUserSubset(users, 10) + # Generate top 10 user recommendations for a specified set of movies + movies = ratings.select(als.getItemCol()).distinct().limit(3) + movieSubSetRecs = model.recommendForItemSubset(movies, 10) # $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 07b15dfa178f7..8091838a2301e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -80,9 +80,18 @@ object ALSExample { val userRecs = model.recommendForAllUsers(10) // Generate top 10 user recommendations for each movie val movieRecs = model.recommendForAllItems(10) + + // Generate top 10 movie recommendations for a specified set of users + val users = ratings.select(als.getUserCol).distinct().limit(3) + val userSubsetRecs = model.recommendForUserSubset(users, 10) + // Generate top 10 user recommendations for a specified set of movies + val movies = ratings.select(als.getItemCol).distinct().limit(3) + val movieSubSetRecs = model.recommendForItemSubset(movies, 10) // $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 3d5fd1794de23..a8843661c873b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -344,6 +344,21 @@ class ALSModel private[ml] ( recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) } + /** + * Returns top `numItems` items recommended for each user id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of user ids. The column name must match `userCol`. + * @param numItems max number of recommendations for each user. + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, userFactors, $(userCol)) + recommendForAll(srcFactorSubset, itemFactors, $(userCol), $(itemCol), numItems) + } + /** * Returns top `numUsers` users recommended for each item, for all items. * @param numUsers max number of recommendations for each item @@ -355,6 +370,39 @@ class ALSModel private[ml] ( recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) } + /** + * Returns top `numUsers` users recommended for each item id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of item ids. The column name must match `itemCol`. + * @param numUsers max number of recommendations for each item. + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol)) + recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Returns a subset of a factor DataFrame limited to only those unique ids contained + * in the input dataset. + * @param dataset input Dataset containing id column to user to filter factors. + * @param factors factor DataFrame to filter. + * @param column column name containing the ids in the input dataset. + * @return DataFrame containing factors only for those ids present in both the input dataset and + * the factor DataFrame. + */ + private def getSourceFactorSubset( + dataset: Dataset[_], + factors: DataFrame, + column: String): DataFrame = { + factors + .join(dataset.select(column), factors("id") === dataset(column), joinType = "left_semi") + .select(factors("id"), factors("features")) + } + /** * Makes recommendations for all users (or items). * diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ac7319110159b..addcd21d50aac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -723,9 +723,9 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), - 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), - 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Seq((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) Seq(2, 4, 6).foreach { k => @@ -743,10 +743,10 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 3 -> Array((0, 54f), (2, 51f), (1, 39f)), - 4 -> Array((0, 44f), (2, 30f), (1, 26f)), - 5 -> Array((2, 45f), (0, 42f), (1, 33f)), - 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 4 -> Seq((0, 44f), (2, 30f), (1, 26f)), + 5 -> Seq((2, 45f), (0, 42f), (1, 33f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) ) Seq(2, 3, 4).foreach { k => @@ -759,9 +759,93 @@ class ALSSuite } } + test("recommendForUserSubset with k <, = and > num_items") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numItems = model.itemFactors.count + val expected = Map( + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + val userSubset = expected.keys.toSeq.toDF("user") + val numUsersSubset = userSubset.count + + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForUserSubset(userSubset, k) + assert(topItems.count() == numUsersSubset) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } + } + + test("recommendForItemSubset with k <, = and > num_users") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numUsers = model.userFactors.count + val expected = Map( + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) + ) + val itemSubset = expected.keys.toSeq.toDF("item") + val numItemsSubset = itemSubset.count + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = model.recommendForItemSubset(itemSubset, k) + assert(topUsers.count() == numItemsSubset) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } + } + + test("subset recommendations eliminate duplicate ids, returns same results as unique ids") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val users = Seq(0, 1).toDF("user") + val dupUsers = Seq(0, 1, 0, 1).toDF("user") + val singleUserRecs = model.recommendForUserSubset(users, k) + val dupUserRecs = model.recommendForUserSubset(dupUsers, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleUserRecs.count == dupUserRecs.size) + checkRecommendations(singleUserRecs, dupUserRecs, "item") + + val items = Seq(3, 4, 5).toDF("item") + val dupItems = Seq(3, 4, 5, 4, 5).toDF("item") + val singleItemRecs = model.recommendForItemSubset(items, k) + val dupItemRecs = model.recommendForItemSubset(dupItems, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleItemRecs.count == dupItemRecs.size) + checkRecommendations(singleItemRecs, dupItemRecs, "user") + } + + test("subset recommendations on full input dataset equivalent to recommendForAll") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val userSubset = model.userFactors.withColumnRenamed("id", "user").drop("features") + val userSubsetRecs = model.recommendForUserSubset(userSubset, k) + val allUserRecs = model.recommendForAllUsers(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(userSubsetRecs, allUserRecs, "item") + + val itemSubset = model.itemFactors.withColumnRenamed("id", "item").drop("features") + val itemSubsetRecs = model.recommendForItemSubset(itemSubset, k) + val allItemRecs = model.recommendForAllItems(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(itemSubsetRecs, allItemRecs, "user") + } + private def checkRecommendations( topK: DataFrame, - expected: Map[Int, Array[(Int, Float)]], + expected: Map[Int, Seq[(Int, Float)]], dstColName: String): Unit = { val spark = this.spark import spark.implicits._ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index bcfb36880eb02..e8bcbe4cd34cb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -90,6 +90,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> item_recs.where(item_recs.item == 2)\ .select("recommendations.user", "recommendations.rating").collect() [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] + >>> user_subset = df.where(df.user == 2) + >>> user_subset_recs = model.recommendForUserSubset(user_subset, 3) + >>> user_subset_recs.select("recommendations.item", "recommendations.rating").first() + Row(item=[2, 1, 0], rating=[4.901..., 1.056..., -1.501...]) + >>> item_subset = df.where(df.item == 0) + >>> item_subset_recs = model.recommendForItemSubset(item_subset, 3) + >>> item_subset_recs.select("recommendations.user", "recommendations.rating").first() + Row(user=[0, 1, 2], rating=[3.910..., 2.625..., -1.501...]) >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -414,6 +422,36 @@ def recommendForAllItems(self, numUsers): """ return self._call_java("recommendForAllItems", numUsers) + @since("2.3.0") + def recommendForUserSubset(self, dataset, numItems): + """ + Returns top `numItems` items recommended for each user id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of user ids. The column name must match + `userCol`. + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForUserSubset", dataset, numItems) + + @since("2.3.0") + def recommendForItemSubset(self, dataset, numUsers): + """ + Returns top `numUsers` users recommended for each item id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of item ids. The column name must match + `itemCol`. + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForItemSubset", dataset, numUsers) + if __name__ == "__main__": import doctest From f31e11404d6d5ee28b574c242ecbee94f35e9370 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 9 Oct 2017 12:53:10 -0700 Subject: [PATCH 1456/1765] [SPARK-21568][CORE] ConsoleProgressBar should only be enabled in shells ## What changes were proposed in this pull request? This PR disables console progress bar feature in non-shell environment by overriding the configuration. ## How was this patch tested? Manual. Run the following examples with and without `spark.ui.showConsoleProgress` in order to see progress bar on master branch and this PR. **Scala Shell** ```scala spark.range(1000000000).map(_ + 1).count ``` **PySpark** ```python spark.range(10000000).rdd.map(lambda x: len(x)).count() ``` **Spark Submit** ```python from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession.builder.getOrCreate() spark.range(2000000).rdd.map(lambda row: len(row)).count() spark.stop() ``` Author: Dongjoon Hyun Closes #19061 from dongjoon-hyun/SPARK-21568. --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++++ .../org/apache/spark/internal/config/package.scala | 5 +++++ .../org/apache/spark/deploy/SparkSubmitSuite.scala | 12 ++++++++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cec61d85ccf38..b3cd03c0cfbe1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -434,7 +434,7 @@ class SparkContext(config: SparkConf) extends Logging { _statusTracker = new SparkStatusTracker(this) _progressBar = - if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + if (_conf.get(UI_SHOW_CONSOLE_PROGRESS) && !log.isInfoEnabled) { Some(new ConsoleProgressBar(this)) } else { None diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 286a4379d2040..135bbe93bf28e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -598,6 +598,11 @@ object SparkSubmit extends CommandLineUtils with Logging { } } + // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. + if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { + sysProps(UI_SHOW_CONSOLE_PROGRESS.key) = "true" + } + // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d85b6a0200b8d..5278e5e0fb270 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -203,6 +203,11 @@ package object config { private[spark] val HISTORY_UI_MAX_APPS = ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) + private[spark] val UI_SHOW_CONSOLE_PROGRESS = ConfigBuilder("spark.ui.showConsoleProgress") + .doc("When true, show the progress bar in the console.") + .booleanConf + .createWithDefault(false) + private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") .booleanConf .createWithDefault(false) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ad801bf8519a6..b06f2e26a4a7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -399,6 +399,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.yarn.Client") } + test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + sysProps1(UI_SHOW_CONSOLE_PROGRESS.key) should be ("true") + + val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") + val appArgs2 = new SparkSubmitArguments(clArgs2) + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + sysProps2.keys should not contain UI_SHOW_CONSOLE_PROGRESS.key + } + test("launch simple application with spark-submit") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( From a74ec6d7bbfe185ba995dcb02d69e90a089c293e Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 9 Oct 2017 12:56:37 -0700 Subject: [PATCH 1457/1765] [SPARK-22218] spark shuffle services fails to update secret on app re-attempts This patch fixes application re-attempts when running spark on yarn using the external shuffle service with security on. Currently executors will fail to launch on any application re-attempt when launched on a nodemanager that had an executor from the first attempt. The reason for this is because we aren't updating the secret key after the first application attempt. The fix here is to just remove the containskey check to see if it already exists. In this way, we always add it and make sure its the most recent secret. Similarly remove the check for containsKey on the remove since its just adding extra check that isn't really needed. Note this worked before spark 2.2 because the check used to be contains (which was looking for the value) rather then containsKey, so that never matched and it was just always adding the new secret. Patch was tested on a 10 node cluster as well as added the unit test. The test ran was a wordcount where the output directory already existed. With the bug present the application attempt failed with max number of executor Failures which were all saslExceptions. With the fix present the application re-attempts fail with directory already exists or when you remove the directory between attempts the re-attemps succeed. Author: Thomas Graves Closes #19450 from tgravescs/SPARK-22218. --- .../network/sasl/ShuffleSecretManager.java | 19 +++---- .../sasl/ShuffleSecretManagerSuite.java | 55 +++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index d2d008f8a3d35..7253101f41df6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -47,12 +47,11 @@ public ShuffleSecretManager() { * fetching shuffle files written by other executors in this application. */ public void registerApp(String appId, String shuffleSecret) { - if (!shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.put(appId, shuffleSecret); - logger.info("Registered shuffle secret for application {}", appId); - } else { - logger.debug("Application {} already registered", appId); - } + // Always put the new secret information to make sure it's the most up to date. + // Otherwise we have to specifically look at the application attempt in addition + // to the applicationId since the secrets change between application attempts on yarn. + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); } /** @@ -67,12 +66,8 @@ public void registerApp(String appId, ByteBuffer shuffleSecret) { * This is called when the application terminates. */ public void unregisterApp(String appId) { - if (shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.remove(appId); - logger.info("Unregistered shuffle secret for application {}", appId); - } else { - logger.warn("Attempted to unregister application {} when it is not registered", appId); - } + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); } /** diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java new file mode 100644 index 0000000000000..46c4c33865eea --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.nio.ByteBuffer; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class ShuffleSecretManagerSuite { + static String app1 = "app1"; + static String app2 = "app2"; + static String pw1 = "password1"; + static String pw2 = "password2"; + static String pw1update = "password1update"; + static String pw2update = "password2update"; + + @Test + public void testMultipleRegisters() { + ShuffleSecretManager secretManager = new ShuffleSecretManager(); + secretManager.registerApp(app1, pw1); + assertEquals(pw1, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2.getBytes())); + assertEquals(pw2, secretManager.getSecretKey(app2)); + + // now update the password for the apps and make sure it takes affect + secretManager.registerApp(app1, pw1update); + assertEquals(pw1update, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2update.getBytes())); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app1); + assertNull(secretManager.getSecretKey(app1)); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app2); + assertNull(secretManager.getSecretKey(app2)); + assertNull(secretManager.getSecretKey(app1)); + } +} From b650ee0265477ada68220cbf286fa79906608ef5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Oct 2017 13:55:55 -0700 Subject: [PATCH 1458/1765] [INFRA] Close stale PRs. Closes #19423 Closes #19455 From dadd13f365aad0d9228cd8b8e6d57ad32175b155 Mon Sep 17 00:00:00 2001 From: Pavel Sakun Date: Mon, 9 Oct 2017 23:00:04 +0100 Subject: [PATCH 1459/1765] [SPARK] Misleading error message for missing --proxy-user value Fix misleading error message when argument is expected. ## What changes were proposed in this pull request? Change message to be accurate. ## How was this patch tested? Messaging change, was tested manually. Author: Pavel Sakun Closes #19457 from pavel-sakun/patch-1. --- .../src/main/java/org/apache/spark/launcher/SparkLauncher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 718a368a8e731..75b8ef5ca5ef4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -625,7 +625,7 @@ private static class ArgumentValidator extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { if (value == null && hasValue) { - throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + throw new IllegalArgumentException(String.format("'%s' expects a value.", opt)); } return true; } From 155ab6347ec7be06c937372a51e8013fdd371d93 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 9 Oct 2017 15:22:41 -0700 Subject: [PATCH 1460/1765] [SPARK-22170][SQL] Reduce memory consumption in broadcast joins. ## What changes were proposed in this pull request? This updates the broadcast join code path to lazily decompress pages and iterate through UnsafeRows to prevent all rows from being held in memory while the broadcast table is being built. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #19394 from rdblue/broadcast-driver-memory. --- .../plans/physical/broadcastMode.scala | 6 ++++ .../spark/sql/execution/SparkPlan.scala | 19 ++++++++---- .../exchange/BroadcastExchangeExec.scala | 29 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 13 ++++++++- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 3 +- 6 files changed, 54 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 2ab46dc8330aa..9fac95aed8f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any + def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any + def canonicalized: BroadcastMode } @@ -36,5 +38,9 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): Array[InternalRow] = rows.toArray + override def canonicalized: BroadcastMode = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b263f100e6068..2ffd948f984bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -223,7 +223,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also * compressed. */ - private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = { execute().mapPartitionsInternal { iter => var count = 0 val buffer = new Array[Byte](4 << 10) // 4K @@ -239,7 +239,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ out.writeInt(-1) out.flush() out.close() - Iterator(bos.toByteArray) + Iterator((count, bos.toByteArray)) } } @@ -274,19 +274,26 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val byteArrayRdd = getByteArrayRdd() val results = ArrayBuffer[InternalRow]() - byteArrayRdd.collect().foreach { bytes => - decodeUnsafeRows(bytes).foreach(results.+=) + byteArrayRdd.collect().foreach { countAndBytes => + decodeUnsafeRows(countAndBytes._2).foreach(results.+=) } results.toArray } + private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = { + val countsAndBytes = getByteArrayRdd().collect() + val total = countsAndBytes.map(_._1).sum + val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2)) + (total, rows) + } + /** * Runs this query returning the result as an iterator of InternalRow. * * @note Triggers multiple jobs (one for each partition). */ def executeToIterator(): Iterator[InternalRow] = { - getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) + getByteArrayRdd().map(_._2).toLocalIterator.flatMap(decodeUnsafeRows) } /** @@ -307,7 +314,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = getByteArrayRdd(n) + val childRDD = getByteArrayRdd(n).map(_._2) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9c859e41f8762..880e18c6808b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils @@ -72,26 +72,39 @@ case class BroadcastExchangeExec( SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() - // Note that we use .executeCollect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = child.executeCollect() - if (input.length >= 512000000) { + // Use executeCollect/executeCollectIterator to avoid conversion to Scala types + val (numRows, input) = child.executeCollectIterator() + if (numRows >= 512000000) { throw new SparkException( - s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") + s"Cannot broadcast the table with more than 512 millions rows: $numRows rows") } + val beforeBuild = System.nanoTime() longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + + // Construct the relation. + val relation = mode.transform(input, Some(numRows)) + + val dataSize = relation match { + case map: HashedRelation => + map.estimatedSize + case arr: Array[InternalRow] => + arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + case _ => + throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " + + relation.getClass.getName) + } + longMetric("dataSize") += dataSize if (dataSize >= (8L << 30)) { throw new SparkException( s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") } - // Construct and broadcast the relation. - val relation = mode.transform(input) val beforeBroadcast = System.nanoTime() longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + // Broadcast the relation val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index f8058b2f7813b..b2dcbe5aa9877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -866,7 +866,18 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalized.key, rows.length) + transform(rows.iterator, Some(rows.length)) + } + + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): HashedRelation = { + sizeHint match { + case Some(numRows) => + HashedRelation(rows, canonicalized.key, numRows.toInt) + case None => + HashedRelation(rows, canonicalized.key) + } } override lazy val canonicalized: HashedRelationBroadcastMode = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 2c1e5db5fd9bb..cee85ec8af04d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -58,7 +58,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { // If we only sample one point, the range boundaries will be pretty bad and the // chi-sq value would be very high. - assert(computeChiSquareTest() > 1000) + assert(computeChiSquareTest() > 300) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0dc612ef735fa..58a194b8af62b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -227,8 +227,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L, - "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + "number of output rows" -> 2L)))) ) } From 71c2b81aa0e0db70013821f5512df1fbd8e59445 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 9 Oct 2017 16:34:39 -0700 Subject: [PATCH 1461/1765] [SPARK-22230] Swap per-row order in state store restore. ## What changes were proposed in this pull request? In state store restore, for each row, put the saved state before the row in the iterator instead of after. This fixes an issue where agg(last('attr)) will forever return the last value of 'attr from the first microbatch. ## How was this patch tested? new unit test Author: Jose Torres Closes #19461 from joseph-torres/SPARK-22230. --- .../execution/streaming/statefulOperators.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index fb960fbdde8b3..0d85542928ee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -225,7 +225,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: Option(savedState).toSeq + Option(savedState).toSeq :+ row } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 995cea3b37d4f..fe7efa69f7e31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -520,6 +520,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } + test("SPARK-22230: last should change with new batches") { + val input = MemoryStream[Int] + + val aggregated = input.toDF().agg(last('value)) + testStream(aggregated, OutputMode.Complete())( + AddData(input, 1, 2, 3), + CheckLastBatch(3), + AddData(input, 4, 5, 6), + CheckLastBatch(6), + AddData(input), + CheckLastBatch(6), + AddData(input, 0), + CheckLastBatch(0) + ) + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From bebd2e1ce10a460555f75cda75df33f39a783469 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 9 Oct 2017 21:34:37 -0700 Subject: [PATCH 1462/1765] [SPARK-22222][CORE] Fix the ARRAY_MAX in BufferHolder and add a test ## What changes were proposed in this pull request? We should not break the assumption that the length of the allocated byte array is word rounded: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java#L170 So we want to use `Integer.MAX_VALUE - 15` instead of `Integer.MAX_VALUE - 8` as the upper bound of an allocated byte array. cc: srowen gatorsmile ## How was this patch tested? Since the Spark unit test JVM has less than 1GB heap, here we run the test code as a submit job, so it can run on a JVM has 4GB memory. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #19460 from liufengdb/fix_array_max. --- .../spark/unsafe/array/ByteArrayMethods.java | 7 ++ .../unsafe/map/HashMapGrowthStrategy.java | 6 +- .../collection/PartitionedPairBuffer.scala | 6 +- .../spark/deploy/SparkSubmitSuite.scala | 52 +++++++------ .../expressions/codegen/BufferHolder.java | 7 +- .../BufferHolderSparkSubmitSutie.scala | 78 +++++++++++++++++++ .../vectorized/WritableColumnVector.java | 3 +- 7 files changed, 124 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 9c551ab19e9aa..f121b1cd745b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -40,6 +40,13 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } } + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat smaller. + // Be conservative and lower the cap a little. + // Refer to "http://hg.openjdk.java.net/jdk8/jdk8/jdk/file/tip/src/share/classes/java/util/ArrayList.java#l229" + // This value is word rounded. Use this value if the allocated byte arrays are used to store other + // types rather than bytes. + public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; + private static final boolean unaligned = Platform.unaligned(); /** * Optimized byte array equality check for byte arrays. diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index b8c2294c7b7ab..ee6d9f75ac5aa 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -17,6 +17,8 @@ package org.apache.spark.unsafe.map; +import org.apache.spark.unsafe.array.ByteArrayMethods; + /** * Interface that defines how we can grow the size of a hash map when it is over a threshold. */ @@ -31,9 +33,7 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; @Override public int nextCapacity(int currentCapacity) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index b755e5da51684..e17a9de97e335 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,6 +19,8 @@ package org.apache.spark.util.collection import java.util.Comparator +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** @@ -96,7 +98,5 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) } private object PartitionedPairBuffer { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - val MAXIMUM_CAPACITY: Int = (Int.MaxValue - 8) / 2 + val MAXIMUM_CAPACITY: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 2 } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b06f2e26a4a7a..b52da4c0c8bc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -100,6 +100,8 @@ class SparkSubmitSuite with TimeLimits with TestPrematureExit { + import SparkSubmitSuite._ + override def beforeEach() { super.beforeEach() System.setProperty("spark.testing", "true") @@ -974,30 +976,6 @@ class SparkSubmitSuite } } - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - private def runSparkSubmit(args: Seq[String]): Unit = { - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val sparkSubmitFile = if (Utils.isWindows) { - new File("..\\bin\\spark-submit.cmd") - } else { - new File("../bin/spark-submit") - } - val process = Utils.executeCommand( - Seq(sparkSubmitFile.getCanonicalPath) ++ args, - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - - try { - val exitCode = failAfter(60 seconds) { process.waitFor() } - if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") - } - } finally { - // Ensure we still kill the process in case it timed out - process.destroy() - } - } - private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() @@ -1020,6 +998,32 @@ class SparkSubmitSuite } } +object SparkSubmitSuite extends SparkFunSuite with TimeLimits { + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + def runSparkSubmit(args: Seq[String], root: String = ".."): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val sparkSubmitFile = if (Utils.isWindows) { + new File(s"$root\\bin\\spark-submit.cmd") + } else { + new File(s"$root/bin/spark-submit") + } + val process = Utils.executeCommand( + Seq(sparkSubmitFile.getCanonicalPath) ++ args, + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + + try { + val exitCode = failAfter(60 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + object JarCreationTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 971d19973f067..259976118c12f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -19,6 +19,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; /** * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and @@ -36,9 +37,7 @@ */ public class BufferHolder { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; @@ -51,7 +50,7 @@ public BufferHolder(UnsafeRow row) { public BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); - if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) { + if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( "Cannot create BufferHolder for input UnsafeRow because there are " + "too many fields (number of fields: " + row.numFields() + ")"); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala new file mode 100644 index 0000000000000..1167d2f3f3891 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.concurrent.Timeouts + +import org.apache.spark.{SparkFunSuite, TestUtils} +import org.apache.spark.deploy.SparkSubmitSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.ResetSystemProperties + +// A test for growing the buffer holder to nearly 2GB. Due to the heap size limitation of the Spark +// unit tests JVM, the actually test code is running as a submit job. +class BufferHolderSparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts { + + test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + + val argsForSparkSubmit = Seq( + "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), + "--name", "SPARK-22222", + "--master", "local-cluster[2,1,1024]", + "--driver-memory", "4g", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.driver.extraJavaOptions=-ea", + unusedJar.toString) + SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..") + } +} + +object BufferHolderSparkSubmitSuite { + + def main(args: Array[String]): Unit = { + + val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val holder = new BufferHolder(new UnsafeRow(1000)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE)) + } + + private def roundToWord(len: Int): Int = { + ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index da72954ddc448..d3a14b9d8bd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; /** @@ -595,7 +596,7 @@ public final int appendStruct(boolean isNull) { * Upper limit for the maximum capacity for this column. */ @VisibleForTesting - protected int MAX_CAPACITY = Integer.MAX_VALUE - 8; + protected int MAX_CAPACITY = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. From af8a34c787dc3d68f5148a7d9975b52650bb7729 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 9 Oct 2017 22:35:34 -0700 Subject: [PATCH 1463/1765] [SPARK-22159][SQL][FOLLOW-UP] Make config names consistently end with "enabled". ## What changes were proposed in this pull request? This is a follow-up of #19384. In the previous pr, only definitions of the config names were modified, but we also need to modify the names in runtime or tests specified as string literal. ## How was this patch tested? Existing tests but modified the config names. Author: Takuya UESHIN Closes #19462 from ueshin/issues/SPARK-22159/fup1. --- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/tests.py | 6 +++--- .../aggregate/HashAggregateExec.scala | 2 +- .../spark/sql/AggregateHashMapSuite.scala | 12 +++++------ .../benchmark/AggregateBenchmark.scala | 20 +++++++++---------- .../execution/AggregationQuerySuite.scala | 2 +- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b7ce9a83a616d..fe69e588fe098 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1878,7 +1878,7 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: import pyarrow tables = self._collectAsArrow() @@ -1889,7 +1889,7 @@ def toPandas(self): return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" + "if using spark.sql.execution.arrow.enabled=true" raise ImportError("%s\n%s" % (e.message, msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b3af42c47ad2..a59378b5e848a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3088,7 +3088,7 @@ class ArrowTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3120,9 +3120,9 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f424096b330e3..8b573fdcf25e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -539,7 +539,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext) = { if (!checkIfFastHashMapSupported(ctx)) { if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { - logInfo("spark.sql.codegen.aggregate.map.twolevel.enable is set to true, but" + logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 7e61a68025158..938d76c9f0837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.SparkConf class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "false", "configuration parameter changed in test body") } } @@ -39,14 +39,14 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") } } @@ -57,7 +57,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed @@ -65,7 +65,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true", "configuration parameter changed in test body") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index aca1be01fa3da..a834b7cd2c69f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,14 +107,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -149,14 +149,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -189,14 +189,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -228,14 +228,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -277,14 +277,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index f245a79f805a2..ae675149df5e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1015,7 +1015,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enable" -> + withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" -> enableTwoLevelMaps) { (1 to 3).foreach { fallbackStartsAt => withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> From 3b5c2a84bfa311a94c1c0a57f2cb3e421fb05650 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Oct 2017 08:27:45 +0100 Subject: [PATCH 1464/1765] [SPARK-21770][ML] ProbabilisticClassificationModel fix corner case: normalization of all-zero raw predictions ## What changes were proposed in this pull request? Fix probabilisticClassificationModel corner case: normalization of all-zero raw predictions, throw IllegalArgumentException with description. ## How was this patch tested? Test case added. Author: WeichenXu Closes #19106 from WeichenXu123/SPARK-21770. --- .../ProbabilisticClassifier.scala | 20 ++++++++++--------- .../ProbabilisticClassifierSuite.scala | 18 +++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index ef08134809915..730fcab333e11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -230,21 +230,23 @@ private[ml] object ProbabilisticClassificationModel { * Normalize a vector of raw predictions to be a multinomial probability vector, in place. * * The input raw predictions should be nonnegative. - * The output vector sums to 1, unless the input vector is all-0 (in which case the output is - * all-0 too). + * The output vector sums to 1. * * NOTE: This is NOT applicable to all models, only ones which effectively use class * instance counts for raw predictions. + * + * @throws IllegalArgumentException if the input vector is all-0 or including negative values */ def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = { + v.values.foreach(value => require(value >= 0, + "The input raw predictions should be nonnegative.")) val sum = v.values.sum - if (sum != 0) { - var i = 0 - val size = v.size - while (i < size) { - v.values(i) /= sum - i += 1 - } + require(sum > 0, "Can't normalize the 0-vector.") + var i = 0 + val size = v.size + while (i < size) { + v.values(i) /= sum + i += 1 } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 4ecd5a05365eb..d649ceac949c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -80,6 +80,24 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) } } + + test("normalizeToProbabilitiesInPlace") { + val vec1 = Vectors.dense(1.0, 2.0, 3.0).toDense + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec1) + assert(vec1 ~== Vectors.dense(1.0 / 6, 2.0 / 6, 3.0 / 6) relTol 1e-3) + + // all-0 input test + val vec2 = Vectors.dense(0.0, 0.0, 0.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec2) + } + + // negative input test + val vec3 = Vectors.dense(1.0, -1.0, 2.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec3) + } + } } object ProbabilisticClassifierSuite { From b8a08f25cc64ed3034f3c90790931c30e5b0f236 Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 10 Oct 2017 20:44:33 +0800 Subject: [PATCH 1465/1765] [SPARK-21506][DOC] The description of "spark.executor.cores" may be not correct ## What changes were proposed in this pull request? The number of cores assigned to each executor is configurable. When this is not explicitly set, multiple executors from the same application may be launched on the same worker too. ## How was this patch tested? N/A Author: liuxian Closes #18711 from 10110346/executorcores. --- .../spark/deploy/client/StandaloneAppClient.scala | 2 +- .../scala/org/apache/spark/deploy/master/Master.scala | 8 +++++++- .../cluster/StandaloneSchedulerBackend.scala | 2 +- docs/configuration.md | 11 ++++------- docs/spark-standalone.md | 8 ++++++++ 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 757c930b84eb2..34ade4ce6f39b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -170,7 +170,7 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, + logInfo("Executor added: %s on %s (%s) with %d core(s)".format(fullId, workerId, hostPort, cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e030cac60a8e4..2c78c15773af2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,7 +581,13 @@ private[deploy] class Master( * The number of cores assigned to each executor is configurable. When this is explicitly set, * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the - * worker by default, in which case only one executor may be launched on each worker. + * worker by default, in which case only one executor per application may be launched on each + * worker during one single schedule iteration. + * Note that when `spark.executor.cores` is not set, we may still launch multiple executors from + * the same application on the same worker. Consider appA and appB both have one executor running + * on worker1, and appA.coresLeft > 0, then appB is finished and release all its cores on worker1, + * thus for the next schedule iteration, appA launches a new executor that grabs all the free + * cores on worker1, therefore we get multiple executors from appA running on worker1. * * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core * at a time). Consider the following example: cluster has 4 workers with 16 cores each. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index a4e2a74341283..505c342a889ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -153,7 +153,7 @@ private[spark] class StandaloneSchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + logInfo("Granted executor ID %s on hostPort %s with %d core(s), %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } diff --git a/docs/configuration.md b/docs/configuration.md index 6e9fe591b70a3..7a777d3c6fa3d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1015,7 +1015,7 @@ Apart from these, the following properties are also available, and may be useful 0.5 Amount of storage memory immune to eviction, expressed as a fraction of the size of the - region set aside by s​park.memory.fraction. The higher this is, the less + region set aside by spark.memory.fraction. The higher this is, the less working memory may be available to execution and tasks may spill to disk more often. Leaving this at the default value is recommended. For more detail, see this description. @@ -1041,7 +1041,7 @@ Apart from these, the following properties are also available, and may be useful spark.memory.useLegacyMode false - ​Whether to enable the legacy memory management mode used in Spark 1.5 and before. + Whether to enable the legacy memory management mode used in Spark 1.5 and before. The legacy mode rigidly partitions the heap space into fixed-size regions, potentially leading to excessive spilling if the application was not tuned. The following deprecated memory fraction configurations are not read unless this is enabled: @@ -1115,11 +1115,8 @@ Apart from these, the following properties are also available, and may be useful The number of cores to use on each executor. - In standalone and Mesos coarse-grained modes, setting this - parameter allows an application to run multiple executors on the - same worker, provided that there are enough cores on that - worker. Otherwise, only one executor per application will run on - each worker. + In standalone and Mesos coarse-grained modes, for more detail, see + this description. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1095386c31ab8..f51c5cc38f4de 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -328,6 +328,14 @@ export SPARK_MASTER_OPTS="-Dspark.deploy.defaultCores=" This is useful on shared clusters where users might not have configured a maximum number of cores individually. +# Executors Scheduling + +The number of cores assigned to each executor is configurable. When `spark.executor.cores` is +explicitly set, multiple executors from the same application may be launched on the same worker +if the worker has enough cores and memory. Otherwise, each executor grabs all the cores available +on the worker by default, in which case only one executor per application may be launched on each +worker during one single schedule iteration. + # Monitoring and Logging Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. From 23af2d79ad9a3c83936485ee57513b39193a446b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 10 Oct 2017 20:48:42 +0800 Subject: [PATCH 1466/1765] [SPARK-20025][CORE] Ignore SPARK_LOCAL* env, while deploying via cluster mode. ## What changes were proposed in this pull request? In a bare metal system with No DNS setup, spark may be configured with SPARK_LOCAL* for IP and host properties. During a driver failover, in cluster deployment mode. SPARK_LOCAL* should be ignored while restarting on another node and should be picked up from target system's local environment. ## How was this patch tested? Distributed deployment against a spark standalone cluster of 6 Workers. Tested by killing JVM's running driver and verified the restarted JVMs have right configurations on them. Author: Prashant Sharma Author: Prashant Sharma Closes #17357 from ScrapCodes/driver-failover-fix. --- core/src/main/scala/org/apache/spark/deploy/Client.scala | 6 +++--- .../apache/spark/deploy/rest/StandaloneRestServer.scala | 4 +++- .../org/apache/spark/deploy/worker/DriverWrapper.scala | 9 ++++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index bf6093236d92b..7acb5c55bb252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -93,19 +93,19 @@ private class ClientEndpoint( driverArgs.cores, driverArgs.supervise, command) - ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + asyncSendToMasterAndForwardReply[SubmitDriverResponse]( RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) } } /** * Send the message to master and forward the reply to self asynchronously. */ - private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { for (masterEndpoint <- masterEndpoints) { masterEndpoint.ask[T](message).onComplete { case Success(v) => self.send(v) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 0164084ab129e..22b65abce611a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -139,7 +139,9 @@ private[rest] class StandaloneSubmitRequestServlet( val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") val superviseDriver = sparkProperties.get("spark.driver.supervise") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. + val environmentVariables = + request.environmentVariables.filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) // Construct driver description val conf = new SparkConf(false) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index c1671192e0c64..b19c9904d5982 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -23,6 +23,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -30,7 +31,7 @@ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, U * Utility object for launching driver programs such that they share fate with the Worker process. * This is used in standalone cluster mode only. */ -object DriverWrapper { +object DriverWrapper extends Logging { def main(args: Array[String]) { args.toList match { /* @@ -41,8 +42,10 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val rpcEnv = RpcEnv.create("Driver", - Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val host: String = Utils.localHostName() + val port: Int = sys.props.getOrElse("spark.driver.port", "0").toInt + val rpcEnv = RpcEnv.create("Driver", host, port, conf, new SecurityManager(conf)) + logInfo(s"Driver address: ${rpcEnv.address}") rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader From 633ffd816d285480bab1f346471135b10ec092bb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 10 Oct 2017 11:01:02 -0700 Subject: [PATCH 1467/1765] rename the file. --- ...rSparkSubmitSutie.scala => BufferHolderSparkSubmitSuite.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/{BufferHolderSparkSubmitSutie.scala => BufferHolderSparkSubmitSuite.scala} (100%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala similarity index 100% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala From 2028e5a82bc3e9a79f9b84f376bdf606b8c9bb0f Mon Sep 17 00:00:00 2001 From: Eyal Farago Date: Tue, 10 Oct 2017 22:49:47 +0200 Subject: [PATCH 1468/1765] [SPARK-21907][CORE] oom during spill ## What changes were proposed in this pull request? 1. a test reproducing [SPARK-21907](https://issues.apache.org/jira/browse/SPARK-21907) 2. a fix for the root cause of the issue. `org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill` calls `org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset` which may trigger another spill, when this happens the `array` member is already de-allocated but still referenced by the code, this causes the nested spill to fail with an NPE in `org.apache.spark.memory.TaskMemoryManager.getPage`. This patch introduces a reproduction in a test case and a fix, the fix simply sets the in-mem sorter's array member to an empty array before actually performing the allocation. This prevents the spilling code from 'touching' the de-allocated array. ## How was this patch tested? introduced a new test case: `org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorterSuite#testOOMDuringSpill`. Author: Eyal Farago Closes #19181 from eyalfa/SPARK-21907__oom_during_spill. --- .../unsafe/sort/UnsafeExternalSorter.java | 4 ++ .../unsafe/sort/UnsafeInMemorySorter.java | 12 ++++- .../sort/UnsafeExternalSorterSuite.java | 33 +++++++++++++ .../sort/UnsafeInMemorySorterSuite.java | 46 +++++++++++++++++++ .../spark/memory/TestMemoryManager.scala | 12 +++-- 5 files changed, 102 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 39eda00dd7efb..e749f7ba87c6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -480,6 +480,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + @VisibleForTesting boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + private static void spillIterator(UnsafeSorterIterator inMemIterator, UnsafeSorterSpillWriter spillWriter) throws IOException { while (inMemIterator.hasNext()) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c14c12664f5ab..869ec908be1fb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -162,7 +162,9 @@ private int getUsableCapacity() { */ public void free() { if (consumer != null) { - consumer.freeArray(array); + if (array != null) { + consumer.freeArray(array); + } array = null; } } @@ -170,6 +172,14 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); + // the call to consumer.allocateArray may trigger a spill + // which in turn access this instance and eventually re-enter this method and try to free the array again. + // by setting the array to null and its length to 0 we effectively make the spill code-path a no-op. + // setting the array to null also indicates that it has already been de-allocated which prevents a double de-allocation in free(). + array = null; + usableCapacity = 0; + pos = 0; + nullBoundaryPos = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 5330a688e63e3..6c5451d0fd2a5 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -23,6 +23,7 @@ import java.util.LinkedList; import java.util.UUID; +import org.hamcrest.Matchers; import scala.Tuple2$; import org.junit.After; @@ -503,6 +504,38 @@ public void testGetIterator() throws Exception { verifyIntIterator(sorter.getIterator(279), 279, 300); } + @Test + public void testOOMDuringSpill() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + // we assume that given default configuration, + // the size of the data we insert to the sorter (ints) + // and assuming we shouldn't spill before pointers array is exhausted + // (memory manager is not configured to throw at this point) + // - so this loop runs a reasonable number of iterations (<2000). + // test indeed completed within <30ms (on a quad i7 laptop). + for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { + insertNumber(sorter, i); + } + // we expect the next insert to attempt growing the pointerssArray + // first allocation is expected to fail, then a spill is triggered which attempts another allocation + // which also fails and we expect to see this OOM here. + // the original code messed with a released array within the spill code + // and ended up with a failed assertion. + // we also expect the location of the OOM to be org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + memoryManager.markconsequentOOM(2); + try { + insertNumber(sorter, 1024); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } + // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (OutOfMemoryError oom){ + String oomStackTrace = Utils.exceptionString(oom); + assertThat("expected OutOfMemoryError in org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString("org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + } + } + private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) throws IOException { for (int i = start; i < end; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index bd89085aa9a14..1a3e11efe9787 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -35,6 +35,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.isIn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { @@ -139,4 +140,49 @@ public int compare( } assertEquals(dataToSort.length, iterLength); } + + @Test + public void freeAfterOOM() { + final SparkConf sparkConf = new SparkConf(); + sparkConf.set("spark.memory.offHeap.enabled", "false"); + + final TestMemoryManager testMemoryManager = + new TestMemoryManager(sparkConf); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + testMemoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = PrefixComparators.LONG; + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, + recordComparator, prefixComparator, 100, shouldUseRadixSort()); + + testMemoryManager.markExecutionAsOutOfMemoryOnce(); + try { + sorter.reset(); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } catch (OutOfMemoryError oom) { + // as expected + } + // [SPARK-21907] this failed on NPE at org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + sorter.free(); + // simulate a 'back to back' free. + sorter.free(); + } + } diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 5f699df8211de..c26945fa5fa31 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -27,8 +27,8 @@ class TestMemoryManager(conf: SparkConf) numBytes: Long, taskAttemptId: Long, memoryMode: MemoryMode): Long = { - if (oomOnce) { - oomOnce = false + if (consequentOOM > 0) { + consequentOOM -= 1 0 } else if (available >= numBytes) { available -= numBytes @@ -58,11 +58,15 @@ class TestMemoryManager(conf: SparkConf) override def maxOffHeapStorageMemory: Long = 0L - private var oomOnce = false + private var consequentOOM = 0 private var available = Long.MaxValue def markExecutionAsOutOfMemoryOnce(): Unit = { - oomOnce = true + markconsequentOOM(1) + } + + def markconsequentOOM(n : Int) : Unit = { + consequentOOM += n } def limit(avail: Long): Unit = { From bfc7e1fe1ad5f9777126f2941e29bbe51ea5da7c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 Oct 2017 07:32:01 +0900 Subject: [PATCH 1469/1765] [SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This PR adds an apply() function on df.groupby(). apply() takes a pandas udf that is a transformation on `pandas.DataFrame` -> `pandas.DataFrame`. Static schema ------------------- ``` schema = df.schema pandas_udf(schema) def normalize(df): df = df.assign(v1 = (df.v1 - df.v1.mean()) / df.v1.std() return df df.groupBy('id').apply(normalize) ``` Dynamic schema ----------------------- **This use case is removed from the PR and we will discuss this as a follow up. See discussion https://github.com/apache/spark/pull/18732#pullrequestreview-66583248** Another example to use pd.DataFrame dtypes as output schema of the udf: ``` sample_df = df.filter(df.id == 1).toPandas() def foo(df): ret = # Some transformation on the input pd.DataFrame return ret foo_udf = pandas_udf(foo, foo(sample_df).dtypes) df.groupBy('id').apply(foo_udf) ``` In interactive use case, user usually have a sample pd.DataFrame to test function `foo` in their notebook. Having been able to use `foo(sample_df).dtypes` frees user from specifying the output schema of `foo`. Design doc: https://github.com/icexelloss/spark/blob/pandas-udf-doc/docs/pyspark-pandas-udf.md ## How was this patch tested? * Added GroupbyApplyTest Author: Li Jin Author: Takuya UESHIN Author: Bryan Cutler Closes #18732 from icexelloss/groupby-apply-SPARK-20396. --- python/pyspark/sql/dataframe.py | 6 +- python/pyspark/sql/functions.py | 98 ++++++++--- python/pyspark/sql/group.py | 88 +++++++++- python/pyspark/sql/tests.py | 157 +++++++++++++++++- python/pyspark/sql/types.py | 2 +- python/pyspark/worker.py | 35 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../logical/pythonLogicalOperators.scala | 39 +++++ .../spark/sql/RelationalGroupedDataset.scala | 36 +++- .../spark/sql/execution/SparkStrategies.scala | 2 + .../python/ArrowEvalPythonExec.scala | 39 ++++- .../execution/python/ArrowPythonRunner.scala | 15 +- .../execution/python/ExtractPythonUDFs.scala | 8 +- .../python/FlatMapGroupsInPandasExec.scala | 103 ++++++++++++ 14 files changed, 561 insertions(+), 69 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fe69e588fe098..2d596229ced7e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1227,7 +1227,7 @@ def groupBy(self, *cols): """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def rollup(self, *cols): @@ -1248,7 +1248,7 @@ def rollup(self, *cols): """ jgd = self._jdf.rollup(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def cube(self, *cols): @@ -1271,7 +1271,7 @@ def cube(self, *cols): """ jgd = self._jdf.cube(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.3) def agg(self, *exprs): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b45a59db93679..9bc12c3b7a162 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2058,7 +2058,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self._vectorized = vectorized + self.vectorized = vectorized @property def returnType(self): @@ -2090,7 +2090,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self._vectorized) + self._name, wrapped_func, jdt, self.vectorized) return judf def __call__(self, *cols): @@ -2118,8 +2118,10 @@ def wrapper(*args): wrapper.__name__ = self._name wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') else self.func.__class__.__module__) + wrapper.func = self.func wrapper.returnType = self.returnType + wrapper.vectorized = self.vectorized return wrapper @@ -2129,8 +2131,12 @@ def _create_udf(f, returnType, vectorized): def _udf(f, returnType=StringType(), vectorized=vectorized): if vectorized: import inspect - if len(inspect.getargspec(f).args) == 0: - raise NotImplementedError("0-parameter pandas_udfs are not currently supported") + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) return udf_obj._wrapped() @@ -2146,7 +2152,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): @since(1.3) def udf(f=None, returnType=StringType()): - """Creates a :class:`Column` expression representing a user defined function (UDF). + """Creates a user defined function (UDF). .. note:: The user-defined functions must be deterministic. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than @@ -2181,30 +2187,70 @@ def udf(f=None, returnType=StringType()): @since(2.3) def pandas_udf(f=None, returnType=StringType()): """ - Creates a :class:`Column` expression representing a user defined function (UDF) that accepts - `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length. + Creates a vectorized user defined function (UDF). - :param f: python function if used as a standalone function + :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf(returnType="integer") - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ + The user-defined function can define one of the following transformations: + + 1. One or more `pandas.Series` -> A `pandas.Series` + + This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + The returnType should be a primitive data type, e.g., `DoubleType()`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) + >>> @pandas_udf(returnType=StringType()) + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf(returnType="integer") + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + + 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + + This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + The returnType should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` + because it defines a `DataFrame` transformation rather than a `Column` + transformation. + + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + + .. note:: The user-defined function must be deterministic. """ return _create_udf(f, returnType=returnType, vectorized=True) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c63054..817d0bc83bb77 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -54,9 +54,10 @@ class GroupedData(object): .. versionadded:: 1.3 """ - def __init__(self, jgd, sql_ctx): + def __init__(self, jgd, df): self._jgd = jgd - self.sql_ctx = sql_ctx + self._df = df + self.sql_ctx = df.sql_ctx @ignore_unicode_prefix @since(1.3) @@ -170,7 +171,7 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): """ - Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + Pivots a column of the current :class:`DataFrame` and perform the specified aggregation. There are two versions of pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally. @@ -192,7 +193,85 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col) else: jgd = self._jgd.pivot(pivot_col, values) - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self._df) + + @since(2.3) + def apply(self, udf): + """ + Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result + as a `DataFrame`. + + The user-defined function should take a `pandas.DataFrame` and return another + `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` + to the user-function and the returned `pandas.DataFrame`s are combined as a + :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the + returnType of the pandas udf. + + This function does not support partial aggregation, and requires shuffling all the data in + the :class:`DataFrame`. + + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + + >>> from pyspark.sql.functions import pandas_udf + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + from pyspark.sql.functions import pandas_udf + + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: + raise ValueError("The argument to apply must be a pandas_udf") + if not isinstance(udf.returnType, StructType): + raise ValueError("The returnType of the pandas_udf must be a StructType") + + df = self._df + func = udf.func + returnType = udf.returnType + + # The python executors expects the function to use pd.Series as input and output + # So we to create a wrapper function that turns that to a pd.DataFrame before passing + # down to the user function, then turn the result pd.DataFrame back into pd.Series + columns = df.columns + + def wrapped(*cols): + from pyspark.sql.types import to_arrow_type + import pandas as pd + result = func(pd.concat(cols, axis=1, keys=columns)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "Pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(returnType): + raise RuntimeError( + "Number of columns of the returned Pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + wrapped_udf_obj = pandas_udf(wrapped, returnType) + udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) def _test(): @@ -206,6 +285,7 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a59378b5e848a..bac2ef84ae7a7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3256,17 +3256,17 @@ def test_vectorized_udf_null_string(self): def test_vectorized_udf_zero_parameter(self): from pyspark.sql.functions import pandas_udf - error_str = '0-parameter pandas_udfs.*not.*supported' + error_str = '0-arg pandas_udfs.*not.*supported' with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): pandas_udf(lambda: 1, LongType()) - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf def zero_no_type(): return 1 - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf(LongType()) def zero_with_type(): return 1 @@ -3348,7 +3348,7 @@ def test_vectorized_udf_wrong_return_type(self): df = self.spark.range(10) f = pandas_udf(lambda x: x * 1.0, StringType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3356,7 +3356,7 @@ def test_vectorized_udf_return_scalar(self): df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'): + with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): @@ -3376,6 +3376,151 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_varargs(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda *v: v[0], LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyApplyTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def assertFramesEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + + ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) + self.assertTrue(expected.equals(result), msg=msg) + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_simple(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_decorator(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + def foo(pdf): + return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_coerce(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + expected = expected.assign(v=expected.v.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_complex_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_empty_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby().apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = normalize.func(pdf) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', StringType())])) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + df.groupby('id').apply(foo).sort('id').toPandas() + + def test_wrong_args(self): + from pyspark.sql.functions import udf, pandas_udf, sum + df = self.data + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(lambda x: x) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(sum(df.v)) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'returnType'): + df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ebdc11c3b744a..f65273d5f0b6c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1597,7 +1597,7 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) -def toArrowType(dt): +def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ import pyarrow as pa diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e24789cf010d..eb6d48688dc0a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import toArrowType +from pyspark.sql.types import to_arrow_type, StructType from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,17 +74,28 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = toArrowType(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + # If the return_type is a StructType, it indicates this is a groupby apply udf, + # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. + # We can distinguish these two by return type because in groupby apply, we always specify + # returnType as a StructType, and in vectorized column udf, StructType is not supported. + # + # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs + if isinstance(return_type, StructType): + return lambda *a: f(*a) + else: + arrow_return_type = to_arrow_type(return_type) + + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result + + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bc2d4a824cb49..d829e01441dcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -452,6 +452,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) + case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala new file mode 100644 index 0000000000000..8abab24bc9b44 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} + +/** + * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. + * This is used by DataFrame.groupby().apply(). + */ +case class FlatMapGroupsInPandas( + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** + * This is needed because output attributes are considered `references` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. + */ + override val producedAttributes = AttributeSet(output) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 147b549964913..cd0ac1feffa51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -27,12 +27,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.NumericType -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NumericType, StructField, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -435,6 +435,36 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan.output, df.logicalPlan)) } + + /** + * Applies a vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. + * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results + * for all groups are combined into a new [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. + */ + private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { + require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.dataType.isInstanceOf[StructType], + "The returnType of the vectorized python udf must be a StructType") + + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + val child = df.logicalPlan + val project = Project(groupingNamedExpressions ++ child.output, child) + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) + + Dataset.ofRows(df.sparkSession, plan) + } } private[sql] object RelationalGroupedDataset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 92eaab5cd8f81..4cdcc73faacd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -392,6 +392,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInPandas(grouping, func, output, child) => + execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index f7e8cbe416121..81896187ecc46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -26,6 +26,35 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType +/** + * Grouped a iterator into batches. + * This is similar to iter.grouped but returns Iterator[T] instead of Seq[T]. + * This is necessary because sometimes we cannot hold reference of input rows + * because the some input rows are mutable and can be reused. + */ +private class BatchIterator[T](iter: Iterator[T], batchSize: Int) + extends Iterator[Iterator[T]] { + + override def hasNext: Boolean = iter.hasNext + + override def next(): Iterator[T] = { + new Iterator[T] { + var count = 0 + + override def hasNext: Boolean = iter.hasNext && count < batchSize + + override def next(): T = { + if (!hasNext) { + Iterator.empty.next() + } else { + count += 1 + iter.next() + } + } + } + } +} + /** * A physical plan that evaluates a [[PythonUDF]], */ @@ -44,14 +73,18 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) + val batchSize = conf.arrowMaxRecordsPerBatch + // DO NOT use iter.grouped(). See BatchIterator. + val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) + val columnarBatchIter = new ArrowPythonRunner( - funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) - .compute(iter, context.partitionId(), context) + .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { - var currentIter = if (columnarBatchIter.hasNext) { + private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() assert(schemaOut.equals(batch.schema), s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index bbad9d6b631fd..f6c03c415dc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,19 +39,18 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - batchSize: Int, bufferSize: Int, reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType) - extends BasePythonRunner[InternalRow, ColumnarBatch]( + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[InternalRow], + inputIterator: Iterator[Iterator[InternalRow]], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { @@ -82,12 +81,12 @@ class ArrowPythonRunner( Utils.tryWithSafeFinally { while (inputIterator.hasNext) { - var rowCount = 0 - while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { - val row = inputIterator.next() - arrowWriter.write(row) - rowCount += 1 + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) } + arrowWriter.finish() writer.writeBatch() arrowWriter.reset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fec456d86dbe2..e3f952e221d53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -111,6 +110,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { + // FlatMapGroupsInPandas can be evaluated directly in python worker + // Therefore we don't need to extract the UDFs + case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) } @@ -169,7 +171,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - execution.ProjectExec(plan.output, newPlan) + ProjectExec(plan.output, newPlan) } else { newPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala new file mode 100644 index 0000000000000..b996b5bb38ba5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.StructType + +/** + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] + * + * Rows in each group are passed to the Python worker as an Arrow record batch. + * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the + * user-defined function, and passes the resulting `pandas.DataFrame` + * as an Arrow record batch. Finally, each record batch is turned to + * Iterator[InternalRow] using ColumnarBatch. + * + * Note on memory usage: + * Both the Python worker and the Java executor need to have enough memory to + * hold the largest group. The memory on the Java side is used to construct the + * record batch (off heap memory). The memory on the Python side is used for + * holding the `pandas.DataFrame`. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the Java side, this + * is left as future work. + */ +case class FlatMapGroupsInPandasExec( + groupingAttributes: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode { + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) + val schema = StructType(child.schema.drop(groupingAttributes.length)) + + inputRDD.mapPartitionsInternal { iter => + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + } + } + + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(grouped, context.partitionId(), context) + + columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + } + } +} From bd4eb9ce57da7bacff69d9ed958c94f349b7e6fb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Oct 2017 15:50:37 -0700 Subject: [PATCH 1470/1765] [SPARK-19558][SQL] Add config key to register QueryExecutionListeners automatically. This change adds a new SQL config key that is equivalent to SparkContext's "spark.extraListeners", allowing users to register QueryExecutionListener instances through the Spark configuration system instead of having to explicitly do it in code. The code used by SparkContext to implement the feature was refactored into a helper method in the Utils class, and SQL's ExecutionListenerManager was modified to use it to initialize listener declared in the configuration. Unit tests were added to verify all the new functionality. Author: Marcelo Vanzin Closes #19309 from vanzin/SPARK-19558. --- .../scala/org/apache/spark/SparkContext.scala | 38 ++-------- .../spark/internal/config/package.scala | 7 ++ .../scala/org/apache/spark/util/Utils.scala | 57 ++++++++++++++- .../spark/scheduler/SparkListenerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 56 ++++++++++++++- .../spark/sql/internal/StaticSQLConf.scala | 8 +++ .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql/util/QueryExecutionListener.scala | 12 +++- .../util/ExecutionListenerManagerSuite.scala | 69 +++++++++++++++++++ 9 files changed, 216 insertions(+), 40 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3cd03c0cfbe1..6f25d346e6e54 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2344,41 +2344,13 @@ class SparkContext(config: SparkConf) extends Logging { * (e.g. after the web UI and event logging listeners have been registered). */ private def setupAndStartListenerBus(): Unit = { - // Use reflection to instantiate listeners specified via `spark.extraListeners` try { - val listenerClassNames: Seq[String] = - conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") - for (className <- listenerClassNames) { - // Use reflection to find the right constructor - val constructors = { - val listenerClass = Utils.classForName(className) - listenerClass - .getConstructors - .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]] + conf.get(EXTRA_LISTENERS).foreach { classNames => + val listeners = Utils.loadExtensions(classOf[SparkListenerInterface], classNames, conf) + listeners.foreach { listener => + listenerBus.addToSharedQueue(listener) + logInfo(s"Registered listener ${listener.getClass().getName()}") } - val constructorTakingSparkConf = constructors.find { c => - c.getParameterTypes.sameElements(Array(classOf[SparkConf])) - } - lazy val zeroArgumentConstructor = constructors.find { c => - c.getParameterTypes.isEmpty - } - val listener: SparkListenerInterface = { - if (constructorTakingSparkConf.isDefined) { - constructorTakingSparkConf.get.newInstance(conf) - } else if (zeroArgumentConstructor.isDefined) { - zeroArgumentConstructor.get.newInstance() - } else { - throw new SparkException( - s"$className did not have a zero-argument constructor or a" + - " single-argument constructor that accepts SparkConf. Note: if the class is" + - " defined inside of another Scala class, then its constructors may accept an" + - " implicit parameter that references the enclosing class; in this case, you must" + - " define the listener as a top-level class in order to prevent this extra" + - " parameter from breaking Spark's ability to find a valid constructor.") - } - } - listenerBus.addToSharedQueue(listener) - logInfo(s"Registered listener $className") } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 5278e5e0fb270..19336f854145f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -419,4 +419,11 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") + .doc("Class names of listeners to add to SparkContext during initialization.") + .stringConf + .toSequence + .createOptional + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 836e33c36d9a1..930e09d90c2f5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer @@ -37,7 +38,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import scala.util.matching.Regex @@ -2687,6 +2688,60 @@ private[spark] object Utils extends Logging { def stringToSeq(str: String): Seq[String] = { str.split(",").map(_.trim()).filter(_.nonEmpty) } + + /** + * Create instances of extension classes. + * + * The classes in the given list must: + * - Be sub-classes of the given base class. + * - Provide either a no-arg constructor, or a 1-arg constructor that takes a SparkConf. + * + * The constructors are allowed to throw "UnsupportedOperationException" if the extension does not + * want to be registered; this allows the implementations to check the Spark configuration (or + * other state) and decide they do not need to be added. A log message is printed in that case. + * Other exceptions are bubbled up. + */ + def loadExtensions[T](extClass: Class[T], classes: Seq[String], conf: SparkConf): Seq[T] = { + classes.flatMap { name => + try { + val klass = classForName(name) + require(extClass.isAssignableFrom(klass), + s"$name is not a subclass of ${extClass.getName()}.") + + val ext = Try(klass.getConstructor(classOf[SparkConf])) match { + case Success(ctor) => + ctor.newInstance(conf) + + case Failure(_) => + klass.getConstructor().newInstance() + } + + Some(ext.asInstanceOf[T]) + } catch { + case _: NoSuchMethodException => + throw new SparkException( + s"$name did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the class as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + + case e: InvocationTargetException => + e.getCause() match { + case uoe: UnsupportedOperationException => + logDebug(s"Extension $name not being initialized.", uoe) + logInfo(s"Extension $name not being initialized.") + None + + case null => throw e + + case cause => throw cause + } + } + } + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d061c7845f4a6..1beb36afa95f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} @@ -446,13 +446,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match classOf[FirehoseListenerThatAcceptsSparkConf], classOf[BasicJobCounter]) val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) + .set(EXTRA_LISTENERS, listeners.map(_.getName)) sc = new SparkContext(conf) sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) sc.listenerBus.listeners.asScala .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) sc.listenerBus.listeners.asScala - .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) + .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } test("add and remove listeners to/from LiveListenerBus queues") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 2b16cc4852ba8..4d3adeb968e84 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -38,9 +38,10 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -1110,4 +1111,57 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {}) TaskContext.unset } + + test("load extensions") { + val extensions = Seq( + classOf[SimpleExtension], + classOf[ExtensionWithConf], + classOf[UnregisterableExtension]).map(_.getName()) + + val conf = new SparkConf(false) + val instances = Utils.loadExtensions(classOf[Object], extensions, conf) + assert(instances.size === 2) + assert(instances.count(_.isInstanceOf[SimpleExtension]) === 1) + + val extWithConf = instances.find(_.isInstanceOf[ExtensionWithConf]) + .map(_.asInstanceOf[ExtensionWithConf]) + .get + assert(extWithConf.conf eq conf) + + class NestedExtension { } + + val invalid = Seq(classOf[NestedExtension].getName()) + intercept[SparkException] { + Utils.loadExtensions(classOf[Object], invalid, conf) + } + + val error = Seq(classOf[ExtensionWithError].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Object], error, conf) + } + + val wrongType = Seq(classOf[ListenerImpl].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Seq[_]], wrongType, conf) + } + } + +} + +private class SimpleExtension + +private class ExtensionWithConf(val conf: SparkConf) + +private class UnregisterableExtension { + + throw new UnsupportedOperationException() + +} + +private class ExtensionWithError { + + throw new IllegalArgumentException() + } + +private class ListenerImpl extends SparkListener diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index c6c0a605d89ff..c018fc8a332fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -87,4 +87,12 @@ object StaticSQLConf { "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") .stringConf .createOptional + + val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners") + .doc("List of class names implementing QueryExecutionListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4e756084bbdbb..2867b4cd7da5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -266,7 +266,8 @@ abstract class BaseSessionStateBuilder( * This gets cloned from parent if available, otherwise is a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { - parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + parentState.map(_.listenerManager.clone()).getOrElse( + new ExecutionListenerManager(session.sparkContext.conf)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index f6240d85fba6f..2b46233e1a5df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,9 +22,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal +import org.apache.spark.SparkConf import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -72,7 +75,14 @@ trait QueryExecutionListener { */ @Experimental @InterfaceStability.Evolving -class ExecutionListenerManager private[sql] () extends Logging { +class ExecutionListenerManager private extends Logging { + + private[sql] def this(conf: SparkConf) = { + this() + conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register) + } + } /** * Registers the specified [[QueryExecutionListener]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala new file mode 100644 index 0000000000000..4205e23ae240a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ + +class ExecutionListenerManagerSuite extends SparkFunSuite { + + import CountingQueryExecutionListener._ + + test("register query execution listeners using configuration") { + val conf = new SparkConf(false) + .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName())) + + val mgr = new ExecutionListenerManager(conf) + assert(INSTANCE_COUNT.get() === 1) + mgr.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 1) + + val clone = mgr.clone() + assert(INSTANCE_COUNT.get() === 1) + + clone.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 2) + } + +} + +private class CountingQueryExecutionListener extends QueryExecutionListener { + + import CountingQueryExecutionListener._ + + INSTANCE_COUNT.incrementAndGet() + + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + +} + +private object CountingQueryExecutionListener { + + val CALLBACK_COUNT = new AtomicInteger() + val INSTANCE_COUNT = new AtomicInteger() + +} From 76fb173dd639baa9534486488155fc05a71f850e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 10 Oct 2017 20:29:02 -0700 Subject: [PATCH 1471/1765] [SPARK-21751][SQL] CodeGeneraor.splitExpressions counts code size more precisely ## What changes were proposed in this pull request? Current `CodeGeneraor.splitExpressions` splits statements into methods if the total length of statements is more than 1024 characters. The length may include comments or empty line. This PR excludes comment or empty line from the length to reduce the number of generated methods in a class, by using `CodeFormatter.stripExtraNewLinesAndComments()` method. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #18966 from kiszk/SPARK-21751. --- .../expressions/codegen/CodeFormatter.scala | 8 +++++ .../expressions/codegen/CodeGenerator.scala | 5 ++- .../codegen/CodeFormatterSuite.scala | 32 +++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 60e600d8dbd8f..7b398f424cead 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -89,6 +89,14 @@ object CodeFormatter { } new CodeAndComment(code.result().trim(), map) } + + def stripExtraNewLinesAndComments(input: String): String = { + val commentReg = + ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ + """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment + val codeWithoutComment = commentReg.replaceAllIn(input, "") + codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines + } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f9c5ef8439085..2cb66599076a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -772,16 +772,19 @@ class CodegenContext { foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() + var length = 0 for (code <- expressions) { // We can't know how many bytecode will be generated, so use the length of source code // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should // also not be too small, or it will have many function calls (for wide table), see the // results in BenchmarkWideTable. - if (blockBuilder.length > 1024) { + if (length > 1024) { blocks += blockBuilder.toString() blockBuilder.clear() + length = 0 } blockBuilder.append(code) + length += CodeFormatter.stripExtraNewLinesAndComments(code).length } blocks += blockBuilder.toString() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9d0a41661beaa..a0f1a64b0ab08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } + test("removing extra new lines and comments") { + val code = + """ + |/* + | * multi + | * line + | * comments + | */ + | + |public function() { + |/*comment*/ + | /*comment_with_space*/ + |code_body + |//comment + |code_body + | //comment_with_space + | + |code_body + |} + """.stripMargin + + val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) + assert(reducedCode === + """ + |public function() { + |code_body + |code_body + |code_body + |} + """.stripMargin) + } + testCase("basic example") { """ |class A { From 655f6f86f84ff5241d1d20766e1ef83bb32ca5e0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 11 Oct 2017 00:16:12 -0700 Subject: [PATCH 1472/1765] [SPARK-22208][SQL] Improve percentile_approx by not rounding up targetError and starting from index 0 ## What changes were proposed in this pull request? Currently percentile_approx never returns the first element when percentile is in (relativeError, 1/N], where relativeError default 1/10000, and N is the total number of elements. But ideally, percentiles in [0, 1/N] should all return the first element as the answer. For example, given input data 1 to 10, if a user queries 10% (or even less) percentile, it should return 1, because the first value 1 already reaches 10%. Currently it returns 2. Based on the paper, targetError is not rounded up, and searching index should start from 0 instead of 1. By following the paper, we should be able to fix the cases mentioned above. ## How was this patch tested? Added a new test case and fix existing test cases. Author: Zhenhua Wang Closes #19438 from wzhfy/improve_percentile_approx. --- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++++---- .../apache/spark/ml/feature/ImputerSuite.scala | 2 +- python/pyspark/sql/dataframe.py | 6 +++--- .../sql/catalyst/util/QuantileSummaries.scala | 4 ++-- .../catalyst/util/QuantileSummariesSuite.scala | 10 ++++++++-- .../sql/ApproximatePercentileQuerySuite.scala | 17 ++++++++++++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 36 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bbea25bc4da5c..4382ef2ed4525 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2538,7 +2538,7 @@ test_that("describe() and summary() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[5, "summary"], "25%") - expect_equal(collect(stats2)[5, "age"], "30") + expect_equal(collect(stats2)[5, "age"], "19") stats3 <- summary(df, "min", "max", "55.1%") @@ -2738,7 +2738,7 @@ test_that("sampleBy() on a DataFrame", { }) test_that("approxQuantile() on a DataFrame", { - l <- lapply(c(0:99), function(i) { list(i, 99 - i) }) + l <- lapply(c(0:100), function(i) { list(i, 100 - i) }) df <- createDataFrame(l, list("a", "b")) quantiles <- approxQuantile(df, "a", c(0.5, 0.8), 0.0) expect_equal(quantiles, list(50, 80)) @@ -2749,8 +2749,8 @@ test_that("approxQuantile() on a DataFrame", { dfWithNA <- createDataFrame(data.frame(a = c(NA, 30, 19, 11, 28, 15), b = c(-30, -19, NA, -11, -28, -15))) quantiles3 <- approxQuantile(dfWithNA, c("a", "b"), c(0.5), 0.0) - expect_equal(quantiles3[[1]], list(28)) - expect_equal(quantiles3[[2]], list(-15)) + expect_equal(quantiles3[[1]], list(19)) + expect_equal(quantiles3[[2]], list(-19)) }) test_that("SQL error message is returned from JVM", { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index ee2ba73fa96d5..c08b35b419266 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -43,7 +43,7 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default (0, 1.0, 1.0, 1.0), (1, 3.0, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 3.0) + (3, -1.0, 2.0, 1.0) )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d596229ced7e..38b01f0011671 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1038,8 +1038,8 @@ def summary(self, *statistics): | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| - | 25%| 5| null| - | 50%| 5| null| + | 25%| 2| null| + | 50%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+------------------+-----+ @@ -1050,7 +1050,7 @@ def summary(self, *statistics): +-------+---+-----+ | count| 2| 2| | min| 2|Alice| - | 25%| 5| null| + | 25%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+---+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index af543b04ba780..eb7941cf9e6af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -193,10 +193,10 @@ class QuantileSummaries( // Target rank val rank = math.ceil(quantile * count).toInt - val targetError = math.ceil(relativeError * count) + val targetError = relativeError * count // Minimum rank at current sample var minRank = 0 - var i = 1 + var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) minRank += curSample.g diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index df579d5ec1ddf..650813975d75c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -57,8 +57,14 @@ class QuantileSummariesSuite extends SparkFunSuite { private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { if (data.nonEmpty) { val approx = summary.query(quant).get - // The rank of the approximation. - val rank = data.count(_ < approx) // has to be <, not <= to be exact + // Get the rank of the approximation. + val rankOfValue = data.count(_ <= approx) + val rankOfPreValue = data.count(_ < approx) + // `rankOfValue` is the last position of the quantile value. If the input repeats the value + // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's + // improper to choose the last position as its rank. Instead, we get the rank by averaging + // `rankOfValue` and `rankOfPreValue`. + val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) val lower = math.floor((quant - summary.relativeError) * data.size) val upper = math.ceil((quant + summary.relativeError) * data.size) val msg = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 1aea33766407f..137c5bea2abb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -53,6 +53,21 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { } } + test("percentile_approx, the first element satisfies small percentages") { + withTempView(table) { + (1 to 10).toDF("col").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s""" + |SELECT + | percentile_approx(col, array(0.01, 0.1, 0.11)) + |FROM $table + """.stripMargin), + Row(Seq(1, 1, 2)) + ) + } + } + test("percentile_approx, array of percentile value") { withTempView(table) { (1 to 1000).toDF("col").createOrReplaceTempView(table) @@ -130,7 +145,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { (1 to 1000).toDF("col").createOrReplaceTempView(table) checkAnswer( spark.sql(s"SELECT percentile_approx(col, array(0.25 + 0.25D), 200 + 800D) FROM $table"), - Row(Seq(500D)) + Row(Seq(499)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 247c30e2ee65b..46b21c3b64a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -141,7 +141,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("approximate quantile") { val n = 1000 - val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + val df = Seq.tabulate(n + 1)(i => (i, 2.0 * i)).toDF("singles", "doubles") val q1 = 0.5 val q2 = 0.8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd8f54b690f64..ad461fa6144b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -855,7 +855,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("mean", null, "33.0", "178.0"), Row("stddev", null, "19.148542155126762", "11.547005383792516"), Row("min", "Alice", "16", "164"), - Row("25%", null, "24", "176"), + Row("25%", null, "16", "164"), Row("50%", null, "24", "176"), Row("75%", null, "32", "180"), Row("max", "David", "60", "192")) From 645e108eeb6364e57f5d7213dbbd42dbcf1124d3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Oct 2017 13:51:33 -0700 Subject: [PATCH 1473/1765] [SPARK-21988][SS] Implement StreamingRelation.computeStats to fix explain ## What changes were proposed in this pull request? Implement StreamingRelation.computeStats to fix explain ## How was this patch tested? - unit tests: `StreamingRelation.computeStats` and `StreamingExecutionRelation.computeStats`. - regression tests: `explain join with a normal source` and `explain join with MemoryStream`. Author: Shixiong Zhu Closes #19465 from zsxwing/SPARK-21988. --- .../streaming/StreamingRelation.scala | 8 +++ .../spark/sql/streaming/StreamSuite.scala | 65 ++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index ab716052c28ba..6b82c78ea653d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -44,6 +44,14 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) + ) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9c901062d570a..3d687d2214e90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -76,20 +76,65 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("StreamingRelation.computeStats") { + val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) + } - test("explain join") { - // Make a table and ensure it will be broadcast. - val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + test("StreamingExecutionRelation.computeStats") { + val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { + case s: StreamingExecutionRelation => s + } + assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") + assert(streamingExecutionRelation.head.computeStats.sizeInBytes + == spark.sessionState.conf.defaultSizeInBytes) + } - // Join the input stream with a table. - val inputData = MemoryStream[Int] - val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value") + test("explain join with a normal source") { + // This test triggers CostBasedJoinReorder to call `computeStats` + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = spark.readStream.format("rate").load() + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) + } + } - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - joined.explain() + test("explain join with MemoryStream") { + // This test triggers CostBasedJoinReorder to call `computeStats` + // Because MemoryStream doesn't use DataSource code path, we need a separate test. + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = MemoryStream[Int].toDF + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) } - assert(outputStream.toString.contains("StreamingRelation")) } test("SPARK-20432: union one stream with itself") { From ccdf21f56e4ff5497d7770dcbee2f7a60bb9e3a7 Mon Sep 17 00:00:00 2001 From: Jorge Machado Date: Wed, 11 Oct 2017 22:13:07 -0700 Subject: [PATCH 1474/1765] [SPARK-20055][DOCS] Added documentation for loading csv files into DataFrames ## What changes were proposed in this pull request? Added documentation for loading csv files into Dataframes ## How was this patch tested? /dev/run-tests Author: Jorge Machado Closes #19429 from jomach/master. --- docs/sql-programming-guide.md | 32 ++++++++++++++++--- .../sql/JavaSQLDataSourceExample.java | 7 ++++ examples/src/main/python/sql/datasource.py | 5 +++ examples/src/main/r/RSparkSQLExample.R | 6 ++++ examples/src/main/resources/people.csv | 3 ++ .../examples/sql/SQLDataSourceExample.scala | 8 +++++ 6 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 examples/src/main/resources/people.csv diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a095263bfa619..639a8ea7bb8ad 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -461,6 +461,8 @@ name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can al names (`json`, `parquet`, `jdbc`, `orc`, `libsvm`, `csv`, `text`). DataFrames loaded from any data source type can be converted into other types using this syntax. +To load a JSON file you can use: +
    {% include_example manual_load_options scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -479,6 +481,26 @@ source type can be converted into other types using this syntax.
    +To load a CSV file you can use: + +
    +
    +{% include_example manual_load_options_csv scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
    + +
    +{% include_example manual_load_options_csv java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
    + +
    +{% include_example manual_load_options_csv python/sql/datasource.py %} +
    + +
    +{% include_example manual_load_options_csv r/RSparkSQLExample.R %} + +
    +
    ### Run SQL on files directly Instead of using read API to load a file into DataFrame and query it, you can also query that @@ -573,7 +595,7 @@ Note that partition information is not gathered by default when creating externa ### Bucketing, Sorting and Partitioning -For file-based data source, it is also possible to bucket and sort or partition the output. +For file-based data source, it is also possible to bucket and sort or partition the output. Bucketing and sorting are applicable only to persistent tables:
    @@ -598,7 +620,7 @@ CREATE TABLE users_bucketed_by_name( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet CLUSTERED BY(name) INTO 42 BUCKETS; {% endhighlight %} @@ -629,7 +651,7 @@ while partitioning can be used with both `save` and `saveAsTable` when using the {% highlight sql %} CREATE TABLE users_by_favorite_color( - name STRING, + name STRING, favorite_color STRING, favorite_numbers array ) USING csv PARTITIONED BY(favorite_color); @@ -664,7 +686,7 @@ CREATE TABLE users_bucketed_and_partitioned( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet PARTITIONED BY (favorite_color) CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS; @@ -675,7 +697,7 @@ CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS;
    `partitionBy` creates a directory structure as described in the [Partition Discovery](#partition-discovery) section. -Thus, it has limited applicability to columns with high cardinality. In contrast +Thus, it has limited applicability to columns with high cardinality. In contrast `bucketBy` distributes data across a fixed number of buckets and can be used when a number of unique values is unbounded. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 95859c52c2aeb..ef3c904775697 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -116,6 +116,13 @@ private static void runBasicDataSourceExample(SparkSession spark) { spark.read().format("json").load("examples/src/main/resources/people.json"); peopleDF.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + Dataset peopleDFCsv = spark.read().format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv"); + // $example off:manual_load_options_csv$ // $example on:direct_sql$ Dataset sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index f86012ea382e8..b375fa775de39 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -53,6 +53,11 @@ def basic_datasource_example(spark): df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") # $example off:manual_load_options$ + # $example on:manual_load_options_csv$ + df = spark.read.load("examples/src/main/resources/people.csv", + format="csv", sep=":", inferSchema="true", header="true") + # $example off:manual_load_options_csv$ + # $example on:write_sorting_and_bucketing$ df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") # $example off:write_sorting_and_bucketing$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 3734568d872d0..a5ed723da47ca 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -113,6 +113,12 @@ write.df(namesAndAges, "namesAndAges.parquet", "parquet") # $example off:manual_load_options$ +# $example on:manual_load_options_csv$ +df <- read.df("examples/src/main/resources/people.csv", "csv") +namesAndAges <- select(df, "name", "age") +# $example off:manual_load_options_csv$ + + # $example on:direct_sql$ df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") # $example off:direct_sql$ diff --git a/examples/src/main/resources/people.csv b/examples/src/main/resources/people.csv new file mode 100644 index 0000000000000..7fe5adba93d77 --- /dev/null +++ b/examples/src/main/resources/people.csv @@ -0,0 +1,3 @@ +name;age;job +Jorge;30;Developer +Bob;32;Developer diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 86b3dc4a84f58..f9477969a4bb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -49,6 +49,14 @@ object SQLDataSourceExample { val peopleDF = spark.read.format("json").load("examples/src/main/resources/people.json") peopleDF.select("name", "age").write.format("parquet").save("namesAndAges.parquet") // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + val peopleDFCsv = spark.read.format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv") + // $example off:manual_load_options_csv$ + // $example on:direct_sql$ val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") // $example off:direct_sql$ From 274f0efefa0c063649bccddb787e8863910f4366 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 20:20:44 +0800 Subject: [PATCH 1475/1765] [SPARK-22252][SQL] FileFormatWriter should respect the input query schema ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/18064, we allowed `RunnableCommand` to have children in order to fix some UI issues. Then we made `InsertIntoXXX` commands take the input `query` as a child, when we do the actual writing, we just pass the physical plan to the writer(`FileFormatWriter.write`). However this is problematic. In Spark SQL, optimizer and planner are allowed to change the schema names a little bit. e.g. `ColumnPruning` rule will remove no-op `Project`s, like `Project("A", Scan("a"))`, and thus change the output schema from "" to ``. When it comes to writing, especially for self-description data format like parquet, we may write the wrong schema to the file and cause null values at the read path. Fortunately, in https://github.com/apache/spark/pull/18450 , we decided to allow nested execution and one query can map to multiple executions in the UI. This releases the major restriction in #18604 , and now we don't have to take the input `query` as child of `InsertIntoXXX` commands. So the fix is simple, this PR partially revert #18064 and make `InsertIntoXXX` commands leaf nodes again. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19474 from cloud-fan/bug. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../sql/catalyst/plans/logical/Command.scala | 3 +- .../spark/sql/execution/QueryExecution.scala | 4 +-- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 3 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../command/DataWritingCommand.scala | 13 ++++++++ .../InsertIntoDataSourceDirCommand.scala | 6 ++-- .../spark/sql/execution/command/cache.scala | 2 +- .../sql/execution/command/commands.scala | 30 ++++++------------- .../command/createDataSourceTables.scala | 2 +- .../spark/sql/execution/command/views.scala | 4 +-- .../execution/datasources/DataSource.scala | 13 +++++++- .../datasources/FileFormatWriter.scala | 14 +++++---- .../InsertIntoDataSourceCommand.scala | 2 +- .../InsertIntoHadoopFsRelationCommand.scala | 9 ++---- .../SaveIntoDataSourceCommand.scala | 2 +- .../execution/streaming/FileStreamSink.scala | 2 +- .../datasources/FileFormatWriterSuite.scala | 16 +++++++++- .../execution/InsertIntoHiveDirCommand.scala | 10 ++----- .../hive/execution/InsertIntoHiveTable.scala | 14 ++------- .../sql/hive/execution/SaveAsHiveFile.scala | 6 ++-- 22 files changed, 85 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7addbaaa9afa5..c7952e3ff8280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -178,7 +178,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }) } - override def innerChildren: Seq[QueryPlan[_]] = subqueries + override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Returns a plan where a best effort attempt has been made to transform `this` in a way diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index ec5766e1f67f2..38f47081b6f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LogicalPlan { +trait Command extends LeafNode { override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4accf54a18232..f404621399cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -119,7 +119,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * `SparkSQLDriver` for CLI applications. */ def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand, _) => + case ExecutedCommandExec(desc: DescribeTableCommand) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. desc.run(sparkSession).map { @@ -130,7 +130,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended => + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4cdcc73faacd7..19b858faba6ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -364,7 +364,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil + case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index bc98d8d9d6d61..a1c62a729900e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -62,7 +62,8 @@ case class InMemoryRelation( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override def innerChildren: Seq[SparkPlan] = Seq(child) + + override protected def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c7ddec55682e1..af3636a5a2ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -34,7 +34,7 @@ case class InMemoryTableScanExec( @transient relation: InMemoryRelation) extends LeafExecNode { - override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 4e1c5e4846f36..2cf06982e25f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.SerializableConfiguration @@ -30,6 +31,18 @@ import org.apache.spark.util.SerializableConfiguration */ trait DataWritingCommand extends RunnableCommand { + /** + * The input query plan that produces the data to be written. + */ + def query: LogicalPlan + + // We make the input `query` an inner child instead of a child in order to hide it from the + // optimizer. This is because optimizer may not preserve the output schema names' case, and we + // have to keep the original analyzed plan here so that we can pass the corrected schema to the + // writer. The schema of analyzed plan is what user expects(or specifies), so we should respect + // it when writing. + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index 633de4c37af94..9e3519073303c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ /** @@ -45,10 +44,9 @@ case class InsertIntoDataSourceDirCommand( query: LogicalPlan, overwrite: Boolean) extends RunnableCommand { - override def children: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty, "Directory path is required") assert(provider.nonEmpty, "Data source is required") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 792290bef0163..140f920eaafae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -30,7 +30,7 @@ case class CacheTableCommand( require(plan.isEmpty || tableIdent.database.isEmpty, "Database name is not allowed in CACHE TABLE AS SELECT") - override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq + override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7cd4baef89e75..e28b5eb2e2a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} @@ -37,19 +37,13 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -trait RunnableCommand extends logical.Command { +trait RunnableCommand extends Command { // The map used to record the metrics of running the command. This will be passed to // `ExecutedCommand` during query planning. lazy val metrics: Map[String, SQLMetric] = Map.empty - def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - throw new NotImplementedError - } - - def run(sparkSession: SparkSession): Seq[Row] = { - throw new NotImplementedError - } + def run(sparkSession: SparkSession): Seq[Row] } /** @@ -57,9 +51,8 @@ trait RunnableCommand extends logical.Command { * saves the result to prevent multiple executions. * * @param cmd the `RunnableCommand` this operator will run. - * @param children the children physical plans ran by the `RunnableCommand`. */ -case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { override lazy val metrics: Map[String, SQLMetric] = cmd.metrics @@ -74,19 +67,14 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - val rows = if (children.isEmpty) { - cmd.run(sqlContext.sparkSession) - } else { - cmd.run(sqlContext.sparkSession, children) - } - rows.map(converter(_).asInstanceOf[InternalRow]) + cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) } - override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil override def output: Seq[Attribute] = cmd.output - override def nodeName: String = cmd.nodeName + override def nodeName: String = "Execute " + cmd.nodeName override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 04b2534ca5eb1..9e3907996995c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -120,7 +120,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override def innerChildren: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index ffdfd527fa701..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -98,7 +98,7 @@ case class CreateViewCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) if (viewType == PersistedView) { require(originalText.isDefined, "'originalText' must be provided to create permanent view") @@ -267,7 +267,7 @@ case class AlterViewAsCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(session: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b9502a95a7c08..b43d282bd434c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -453,6 +453,17 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) + + // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does + // not need to have the query as child, to avoid to analyze an optimized query, + // because InsertIntoHadoopFsRelationCommand will be optimized first. + val partitionAttributes = partitionColumns.map { name => + data.output.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location @@ -465,7 +476,7 @@ case class DataSource( outputPath = outputPath, staticPartitions = Map.empty, ifPartitionNotExists = false, - partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), + partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, options = options, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 514969715091a..75b1695fbc275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -101,7 +101,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -117,7 +117,9 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = plan.output + // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema + // names' case. + val allColumns = queryExecution.analyzed.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = allColumns.filterNot(partitionSet.contains) @@ -158,7 +160,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = plan.outputOrdering.map(_.child) + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -176,12 +178,12 @@ object FileFormatWriter extends Logging { try { val rdd = if (orderingMatched) { - plan.execute() + queryExecution.toRdd } else { SortExec( requiredOrdering.map(SortOrder(_, Ascending)), global = false, - child = plan).execute() + child = queryExecution.executedPlan).execute() } val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 08b2f4f31170f..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand( overwrite: Boolean) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 64e5a57adc37c..675bee85bf61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.util.SchemaUtils @@ -57,11 +56,7 @@ case class InsertIntoHadoopFsRelationCommand( extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkSchemaColumnNameDuplication( query.schema, @@ -144,7 +139,7 @@ case class InsertIntoHadoopFsRelationCommand( val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 5eb6a8471be0d..96c84eab1c894 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -38,7 +38,7 @@ case class SaveIntoDataSourceCommand( options: Map[String, String], mode: SaveMode) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { dataSource.createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 72e5ac40bbfed..6bd0696622005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -121,7 +121,7 @@ class FileStreamSink( FileFormatWriter.write( sparkSession = sparkSession, - plan = data.queryExecution.executedPlan, + queryExecution = data.queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index a0c1ea63d3827..6f8767db176aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext class FileFormatWriterSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("empty file should be skipped while write to file") { withTempPath { path => @@ -30,4 +31,17 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { assert(partFiles.length === 2) } } + + test("FileFormatWriter should respect the input query schema") { + withTable("t1", "t2", "t3", "t4") { + spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") + checkAnswer(spark.table("t2"), Row(0, 0)) + + // Test picking part of the columns when writing. + spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4") + checkAnswer(spark.table("t4"), Row(0, 0)) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 918c8be00d69d..1c6f8dd77fc2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -27,11 +27,10 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred._ import org.apache.spark.SparkException -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl /** @@ -57,10 +56,7 @@ case class InsertIntoHiveDirCommand( query: LogicalPlan, overwrite: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( @@ -102,7 +98,7 @@ case class InsertIntoHiveDirCommand( try { saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpPath.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index e5b59ed7a1a6b..56e10bc457a00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,20 +17,16 @@ package org.apache.spark.sql.hive.execution -import scala.util.control.NonFatal - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils -import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveClientImpl @@ -72,16 +68,12 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. */ - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { val externalCatalog = sparkSession.sharedState.externalCatalog val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -170,7 +162,7 @@ case class InsertIntoHiveTable( saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 2d74ef040ef5a..63657590e5e79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog @@ -47,7 +47,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { protected def saveAsHiveFile( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, @@ -75,7 +75,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.write( sparkSession = sparkSession, - plan = plan, + queryExecution = queryExecution, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations), From b5c1ef7a8e4db4067bc361d10d554ee9a538423f Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Thu, 12 Oct 2017 20:26:51 +0800 Subject: [PATCH 1476/1765] [SPARK-22097][CORE] Request an accurate memory after we unrolled the block ## What changes were proposed in this pull request? We only need request `bbos.size - unrollMemoryUsedByThisBlock` after unrolled the block. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19316 from ConeyLiu/putIteratorAsBytes. --- .../org/apache/spark/storage/memory/MemoryStore.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 651e9c7b2ab61..17f7a69ad6ba1 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -388,7 +388,13 @@ private[spark] class MemoryStore( // perform one final call to attempt to allocate additional memory if necessary. if (keepUnrolling) { serializationStream.close() - reserveAdditionalMemoryIfNecessary() + if (bbos.size > unrollMemoryUsedByThisBlock) { + val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } } if (keepUnrolling) { From 73d80ec49713605d6a589e688020f0fc2d6feab2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 20:34:03 +0800 Subject: [PATCH 1477/1765] [SPARK-22197][SQL] push down operators to data source before planning ## What changes were proposed in this pull request? As we discussed in https://github.com/apache/spark/pull/19136#discussion_r137023744 , we should push down operators to data source before planning, so that data source can report statistics more accurate. This PR also includes some cleanup for the read path. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19424 from cloud-fan/follow. --- .../spark/sql/sources/v2/ReadSupport.java | 5 +- .../sql/sources/v2/ReadSupportWithSchema.java | 5 +- .../sql/sources/v2/reader/DataReader.java | 4 + .../sources/v2/reader/DataSourceV2Reader.java | 2 +- .../spark/sql/sources/v2/reader/ReadTask.java | 3 +- .../SupportsPushDownCatalystFilters.java | 8 + .../v2/reader/SupportsPushDownFilters.java | 8 + .../apache/spark/sql/DataFrameReader.scala | 5 +- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../v2/DataSourceReaderHolder.scala | 68 +++++++++ .../datasources/v2/DataSourceV2Relation.scala | 8 +- .../datasources/v2/DataSourceV2ScanExec.scala | 22 +-- .../datasources/v2/DataSourceV2Strategy.scala | 60 +------- .../v2/PushDownOperatorsToDataSource.scala | 140 ++++++++++++++++++ .../sources/v2/JavaAdvancedDataSourceV2.java | 5 + .../sql/sources/v2/DataSourceV2Suite.scala | 2 + 16 files changed, 262 insertions(+), 87 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ab5254a688d5a..ee489ad0f608f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,9 +30,8 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index c13aeca2ef36f..74e81a2c84d68 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -39,9 +39,8 @@ public interface ReadSupportWithSchema { * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full * schema, as column pruning or other optimizations may happen. - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index cfafc1a576793..95e091569b614 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -24,6 +24,10 @@ /** * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data * for a RDD partition. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface DataReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index fb4d5c0d7ae41..5989a4ac8440b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,7 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 7885bfcdd49e4..01362df0978cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -27,7 +27,8 @@ * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. + * created on executors and do the actual reading. So {@link ReadTask} must be serializable and + * {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving public interface ReadTask extends Serializable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 19d706238ec8e..d6091774d75aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -40,4 +40,12 @@ public interface SupportsPushDownCatalystFilters { * Pushes down filters, and returns unsupported filters. */ Expression[] pushCatalystFilters(Expression[] filters); + + /** + * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * It's possible that there is no filters in the query and + * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for + * this case. + */ + Expression[] pushedCatalystFilters(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d4b509e7080f2..d6f297c013375 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** @@ -35,4 +36,11 @@ public interface SupportsPushDownFilters { * Pushes down filters, and returns unsupported filters. */ Filter[] pushFilters(Filter[] filters); + + /** + * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * is never called, empty array should be returned for this case. + */ + Filter[] pushedFilters(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 78b668c04fd5c..17966eecfc051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -184,7 +184,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val dataSource = cls.newInstance() val options = new DataSourceV2Options(extraOptions.asJava) val reader = (cls.newInstance(), userSpecifiedSchema) match { @@ -194,8 +193,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case (ds: ReadSupport, None) => ds.createReader(options) - case (_: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + case (ds: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $ds.") case (ds: ReadSupport, Some(schema)) => val reader = ds.createReader(options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..1c8e4050978dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -31,7 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..6093df26630cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The full output of the data source reader, without column pruning. + */ + def fullOutput: Seq[AttributeReference] + + /** + * The held data source reader. + */ + def reader: DataSourceV2Reader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => + fullOutput.find(_.name == name).get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3c9b598fd07c9..7eb99a645001a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode { + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7999c0ceb5749..addc12a3f0901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,20 +29,14 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType +/** + * Physical plan node for scanning data from a data source. + */ case class DataSourceV2ScanExec( - fullOutput: Array[AttributeReference], - @transient reader: DataSourceV2Reader, - // TODO: these 3 parameters are only used to determine the equality of the scan node, however, - // the reader also have this information, and ideally we can just rely on the equality of the - // reader. The only concern is, the reader implementation is outside of Spark and we have no - // control. - readSchema: StructType, - @transient filters: ExpressionSet, - hashPartitionKeys: Seq[String]) extends LeafExecNode { - - def output: Seq[Attribute] = readSchema.map(_.name).map { name => - fullOutput.find(_.name == name).get - } + fullOutput: Seq[AttributeReference], + @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] override def references: AttributeSet = AttributeSet.empty @@ -74,7 +68,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b80f695b2a87f..f2cda002245e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -29,64 +29,8 @@ import org.apache.spark.sql.sources.v2.reader._ object DataSourceV2Strategy extends Strategy { // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(filters.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => filters - } - - val attrMap = AttributeMap(output.zip(output)) - val projectSet = AttributeSet(projects.flatMap(_.references)) - val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) - - // Match original case of attributes. - // TODO: nested fields pruning - val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) - reader match { - case r: SupportsPushDownRequiredColumns => - r.pruneColumns(requiredColumns.toStructType) - case _ => - } - - val scan = DataSourceV2ScanExec( - output.toArray, - reader, - reader.readSchema(), - ExpressionSet(filters), - Nil) - - val filterCondition = stayUpFilters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - - val withProject = if (projects == withFilter.output) { - withFilter - } else { - ProjectExec(projects, withFilter) - } - - withProject :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala new file mode 100644 index 0000000000000..0c1708131ae46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.reader._ + +/** + * Pushes down various operators to the underlying data source for better performance. Operators are + * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you + * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the + * data source should execute FILTER before LIMIT. And required columns are calculated at the end, + * because when more operators are pushed down, we may need less columns at Spark side. + */ +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may + // appear in many places for column pruning. + // TODO: Ideally column pruning should be implemented via a plan property that is propagated + // top-down, then we can simplify the logic here and only collect target operators. + val filterPushed = plan transformUp { + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + // Non-deterministic expressions are stateful and we must keep the input sequence unchanged + // to avoid changing the result. This means, we can't evaluate the filter conditions that + // are after the first non-deterministic condition ahead. Here we only try to push down + // deterministic conditions that are before the first non-deterministic condition. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(candidates.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => candidates + } + + val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) + if (withFilter.output == fields) { + withFilter + } else { + Project(fields, withFilter) + } + } + + // TODO: add more push down rules. + + // TODO: nested fields pruning + def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.filter(requiredByParent.contains).flatMap(_.references) + pushDownRequiredColumns(child, required) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) + + case DataSourceV2Relation(fullOutput, reader) => reader match { + case r: SupportsPushDownRequiredColumns => + // Match original case of attributes. + val attrMap = AttributeMap(fullOutput.zip(fullOutput)) + val requiredColumns = requiredByParent.map(attrMap) + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + // TODO: there may be more operators can be used to calculate required columns, we can add + // more and more in the future. + case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + } + } + + pushDownRequiredColumns(filterPushed, filterPushed.output) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + /** + * Finds a Filter node(with an optional Project child) above data source relation. + */ + object FilterAndProject { + // returns the project list, the filter condition and the data source relation. + def unapply(plan: LogicalPlan) + : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + + case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + + case Filter(condition, Project(fields, r: DataSourceV2Relation)) + if fields.forall(_.deterministic) => + val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) + val substituted = condition.transform { + case a: Attribute => attributeMap.getOrElse(a, a) + } + Some((fields, substituted, r)) + + case _ => None + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 7aacf0346d2fb..da2c13f70c52a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -54,6 +54,11 @@ public Filter[] pushFilters(Filter[] filters) { return new Filter[0]; } + @Override + public Filter[] pushedFilters() { + return filters; + } + @Override public List> createReadTasks() { List> res = new ArrayList<>(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 9ce93d7ae926c..f238e565dc2fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -129,6 +129,8 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { Array.empty } + override def pushedFilters(): Array[Filter] = filters + override def readSchema(): StructType = { requiredSchema } From 02218c4c73c32741390d9906b6190ef2124ce518 Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Thu, 12 Oct 2017 17:00:22 +0200 Subject: [PATCH 1478/1765] [SPARK-22251][SQL] Metric 'aggregate time' is incorrect when codegen is off ## What changes were proposed in this pull request? Adding the code for setting 'aggregate time' metric to non-codegen path in HashAggregateExec and to ObjectHashAggregateExces. ## How was this patch tested? Tested manually. Author: Ala Luszczak Closes #19473 from ala/fix-agg-time. --- .../sql/execution/aggregate/HashAggregateExec.scala | 6 +++++- .../execution/aggregate/ObjectHashAggregateExec.scala | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8b573fdcf25e1..43e5ff89afee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -95,11 +95,13 @@ case class HashAggregateExec( val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") val avgHashProbe = longMetric("avgHashProbe") + val aggTime = longMetric("aggTime") child.execute().mapPartitionsWithIndex { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. Iterator.empty @@ -128,6 +130,8 @@ case class HashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 6316e06a8f34e..ec3f9a05b5ccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -76,7 +76,8 @@ case class ObjectHashAggregateExec( aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows") + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time") ) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -96,11 +97,13 @@ case class ObjectHashAggregateExec( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") + val aggTime = longMetric("aggTime") val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input kvIterator is empty, // so return an empty kvIterator. Iterator.empty @@ -127,6 +130,8 @@ case class ObjectHashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } From 9104add4c7c6b578df15b64a8533a1266f90734e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 13 Oct 2017 08:40:26 +0900 Subject: [PATCH 1479/1765] [SPARK-22217][SQL] ParquetFileFormat to support arbitrary OutputCommitters ## What changes were proposed in this pull request? `ParquetFileFormat` to relax its requirement of output committer class from `org.apache.parquet.hadoop.ParquetOutputCommitter` or subclass thereof (and so implicitly Hadoop `FileOutputCommitter`) to any committer implementing `org.apache.hadoop.mapreduce.OutputCommitter` This enables output committers which don't write to the filesystem the way `FileOutputCommitter` does to save parquet data from a dataframe: at present you cannot do this. Before a committer which isn't a subclass of `ParquetOutputCommitter`, it checks to see if the context has requested summary metadata by setting `parquet.enable.summary-metadata`. If true, and the committer class isn't a parquet committer, it raises a RuntimeException with an error message. (It could downgrade, of course, but raising an exception makes it clear there won't be an summary. It also makes the behaviour testable.) Note that `SQLConf` already states that any `OutputCommitter` can be used, but that typically it's a subclass of ParquetOutputCommitter. That's not currently true. This patch will make the code consistent with the docs, adding tests to verify, ## How was this patch tested? The patch includes a test suite, `ParquetCommitterSuite`, with a new committer, `MarkingFileOutputCommitter` which extends `FileOutputCommitter` and writes a marker file in the destination directory. The presence of the marker file can be used to verify the new committer was used. The tests then try the combinations of Parquet committer summary/no-summary and marking committer summary/no-summary. | committer | summary | outcome | |-----------|---------|---------| | parquet | true | success | | parquet | false | success | | marking | false | success with marker | | marking | true | exception | All tests are happy. Author: Steve Loughran Closes #19448 from steveloughran/cloud/SPARK-22217-committer. --- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../parquet/ParquetFileFormat.scala | 12 +- .../parquet/ParquetCommitterSuite.scala | 152 ++++++++++++++++++ 3 files changed, 165 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 58323740b80cc..618d4a0d6148a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -306,8 +306,9 @@ object SQLConf { val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + - "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter.") + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + + "will never be created, irrespective of the value of parquet.enable.summary-metadata") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e1e740500205a..c1535babbae1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -86,7 +86,7 @@ class ParquetFileFormat conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + @@ -98,7 +98,7 @@ class ParquetFileFormat conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why @@ -138,6 +138,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } + if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. " + + s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + } + new OutputWriterFactory { // This OutputWriterFactory instance is deserialized when writing Parquet files on the // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala new file mode 100644 index 0000000000000..caa4f6d70c6a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.FileNotFoundException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} + +import org.apache.spark.{LocalSparkContext, SparkFunSuite} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Test logic related to choice of output committers. + */ +class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils + with LocalSparkContext { + + private val PARQUET_COMMITTER = classOf[ParquetOutputCommitter].getCanonicalName + + protected var spark: SparkSession = _ + + /** + * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + try { + if (spark != null) { + spark.stop() + spark = null + } + } finally { + super.afterAll() + } + } + + test("alternative output committer, merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = true, check = true) + } + + test("alternative output committer, no merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = false, check = true) + } + + test("Parquet output committer, merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = true, check = false) + } + + test("Parquet output committer, no merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = false, check = false) + } + + /** + * Write a trivial dataframe as Parquet, using the given committer + * and job summary option. + * @param committer committer to use + * @param summary create a job summary + * @param check look for a marker file + * @return if a marker file was sought, it's file status. + */ + private def writeDataFrame( + committer: String, + summary: Boolean, + check: Boolean): Option[FileStatus] = { + var result: Option[FileStatus] = None + withSQLConf( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, + ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + withTempPath { dest => + val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) + val destPath = new Path(dest.toURI) + df.write.format("parquet").save(destPath.toString) + if (check) { + result = Some(MarkingFileOutput.checkMarker( + destPath, + spark.sparkContext.hadoopConfiguration)) + } + } + } + result + } +} + +/** + * A file output committer which explicitly touches a file "marker"; this + * is how tests can verify that this committer was used. + * @param outputPath output path + * @param context task context + */ +private class MarkingFileOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + super.commitJob(context) + MarkingFileOutput.touch(outputPath, context.getConfiguration) + } +} + +private object MarkingFileOutput { + + val COMMITTER = classOf[MarkingFileOutputCommitter].getCanonicalName + + /** + * Touch the marker. + * @param outputPath destination directory + * @param conf configuration to create the FS with + */ + def touch(outputPath: Path, conf: Configuration): Unit = { + outputPath.getFileSystem(conf).create(new Path(outputPath, "marker")).close() + } + + /** + * Get the file status of the marker + * + * @param outputPath destination directory + * @param conf configuration to create the FS with + * @return the status of the marker + * @throws FileNotFoundException if the marker is absent + */ + def checkMarker(outputPath: Path, conf: Configuration): FileStatus = { + outputPath.getFileSystem(conf).getFileStatus(new Path(outputPath, "marker")) + } +} From 3ff766f61afbd09dcc7a73eae02e68a39114ce3f Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 12 Oct 2017 18:47:16 -0700 Subject: [PATCH 1480/1765] [SPARK-22263][SQL] Refactor deterministic as lazy value ## What changes were proposed in this pull request? The method `deterministic` is frequently called in optimizer. Refactor `deterministic` as lazy value, in order to avoid redundant computations. ## How was this patch tested? Simple benchmark test over TPC-DS queries, run time from query string to optimized plan(continuous 20 runs, and get the average of last 5 results): Before changes: 12601 ms After changes: 11993ms This is 4.8% performance improvement. Also run test with Unit test. Author: Wang Gengliang Closes #19478 from gengliangwang/deterministicAsLazyVal. --- .../sql/catalyst/expressions/CallMethodViaReflection.scala | 2 +- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++-- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/First.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/Last.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/collect.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 2 +- .../sql/execution/aggregate/TypedAggregateExpression.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 2 +- .../org/apache/spark/sql/TypedImperativeAggregateSuite.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index cd97304302e48..65bb9a8c642b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -76,7 +76,7 @@ case class CallMethodViaReflection(children: Seq[Expression]) } } - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = true override val dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c058425b4bc36..0e75ac88dc2b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -79,7 +79,7 @@ abstract class Expression extends TreeNode[Expression] { * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. * By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. */ - def deterministic: Boolean = children.forall(_.deterministic) + lazy val deterministic: Boolean = children.forall(_.deterministic) def nullable: Boolean @@ -265,7 +265,7 @@ trait NonSQLExpression extends Expression { * An expression that is nondeterministic. */ trait Nondeterministic extends Expression { - final override def deterministic: Boolean = false + final override lazy val deterministic: Boolean = false final override def foldable: Boolean = false @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 527f1670c25e1..179853032035e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,7 +49,7 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { - override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index bfc58c22886cc..4e671e1f3e6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -44,7 +44,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // First is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 96a6ec08a160a..0ccabb9d98914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -44,7 +44,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // Last is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 405c2065680f5..be972f006352e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -44,7 +44,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the // actual order of input rows. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def update(buffer: T, input: InternalRow): T = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ef293ff3f18ea..b86e271fe2958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -119,7 +119,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:on line.size.limit case class Uuid() extends LeafExpression { - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 717758fdf716f..aab8cc50b9526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -127,7 +127,7 @@ case class SimpleTypedAggregateExpression( nullable: Boolean) extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer @@ -221,7 +221,7 @@ case class ComplexTypedAggregateExpression( inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fec1add18cbf2..72aa4adff4e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -340,7 +340,7 @@ case class ScalaUDAF( override def dataType: DataType = udaf.dataType - override def deterministic: Boolean = udaf.deterministic + override lazy val deterministic: Boolean = udaf.deterministic override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index b76f168220d84..c5fb17345222a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -268,7 +268,7 @@ object TypedImperativeAggregateSuite { } } - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = Seq(child) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e9bdcf00b9346..68af99ea272a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -48,7 +48,7 @@ private[hive] case class HiveSimpleUDF( with Logging with UserDefinedExpression { - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -131,7 +131,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] From ec122209fb35a65637df42eded64b0203e105aae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Oct 2017 13:09:35 +0800 Subject: [PATCH 1481/1765] [SPARK-21165][SQL] FileFormatWriter should handle mismatched attribute ids between logical and physical plan ## What changes were proposed in this pull request? Due to optimizer removing some unnecessary aliases, the logical and physical plan may have different output attribute ids. FileFormatWriter should handle this when creating the physical sort node. ## How was this patch tested? new regression test. Author: Wenchen Fan Closes #19483 from cloud-fan/bug2. --- .../datasources/FileFormatWriter.scala | 7 +++++- .../datasources/FileFormatWriterSuite.scala | 2 +- .../apache/spark/sql/hive/InsertSuite.scala | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 75b1695fbc275..1fac01a2c26c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -180,8 +180,13 @@ object FileFormatWriter extends Logging { val rdd = if (orderingMatched) { queryExecution.toRdd } else { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. + val orderingExpr = requiredOrdering + .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns)) SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), + orderingExpr, global = false, child = queryExecution.executedPlan).execute() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index 6f8767db176aa..13f0e0bca86c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -32,7 +32,7 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { } } - test("FileFormatWriter should respect the input query schema") { + test("SPARK-22252: FileFormatWriter should respect the input query schema") { withTable("t1", "t2", "t3", "t4") { spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index aa5cae33f5cd9..ab91727049ff5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -728,4 +728,26 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter assert(e.contains("mismatched input 'ROW'")) } } + + test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + withTable("tab1", "tab2") { + Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1") + + spark.sql( + """ + |CREATE TABLE tab2 (word string, length int) + |PARTITIONED BY (first string) + """.stripMargin) + + spark.sql( + """ + |INSERT INTO TABLE tab2 PARTITION(first) + |SELECT word, length, cast(first as string) as first FROM tab1 + """.stripMargin) + + checkAnswer(spark.table("tab2"), Row("a", 3, "b")) + } + } + } } From 2f00a71a876321af02865d7cd53ada167e1ce2e3 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 12 Oct 2017 22:45:19 -0700 Subject: [PATCH 1482/1765] [SPARK-22257][SQL] Reserve all non-deterministic expressions in ExpressionSet ## What changes were proposed in this pull request? For non-deterministic expressions, they should be considered as not contained in the [[ExpressionSet]]. This is consistent with how we define `semanticEquals` between two expressions. Otherwise, combining expressions will remove non-deterministic expressions which should be reserved. E.g. Combine filters of ```scala testRelation.where(Rand(0) > 0.1).where(Rand(0) > 0.1) ``` should result in ```scala testRelation.where(Rand(0) > 0.1 && Rand(0) > 0.1) ``` ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19475 from gengliangwang/non-deterministic-expressionSet. --- .../catalyst/expressions/ExpressionSet.scala | 23 ++++++--- .../expressions/ExpressionSetSuite.scala | 51 +++++++++++++++---- .../optimizer/FilterPushdownSuite.scala | 15 ++++++ 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 305ac90e245b8..7e8e7b8cd5f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -30,8 +30,9 @@ object ExpressionSet { } /** - * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] - * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * A [[Set]] where membership is determined based on determinacy and a canonical representation of + * an [[Expression]] (i.e. one that attempts to ignore cosmetic differences). + * See [[Canonicalize]] for more details. * * Internally this set uses the canonical representation, but keeps also track of the original * expressions to ease debugging. Since different expressions can share the same canonical @@ -46,6 +47,10 @@ object ExpressionSet { * set.contains(1 + a) => true * set.contains(a + 2) => false * }}} + * + * For non-deterministic expressions, they are always considered as not contained in the [[Set]]. + * On adding a non-deterministic expression, simply append it to the original expressions. + * This is consistent with how we define `semanticEquals` between two expressions. */ class ExpressionSet protected( protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, @@ -53,7 +58,9 @@ class ExpressionSet protected( extends Set[Expression] { protected def add(e: Expression): Unit = { - if (!baseSet.contains(e.canonicalized)) { + if (!e.deterministic) { + originals += e + } else if (!baseSet.contains(e.canonicalized) ) { baseSet.add(e.canonicalized) originals += e } @@ -74,9 +81,13 @@ class ExpressionSet protected( } override def -(elem: Expression): ExpressionSet = { - val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) - val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) - new ExpressionSet(newBaseSet, newOriginals) + if (elem.deterministic) { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } else { + new ExpressionSet(baseSet.clone(), originals.clone()) + } } override def iterator: Iterator[Expression] = originals.iterator diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index a1000a0e80799..12eddf557109f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -175,20 +175,14 @@ class ExpressionSetSuite extends SparkFunSuite { aUpper > bUpper || aUpper <= Rand(1L) || aUpper <= 10, aUpper <= Rand(1L) || aUpper <= 10 || aUpper > bUpper) - // Partial reorder case: we don't reorder non-deterministic expressions, - // but we can reorder sub-expressions in deterministic AND/OR expressions. - // There are two predicates: - // (aUpper > bUpper || bUpper > 100) => we can reorder sub-expressions in it. - // (aUpper === Rand(1L)) - setTest(1, + // Keep all the non-deterministic expressions even they are semantically equal. + setTest(2, Rand(1L), Rand(1L)) + + setTest(2, (aUpper > bUpper || bUpper > 100) && aUpper === Rand(1L), (bUpper > 100 || aUpper > bUpper) && aUpper === Rand(1L)) - // There are three predicates: - // (Rand(1L) > aUpper) - // (aUpper <= Rand(1L) && aUpper > bUpper) - // (aUpper > 10 && bUpper > 10) => we can reorder sub-expressions in it. - setTest(1, + setTest(2, Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (bUpper > 10 && aUpper > 10)) @@ -219,4 +213,39 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet ++ setToAddWithSameExpression).size == 2) assert((initialSet ++ setToAddWithOutSameExpression).size == 3) } + + test("add single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet + (aUpper + 1)).size == 2) + assert((initialSet + Rand(0)).size == 3) + assert((initialSet + (aUpper + 2)).size == 3) + } + + test("remove single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet - (aUpper + 1)).size == 1) + assert((initialSet - Rand(0)).size == 2) + assert((initialSet - (aUpper + 2)).size == 2) + } + + test("add multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet ++ setToAddWithSameDeterministicExpression).size == 3) + assert((initialSet ++ setToAddWithOutSameExpression).size == 4) + } + + test("remove multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet -- setToRemoveWithSameDeterministicExpression).size == 1) + assert((initialSet -- setToRemoveWithOutSameExpression).size == 2) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 582b3ead5e54a..de0e7c7ee49ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -94,6 +94,21 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("combine redundant deterministic filters") { + val originalQuery = + testRelation + .where(Rand(0) > 0.1 && 'a === 1) + .where(Rand(0) > 0.1 && 'a === 1) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Rand(0) > 0.1 && 'a === 1 && Rand(0) > 0.1) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { val originalQuery = testRelation From e6e36004afc3f9fc8abea98542248e9de11b4435 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 13 Oct 2017 23:09:12 +0800 Subject: [PATCH 1483/1765] [SPARK-14387][SPARK-16628][SPARK-18355][SQL] Use Spark schema to read ORC table instead of ORC file schema ## What changes were proposed in this pull request? Before Hive 2.0, ORC File schema has invalid column names like `_col1` and `_col2`. This is a well-known limitation and there are several Apache Spark issues with `spark.sql.hive.convertMetastoreOrc=true`. This PR ignores ORC File schema and use Spark schema. ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #19470 from dongjoon-hyun/SPARK-18355. --- .../spark/sql/hive/orc/OrcFileFormat.scala | 31 ++++++---- .../sql/hive/execution/SQLQuerySuite.scala | 62 ++++++++++++++++++- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index c76f0ebb36a60..194e69c93e1a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -134,12 +134,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) - if (maybePhysicalSchema.isEmpty) { + val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + if (isEmptyFile) { Iterator.empty } else { - val physicalSchema = maybePhysicalSchema.get - OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -163,6 +162,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, + dataSchema, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), recordsIterator) @@ -272,25 +272,32 @@ private[orc] object OrcRelation extends HiveInspectors { def unwrapOrcStructs( conf: Configuration, dataSchema: StructType, + requiredSchema: StructType, maybeStructOI: Option[StructObjectInspector], iterator: Iterator[Writable]): Iterator[InternalRow] = { val deserializer = new OrcSerde - val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(dataSchema) + val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(requiredSchema) def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { - val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map { - case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal + val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map { + case (field, ordinal) => + var ref = oi.getStructFieldRef(field.name) + if (ref == null) { + ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name)) + } + ref -> ordinal }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) + val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r)) iterator.map { value => val raw = deserializer.deserialize(value) var i = 0 val length = fieldRefs.length while (i < length) { - val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) + val fieldRef = fieldRefs(i) + val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) } else { @@ -306,8 +313,8 @@ private[orc] object OrcRelation extends HiveInspectors { } def setRequiredColumns( - conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 09c59000b3e3f..94fa43dec7313 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -2050,4 +2050,64 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + Seq("orc", "parquet").foreach { format => + test(s"SPARK-18355 Read data from a hive table with a new column - $format") { + val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + Seq("true", "false").foreach { value => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> value, + HiveUtils.CONVERT_METASTORE_PARQUET.key -> value) { + withTempDatabase { db => + client.runSqlHive( + s""" + |CREATE TABLE $db.t( + | click_id string, + | search_id string, + | uid bigint) + |PARTITIONED BY ( + | ts string, + | hour string) + |STORED AS $format + """.stripMargin) + + client.runSqlHive( + s""" + |INSERT INTO TABLE $db.t + |PARTITION (ts = '98765', hour = '01') + |VALUES (12, 2, 12345) + """.stripMargin + ) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, ts, hour FROM $db.t"), + Row("12", "2", 12345, "98765", "01")) + + client.runSqlHive(s"ALTER TABLE $db.t ADD COLUMNS (dummy string)") + + checkAnswer( + sql(s"SELECT click_id, search_id FROM $db.t"), + Row("12", "2")) + + checkAnswer( + sql(s"SELECT search_id, click_id FROM $db.t"), + Row("2", "12")) + + checkAnswer( + sql(s"SELECT search_id FROM $db.t"), + Row("2")) + + checkAnswer( + sql(s"SELECT dummy, click_id FROM $db.t"), + Row(null, "12")) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, dummy, ts, hour FROM $db.t"), + Row("12", "2", 12345, null, "98765", "01")) + } + } + } + } + } } From 6412ea1759d39a2380c572ec24cfd8ae4f2d81f7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 14 Oct 2017 00:35:12 +0800 Subject: [PATCH 1484/1765] [SPARK-21247][SQL] Type comparison should respect case-sensitive SQL conf ## What changes were proposed in this pull request? This is an effort to reduce the difference between Hive and Spark. Spark supports case-sensitivity in columns. Especially, for Struct types, with `spark.sql.caseSensitive=true`, the following is supported. ```scala scala> sql("select named_struct('a', 1, 'A', 2).a").show +--------------------------+ |named_struct(a, 1, A, 2).a| +--------------------------+ | 1| +--------------------------+ scala> sql("select named_struct('a', 1, 'A', 2).A").show +--------------------------+ |named_struct(a, 1, A, 2).A| +--------------------------+ | 2| +--------------------------+ ``` And vice versa, with `spark.sql.caseSensitive=false`, the following is supported. ```scala scala> sql("select named_struct('a', 1).A, named_struct('A', 1).a").show +--------------------+--------------------+ |named_struct(a, 1).A|named_struct(A, 1).a| +--------------------+--------------------+ | 1| 1| +--------------------+--------------------+ ``` However, types are considered different. For example, SET operations fail. ```scala scala> sql("SELECT named_struct('a',1) union all (select named_struct('A',2))").show org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. struct <> struct at the first column of the second table;; 'Union :- Project [named_struct(a, 1) AS named_struct(a, 1)#57] : +- OneRowRelation$ +- Project [named_struct(A, 2) AS named_struct(A, 2)#58] +- OneRowRelation$ ``` This PR aims to support case-insensitive type equality. For example, in Set operation, the above operation succeed when `spark.sql.caseSensitive=false`. ```scala scala> sql("SELECT named_struct('a',1) union all (select named_struct('A',2))").show +------------------+ |named_struct(a, 1)| +------------------+ | [1]| | [2]| +------------------+ ``` ## How was this patch tested? Pass the Jenkins with a newly add test case. Author: Dongjoon Hyun Closes #18460 from dongjoon-hyun/SPARK-21247. --- .../sql/catalyst/analysis/TypeCoercion.scala | 10 ++++ .../org/apache/spark/sql/types/DataType.scala | 7 ++- .../catalyst/analysis/TypeCoercionSuite.scala | 52 +++++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 38 ++++++++++++++ 4 files changed, 102 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 9ffe646b5e4ec..532d22dbf2321 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -100,6 +100,16 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 30745c6a9d42a..d6e0df12218ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -80,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d62e3b6dfe34f..793e04f66f0f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -131,14 +131,17 @@ class TypeCoercionSuite extends AnalysisTest { widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = widenFunc(t2, t1) - assert(found == expected, - s"Expected $expected as wider common type for $t2 and $t1, found $found") + if (isSymmetric) { + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } } test("implicit type cast - ByteType") { @@ -385,6 +388,47 @@ class TypeCoercionSuite extends AnalysisTest { widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) + + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("b", IntegerType))), + None) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", DoubleType, nullable = false))), + None) + + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = false))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("A", IntegerType))), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkWidenType( + TypeCoercion.findTightestCommonType, + StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType))), + StructType(Seq(StructField("A", IntegerType), StructField("b", IntegerType))), + Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), + isSymmetric = false) + } } test("wider common type for decimal and array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 93a7777b70b46..f0c58e2e5bf45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2646,6 +2646,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-21247: Allow case-insensitive type equality in Set operation") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + + withTable("t", "S") { + sql("CREATE TABLE t(c struct) USING parquet") + sql("CREATE TABLE S(C struct) USING parquet") + Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach { + case (left, right) => + checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty) + } + } + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val m1 = intercept[AnalysisException] { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + }.message + assert(m1.contains("Union can only be performed on tables with the compatible column types")) + + val m2 = intercept[AnalysisException] { + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + }.message + assert(m2.contains("Except can only be performed on tables with the compatible column types")) + + withTable("t", "S") { + sql("CREATE TABLE t(c struct) USING parquet") + sql("CREATE TABLE S(C struct) USING parquet") + checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty) + val m = intercept[AnalysisException] { + sql("SELECT * FROM t, S WHERE c = C") + }.message + assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + } + } + } + test("SPARK-21335: support un-aliased subquery") { withTempView("v") { Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") From 3823dc88d3816c7d1099f9601426108acc90574c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Oct 2017 10:49:48 -0700 Subject: [PATCH 1485/1765] [SPARK-22252][SQL][FOLLOWUP] Command should not be a LeafNode ## What changes were proposed in this pull request? This is a minor folllowup of #19474 . #19474 partially reverted #18064 but accidentally introduced a behavior change. `Command` extended `LogicalPlan` before #18064 , but #19474 made it extend `LeafNode`. This is an internal behavior change as now all `Command` subclasses can't define children, and they have to implement `computeStatistic` method. This PR fixes this by making `Command` extend `LogicalPlan` ## How was this patch tested? N/A Author: Wenchen Fan Closes #19493 from cloud-fan/minor. --- .../org/apache/spark/sql/catalyst/plans/logical/Command.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 38f47081b6f55..ec5766e1f67f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LeafNode { +trait Command extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty } From 1bb8b76045420b61a37806d6f4765af15c4052a7 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Fri, 13 Oct 2017 15:13:06 -0700 Subject: [PATCH 1486/1765] [MINOR][SS] keyWithIndexToNumValues" -> "keyWithIndexToValue" ## What changes were proposed in this pull request? This PR changes `keyWithIndexToNumValues` to `keyWithIndexToValue`. There will be directories on HDFS named with this `keyWithIndexToNumValues`. So if we ever want to fix this, let's fix it now. ## How was this patch tested? existing unit test cases. Author: Liwei Lin Closes #19435 from lw-lin/keyWithIndex. --- .../streaming/state/SymmetricHashJoinStateManager.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index d256fb578d921..6b386308c79fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -384,7 +384,7 @@ class SymmetricHashJoinStateManager( } /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */ - private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValuesType) { + private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) { private val keyWithIndexExprs = keyAttributes :+ Literal(1L) private val keyWithIndexSchema = keySchema.add("index", LongType) private val indexOrdinalInKeyWithIndexRow = keyAttributes.size @@ -471,7 +471,7 @@ class SymmetricHashJoinStateManager( object SymmetricHashJoinStateManager { def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValuesType) + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { getStateStoreName(joinSide, stateStoreType) } @@ -483,8 +483,8 @@ object SymmetricHashJoinStateManager { override def toString(): String = "keyToNumValues" } - private case object KeyWithIndexToValuesType extends StateStoreType { - override def toString(): String = "keyWithIndexToNumValues" + private case object KeyWithIndexToValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" } private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { From 06df34d35ec088277445ef09cfb24bfe996f072e Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Fri, 13 Oct 2017 17:12:50 -0700 Subject: [PATCH 1487/1765] [SPARK-11034][LAUNCHER][MESOS] Launcher: add support for monitoring Mesos apps ## What changes were proposed in this pull request? Added Launcher support for monitoring Mesos apps in Client mode. SPARK-11033 can handle the support for Mesos/Cluster mode since the Standalone/Cluster and Mesos/Cluster modes use the same code at client side. ## How was this patch tested? I verified it manually by running launcher application, able to launch, stop and kill the mesos applications and also can invoke other launcher API's. Author: Devaraj K Closes #19385 from devaraj-kavali/SPARK-11034. --- .../MesosCoarseGrainedSchedulerBackend.scala | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 80c0a041b7322..603c980cb268d 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -32,6 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskStat import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointAddress @@ -89,6 +90,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = { + stopSchedulerBackend() + setState(SparkAppHandle.State.KILLED) + } + } + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -182,6 +190,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def start() { super.start() + if (sc.deployMode == "client") { + launcherBackend.connect() + } val startedBefore = IdHelper.startedBefore.getAndSet(true) val suffix = if (startedBefore) { @@ -202,6 +213,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.conf.getOption("spark.mesos.driver.frameworkId").map(_ + suffix) ) + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) startScheduler(driver) } @@ -295,15 +307,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( this.mesosExternalShuffleClient.foreach(_.init(appId)) this.schedulerDriver = driver markRegistered() + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get >= maxCoresOption.getOrElse(0) * minRegisteredRatio } - override def disconnected(d: org.apache.mesos.SchedulerDriver) {} + override def disconnected(d: org.apache.mesos.SchedulerDriver) { + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) + } - override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) {} + override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) { + launcherBackend.setState(SparkAppHandle.State.RUNNING) + } /** * Method called by Mesos to offer resources on slaves. We respond by launching an executor, @@ -611,6 +629,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def stop() { + stopSchedulerBackend() + launcherBackend.setState(SparkAppHandle.State.FINISHED) + launcherBackend.close() + } + + private def stopSchedulerBackend() { // Make sure we're not launching tasks during shutdown stateLock.synchronized { if (stopCalled) { From e3536406ec6ff65a8b41ba2f2fd40517a760cfd6 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 13 Oct 2017 23:08:17 -0700 Subject: [PATCH 1488/1765] [SPARK-21762][SQL] FileFormatWriter/BasicWriteTaskStatsTracker metrics collection fails if a new file isn't yet visible ## What changes were proposed in this pull request? `BasicWriteTaskStatsTracker.getFileSize()` to catch `FileNotFoundException`, log info and then return 0 as a file size. This ensures that if a newly created file isn't visible due to the store not always having create consistency, the metric collection doesn't cause the failure. ## How was this patch tested? New test suite included, `BasicWriteTaskStatsTrackerSuite`. This not only checks the resilience to missing files, but verifies the existing logic as to how file statistics are gathered. Note that in the current implementation 1. if you call `Tracker..getFinalStats()` more than once, the file size count will increase by size of the last file. This could be fixed by clearing the filename field inside `getFinalStats()` itself. 2. If you pass in an empty or null string to `Tracker.newFile(path)` then IllegalArgumentException is raised, but only in `getFinalStats()`, rather than in `newFile`. There's a test for this behaviour in the new suite, as it verifies that only FNFEs get swallowed. Author: Steve Loughran Closes #18979 from steveloughran/cloud/SPARK-21762-missing-files-in-metrics. --- .../datasources/BasicWriteStatsTracker.scala | 49 +++- .../BasicWriteTaskStatsTrackerSuite.scala | 220 ++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 8 + 3 files changed, 265 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b8f7d130d569f..11af0aaa7b206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -44,20 +47,32 @@ case class BasicWriteTaskStats( * @param hadoopConf */ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) - extends WriteTaskStatsTracker { + extends WriteTaskStatsTracker with Logging { private[this] var numPartitions: Int = 0 private[this] var numFiles: Int = 0 + private[this] var submittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: String = null - + private[this] var curFile: Option[String] = None - private def getFileSize(filePath: String): Long = { + /** + * Get the size of the file expected to have been written by a worker. + * @param filePath path to the file + * @return the file size or None if the file was not found. + */ + private def getFileSize(filePath: String): Option[Long] = { val path = new Path(filePath) val fs = path.getFileSystem(hadoopConf) - fs.getFileStatus(path).getLen() + try { + Some(fs.getFileStatus(path).getLen()) + } catch { + case e: FileNotFoundException => + // may arise against eventually consistent object stores + logDebug(s"File $path is not yet visible", e) + None + } } @@ -70,12 +85,19 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def newFile(filePath: String): Unit = { - if (numFiles > 0) { - // we assume here that we've finished writing to disk the previous file by now - numBytes += getFileSize(curFile) + statCurrentFile() + curFile = Some(filePath) + submittedFiles += 1 + } + + private def statCurrentFile(): Unit = { + curFile.foreach { path => + getFileSize(path).foreach { len => + numBytes += len + numFiles += 1 + } + curFile = None } - curFile = filePath - numFiles += 1 } override def newRow(row: InternalRow): Unit = { @@ -83,8 +105,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - if (numFiles > 0) { - numBytes += getFileSize(curFile) + statCurrentFile() + if (submittedFiles != numFiles) { + logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + + "This could be due to the output format not writing empty files, " + + "or files being not immediately visible in the filesystem.") } BasicWriteTaskStats(numPartitions, numFiles, numBytes, numRows) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala new file mode 100644 index 0000000000000..bf3c8ede9a980 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.nio.charset.Charset + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +/** + * Test how BasicWriteTaskStatsTracker handles files. + * + * Two different datasets are written (alongside 0), one of + * length 10, one of 3. They were chosen to be distinct enough + * that it is straightforward to determine which file lengths were added + * from the sum of all files added. Lengths like "10" and "5" would + * be less informative. + */ +class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { + + private val tempDir = Utils.createTempDir() + private val tempDirPath = new Path(tempDir.toURI) + private val conf = new Configuration() + private val localfs = tempDirPath.getFileSystem(conf) + private val data1 = "0123456789".getBytes(Charset.forName("US-ASCII")) + private val data2 = "012".getBytes(Charset.forName("US-ASCII")) + private val len1 = data1.length + private val len2 = data2.length + + /** + * In teardown delete the temp dir. + */ + protected override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + } + + /** + * Assert that the stats match that expected. + * @param tracker tracker to check + * @param files number of files expected + * @param bytes total number of bytes expected + */ + private def assertStats( + tracker: BasicWriteTaskStatsTracker, + files: Int, + bytes: Int): Unit = { + val stats = finalStatus(tracker) + assert(files === stats.numFiles, "Wrong number of files") + assert(bytes === stats.numBytes, "Wrong byte count of file size") + } + + private def finalStatus(tracker: BasicWriteTaskStatsTracker): BasicWriteTaskStats = { + tracker.getFinalStats().asInstanceOf[BasicWriteTaskStats] + } + + test("No files in run") { + val tracker = new BasicWriteTaskStatsTracker(conf) + assertStats(tracker, 0, 0) + } + + test("Missing File") { + val missing = new Path(tempDirPath, "missing") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(missing.toString) + assertStats(tracker, 0, 0) + } + + test("Empty filename is forwarded") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile("") + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("Null filename is only picked up in final status") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(null) + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("0 byte file") { + val file = new Path(tempDirPath, "file0") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + touch(file) + assertStats(tracker, 1, 0) + } + + test("File with data") { + val file = new Path(tempDirPath, "file-with-data") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + write1(file) + assertStats(tracker, 1, len1) + } + + test("Open file") { + val file = new Path(tempDirPath, "file-open") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + val stream = localfs.create(file, true) + try { + assertStats(tracker, 1, 0) + stream.write(data1) + stream.flush() + assert(1 === finalStatus(tracker).numFiles, "Wrong number of files") + } finally { + stream.close() + } + } + + test("Two files") { + val file1 = new Path(tempDirPath, "f-2-1") + val file2 = new Path(tempDirPath, "f-2-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + assertStats(tracker, 2, len1 + len2) + } + + test("Three files, last one empty") { + val file1 = new Path(tempDirPath, "f-3-1") + val file2 = new Path(tempDirPath, "f-3-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + tracker.newFile(file3.toString) + touch(file3) + assertStats(tracker, 3, len1 + len2) + } + + test("Three files, one not found") { + val file1 = new Path(tempDirPath, "f-4-1") + val file2 = new Path(tempDirPath, "f-4-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + // file 1 + tracker.newFile(file1.toString) + write1(file1) + + // file 2 is noted, but not created + tracker.newFile(file2.toString) + + // file 3 is noted & then created + tracker.newFile(file3.toString) + write2(file3) + + // the expected size is file1 + file3; only two files are reported + // as found + assertStats(tracker, 2, len1 + len2) + } + + /** + * Write a 0-byte file. + * @param file file path + */ + private def touch(file: Path): Unit = { + localfs.create(file, true).close() + } + + /** + * Write a byte array. + * @param file path to file + * @param data data + * @return bytes written + */ + private def write(file: Path, data: Array[Byte]): Integer = { + val stream = localfs.create(file, true) + try { + stream.write(data) + } finally { + stream.close() + } + data.length + } + + /** + * Write a data1 array. + * @param file file + */ + private def write1(file: Path): Unit = { + write(file, data1) + } + + /** + * Write a data2 array. + * + * @param file file + */ + private def write2(file: Path): Unit = { + write(file, data2) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 94fa43dec7313..60935c3e85c43 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2110,4 +2110,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempDir { dir => + Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") + } + } + } } From e0503a7223410289d01bc4b20da3a451730577da Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 13 Oct 2017 23:24:36 -0700 Subject: [PATCH 1489/1765] [SPARK-22273][SQL] Fix key/value schema field names in HashMapGenerators. ## What changes were proposed in this pull request? When fixing schema field names using escape characters with `addReferenceMinorObj()` at [SPARK-18952](https://issues.apache.org/jira/browse/SPARK-18952) (#16361), double-quotes around the names were remained and the names become something like `"((java.lang.String) references[1])"`. ```java /* 055 */ private int maxSteps = 2; /* 056 */ private int numRows = 0; /* 057 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add("((java.lang.String) references[1])", org.apache.spark.sql.types.DataTypes.StringType); /* 058 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add("((java.lang.String) references[2])", org.apache.spark.sql.types.DataTypes.LongType); /* 059 */ private Object emptyVBase; ``` We should remove the double-quotes to refer the values in `references` properly: ```java /* 055 */ private int maxSteps = 2; /* 056 */ private int numRows = 0; /* 057 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1]), org.apache.spark.sql.types.DataTypes.StringType); /* 058 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2]), org.apache.spark.sql.types.DataTypes.LongType); /* 059 */ private Object emptyVBase; ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #19491 from ueshin/issues/SPARK-22273. --- .../execution/aggregate/RowBasedHashMapGenerator.scala | 8 ++++---- .../execution/aggregate/VectorizedHashMapGenerator.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 9316ebcdf105c..3718424931b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -50,10 +50,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -63,10 +63,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 13f79275cac41..812d405d5ebfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -55,10 +55,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -68,10 +68,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") From 014dc8471200518d63005eed531777d30d8a6639 Mon Sep 17 00:00:00 2001 From: liulijia Date: Sat, 14 Oct 2017 17:37:33 +0900 Subject: [PATCH 1490/1765] [SPARK-22233][CORE] Allow user to filter out empty split in HadoopRDD ## What changes were proposed in this pull request? Add a flag spark.files.ignoreEmptySplits. When true, methods like that use HadoopRDD and NewHadoopRDD such as SparkContext.textFiles will not create a partition for input splits that are empty. Author: liulijia Closes #19464 from liutang123/SPARK-22233. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 12 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 13 ++- .../scala/org/apache/spark/FileSuite.scala | 95 +++++++++++++++++-- 4 files changed, 112 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 19336f854145f..ce013d69579c1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -270,6 +270,12 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) + private[spark] val IGNORE_EMPTY_SPLITS = ConfigBuilder("spark.files.ignoreEmptySplits") + .doc("If true, methods that use HadoopRDD and NewHadoopRDD such as " + + "SparkContext.textFiles will not create a partition for input splits that are empty.") + .booleanConf + .createWithDefault(false) + private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 23b344230e490..1f33c0a2b709f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,6 +134,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value @@ -195,8 +197,12 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val inputFormat = getInputFormat(jobConf) - val inputSplits = inputFormat.getSplits(jobConf, minPartitions) + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { array(i) = new HadoopPartition(id, i, inputSplits(i)) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 482875e6c1ac5..db4eac1d0a775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -21,6 +21,7 @@ import java.io.IOException import java.text.SimpleDateFormat import java.util.{Date, Locale} +import scala.collection.JavaConverters.asScalaBufferConverter import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} @@ -34,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -89,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value if (shouldCloneJobConf) { @@ -121,8 +124,12 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = new JobContextImpl(_conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 02728180ac82d..4da4323ceb5c8 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -347,10 +347,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test ("allow user to disable the output directory existence checking (old Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + test ("allow user to disable the output directory existence checking (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) randomRDD.saveAsTextFile(tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-00000").exists() === true) @@ -380,9 +380,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test ("allow user to disable the output directory existence checking (new Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( @@ -510,4 +510,83 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } + test("spark.files.ignoreEmptySplits work correctly (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsHadoopFile[TextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-0000$i").exists() === true) + } + val hadoopRDD = sc.textFile(new File(output, "part-*").getPath) + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a"), ("key3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } + + test("spark.files.ignoreEmptySplits work correctly (new Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-r-0000$i").exists() === true) + } + val hadoopRDD = sc.newAPIHadoopFile(new File(output, "part-r-*").getPath, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[NewHadoopRDD[_, _]] + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "a"), ("3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "b")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } } From e8547ffb49071525c06876c856cecc0d4731b918 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sat, 14 Oct 2017 17:39:15 -0700 Subject: [PATCH 1491/1765] [SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning ## What changes were proposed in this pull request? In EnsureStatefulOpPartitioning, we check that the inputRDD to a SparkPlan has the expected partitioning for Streaming Stateful Operators. The problem is that we are not allowed to access this information during planning. The reason we added that check was because CoalesceExec could actually create RDDs with 0 partitions. We should fix it such that when CoalesceExec says that there is a SinglePartition, there is in fact an inputRDD of 1 partition instead of 0 partitions. ## How was this patch tested? Regression test in StreamingQuerySuite Author: Burak Yavuz Closes #19467 from brkyvz/stateful-op. --- .../plans/physical/partitioning.scala | 15 +- .../execution/basicPhysicalOperators.scala | 27 +++- .../exchange/EnsureRequirements.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 2 +- .../streaming/IncrementalExecution.scala | 39 ++--- .../streaming/statefulOperators.scala | 11 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 + .../spark/sql/execution/PlannerSuite.scala | 17 +++ .../streaming/state/StateStoreRDDSuite.scala | 2 +- .../SymmetricHashJoinStateManagerSuite.scala | 2 +- .../sql/streaming/DeduplicateSuite.scala | 11 +- .../EnsureStatefulOpPartitioningSuite.scala | 138 ------------------ .../FlatMapGroupsWithStateSuite.scala | 6 +- .../sql/streaming/StatefulOperatorTest.scala | 49 +++++++ .../streaming/StreamingAggregationSuite.scala | 8 +- .../sql/streaming/StreamingJoinSuite.scala | 2 +- .../sql/streaming/StreamingQuerySuite.scala | 13 ++ 17 files changed, 160 insertions(+), 189 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233fe..e57c842ce2a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -49,7 +49,9 @@ case object AllTuples extends Distribution * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. */ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + numPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false + case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) case _ => true } @@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } @@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 63cd1691f4cd7..d15ece304cac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -590,10 +590,33 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } protected override def doExecute(): RDD[InternalRow] = { - child.execute().coalesce(numPartitions, shuffle = false) + if (numPartitions == 1 && child.execute().getNumPartitions < 1) { + // Make sure we don't output an RDD with 0 partitions, when claiming that we have a + // `SinglePartition`. + new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions) + } else { + child.execute().coalesce(numPartitions, shuffle = false) + } } } +object CoalesceExec { + /** A simple RDD with no data, but with the given number of partitions. */ + class EmptyRDDWithPartitions( + @transient private val sc: SparkContext, + numPartitions: Int) extends RDD[InternalRow](sc, Nil) { + + override def getPartitions: Array[Partition] = + Array.tabulate(numPartitions)(i => EmptyPartition(i)) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + Iterator.empty + } + } + + case class EmptyPartition(index: Int) extends Partition +} + /** * Physical plan for a subquery. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d28ce60e276d5..4e2ca37bc1a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -44,13 +44,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { /** * Given a required distribution, returns a partitioning that satisfies that distribution. + * @param requiredDistribution The distribution that is required by the operator + * @param numPartitions Used when the distribution doesn't require a specific number of partitions */ private def createPartitioning( requiredDistribution: Distribution, numPartitions: Int): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case ClusteredDistribution(clustering, desiredPartitions) => + HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) case dist => sys.error(s"Do not know how to satisfy distribution $dist") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index aab06d611a5ea..c81f1a8142784 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil /** Ordering needed for using GroupingIterator */ override def requiredChildOrdering: Seq[Seq[SortOrder]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 82f879c763c2b..2e378637727fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** @@ -61,6 +62,10 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } + private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) + .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + /** * See [SPARK-18339] * Walk the optimized logical plan and replace CurrentBatchTimestamp @@ -83,7 +88,11 @@ class IncrementalExecution( /** Get the state info of the next stateful operator */ private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { StatefulOperatorStateInfo( - checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + checkpointLocation, + runId, + statefulOperatorId.getAndIncrement(), + currentBatchId, + numStateStores) } /** Locates save/restore pairs surrounding aggregation. */ @@ -130,34 +139,8 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = - Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } - -object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { - // Needs to be transformUp to avoid extra shuffles - override def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case so: StatefulOperator => - val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions - val distributions = so.requiredChildDistribution - val children = so.children.zip(distributions).map { case (child, reqDistribution) => - val expectedPartitioning = reqDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) - case _ => throw new AnalysisException("Unexpected distribution expected for " + - s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " + - s"$reqDistribution.") - } - if (child.outputPartitioning.guarantees(expectedPartitioning) && - child.execute().getNumPartitions == expectedPartitioning.numPartitions) { - child - } else { - ShuffleExchangeExec(expectedPartitioning, child) - } - } - so.withNewChildren(children) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 0d85542928ee6..b9b07a2e688f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo( checkpointLocation: String, queryRunId: UUID, operatorId: Long, - storeVersion: Long) { + storeVersion: Long, + numPartitions: Int) { override def toString(): String = { s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " + - s"opId = $operatorId, ver = $storeVersion]" + s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]" } } @@ -239,7 +240,7 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -386,7 +387,7 @@ case class StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -401,7 +402,7 @@ case class StreamingDeduplicateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ad461fa6144b3..50de2fd3bca8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( testData.select('key).coalesce(1).select('key), testData.select('key).collect().toSeq) + + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) } test("convert $\"attribute name\" into unresolved attribute") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 86066362da9dd..c25c90d0c70e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements should respect ClusteredDistribution's num partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13)) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e } + assert(shuffle.size === 1) + assert(shuffle.head.newPartitioning === finalPartitioning) + } + test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index defb9ed63a881..65b39f0fbd73d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn path: String, queryRunId: UUID = UUID.randomUUID, version: Int = 0): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5) } private val increment = (store: StateStore, iter: Iterator[String]) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index d44af1d14c27a..c0216a2ef3e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter withTempDir { file => val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0) + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) val manager = new SymmetricHashJoinStateManager( LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index e858b7d9998a8..caf2bab8a5859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class DeduplicateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ @@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala deleted file mode 100644 index ed9823fbddfda..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import java.util.UUID - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} -import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} -import org.apache.spark.sql.test.SharedSQLContext - -class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { - - import testImplicits._ - - private var baseDf: DataFrame = null - - override def beforeAll(): Unit = { - super.beforeAll() - baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") - } - - test("ClusteredDistribution generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("AllTuples generates Exchange with SinglePartition") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = true) - } - - test("AllTuples with coalesce(1) doesn't need Exchange") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = false) - } - - /** - * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan - * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to - * ensure the expected partitioning. - */ - private def testEnsureStatefulOpPartitioning( - inputPlan: SparkPlan, - requiredDistribution: Seq[Attribute] => Distribution, - expectedPartitioning: Seq[Attribute] => Partitioning, - expectShuffle: Boolean): Unit = { - val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) - val executed = executePlan(operator, OutputMode.Complete()) - if (expectShuffle) { - val exchange = executed.children.find(_.isInstanceOf[Exchange]) - if (exchange.isEmpty) { - fail(s"Was expecting an exchange but didn't get one in:\n$executed") - } - assert(exchange.get === - ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan), - s"Exchange didn't have expected properties:\n${exchange.get}") - } else { - assert(!executed.children.exists(_.isInstanceOf[Exchange]), - s"Unexpected exchange found in:\n$executed") - } - } - - /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */ - private def executePlan( - p: SparkPlan, - outputMode: OutputMode = OutputMode.Append()): SparkPlan = { - val execution = new IncrementalExecution( - spark, - null, - OutputMode.Complete(), - "chk", - UUID.randomUUID(), - 0L, - OffsetSeqMetadata()) { - override lazy val sparkPlan: SparkPlan = p transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - execution.executedPlan - } -} - -/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ -case class TestStatefulOperator( - child: SparkPlan, - requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { - override def output: Seq[Attribute] = child.output - override def doExecute(): RDD[InternalRow] = child.execute() - override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil - override def stateInfo: Option[StatefulOperatorStateInfo] = None -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index d2e8beb2f5290..aeb83835f981a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -41,7 +41,9 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) -class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ import GroupStateImpl._ @@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( + sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala new file mode 100644 index 0000000000000..45142278993bb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.streaming._ + +trait StatefulOperatorTest { + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + colNames: Seq[String]): Boolean = { + val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output + val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions + val groupingAttr = attr.filter(a => colNames.contains(a.name)) + checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) + } + + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + expectedPartitioning: Partitioning): Boolean = { + val operator = sq.asInstanceOf[StreamExecution].lastExecution + .executedPlan.collect { case p: T => p } + operator.head.children.forall( + _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index fe7efa69f7e31..1b4d8556f6ae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions { + with BeforeAndAfterAll with Assertions with StatefulOperatorTest { override def afterAll(): Unit = { super.afterAll() @@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), @@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest }, AddBlockData(inputSource), // create an empty trigger CheckLastBatch(1), - AssertOnQuery("Verify addition of exchange operator") { se => - checkAggregationChain(se, expectShuffling = true, 1) + AssertOnQuery("Verify that no exchange is required") { se => + checkAggregationChain(se, expectShuffling = false, 1) }, AddBlockData(inputSource, Seq(2, 3)), CheckLastBatch(3), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index a6593b71e51de..d32617275aadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with val queryId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString - val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L) + val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5) implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ab35079dca23f..c53889bb8566c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") { + val stream = MemoryStream[(Int, Int)] + val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'") + val otherDf = stream.toDF().toDF("num", "numSq") + .join(broadcast(baseDf), "num") + .groupBy('char) + .agg(sum('numSq)) + + testStream(otherDf, OutputMode.Complete())( + AddData(stream, (1, 1), (2, 4)), + CheckLastBatch(("A", 1))) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) From 13c1559587d0eb533c94f5a492390f81b048b347 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 15 Oct 2017 18:40:53 -0700 Subject: [PATCH 1492/1765] [SPARK-21549][CORE] Respect OutputFormats with no/invalid output directory provided ## What changes were proposed in this pull request? PR #19294 added support for null's - but spark 2.1 handled other error cases where path argument can be invalid. Namely: * empty string * URI parse exception while creating Path This is resubmission of PR #19487, which I messed up while updating my repo. ## How was this patch tested? Enhanced test to cover new support added. Author: Mridul Muralidharan Closes #19497 from mridulm/master. --- .../io/HadoopMapReduceCommitProtocol.scala | 24 +++++++------- .../spark/rdd/PairRDDFunctionsSuite.scala | 31 +++++++++++++------ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index a7e6859ef6b64..95c99d29c3a9c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import java.util.{Date, UUID} import scala.collection.mutable +import scala.util.Try import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -47,6 +48,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ + /** + * Checks whether there are files to be committed to a valid output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether a valid output path is specified. + * [[HadoopMapReduceCommitProtocol#path]] need not be a valid [[org.apache.hadoop.fs.Path]] for + * committers not writing to distributed file systems. + */ + private val hasValidPath = Try { new Path(path) }.isSuccess + /** * Tracks files staged by this task for absolute output paths. These outputs are not managed by * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. @@ -60,15 +71,6 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) */ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) - /** - * Checks whether there are files to be committed to an absolute output location. - * - * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, - * it is necessary to check whether the output path is specified. Output path may not be required - * for committers not writing to distributed file systems. - */ - private def hasAbsPathFiles: Boolean = path != null - protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -142,7 +144,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - if (hasAbsPathFiles) { + if (hasValidPath) { val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) @@ -153,7 +155,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - if (hasAbsPathFiles) { + if (hasValidPath) { val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) fs.delete(absPathStagingDir, true) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 07579c5098014..0a248b6064ee8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -568,21 +568,34 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } - test("saveAsNewAPIHadoopDataset should respect empty output directory when " + + test("saveAsNewAPIHadoopDataset should support invalid output paths when " + "there are no files to be committed to an absolute output location") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) - val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) - job.setOutputKeyClass(classOf[Integer]) - job.setOutputValueClass(classOf[Integer]) - job.setOutputFormatClass(classOf[NewFakeFormat]) - val jobConfiguration = job.getConfiguration + def saveRddWithPath(path: String): Unit = { + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + if (null != path) { + job.getConfiguration.set("mapred.output.dir", path) + } else { + job.getConfiguration.unset("mapred.output.dir") + } + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with java.lang.IllegalArgumentException. + pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } - // just test that the job does not fail with - // java.lang.IllegalArgumentException: Can not create a Path from a null string - pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + saveRddWithPath(null) + saveRddWithPath("") + saveRddWithPath("::invalid::") } + // In spark 2.1, only null was supported - not other invalid paths. + // org.apache.hadoop.mapred.FileOutputFormat.getOutputPath fails with IllegalArgumentException + // for non-null invalid paths. test("saveAsHadoopDataset should respect empty output directory when " + "there are no files to be committed to an absolute output location") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) From 0ae96495dedb54b3b6bae0bd55560820c5ca29a2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Oct 2017 13:37:58 +0800 Subject: [PATCH 1493/1765] [SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle ## What changes were proposed in this pull request? `ObjectHashAggregateExec` should override `outputPartitioning` in order to avoid unnecessary shuffle. ## How was this patch tested? Added Jenkins test. Author: Liang-Chi Hsieh Closes #19501 from viirya/SPARK-22223. --- .../aggregate/ObjectHashAggregateExec.scala | 2 ++ .../spark/sql/DataFrameAggregateSuite.scala | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index ec3f9a05b5ccc..66955b8ef723c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -95,6 +95,8 @@ case class ObjectHashAggregateExec( } } + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") val aggTime = longMetric("aggTime") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8549eac58ee95..06848e4d2b297 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), Seq(Row(3, 4, 9))) } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + + val objHashAggDF = df + .withColumn("d", expr("(a, b, c)")) + .groupBy("a", "b").agg(collect_list("d").as("e")) + .withColumn("f", expr("(b, e)")) + .groupBy("a").agg(collect_list("f").as("g")) + val aggPlan = objHashAggDF.queryExecution.executedPlan + + val sortAggPlans = aggPlan.collect { + case sortAgg: SortAggregateExec => sortAgg + } + assert(sortAggPlans.isEmpty) + + val objHashAggPlans = aggPlan.collect { + case objHashAgg: ObjectHashAggregateExec => objHashAgg + } + assert(objHashAggPlans.nonEmpty) + + val exchangePlans = aggPlan.collect { + case shuffle: ShuffleExchangeExec => shuffle + } + assert(exchangePlans.length == 1) + } + } } From 0fa10666cf75e3c4929940af49c8a6f6ea874759 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 16 Oct 2017 22:15:50 +0800 Subject: [PATCH 1494/1765] [SPARK-22233][CORE][FOLLOW-UP] Allow user to filter out empty split in HadoopRDD ## What changes were proposed in this pull request? Update the config `spark.files.ignoreEmptySplits`, rename it and make it internal. This is followup of #19464 ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #19504 from jiangxb1987/partitionsplit. --- .../org/apache/spark/internal/config/package.scala | 11 ++++++----- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- .../test/scala/org/apache/spark/FileSuite.scala | 14 +++++++++----- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ce013d69579c1..efffdca1ea59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -270,11 +270,12 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) - private[spark] val IGNORE_EMPTY_SPLITS = ConfigBuilder("spark.files.ignoreEmptySplits") - .doc("If true, methods that use HadoopRDD and NewHadoopRDD such as " + - "SparkContext.textFiles will not create a partition for input splits that are empty.") - .booleanConf - .createWithDefault(false) + private[spark] val HADOOP_RDD_IGNORE_EMPTY_SPLITS = + ConfigBuilder("spark.hadoopRDD.ignoreEmptySplits") + .internal() + .doc("When true, HadoopRDD/NewHadoopRDD will not create partitions for empty input splits.") + .booleanConf + .createWithDefault(false) private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 1f33c0a2b709f..2480559a41b7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,7 +134,7 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) - private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index db4eac1d0a775..e4dd1b6a82498 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -90,7 +90,7 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) - private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4da4323ceb5c8..e9539dc73f6fa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -510,9 +510,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test("spark.files.ignoreEmptySplits work correctly (old Hadoop API)") { + test("spark.hadoopRDD.ignoreEmptySplits work correctly (old Hadoop API)") { val conf = new SparkConf() - conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) sc = new SparkContext(conf) def testIgnoreEmptySplits( @@ -549,9 +551,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { expectedPartitionNum = 2) } - test("spark.files.ignoreEmptySplits work correctly (new Hadoop API)") { + test("spark.hadoopRDD.ignoreEmptySplits work correctly (new Hadoop API)") { val conf = new SparkConf() - conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) sc = new SparkContext(conf) def testIgnoreEmptySplits( From 561505e2fc290fc2cee3b8464ec49df773dca5eb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 16 Oct 2017 11:27:08 -0700 Subject: [PATCH 1495/1765] [SPARK-22282][SQL] Rename OrcRelation to OrcFileFormat and remove ORC_COMPRESSION ## What changes were proposed in this pull request? This PR aims to - Rename `OrcRelation` to `OrcFileFormat` object. - Replace `OrcRelation.ORC_COMPRESSION` with `org.apache.orc.OrcConf.COMPRESS`. Since [SPARK-21422](https://issues.apache.org/jira/browse/SPARK-21422), we can use `OrcConf.COMPRESS` instead of Hive's. ```scala // The references of Hive's classes will be minimized. val ORC_COMPRESSION = "orc.compress" ``` ## How was this patch tested? Pass the Jenkins with the existing and updated test cases. Author: Dongjoon Hyun Closes #19502 from dongjoon-hyun/SPARK-22282. --- .../org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 18 ++++++++---------- .../apache/spark/sql/hive/orc/OrcOptions.scala | 8 +++++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 11 ++++++----- .../spark/sql/hive/orc/OrcSourceSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d2748544..c9e45436ed42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -520,8 +520,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default is the value specified in `spark.sql.orc.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive * shorten names(`none`, `snappy`, `zlib`, and `lzo`). This will override - * `orc.compress` and `spark.sql.parquet.compression.codec`. If `orc.compress` is given, - * it overrides `spark.sql.parquet.compression.codec`.
  • + * `orc.compress` and `spark.sql.orc.compression.codec`. If `orc.compress` is given, + * it overrides `spark.sql.orc.compression.codec`. * * * @since 1.5.0 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 194e69c93e1a8..d26ec15410d95 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.orc.OrcConf.COMPRESS import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession @@ -72,7 +73,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(OrcRelation.ORC_COMPRESSION, orcOptions.compressionCodec) + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) configuration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) @@ -93,8 +94,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable override def getFileExtension(context: TaskAttemptContext): String = { val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "") } compressionExtension + ".orc" @@ -120,7 +121,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => - hadoopConf.set(OrcRelation.SARG_PUSHDOWN, f.toKryo) + hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo) hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } @@ -138,7 +139,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (isEmptyFile) { Iterator.empty } else { - OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema) + OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -160,7 +161,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( + OrcFileFormat.unwrapOrcStructs( conf, dataSchema, requiredSchema, @@ -255,10 +256,7 @@ private[orc] class OrcOutputWriter( } } -private[orc] object OrcRelation extends HiveInspectors { - // The references of Hive's classes will be minimized. - val ORC_COMPRESSION = "orc.compress" - +private[orc] object OrcFileFormat extends HiveInspectors { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 7f94c8c579026..6ce90c07b4921 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -40,9 +42,9 @@ private[orc] class OrcOptions( * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ val compressionCodec: String = { - // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are - // in order of precedence from highest to lowest. - val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` + // are in order of precedence from highest to lowest. + val orcCompressionConf = parameters.get(COMPRESS.getAttribute) val codecName = parameters .get("compression") .orElse(orcCompressionConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 60ccd996d6d58..1fa9091f967a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ @@ -176,11 +177,11 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("SPARK-16610: Respect orc.compress option when compression is unset") { - // Respect `orc.compress`. + test("SPARK-16610: Respect orc.compress (i.e., OrcConf.COMPRESS) when compression is unset") { + // Respect `orc.compress` (i.e., OrcConf.COMPRESS). withTempPath { file => spark.range(0, 10).write - .option("orc.compress", "ZLIB") + .option(COMPRESS.getAttribute, "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -191,7 +192,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => spark.range(0, 10).write .option("compression", "ZLIB") - .option("orc.compress", "SNAPPY") + .option(COMPRESS.getAttribute, "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -598,7 +599,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val requestedSchema = StructType(Nil) val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) assert(maybeOrcReader.isDefined) val orcRecordReader = new SparkOrcNewRecordReader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 781de6631f324..ef9e67c743837 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.hive.orc import java.io.File +import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} @@ -150,7 +152,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val conf = sqlContext.sessionState.conf - assert(new OrcOptions(Map("Orc.Compress" -> "NONE"), conf).compressionCodec == "NONE") + val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) + assert(option.compressionCodec == "NONE") } test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { @@ -205,8 +208,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") - val map1 = Map("orc.compress" -> "zlib") - val map2 = Map("orc.compress" -> "zlib", "compression" -> "lzo") + val map1 = Map(COMPRESS.getAttribute -> "zlib") + val map2 = Map(COMPRESS.getAttribute -> "zlib", "compression" -> "lzo") assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") assert(new OrcOptions(map2, conf).compressionCodec == "LZO") } From c09a2a76b52905a784d2767cb899dc886c330628 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 16 Oct 2017 16:16:34 -0700 Subject: [PATCH 1496/1765] [SPARK-22280][SQL][TEST] Improve StatisticsSuite to test `convertMetastore` properly ## What changes were proposed in this pull request? This PR aims to improve **StatisticsSuite** to test `convertMetastore` configuration properly. Currently, some test logic in `test statistics of LogicalRelation converted from Hive serde tables` depends on the default configuration. New test case is shorter and covers both(true/false) cases explicitly. This test case was previously modified by SPARK-17410 and SPARK-17284 in Spark 2.3.0. - https://github.com/apache/spark/commit/a2460be9c30b67b9159fe339d115b84d53cc288a#diff-1c464c86b68c2d0b07e73b7354e74ce7R443 ## How was this patch tested? Pass the Jenkins with the improved test case. Author: Dongjoon Hyun Closes #19500 from dongjoon-hyun/SPARK-22280. --- .../spark/sql/hive/StatisticsSuite.scala | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9ff9ecf7f3677..b9a5ad7657134 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -937,26 +937,20 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } test("test statistics of LogicalRelation converted from Hive serde tables") { - val parquetTable = "parquetTable" - val orcTable = "orcTable" - withTable(parquetTable, orcTable) { - sql(s"CREATE TABLE $parquetTable (key STRING, value STRING) STORED AS PARQUET") - sql(s"CREATE TABLE $orcTable (key STRING, value STRING) STORED AS ORC") - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - sql(s"INSERT INTO TABLE $orcTable SELECT * FROM src") - - // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it - // for robustness - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { - checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) - } - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - // We still can get tableSize from Hive before Analyze - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + withTable(format) { + sql(s"CREATE TABLE $format (key STRING, value STRING) STORED AS $format") + sql(s"INSERT INTO TABLE $format SELECT * FROM src") + + checkTableStats(format, hasSizeInBytes = !isConverted, expectedRowCounts = None) + sql(s"ANALYZE TABLE $format COMPUTE STATISTICS") + checkTableStats(format, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } + } } } } From e66cabb0215204605ca7928406d4787d41853dd1 Mon Sep 17 00:00:00 2001 From: Ben Barnard Date: Tue, 17 Oct 2017 09:36:09 +0200 Subject: [PATCH 1497/1765] [SPARK-20992][SCHEDULER] Add links in documentation to Nomad integration. ## What changes were proposed in this pull request? Adds links to the fork that provides integration with Nomad, in the same places the k8s integration is linked to. ## How was this patch tested? I clicked on the links to make sure they're correct ;) Author: Ben Barnard Closes #19354 from barnardb/link-to-nomad-integration. --- docs/cluster-overview.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index a2ad958959a50..c42bb4bb8377e 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -58,6 +58,9 @@ for providing container-centric infrastructure. Kubernetes support is being acti developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. For documentation, refer to that project's README. +A third-party project (not supported by the Spark project) exists to add support for +[Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. + # Submitting Applications Applications can be submitted to a cluster of any type using the `spark-submit` script. From 8148f19ca1f0e0375603cb4f180c1bad8b0b8042 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Oct 2017 09:41:23 +0200 Subject: [PATCH 1498/1765] [SPARK-22249][SQL] isin with empty list throws exception on cached DataFrame ## What changes were proposed in this pull request? As pointed out in the JIRA, there is a bug which causes an exception to be thrown if `isin` is called with an empty list on a cached DataFrame. The PR fixes it. ## How was this patch tested? Added UT. Author: Marco Gaido Closes #19494 from mgaido91/SPARK-22249. --- .../columnar/InMemoryTableScanExec.scala | 1 + .../columnar/InMemoryColumnarQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index af3636a5a2ca7..846ec03e46a12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -102,6 +102,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 + case In(_: AttributeReference, list: Seq[Expression]) if list.isEmpty => Literal.FalseLiteral case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 8d411eb191cd9..75d17bc79477d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -429,4 +429,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(agg_without_cache, agg_with_cache) } } + + test("SPARK-22249: IN should work also with cached DataFrame") { + val df = spark.range(10).cache() + // with an empty list + assert(df.filter($"id".isin()).count() == 0) + // with a non-empty list + assert(df.filter($"id".isin(2)).count() == 1) + assert(df.filter($"id".isin(2, 3)).count() == 2) + df.unpersist() + val dfNulls = spark.range(10).selectExpr("null as id").cache() + // with null as value for the attribute + assert(dfNulls.filter($"id".isin()).count() == 0) + assert(dfNulls.filter($"id".isin(2, 3)).count() == 0) + dfNulls.unpersist() + } } From 99e32f8ba5d908d5408e9857fd96ac1d7d7e5876 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 17 Oct 2017 17:58:45 +0800 Subject: [PATCH 1499/1765] [SPARK-22224][SQL] Override toString of KeyValue/Relational-GroupedDataset ## What changes were proposed in this pull request? #### before ```scala scala> val words = spark.read.textFile("README.md").flatMap(_.split(" ")) words: org.apache.spark.sql.Dataset[String] = [value: string] scala> val grouped = words.groupByKey(identity) grouped: org.apache.spark.sql.KeyValueGroupedDataset[String,String] = org.apache.spark.sql.KeyValueGroupedDataset65214862 ``` #### after ```scala scala> val words = spark.read.textFile("README.md").flatMap(_.split(" ")) words: org.apache.spark.sql.Dataset[String] = [value: string] scala> val grouped = words.groupByKey(identity) grouped: org.apache.spark.sql.KeyValueGroupedDataset[String,String] = [key: [value: string], value: [value: string]] ``` ## How was this patch tested? existing ut cc gatorsmile cloud-fan Author: Kent Yao Closes #19363 from yaooqinn/minor-dataset-tostring. --- .../spark/sql/KeyValueGroupedDataset.scala | 22 ++++++- .../spark/sql/RelationalGroupedDataset.scala | 19 +++++- .../org/apache/spark/sql/DatasetSuite.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 12 +--- 4 files changed, 100 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cb42e9e4560cf..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,7 +24,6 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} @@ -564,4 +563,25 @@ class KeyValueGroupedDataset[K, V] private[sql]( encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } + + override def toString: String = { + val builder = new StringBuilder + val kFields = kExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + val vFields = vExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("KeyValueGroupedDataset: [key: [") + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append("], value: [") + builder.append(vFields.take(2).mkString(", ")) + if (vFields.length > 2) { + builder.append(" ... " + (vFields.length - 2) + " more field(s)") + } + builder.append("]]").toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index cd0ac1feffa51..33ec3a27110a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{NumericType, StructField, StructType} +import org.apache.spark.sql.types.{NumericType, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -465,6 +465,19 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + + override def toString: String = { + val builder = new StringBuilder + builder.append("RelationalGroupedDataset: [grouping expressions: [") + val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append(s"], value: ${df.toString}, type: $groupType]").toString() + } } private[sql] object RelationalGroupedDataset { @@ -479,7 +492,9 @@ private[sql] object RelationalGroupedDataset { /** * The Grouping Type */ - private[sql] trait GroupType + private[sql] trait GroupType { + override def toString: String = getClass.getSimpleName.stripSuffix("$").stripSuffix("Type") + } /** * To indicate it's the GroupBy diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dace6825ee40e..1537ce3313c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1341,8 +1341,69 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), ("", Seq((1, 1), (2, 2)))) } + + test("Check RelationalGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").groupBy("id") + val expected = "RelationalGroupedDataset: [" + + "grouping expressions: [id: int], value: [id: int], type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check RelationalGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").groupBy("id") + val expected = "RelationalGroupedDataset:" + + " [grouping expressions: [id: int]," + + " value: [id: int, val1: string ... 1 more field]," + + " type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + + test("Check KeyValueGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").as[SingleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset: [key: [id: int], value: [id: int]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Unnamed KV-pair") { + val kvDataset = (1 to 3).map(x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => (x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [_1: int, _2: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Named KV-pair") { + val kvDataset = (1 to 3).map( x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => DoubleData(x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").as[TripleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string ... 1 more field(s)]," + + " value: [id: int, val1: string ... 1 more field(s)]]" + val actual = kvDataset.toString + assert(expected === actual) + } } +case class SingleData(id: Int) +case class DoubleData(id: Int, val1: String) +case class TripleData(id: Int, val1: String, val2: Long) + case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f9808834df4a5..fcaca3d75b74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,23 +17,13 @@ package org.apache.spark.sql -import java.util.{ArrayDeque, Locale, TimeZone} +import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.{Metadata, ObjectType} abstract class QueryTest extends PlanTest { From e1960c3d6f380b0dfbba6ee5d8ac6da4bc29a698 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Oct 2017 22:54:38 +0800 Subject: [PATCH 1500/1765] [SPARK-22062][CORE] Spill large block to disk in BlockManager's remote fetch to avoid OOM ## What changes were proposed in this pull request? In the current BlockManager's `getRemoteBytes`, it will call `BlockTransferService#fetchBlockSync` to get remote block. In the `fetchBlockSync`, Spark will allocate a temporary `ByteBuffer` to store the whole fetched block. This will potentially lead to OOM if block size is too big or several blocks are fetched simultaneously in this executor. So here leveraging the idea of shuffle fetch, to spill the large block to local disk before consumed by upstream code. The behavior is controlled by newly added configuration, if block size is smaller than the threshold, then this block will be persisted in memory; otherwise it will first spill to disk, and then read from disk file. To achieve this feature, what I did is: 1. Rename `TempShuffleFileManager` to `TempFileManager`, since now it is not only used by shuffle. 2. Add a new `TempFileManager` to manage the files of fetched remote blocks, the files are tracked by weak reference, will be deleted when no use at all. ## How was this patch tested? This was tested by adding UT, also manual verification in local test to perform GC to clean the files. Author: jerryshao Closes #19476 from jerryshao/SPARK-22062. --- .../shuffle/ExternalShuffleClient.java | 4 +- .../shuffle/OneForOneBlockFetcher.java | 12 +-- .../spark/network/shuffle/ShuffleClient.java | 10 +- ...eFileManager.java => TempFileManager.java} | 12 +-- .../scala/org/apache/spark/SparkConf.scala | 4 +- .../spark/internal/config/package.scala | 15 +-- .../spark/network/BlockTransferService.scala | 28 +++-- .../netty/NettyBlockTransferService.scala | 6 +- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 102 ++++++++++++++++-- .../spark/storage/BlockManagerMaster.scala | 6 ++ .../storage/BlockManagerMasterEndpoint.scala | 14 +++ .../spark/storage/BlockManagerMessages.scala | 7 ++ .../storage/ShuffleBlockFetcherIterator.scala | 8 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 57 +++++++--- .../ShuffleBlockFetcherIteratorSuite.scala | 10 +- docs/configuration.md | 11 +- 18 files changed, 236 insertions(+), 74 deletions(-) rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/{TempShuffleFileManager.java => TempFileManager.java} (74%) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 77702447edb88..510017fee2db5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -91,7 +91,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +99,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempShuffleFileManager).start(); + blockIds1, listener1, conf, tempFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 66b67e282c80d..3f2f20b4149f1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -58,7 +58,7 @@ public class OneForOneBlockFetcher { private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; private final TransportConf transportConf; - private final TempShuffleFileManager tempShuffleFileManager; + private final TempFileManager tempFileManager; private StreamHandle streamHandle = null; @@ -79,14 +79,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - this.tempShuffleFileManager = tempShuffleFileManager; + this.tempFileManager = tempFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -125,7 +125,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempShuffleFileManager != null) { + if (tempFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -164,7 +164,7 @@ private class DownloadCallback implements StreamCallback { private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = tempShuffleFileManager.createTempShuffleFile(); + this.targetFile = tempFileManager.createTempFile(); this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath())); this.chunkIndex = chunkIndex; } @@ -180,7 +180,7 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); - if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) { + if (!tempFileManager.registerTempFileToClean(targetFile)) { targetFile.delete(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 5bd4412b75275..18b04fedcac5b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -43,10 +43,10 @@ public void init(String appId) { } * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. - * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. + * @param tempFileManager TempFileManager to create and clean temp files. + * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. */ public abstract void fetchBlocks( String host, @@ -54,7 +54,7 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager); + TempFileManager tempFileManager); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java similarity index 74% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java index 84a5ed6a276bd..552364d274f19 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java @@ -20,17 +20,17 @@ import java.io.File; /** - * A manager to create temp shuffle block files to reduce the memory usage and also clean temp + * A manager to create temp block files to reduce the memory usage and also clean temp * files when they won't be used any more. */ -public interface TempShuffleFileManager { +public interface TempFileManager { - /** Create a temp shuffle block file. */ - File createTempShuffleFile(); + /** Create a temp block file. */ + File createTempFile(); /** - * Register a temp shuffle file to clean up when it won't be used any more. Return whether the + * Register a temp file to clean up when it won't be used any more. Return whether the * file is registered successfully. If `false`, the caller should clean up the file by itself. */ - boolean registerTempShuffleFileToClean(File file); + boolean registerTempFileToClean(File file); } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e61f943af49f2..57b3744e9c30a 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -662,7 +662,9 @@ private[spark] object SparkConf extends Logging { "spark.yarn.jars" -> Seq( AlternateConfig("spark.yarn.jar", "2.0")), "spark.yarn.access.hadoopFileSystems" -> Seq( - AlternateConfig("spark.yarn.access.namenodes", "2.2")) + AlternateConfig("spark.yarn.access.namenodes", "2.2")), + "spark.maxRemoteBlockSizeFetchToMem" -> Seq( + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")) ) /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index efffdca1ea59b..e7b406af8d9b1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -357,13 +357,15 @@ package object config { .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) - private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = - ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") - .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + + private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = + ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") + .doc("Remote block will be fetched to disk when size of the block is " + "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note that this config can " + - "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + - " service is disabled.") + "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + + "affect both shuffle fetch and block manager remote block fetch. For users who " + + "enabled external shuffle service, this feature can only be worked when external shuffle" + + " service is newer than Spark 2.2.") + .withAlternative("spark.reducer.maxReqSizeShuffleToMem") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -432,5 +434,4 @@ package object config { .stringConf .toSequence .createOptional - } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index fe5fd2da039bb..1d8a266d0079c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -25,8 +25,8 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit + tempFileManager: TempFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -87,7 +87,12 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { + def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), @@ -96,12 +101,17 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.failure(exception) } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result.success(new NioManagedBuffer(ret)) + data match { + case f: FileSegmentManagedBuffer => + result.success(f) + case _ => + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result.success(new NioManagedBuffer(ret)) + } } - }, tempShuffleFileManager = null) + }, tempFileManager) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 6a29e18bf3cbb..b7d8c35032763 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -32,7 +32,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -105,14 +105,14 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempShuffleFileManager).start() + transportConf, tempFileManager).start() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c8d1460300934..0562d45ff57c5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,7 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a98083df5bd84..e0276a4dc4224 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,8 +18,11 @@ package org.apache.spark.storage import java.io._ +import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.collection.mutable.HashMap @@ -39,7 +42,7 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.{SerializerInstance, SerializerManager} @@ -203,6 +206,13 @@ private[spark] class BlockManager( private var blockReplicationPolicy: BlockReplicationPolicy = _ + // A TempFileManager used to track all the files of remote blocks which above the + // specified memory threshold. Files will be deleted automatically based on weak reference. + // Exposed for test + private[storage] val remoteBlockTempFileManager = + new BlockManager.RemoteBlockTempFileManager(this) + private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -632,8 +642,8 @@ private[spark] class BlockManager( * Return a list of locations for the given block, prioritizing the local machine since * multiple block managers can share the same host, followed by hosts on the same rack. */ - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - val locs = Random.shuffle(master.getLocations(blockId)) + private def sortLocations(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { + val locs = Random.shuffle(locations) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } blockManagerId.topologyInfo match { case None => preferredLocs ++ otherLocs @@ -653,7 +663,25 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") var runningFailureCount = 0 var totalFailureCount = 0 - val locations = getLocations(blockId) + + // Because all the remote blocks are registered in driver, it is not necessary to ask + // all the slave executors to get block status. + val locationsAndStatus = master.getLocationsAndStatus(blockId) + val blockSize = locationsAndStatus.map { b => + b.status.diskSize.max(b.status.memSize) + }.getOrElse(0L) + val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) + + // If the block size is above the threshold, we should pass our FileManger to + // BlockTransferService, which will leverage it to spill the block; if not, then passed-in + // null value means the block will be persisted in memory. + val tempFileManager = if (blockSize > maxRemoteBlockToMem) { + remoteBlockTempFileManager + } else { + null + } + + val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator while (locationIterator.hasNext) { @@ -661,7 +689,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 @@ -684,7 +712,7 @@ private[spark] class BlockManager( // take a significant amount of time. To get rid of these stale entries // we refresh the block locations after a certain number of fetch failures if (runningFailureCount >= maxFailuresBeforeLocationRefresh) { - locationIterator = getLocations(blockId).iterator + locationIterator = sortLocations(master.getLocations(blockId)).iterator logDebug(s"Refreshed locations from the driver " + s"after ${runningFailureCount} fetch failures.") runningFailureCount = 0 @@ -1512,6 +1540,7 @@ private[spark] class BlockManager( // Closing should be idempotent, but maybe not for the NioBlockTransferService. shuffleClient.close() } + remoteBlockTempFileManager.stop() diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) blockInfoManager.clear() @@ -1552,4 +1581,65 @@ private[spark] object BlockManager { override val metricRegistry = new MetricRegistry metricRegistry.registerAll(metricSet) } + + class RemoteBlockTempFileManager(blockManager: BlockManager) + extends TempFileManager with Logging { + + private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File]) + extends WeakReference[File](file, referenceQueue) { + private val filePath = file.getAbsolutePath + + def cleanUp(): Unit = { + logDebug(s"Clean up file $filePath") + + if (!new File(filePath).delete()) { + logDebug(s"Fail to delete file $filePath") + } + } + } + + private val referenceQueue = new JReferenceQueue[File] + private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup]( + new ConcurrentHashMap) + + private val POLL_TIMEOUT = 1000 + @volatile private var stopped = false + + private val cleaningThread = new Thread() { override def run() { keepCleaning() } } + cleaningThread.setDaemon(true) + cleaningThread.setName("RemoteBlock-temp-file-clean-thread") + cleaningThread.start() + + override def createTempFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempFileToClean(file: File): Boolean = { + referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue)) + } + + def stop(): Unit = { + stopped = true + cleaningThread.interrupt() + cleaningThread.join() + } + + private def keepCleaning(): Unit = { + while (!stopped) { + try { + Option(referenceQueue.remove(POLL_TIMEOUT)) + .map(_.asInstanceOf[ReferenceWithCleanup]) + .foreach { ref => + referenceBuffer.remove(ref) + ref.cleanUp() + } + } catch { + case _: InterruptedException => + // no-op + case NonFatal(e) => + logError("Error in cleaning thread", e) + } + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 8b1dc0ba6356a..d24421b962774 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -84,6 +84,12 @@ class BlockManagerMaster( driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId)) } + /** Get locations as well as status of the blockId from the driver */ + def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + driverEndpoint.askSync[Option[BlockLocationsAndStatus]]( + GetLocationsAndStatus(blockId)) + } + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]]( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index df0a5f5e229fb..56d0266b8edad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -82,6 +82,9 @@ class BlockManagerMasterEndpoint( case GetLocations(blockId) => context.reply(getLocations(blockId)) + case GetLocationsAndStatus(blockId) => + context.reply(getLocationsAndStatus(blockId)) + case GetLocationsMultipleBlockIds(blockIds) => context.reply(getLocationsMultipleBlockIds(blockIds)) @@ -422,6 +425,17 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } + private def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty) + val status = locations.headOption.flatMap { bmId => blockManagerInfo(bmId).getStatus(blockId) } + + if (locations.nonEmpty && status.isDefined) { + Some(BlockLocationsAndStatus(locations, status.get)) + } else { + None + } + } + private def getLocationsMultipleBlockIds( blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0c0ff144596ac..1bbe7a5b39509 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -93,6 +93,13 @@ private[spark] object BlockManagerMessages { case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster + case class GetLocationsAndStatus(blockId: BlockId) extends ToBlockManagerMaster + + // The response message of `GetLocationsAndStatus` request. + case class BlockLocationsAndStatus(locations: Seq[BlockManagerId], status: BlockStatus) { + assert(locations.nonEmpty) + } + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2d176b62f8b36..98b5a735a4529 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -69,7 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -162,11 +162,11 @@ final class ShuffleBlockFetcherIterator( currentResult = null } - override def createTempShuffleFile(): File = { + override def createTempFile(): File = { blockManager.diskBlockManager.createTempLocalBlock()._2 } - override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + override def registerTempFileToClean(file: File): Boolean = synchronized { if (isZombie) { false } else { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index bea67b71a5a12..f8005610f7e4f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString) + blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index cfe89fde63f88..d45c194d31adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.storage -import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -45,14 +44,14 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -512,8 +511,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) val blockManager = makeBlockManager(128, "exec", bmMaster) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) } @@ -535,8 +534,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) blockManager.blockManagerId = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) assert(locations.flatMap(_.topologyInfo) === Seq(localRack, localRack, localRack, otherRack, otherRack)) @@ -1274,13 +1273,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // so that we have a chance to do location refresh val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh) .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) } - when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds) + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockManagerIds, BlockStatus.empty))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn( + blockManagerIds) + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, transferService = Option(mockBlockTransferService)) val block = store.getRemoteBytes("item") .asInstanceOf[Option[ByteBuffer]] assert(block.isDefined) - verify(mockBlockManagerMaster, times(2)).getLocations("item") + verify(mockBlockManagerMaster, times(1)).getLocationsAndStatus("item") + verify(mockBlockManagerMaster, times(1)).getLocations("item") } test("SPARK-17484: block status is properly updated following an exception in put()") { @@ -1371,8 +1375,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE server.close() } + test("fetch remote block to local disk if block size is larger than threshold") { + conf.set(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1000L) + + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockTransferService = new MockBlockTransferService(0) + val blockLocations = Seq(BlockManagerId("id-0", "host-0", 1)) + val blockStatus = BlockStatus(StorageLevel.DISK_ONLY, 0L, 2000L) + + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockLocations, blockStatus))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockLocations) + + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, + transferService = Option(mockBlockTransferService)) + val block = store.getRemoteBytes("item") + .asInstanceOf[Option[ByteBuffer]] + + assert(block.isDefined) + assert(mockBlockTransferService.numCalls === 1) + // assert FileManager is not null if the block size is larger than threshold. + assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 + var tempFileManager: TempFileManager = null override def init(blockDataManager: BlockDataManager): Unit = {} @@ -1382,7 +1410,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1394,7 +1422,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def uploadBlock( hostname: String, - port: Int, execId: String, + port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel, @@ -1407,12 +1436,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE host: String, port: Int, execId: String, - blockId: String): ManagedBuffer = { + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { numCalls += 1 + this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c371cbcf8dff5..5bfe9905ff17b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -437,12 +437,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempShuffleFileManager: TempShuffleFileManager = null + var tempFileManager: TempFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] + tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -472,13 +472,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempShuffleFileManager == null) + assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(tempShuffleFileManager != null) + assert(tempFileManager != null) } } diff --git a/docs/configuration.md b/docs/configuration.md index 7a777d3c6fa3d..bb06c8faaaed7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -547,13 +547,14 @@ Apart from these, the following properties are also available, and may be useful - spark.reducer.maxReqSizeShuffleToMem + spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The blocks of a shuffle request will be fetched to disk when size of the request is above - this threshold. This is to avoid a giant request takes too much memory. We can enable this - config by setting a specific value(e.g. 200m). Note that this config can be enabled only when - the shuffle shuffle service is newer than Spark-2.2 or the shuffle service is disabled. + The remote block will be fetched to disk when size of the block is above this threshold. + This is to avoid a giant request takes too much memory. We can enable this config by setting + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + and block manager remote block fetch. For users who enabled external shuffle service, + this feature can only be worked when external shuffle service is newer than Spark 2.2. From 75d666b95a711787355ca3895057dabadd429023 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Oct 2017 12:26:53 -0700 Subject: [PATCH 1501/1765] [SPARK-22136][SS] Evaluate one-sided conditions early in stream-stream joins. ## What changes were proposed in this pull request? Evaluate one-sided conditions early in stream-stream joins. This is in addition to normal filter pushdown, because integrating it with the join logic allows it to take place in outer join scenarios. This means that rows which can never satisfy the join condition won't clog up the state. ## How was this patch tested? new unit tests Author: Jose Torres Closes #19452 from joseph-torres/SPARK-22136. --- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 134 ++++++++++------ .../StreamingSymmetricHashJoinHelper.scala | 70 +++++++- .../sql/streaming/StreamingJoinSuite.scala | 150 +++++++++++++++++- ...treamingSymmetricHashJoinHelperSuite.scala | 130 +++++++++++++++ 5 files changed, 433 insertions(+), 53 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 2e378637727fc..a10ed5f2df1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -133,7 +133,7 @@ class IncrementalExecution( eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( - j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition, + j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, Some(offsetSeqMetadata.batchWatermarkMs)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 9bd2127a28ff6..c351f658cb955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SessionState -import org.apache.spark.sql.types.{LongType, TimestampType} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -115,7 +114,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} * @param leftKeys Expression to generate key rows for joining from left input * @param rightKeys Expression to generate key rows for joining from right input * @param joinType Type of join (inner, left outer, etc.) - * @param condition Optional, additional condition to filter output of the equi-join + * @param condition Conditions to filter rows, split by left, right, and joined. See + * [[JoinConditionSplitPredicates]] * @param stateInfo Version information required to read join state (buffered rows) * @param eventTimeWatermark Watermark of input event, same for both sides * @param stateWatermarkPredicates Predicates for removal of state, see @@ -127,7 +127,7 @@ case class StreamingSymmetricHashJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression], + condition: JoinConditionSplitPredicates, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, @@ -141,8 +141,10 @@ case class StreamingSymmetricHashJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan) = { + this( - leftKeys, rightKeys, joinType, condition, stateInfo = None, eventTimeWatermark = None, + leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), + stateInfo = None, eventTimeWatermark = None, stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } @@ -161,6 +163,9 @@ case class StreamingSymmetricHashJoinExec( new SerializableConfiguration(SessionState.newHadoopConf( sparkContext.hadoopConfiguration, sqlContext.conf))) + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -206,10 +211,15 @@ case class StreamingSymmetricHashJoinExec( val updateStartTimeNs = System.nanoTime val joinedRow = new JoinedRow + + val postJoinFilter = + newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _ val leftSideJoiner = new OneSideHashJoiner( - LeftSide, left.output, leftKeys, leftInputIter, stateWatermarkPredicates.left) + LeftSide, left.output, leftKeys, leftInputIter, + condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left) val rightSideJoiner = new OneSideHashJoiner( - RightSide, right.output, rightKeys, rightInputIter, stateWatermarkPredicates.right) + RightSide, right.output, rightKeys, rightInputIter, + condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right) // Join one side input using the other side's buffered/state rows. Here is how it is done. // @@ -221,43 +231,28 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input) } - // Filter the joined rows based on the given condition. - val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ - // We need to save the time that the inner join output iterator completes, since outer join // output counts as both update and removal time. var innerOutputCompletionTimeNs: Long = 0 def onInnerOutputCompletion = { innerOutputCompletionTimeNs = System.nanoTime } - val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) - - def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { - rightSideJoiner.get(leftKeyValue.key).exists( - rightValue => { - outputFilterFunction( - joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) - }) - } + // This is the iterator which produces the inner join rows. For outer joins, this will be + // prepended to a second iterator producing outer join rows; for inner joins, this is the full + // output. + val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion) - def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { - leftSideJoiner.get(rightKeyValue.key).exists( - leftValue => { - outputFilterFunction( - joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) - }) - } val outputIter: Iterator[InternalRow] = joinType match { case Inner => - filteredInnerOutputIter + innerOutputIter case LeftOuter => // We generate the outer join input by: // * Getting an iterator over the rows that have aged out on the left side. These rows are @@ -268,28 +263,37 @@ case class StreamingSymmetricHashJoinExec( // we know we can join with null, since there was never (including this batch) a match // within the watermark period. If it does, there must have been a match at some point, so // we know we can't join with null. - val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists { rightValue => + postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + } + } val removedRowIter = leftSideJoiner.removeOldState() val outerOutputIter = removedRowIter .filterNot(pair => matchesWithRightSideState(pair)) .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) - filteredInnerOutputIter ++ outerOutputIter + innerOutputIter ++ outerOutputIter case RightOuter => // See comments for left outer case. - val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists { leftValue => + postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + } + } val removedRowIter = rightSideJoiner.removeOldState() val outerOutputIter = removedRowIter .filterNot(pair => matchesWithLeftSideState(pair)) .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) - filteredInnerOutputIter ++ outerOutputIter + innerOutputIter ++ outerOutputIter case _ => throwBadJoinTypeException() } + val outputProjection = UnsafeProjection.create(left.output ++ right.output, output) val outputIterWithMetrics = outputIter.map { row => numOutputRows += 1 - row + outputProjection(row) } // Function to remove old state after all the input has been consumed and output generated @@ -349,14 +353,36 @@ case class StreamingSymmetricHashJoinExec( /** * Internal helper class to consume input rows, generate join output rows using other sides * buffered state rows, and finally clean up this sides buffered state rows + * + * @param joinSide The JoinSide - either left or right. + * @param inputAttributes The input attributes for this side of the join. + * @param joinKeys The join keys. + * @param inputIter The iterator of input rows on this side to be joined. + * @param preJoinFilterExpr A filter over rows on this side. This filter rejects rows that could + * never pass the overall join condition no matter what other side row + * they're joined with. + * @param postJoinFilter A filter over joined rows. This filter completes the application of + * the overall join condition, assuming that preJoinFilter on both sides + * of the join has already been passed. + * Passed as a function rather than expression to avoid creating the + * predicate twice; we also need this filter later on in the parent exec. + * @param stateWatermarkPredicate The state watermark predicate. See + * [[StreamingSymmetricHashJoinExec]] for further description of + * state watermarks. */ private class OneSideHashJoiner( joinSide: JoinSide, inputAttributes: Seq[Attribute], joinKeys: Seq[Expression], inputIter: Iterator[InternalRow], + preJoinFilterExpr: Option[Expression], + postJoinFilter: (InternalRow) => Boolean, stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) { + // Filter the joined rows based on the given condition. + val preJoinFilter = + newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ + private val joinStateManager = new SymmetricHashJoinStateManager( joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) @@ -388,8 +414,8 @@ case class StreamingSymmetricHashJoinExec( */ def storeAndJoinWithOtherSide( otherSideJoiner: OneSideHashJoiner)( - generateJoinedRow: (UnsafeRow, UnsafeRow) => JoinedRow): Iterator[InternalRow] = { - + generateJoinedRow: (InternalRow, InternalRow) => JoinedRow): + Iterator[InternalRow] = { val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) val nonLateRows = WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { @@ -402,17 +428,31 @@ case class StreamingSymmetricHashJoinExec( nonLateRows.flatMap { row => val thisRow = row.asInstanceOf[UnsafeRow] - val key = keyGenerator(thisRow) - val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => - generateJoinedRow(thisRow, thatRow) - } - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) - if (shouldAddToState) { - joinStateManager.append(key, thisRow) - updatedStateRowsCount += 1 + // If this row fails the pre join filter, that means it can never satisfy the full join + // condition no matter what other side row it's matched with. This allows us to avoid + // adding it to the state, and generate an outer join row immediately (or do nothing in + // the case of inner join). + if (preJoinFilter(thisRow)) { + val key = keyGenerator(thisRow) + val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => + generateJoinedRow(thisRow, thatRow) + }.filter(postJoinFilter) + val shouldAddToState = // add only if both removal predicates do not match + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + if (shouldAddToState) { + joinStateManager.append(key, thisRow) + updatedStateRowsCount += 1 + } + outputIter + } else { + joinSide match { + case LeftSide if joinType == LeftOuter => + Iterator(generateJoinedRow(thisRow, nullRight)) + case RightSide if joinType == RightOuter => + Iterator(generateJoinedRow(thisRow, nullLeft)) + case _ => Iterator() + } } - outputIter } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 64c7189f72ac3..167e991ca62f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -24,8 +24,9 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper -import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProvider, StateStoreProviderId} import org.apache.spark.sql.types._ @@ -66,6 +67,73 @@ object StreamingSymmetricHashJoinHelper extends Logging { } } + /** + * Wrapper around various useful splits of the join condition. + * left AND right AND joined is equivalent to full. + * + * Note that left and right do not necessarily contain *all* conjuncts which satisfy + * their condition. Any conjuncts after the first nondeterministic one are treated as + * nondeterministic for purposes of the split. + * + * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. + * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. + * @param bothSides Conjuncts which are nondeterministic, occur after a nondeterministic conjunct, + * or reference both left and right sides of the join. + * @param full The full join condition. + */ + case class JoinConditionSplitPredicates( + leftSideOnly: Option[Expression], + rightSideOnly: Option[Expression], + bothSides: Option[Expression], + full: Option[Expression]) { + override def toString(): String = { + s"condition = [ leftOnly = ${leftSideOnly.map(_.toString).getOrElse("null")}, " + + s"rightOnly = ${rightSideOnly.map(_.toString).getOrElse("null")}, " + + s"both = ${bothSides.map(_.toString).getOrElse("null")}, " + + s"full = ${full.map(_.toString).getOrElse("null")} ]" + } + } + + object JoinConditionSplitPredicates extends PredicateHelper { + def apply(condition: Option[Expression], left: SparkPlan, right: SparkPlan): + JoinConditionSplitPredicates = { + // Split the condition into 3 parts: + // * Conjuncts that can be evaluated on only the left input. + // * Conjuncts that can be evaluated on only the right input. + // * Conjuncts that require both left and right input. + // + // Note that we treat nondeterministic conjuncts as though they require both left and right + // input. To maintain their semantics, they need to be evaluated exactly once per joined row. + val (leftCondition, rightCondition, joinedCondition) = { + if (condition.isEmpty) { + (None, None, None) + } else { + // Span rather than partition, because nondeterministic expressions don't commute + // across AND. + val (deterministicConjuncts, nonDeterministicConjuncts) = + splitConjunctivePredicates(condition.get).span(_.deterministic) + + val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(left.outputSet) + } + + val (rightConjuncts, nonRightConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(right.outputSet) + } + + ( + leftConjuncts.reduceOption(And), + rightConjuncts.reduceOption(And), + (nonLeftConjuncts.intersect(nonRightConjuncts) ++ nonDeterministicConjuncts) + .reduceOption(And) + ) + } + } + + JoinConditionSplitPredicates(leftCondition, rightCondition, joinedCondition, condition) + } + } + /** Get the predicates defining the state watermarks for both sides of the join */ def getStateWatermarkPredicates( leftAttributes: Seq[Attribute], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index d32617275aadc..54eb863dacc83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -365,6 +365,24 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with } } } + + test("join between three streams") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val input3 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue") + val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue") + + val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey")) + + testStream(joined)( + AddData(input1, 1, 5), + AddData(input2, 1, 5, 10), + AddData(input3, 5, 10), + CheckLastBatch((5, 10, 5, 15, 5, 25))) + } } class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { @@ -405,6 +423,130 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with (input1, input2, joined) } + test("left outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + + test("left outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with value <= 7 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + ) + } + + test("right outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with value <= 4 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + ) + } + + test("right outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + test("windowed left outer join") { val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") @@ -495,7 +637,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // When the join condition isn't true, the outer null rows must be generated, even if the join // keys themselves have a match. - test("left outer join with non-key condition violated on left") { + test("left outer join with non-key condition violated") { val (leftInput, simpleLeftDf) = setupStream("left", 2) val (rightInput, simpleRightDf) = setupStream("right", 3) @@ -513,14 +655,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // leftValue <= 10 should generate outer join rows even though it matches right keys AddData(leftInput, 1, 2, 3), AddData(rightInput, 1, 2, 3), - CheckLastBatch(), + CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), AddData(leftInput, 20), AddData(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), CheckLastBatch( - Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows AddData(leftInput, 40, 41), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala new file mode 100644 index 0000000000000..2a854e37bf0df --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, SparkPlan} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates +import org.apache.spark.sql.types._ + +class StreamingSymmetricHashJoinHelperSuite extends StreamTest { + import org.apache.spark.sql.functions._ + + val leftAttributeA = AttributeReference("a", IntegerType)() + val leftAttributeB = AttributeReference("b", IntegerType)() + val rightAttributeC = AttributeReference("c", IntegerType)() + val rightAttributeD = AttributeReference("d", IntegerType)() + val leftColA = new Column(leftAttributeA) + val leftColB = new Column(leftAttributeB) + val rightColC = new Column(rightAttributeC) + val rightColD = new Column(rightAttributeD) + + val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq()) + val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq()) + + test("empty") { + val split = JoinConditionSplitPredicates(None, left, right) + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.isEmpty) + } + + test("only literals") { + // Literal-only conjuncts end up on the left side because that's the first bucket they fit in. + // There's no semantic reason they couldn't be in any bucket. + val predicate = (lit(1) < lit(5) && lit(6) < lit(7) && lit(0) === lit(-1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only left") { + val predicate = (leftColA > lit(1) && leftColB > lit(5) && leftColA < leftColB).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only right") { + val predicate = (rightColC > lit(1) && rightColD > lit(5) && rightColD < rightColC).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("mixed conjuncts") { + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC).expr)) + assert(split.full.contains(predicate)) + } + + test("conjuncts after nondeterministic") { + // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't + // commute across it. + val predicate = + (rand() > lit(0) + && leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.contains(predicate)) + assert(split.full.contains(predicate)) + } + + + test("conjuncts before nondeterministic") { + val randCol = rand() + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1) + && randCol > lit(0)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && randCol > lit(0)).expr)) + assert(split.full.contains(predicate)) + } +} From 28f9f3f22511e9f2f900764d9bd5b90d2eeee773 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 17 Oct 2017 12:50:41 -0700 Subject: [PATCH 1502/1765] [SPARK-22271][SQL] mean overflows and returns null for some decimal variables ## What changes were proposed in this pull request? In Average.scala, it has ``` override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } def setChild (newchild: Expression) = { child = newchild } ``` It is possible that Cast(count, dt), resultType) will make the precision of the decimal number bigger than 38, and this causes over flow. Since count is an integer and doesn't need a scale, I will cast it using DecimalType.bounded(38,0) ## How was this patch tested? In DataFrameSuite, I will add a test case. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Huaxin Gao Closes #19496 from huaxingao/spark-22271. --- .../sql/catalyst/expressions/aggregate/Average.scala | 3 ++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c423e17169e85..708bdbfc36058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50de2fd3bca8d..473c355cf3c7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2105,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + + test("SPARK-22271: mean overflows and returns null for some decimal variables") { + val d = 0.034567890 + val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") + val result = df.select('DecimalCol cast DecimalType(38, 33)) + .select(col("DecimalCol")).describe() + val mean = result.select("DecimalCol").where($"summary" === "mean") + assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) + } } From 1437e344ec0c29a44a19f4513986f5f184c44695 Mon Sep 17 00:00:00 2001 From: Michael Mior Date: Tue, 17 Oct 2017 14:30:52 -0700 Subject: [PATCH 1503/1765] [SPARK-22050][CORE] Allow BlockUpdated events to be optionally logged to the event log ## What changes were proposed in this pull request? I see that block updates are not logged to the event log. This makes sense as a default for performance reasons. However, I find it helpful when trying to get a better understanding of caching for a job to be able to log these updates. This PR adds a configuration setting `spark.eventLog.blockUpdates` (defaulting to false) which allows block updates to be recorded in the log. This contribution is original work which is licensed to the Apache Spark project. ## How was this patch tested? Current and additional unit tests. Author: Michael Mior Closes #19263 from michaelmior/log-block-updates. --- .../spark/internal/config/package.scala | 23 +++++++++++++ .../scheduler/EventLoggingListener.scala | 18 ++++++---- .../org/apache/spark/util/JsonProtocol.scala | 34 +++++++++++++++++-- .../scheduler/EventLoggingListenerSuite.scala | 2 ++ .../apache/spark/util/JsonProtocolSuite.scala | 27 +++++++++++++++ docs/configuration.md | 8 +++++ 6 files changed, 104 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e7b406af8d9b1..0c36bdcdd2904 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -41,6 +41,29 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val EVENT_LOG_COMPRESS = + ConfigBuilder("spark.eventLog.compress") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_BLOCK_UPDATES = + ConfigBuilder("spark.eventLog.logBlockUpdates.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_TESTING = + ConfigBuilder("spark.eventLog.testing") + .internal() + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .bytesConf(ByteUnit.KiB) + .createWithDefaultString("100k") + + private[spark] val EVENT_LOG_OVERWRITE = + ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 9dafa0b7646bf..a77adc5ff3545 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -37,6 +37,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -45,6 +46,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * * Event logging is specified by the following configurable parameters: * spark.eventLog.enabled - Whether event logging is enabled. + * spark.eventLog.logBlockUpdates.enabled - Whether to log block updates * spark.eventLog.compress - Whether to compress logged events * spark.eventLog.overwrite - Whether to overwrite any existing files. * spark.eventLog.dir - Path to the directory in which events are logged. @@ -64,10 +66,11 @@ private[spark] class EventLoggingListener( this(appId, appAttemptId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) - private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) - private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) - private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) - private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 + private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) + private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) + private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val testing = sparkConf.get(EVENT_LOG_TESTING) + private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { @@ -216,8 +219,11 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } - // No-op because logging every update would be overkill - override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + if (shouldLogBlockUpdates) { + logEvent(event, flushLogger = true) + } + } // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8406826a228db..5e60218c5740b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -98,8 +98,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) - case blockUpdated: SparkListenerBlockUpdated => - throw new MatchError(blockUpdated) // TODO(ekl) implement this + case blockUpdate: SparkListenerBlockUpdated => + blockUpdateToJson(blockUpdate) case _ => parse(mapper.writeValueAsString(event)) } } @@ -246,6 +246,12 @@ private[spark] object JsonProtocol { }) } + def blockUpdateToJson(blockUpdate: SparkListenerBlockUpdated): JValue = { + val blockUpdatedInfo = blockUpdatedInfoToJson(blockUpdate.blockUpdatedInfo) + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockUpdate) ~ + ("Block Updated Info" -> blockUpdatedInfo) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -458,6 +464,14 @@ private[spark] object JsonProtocol { ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } + def blockUpdatedInfoToJson(blockUpdatedInfo: BlockUpdatedInfo): JValue = { + ("Block Manager ID" -> blockManagerIdToJson(blockUpdatedInfo.blockManagerId)) ~ + ("Block ID" -> blockUpdatedInfo.blockId.toString) ~ + ("Storage Level" -> storageLevelToJson(blockUpdatedInfo.storageLevel)) ~ + ("Memory Size" -> blockUpdatedInfo.memSize) ~ + ("Disk Size" -> blockUpdatedInfo.diskSize) + } + /** ------------------------------ * * Util JSON serialization methods | * ------------------------------- */ @@ -515,6 +529,7 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) } def sparkEventFromJson(json: JValue): SparkListenerEvent = { @@ -538,6 +553,7 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `blockUpdate` => blockUpdateFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] } @@ -676,6 +692,11 @@ private[spark] object JsonProtocol { SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } + def blockUpdateFromJson(json: JValue): SparkListenerBlockUpdated = { + val blockUpdatedInfo = blockUpdatedInfoFromJson(json \ "Block Updated Info") + SparkListenerBlockUpdated(blockUpdatedInfo) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -989,6 +1010,15 @@ private[spark] object JsonProtocol { new ExecutorInfo(executorHost, totalCores, logUrls) } + def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { + val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") + val blockId = BlockId((json \ "Block ID").extract[String]) + val storageLevel = storageLevelFromJson(json \ "Storage Level") + val memorySize = (json \ "Memory Size").extract[Long] + val diskSize = (json \ "Disk Size").extract[Long] + BlockUpdatedInfo(blockManagerId, blockId, storageLevel, memorySize, diskSize) + } + /** -------------------------------- * * Util JSON deserialization methods | * --------------------------------- */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 6b42775ccb0f6..a9e92fa07b9dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -228,6 +228,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerStageCompleted, SparkListenerTaskStart, SparkListenerTaskEnd, + SparkListenerBlockUpdated, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) Utils.tryWithSafeFinally { val logStart = SparkListenerLogStart(SPARK_VERSION) @@ -291,6 +292,7 @@ object EventLoggingListenerSuite { def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") + conf.set("spark.eventLog.logBlockUpdates.enabled", "true") conf.set("spark.eventLog.testing", "true") conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a1a858765a7d4..4abbb8e7894f5 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -96,6 +96,9 @@ class JsonProtocolSuite extends SparkFunSuite { .zipWithIndex.map { case (a, i) => a.copy(id = i) } SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) } + val blockUpdated = + SparkListenerBlockUpdated(BlockUpdatedInfo(BlockManagerId("Stars", + "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -120,6 +123,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(nodeBlacklisted, nodeBlacklistedJsonString) testEvent(nodeUnblacklisted, nodeUnblacklistedJsonString) testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) + testEvent(blockUpdated, blockUpdatedJsonString) } test("Dependent Classes") { @@ -2007,6 +2011,29 @@ private[spark] object JsonProtocolSuite extends Assertions { |} """.stripMargin + private val blockUpdatedJsonString = + """ + |{ + | "Event": "SparkListenerBlockUpdated", + | "Block Updated Info": { + | "Block Manager ID": { + | "Executor ID": "Stars", + | "Host": "In your multitude...", + | "Port": 300 + | }, + | "Block ID": "rdd_0_0", + | "Storage Level": { + | "Use Disk": false, + | "Use Memory": true, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Memory Size": 100, + | "Disk Size": 0 + | } + |} + """.stripMargin + private val executorBlacklistedJsonString = s""" |{ diff --git a/docs/configuration.md b/docs/configuration.md index bb06c8faaaed7..7b9e16a382449 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -714,6 +714,14 @@ Apart from these, the following properties are also available, and may be useful + + + + + From f3137feecd30c74c47dbddb0e22b4ddf8cf2f912 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 17 Oct 2017 20:09:12 -0700 Subject: [PATCH 1504/1765] [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState ## What changes were proposed in this pull request? Complex state-updating and/or timeout-handling logic in mapGroupsWithState functions may require taking decisions based on the current event-time watermark and/or processing time. Currently, you can use the SQL function `current_timestamp` to get the current processing time, but it needs to be passed inserted in every row with a select, and then passed through the encoder, which isn't efficient. Furthermore, there is no way to get the current watermark. This PR exposes both of them through the GroupState API. Additionally, it also cleans up some of the GroupState docs. ## How was this patch tested? New unit tests Author: Tathagata Das Closes #19495 from tdas/SPARK-22278. --- .../apache/spark/sql/execution/objects.scala | 8 +- .../FlatMapGroupsWithStateExec.scala | 7 +- .../execution/streaming/GroupStateImpl.scala | 50 +++--- .../spark/sql/streaming/GroupState.scala | 92 ++++++---- .../FlatMapGroupsWithStateSuite.scala | 160 +++++++++++++++--- 5 files changed, 238 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c68975bea490f..d861109436a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.plans.logical.{FunctionUtils, LogicalGroupState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.streaming.GroupStateTimeout @@ -361,8 +361,12 @@ object MapGroupsExec { outputObjAttr: Attribute, timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } val f = (key: Any, values: Iterator[Any]) => { - func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent)) } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index c81f1a8142784..29f38fab3f896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -61,6 +61,10 @@ case class FlatMapGroupsWithStateExec( private val isTimeoutEnabled = timeoutConf != NoTimeout val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -190,7 +194,8 @@ case class FlatMapGroupsWithStateExec( batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, - hasTimedOut) + hasTimedOut, + watermarkPresent) // Call function, get the returned objects and convert them to rows val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 4401e86936af9..7f65e3ea9dd5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -43,7 +43,8 @@ private[sql] class GroupStateImpl[S] private( batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - override val hasTimedOut: Boolean) extends GroupState[S] { + override val hasTimedOut: Boolean, + watermarkPresent: Boolean) extends GroupState[S] { private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -90,7 +91,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != ProcessingTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout duration without enabling processing time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMap]GroupsWithState") } if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") @@ -102,10 +103,6 @@ private[sql] class GroupStateImpl[S] private( setTimeoutDuration(parseDuration(duration)) } - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long): Unit = { checkTimeoutTimestampAllowed() if (timestampMs <= 0) { @@ -119,32 +116,34 @@ private[sql] class GroupStateImpl[S] private( timeoutTimestamp = timestampMs } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) } - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime) } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } + override def getCurrentWatermarkMs(): Long = { + if (!watermarkPresent) { + throw new UnsupportedOperationException( + "Cannot get event time watermark timestamp without setting watermark before " + + "[map|flatMap]GroupsWithState") + } + eventTimeWatermarkMs + } + + override def getCurrentProcessingTimeMs(): Long = { + batchProcessingTimeMs + } + override def toString: String = { s"GroupState(${getOption.map(_.toString).getOrElse("")})" } @@ -187,7 +186,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != EventTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout timestamp without enabling event time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMapGroupsWithState") } } } @@ -202,17 +201,22 @@ private[sql] object GroupStateImpl { batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - hasTimedOut: Boolean): GroupStateImpl[S] = { + hasTimedOut: Boolean, + watermarkPresent: Boolean): GroupStateImpl[S] = { new GroupStateImpl[S]( - optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, + timeoutConf, hasTimedOut, watermarkPresent) } - def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = { + def createForBatch( + timeoutConf: GroupStateTimeout, + watermarkPresent: Boolean): GroupStateImpl[Any] = { new GroupStateImpl[Any]( optionalValue = None, - batchProcessingTimeMs = NO_TIMESTAMP, + batchProcessingTimeMs = System.currentTimeMillis, eventTimeWatermarkMs = NO_TIMESTAMP, timeoutConf, - hasTimedOut = false) + hasTimedOut = false, + watermarkPresent) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 04a956b70b022..e9510c903acae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -205,11 +205,7 @@ trait GroupState[S] extends LogicalGroupState[S] { /** Get the state value as a scala Option. */ def getOption: Option[S] - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") + /** Update the value of the state. */ def update(newState: S): Unit /** Remove this state. */ @@ -217,80 +213,114 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Whether the function has been called because the key has timed out. - * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. */ def hasTimedOut: Boolean + /** * Set the timeout duration in ms for this key. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(durationMs: Long): Unit + /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(duration: String): Unit - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'timestampMs' is not positive or less than the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + + "the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit + + + /** + * Get the current event time watermark as milliseconds in epoch time. + * + * @note In a streaming query, this can be called only when watermark is set before calling + * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. + */ + @throws[UnsupportedOperationException]( + "if watermark has not been set before in [map|flatMap]GroupsWithState") + def getCurrentWatermarkMs(): Long + + + /** + * Get the current processing time as milliseconds in epoch time. + * @note In a streaming query, this will return a constant value throughout the duration of a + * trigger, even if the trigger is re-executed. + */ + def getCurrentProcessingTimeMs(): Long } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index aeb83835f981a..af08186aadbb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,6 +21,7 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -48,6 +49,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest import testImplicits._ import GroupStateImpl._ import GroupStateTimeout._ + import FlatMapGroupsWithStateSuite._ override def afterAll(): Unit = { super.afterAll() @@ -77,13 +79,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // === Tests for state in streaming queries === // Updating empty state - state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -104,8 +108,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with NoTimeout") { for (initValue <- Seq(None, Some(5))) { val states = Seq( - GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), - GroupStateImpl.createForBatch(NoTimeout) + GroupStateImpl.createForStreaming( + initValue, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false), + GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) ) for (state <- states) { // for streaming queries @@ -122,7 +127,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with ProcessingTimeTimeout") { // for streaming queries var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state @@ -143,7 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch( + ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -160,7 +166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with EventTimeTimeout") { var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, EventTimeTimeout, false) + None, 1000, 1000, EventTimeTimeout, false, watermarkPresent = true) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) @@ -182,7 +188,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testTimeoutDurationNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = false) + .asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) @@ -209,7 +216,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -227,7 +234,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -259,29 +266,92 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // for streaming queries for (initState <- Seq(None, Some(5))) { val state1 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = false) + initState, 1000, 1000, timeoutConf, hasTimedOut = false, watermarkPresent = false) assert(state1.hasTimedOut === false) val state2 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = true) + initState, 1000, 1000, timeoutConf, hasTimedOut = true, watermarkPresent = false) assert(state2.hasTimedOut === true) } // for batch queries - assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) + assert( + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false) + } + } + + test("GroupState - getCurrentWatermarkMs") { + def streamingState(timeoutConf: GroupStateTimeout, watermark: Option[Long]): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, 1000, watermark.getOrElse(-1), timeoutConf, + hasTimedOut = false, watermark.nonEmpty) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + def assertWrongTimeoutError(test: => Unit): Unit = { + val e = intercept[UnsupportedOperationException] { test } + assert(e.getMessage.contains( + "Cannot get event time watermark timestamp without setting watermark")) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + // Tests for getCurrentWatermarkMs in streaming queries + assertWrongTimeoutError { streamingState(timeoutConf, None).getCurrentWatermarkMs() } + assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() === 1000) + assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() === 2000) + + // Tests for getCurrentWatermarkMs in batch queries + assertWrongTimeoutError { + batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs() + } + assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1) + } + } + + test("GroupState - getCurrentProcessingTimeMs") { + def streamingState( + timeoutConf: GroupStateTimeout, + procTime: Long, + watermarkPresent: Boolean): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent = false) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + for (watermarkPresent <- Seq(false, true)) { + // Tests for getCurrentProcessingTimeMs in streaming queries + assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent) + .getCurrentProcessingTimeMs() === -1) + assert(streamingState(timeoutConf, 1000, watermarkPresent) + .getCurrentProcessingTimeMs() === 1000) + assert(streamingState(timeoutConf, 2000, watermarkPresent) + .getCurrentProcessingTimeMs() === 2000) + + // Tests for getCurrentProcessingTimeMs in batch queries + val currentTime = System.currentTimeMillis() + assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs >= currentTime) + } } } + test("GroupState - primitive type") { var intState = GroupStateImpl.createForStreaming[Int]( - None, 1000, 1000, NoTimeout, hasTimedOut = false) + None, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) intState = GroupStateImpl.createForStreaming[Int]( - Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -304,7 +374,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( testName + "no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change @@ -342,7 +416,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( s"$timeoutConf - $testName - no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -466,7 +544,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithTimeout( s"$timeoutConf - should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change @@ -525,6 +607,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -647,6 +731,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } @@ -660,6 +747,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.hasTimedOut) { state.remove() Iterator((key, "-1")) @@ -713,10 +803,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout - val stateFunc = ( - key: String, - values: Iterator[(String, Long)], - state: GroupState[Long]) => { + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + val timeoutDelay = 5 if (key != "a") { Iterator.empty @@ -760,6 +850,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -802,7 +894,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // - no initial state // - timeouts operations work, does not throw any error [SPARK-20792] // - works with primitive state type + // - can get processing time val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") state.setTimeoutTimestamp(0, "1 hour") state.update(10) @@ -1090,4 +1186,24 @@ object FlatMapGroupsWithStateSuite { override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) override def hasCommitted: Boolean = true } + + def assertCanGetProcessingTime(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCanGetWatermark(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCannotGetWatermark(func: => Unit): Unit = { + try { + func + } catch { + case u: UnsupportedOperationException => + return + case _ => + throw new TestFailedException("Unexpected exception when trying to get watermark", 20) + } + throw new TestFailedException("Could get watermark when not expected", 20) + } } From 72561ecf4b611d68f8bf695ddd0c4c2cce3a29d9 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Oct 2017 20:59:40 +0800 Subject: [PATCH 1505/1765] [SPARK-22266][SQL] The same aggregate function was evaluated multiple times ## What changes were proposed in this pull request? To let the same aggregate function that appear multiple times in an Aggregate be evaluated only once, we need to deduplicate the aggregate expressions. The original code was trying to use a "distinct" call to get a set of aggregate expressions, but did not work, since the "distinct" did not compare semantic equality. And even if it did, further work should be done in result expression rewriting. In this PR, I changed the "set" to a map mapping the semantic identity of a aggregate expression to itself. Thus, later on, when rewriting result expressions (i.e., output expressions), the aggregate expression reference can be fixed. ## How was this patch tested? Added a new test in SQLQuerySuite Author: maryannxue Closes #19488 from maryannxue/spark-22266. --- .../sql/catalyst/planning/patterns.scala | 16 +++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 +++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8d034c21a4960..cc391aae55787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -205,14 +205,17 @@ object PhysicalAggregation { case logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. + // build a set of semantically distinct aggregate expressions and re-write expressions so + // that they reference the single copy of the aggregate function which actually gets computed. + // Non-deterministic aggregate expressions are not deduplicated. + val equivalentAggregateExpressions = new EquivalentExpressions val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { - case agg: AggregateExpression => agg + // addExpr() always returns false for non-deterministic expressions and do not add them. + case agg: AggregateExpression + if (!equivalentAggregateExpressions.addExpr(agg)) => agg } - }.distinct + } val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -236,7 +239,8 @@ object PhysicalAggregation { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, // so replace each aggregate expression by its corresponding attribute in the set: - ae.resultAttribute + equivalentAggregateExpressions.getEquivalentExprs(ae).headOption + .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f0c58e2e5bf45..caf332d050d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2715,4 +2716,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 1, 1)) } } + + test("SRARK-22266: the same aggregate function was calculated multiple times") { + val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 1) + checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) + } + + test("Non-deterministic aggregate functions should not be deduplicated") { + val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 2) + } } From 1f25d8683a84a479fd7fc77b5a1ea980289b681b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Oct 2017 09:14:46 -0700 Subject: [PATCH 1506/1765] [SPARK-22249][FOLLOWUP][SQL] Check if list of value for IN is empty in the optimizer ## What changes were proposed in this pull request? This PR addresses the comments by gatorsmile on [the previous PR](https://github.com/apache/spark/pull/19494). ## How was this patch tested? Previous UT and added UT. Author: Marco Gaido Closes #19522 from mgaido91/SPARK-22249_FOLLOWUP. --- .../execution/columnar/InMemoryTableScanExec.scala | 4 ++-- .../columnar/InMemoryColumnarQuerySuite.scala | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 846ec03e46a12..139da1c519da2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -102,8 +102,8 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(_: AttributeReference, list: Seq[Expression]) if list.isEmpty => Literal.FalseLiteral - case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) => + case In(a: AttributeReference, list: Seq[Expression]) + if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 75d17bc79477d..2f249c850a088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,8 +21,9 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -444,4 +445,13 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(dfNulls.filter($"id".isin(2, 3)).count() == 0) dfNulls.unpersist() } + + test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { + val attribute = AttributeReference("a", IntegerType)() + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, + LocalTableScanExec(Seq(attribute), Nil), None) + val tableScanExec = InMemoryTableScanExec(Seq(attribute), + Seq(In(attribute, Nil)), testRelation) + assert(tableScanExec.partitionFilters.isEmpty) + } } From 52facb0062a4253fa45ac0c633d0510a9b684a62 Mon Sep 17 00:00:00 2001 From: Valeriy Avanesov Date: Wed, 18 Oct 2017 10:46:46 -0700 Subject: [PATCH 1507/1765] [SPARK-14371][MLLIB] OnlineLDAOptimizer should not collect stats for each doc in mini-batch to driver Hi, # What changes were proposed in this pull request? as it was proposed by jkbradley , ```gammat``` are not collected to the driver anymore. # How was this patch tested? existing test suite. Author: Valeriy Avanesov Author: Valeriy Avanesov Closes #18924 from akopich/master. --- .../spark/mllib/clustering/LDAOptimizer.scala | 82 +++++++++++++------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index d633893e55f55..693a2a31f026b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -26,6 +26,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.PeriodicGraphCheckpointer +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -259,7 +260,7 @@ final class EMLDAOptimizer extends LDAOptimizer { */ @Since("1.4.0") @DeveloperApi -final class OnlineLDAOptimizer extends LDAOptimizer { +final class OnlineLDAOptimizer extends LDAOptimizer with Logging { // LDA common parameters private var k: Int = 0 @@ -462,31 +463,61 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape - - val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val optimizeDocConcentration = this.optimizeDocConcentration + // If and only if optimizeDocConcentration is set true, + // we calculate logphat in the same pass as other statistics. + // No calculation of loghat happens otherwise. + val logphatPartOptionBase = () => if (optimizeDocConcentration) { + Some(BDV.zeros[Double](k)) + } else { + None + } + + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) val stat = BDM.zeros[Double](k, vocabSize) - var gammaPart = List[BDV[Double]]() + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids).toDenseMatrix + sstats - gammaPart = gammad :: gammaPart + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) } - Iterator((stat, gammaPart)) - }.persist(StorageLevel.MEMORY_AND_DISK) - val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( - _ += _, _ += _) - val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) - stats.unpersist() + Iterator((stat, logphatPartOption, nonEmptyDocCount)) + } + + val elementWiseSum = ( + u: (BDM[Double], Option[BDV[Double]], Long), + v: (BDM[Double], Option[BDV[Double]], Long)) => { + u._1 += v._1 + u._2.foreach(_ += v._2.get) + (u._1, u._2, u._3 + v._3) + } + + val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN: Long) = stats + .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))( + elementWiseSum, elementWiseSum + ) + expElogbetaBc.destroy(false) - val batchResult = statsSum *:* expElogbeta.t + if (nonEmptyDocsN == 0) { + logWarning("No non-empty documents were submitted in the batch.") + // Therefore, there is no need to update any of the model parameters + return this + } + + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count - updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeDocConcentration) updateAlpha(gammat) + val batchSize = (miniBatchFraction * corpusSize).ceil.toInt + updateLambda(batchResult, batchSize) + + logphatOption.foreach(_ /= nonEmptyDocsN.toDouble) + logphatOption.foreach(updateAlpha(_, nonEmptyDocsN)) + this } @@ -503,21 +534,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Update alpha based on `gammat`, the inferred topic distributions for documents in the - * current mini-batch. Uses Newton-Rhapson method. + * Update alpha based on `logphat`. + * Uses Newton-Rhapson method. * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf) + * @param logphat Expectation of estimated log-posterior distribution of + * topics in a document averaged over the batch. + * @param nonEmptyDocsN number of non-empty documents */ - private def updateAlpha(gammat: BDM[Double]): Unit = { + private def updateAlpha(logphat: BDV[Double], nonEmptyDocsN: Double): Unit = { val weight = rho() - val N = gammat.rows.toDouble val alpha = this.alpha.asBreeze.toDenseVector - val logphat: BDV[Double] = - sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)).t / N - val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat) - val c = N * trigamma(sum(alpha)) - val q = -N * trigamma(alpha) + val gradf = nonEmptyDocsN * (-LDAUtils.dirichletExpectation(alpha) + logphat) + + val c = nonEmptyDocsN * trigamma(sum(alpha)) + val q = -nonEmptyDocsN * trigamma(alpha) val b = sum(gradf / q) / (1D / c + sum(1D / q)) val dalpha = -(gradf - b) / q From 6f1d0dea1cdda558c998179789b386f6e52b9e36 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 19 Oct 2017 13:30:55 +0800 Subject: [PATCH 1508/1765] [SPARK-22300][BUILD] Update ORC to 1.4.1 ## What changes were proposed in this pull request? Apache ORC 1.4.1 is released yesterday. - https://orc.apache.org/news/2017/10/16/ORC-1.4.1/ Like ORC-233 (Allow `orc.include.columns` to be empty), there are several important fixes. This PR updates Apache ORC dependency to use the latest one, 1.4.1. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #19521 from dongjoon-hyun/SPARK-22300. --- dev/deps/spark-deps-hadoop-2.6 | 6 +++--- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- pom.xml | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 76fcbd15869f1..6e2fc63d67108 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.3.jar +aircompressor-0.8.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -149,8 +149,8 @@ netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar opencsv-2.3.jar -orc-core-1.4.0-nohive.jar -orc-mapreduce-1.4.0-nohive.jar +orc-core-1.4.1-nohive.jar +orc-mapreduce-1.4.1-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cb20072bf8b30..c2bbc253d723a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.3.jar +aircompressor-0.8.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -150,8 +150,8 @@ netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar opencsv-2.3.jar -orc-core-1.4.0-nohive.jar -orc-mapreduce-1.4.0-nohive.jar +orc-core-1.4.1-nohive.jar +orc-mapreduce-1.4.1-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 9fac8b1e53788..b9c972855204a 100644 --- a/pom.xml +++ b/pom.xml @@ -128,7 +128,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.0 + 1.4.1 nohive 1.6.0 9.3.20.v20170531 @@ -1712,6 +1712,10 @@ org.apache.hive hive-storage-api + + io.airlift + slice + From dc2714da50ecba1bf1fdf555a82a4314f763a76e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Oct 2017 14:56:48 +0800 Subject: [PATCH 1509/1765] [SPARK-22290][CORE] Avoid creating Hive delegation tokens when not necessary. Hive delegation tokens are only needed when the Spark driver has no access to the kerberos TGT. That happens only in two situations: - when using a proxy user - when using cluster mode without a keytab This change modifies the Hive provider so that it only generates delegation tokens in those situations, and tweaks the YARN AM so that it makes the proper user visible to the Hive code when running with keytabs, so that the TGT can be used instead of a delegation token. The effect of this change is that now it's possible to initialize multiple, non-concurrent SparkContext instances in the same JVM. Before, the second invocation would fail to fetch a new Hive delegation token, which then could make the second (or third or...) application fail once the token expired. With this change, the TGT will be used to authenticate to the HMS instead. This change also avoids polluting the current logged in user's credentials when launching applications. The credentials are copied only when running applications as a proxy user. This makes it possible to implement SPARK-11035 later, where multiple threads might be launching applications, and each app should have its own set of credentials. Tested by verifying HDFS and Hive access in following scenarios: - client and cluster mode - client and cluster mode with proxy user - client and cluster mode with principal / keytab - long-running cluster app with principal / keytab - pyspark app that creates (and stops) multiple SparkContext instances through its lifetime Author: Marcelo Vanzin Closes #19509 from vanzin/SPARK-22290. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 17 +++-- .../HBaseDelegationTokenProvider.scala | 4 +- .../HadoopDelegationTokenManager.scala | 2 +- .../HadoopDelegationTokenProvider.scala | 2 +- .../HadoopFSDelegationTokenProvider.scala | 4 +- .../HiveDelegationTokenProvider.scala | 20 +++++- docs/running-on-yarn.md | 9 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 69 +++++++++++++++---- .../org/apache/spark/deploy/yarn/Client.scala | 5 +- .../org/apache/spark/deploy/yarn/config.scala | 4 ++ .../sql/hive/client/HiveClientImpl.scala | 6 -- 11 files changed, 110 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 53775db251bc6..1fa10ab943f34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -61,13 +61,17 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { + createSparkUser().doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } + + def createSparkUser(): UserGroupInformation = { val user = Utils.getCurrentUserName() - logDebug("running as user: " + user) + logDebug("creating UGI for user: " + user) val ugi = UserGroupInformation.createRemoteUser(user) transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) + ugi } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -417,6 +421,11 @@ class SparkHadoopUtil extends Logging { creds.readTokenStorageStream(new DataInputStream(tokensBuf)) creds } + + def isProxyUser(ugi: UserGroupInformation): Boolean = { + ugi.getAuthenticationMethod() == UserGroupInformation.AuthenticationMethod.PROXY + } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 78b0e6b2cbf39..5dcde4ec3a8a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -56,7 +56,9 @@ private[security] class HBaseDelegationTokenProvider None } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index c134b7ebe38fa..483d0deec8070 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -115,7 +115,7 @@ private[spark] class HadoopDelegationTokenManager( hadoopConf: Configuration, creds: Credentials): Long = { delegationTokenProviders.values.flatMap { provider => - if (provider.delegationTokensRequired(hadoopConf)) { + if (provider.delegationTokensRequired(sparkConf, hadoopConf)) { provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) } else { logDebug(s"Service ${provider.serviceName} does not require a token." + diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala index 1ba245e84af4b..ed0905088ab25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -37,7 +37,7 @@ private[spark] trait HadoopDelegationTokenProvider { * Returns true if delegation tokens are required for this service. By default, it is based on * whether Hadoop security is enabled. */ - def delegationTokensRequired(hadoopConf: Configuration): Boolean + def delegationTokensRequired(sparkConf: SparkConf, hadoopConf: Configuration): Boolean /** * Obtain delegation tokens for this service and get the time of the next renewal. diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 300773c58b183..21ca669ea98f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -69,7 +69,9 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration nextRenewalDate } - def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index b31cc595ed83b..ece5ce79c650d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -31,7 +31,9 @@ import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils private[security] class HiveDelegationTokenProvider @@ -55,9 +57,21 @@ private[security] class HiveDelegationTokenProvider } } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { + // Delegation tokens are needed only when: + // - trying to connect to a secure metastore + // - either deploying in cluster mode without a keytab, or impersonating another user + // + // Other modes (such as client with or without keytab, or cluster mode with keytab) do not need + // a delegation token, since there's a valid kerberos TGT for the right user available to the + // driver, which is the only process that connects to the HMS. + val deployMode = sparkConf.get("spark.submit.deployMode", "client") UserGroupInformation.isSecurityEnabled && - hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty + hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty && + (SparkHadoopUtil.get.isProxyUser(UserGroupInformation.getCurrentUser()) || + (deployMode == "cluster" && !sparkConf.contains(KEYTAB))) } override def obtainDelegationTokens( @@ -83,7 +97,7 @@ private[security] class HiveDelegationTokenProvider val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) - logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") + logDebug(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 432639588cc2b..9599d40c545b2 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -401,6 +401,15 @@ To use a custom metrics.properties for the application master and executors, upd Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) + + + + + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e227bff88f71d..f6167235f89e4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URI, URL} +import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} import scala.collection.mutable.HashMap @@ -28,6 +29,7 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -49,10 +51,7 @@ import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster( - args: ApplicationMasterArguments, - client: YarnRMClient) - extends Logging { +private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Logging { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -62,6 +61,46 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null + private val ugi = { + val original = UserGroupInformation.getCurrentUser() + + // If a principal and keytab were provided, log in to kerberos, and set up a thread to + // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL + // of the TGT, use a configuration to define how often to check that a relogin is necessary. + // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. + val principal = sparkConf.get(PRINCIPAL).orNull + val keytab = sparkConf.get(KEYTAB).orNull + if (principal != null && keytab != null) { + UserGroupInformation.loginUserFromKeytab(principal, keytab) + + val renewer = new Thread() { + override def run(): Unit = Utils.tryLogNonFatalError { + while (true) { + TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) + UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() + } + } + } + renewer.setName("am-kerberos-renewer") + renewer.setDaemon(true) + renewer.start() + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. It also copies over any delegation tokens that might have been created by the + // client, which will then be transferred over when starting executors (until new ones + // are created by the periodic task). + val newUser = UserGroupInformation.getCurrentUser() + SparkHadoopUtil.get.transferCredentials(original, newUser) + newUser + } else { + SparkHadoopUtil.get.createSparkUser() + } + } + + private val client = ugi.doAs(new PrivilegedExceptionAction[YarnRMClient]() { + def run: YarnRMClient = new YarnRMClient() + }) + // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -201,6 +240,13 @@ private[spark] class ApplicationMaster( } final def run(): Int = { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + def run: Unit = runImpl() + }) + exitCode + } + + private def runImpl(): Unit = { try { val appAttemptId = client.getAttemptId() @@ -254,11 +300,6 @@ private[spark] class ApplicationMaster( } } - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserApplication which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) - // If the credentials file config is present, we must periodically renew tokens. So create // a new AMDelegationTokenRenewer if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { @@ -284,6 +325,9 @@ private[spark] class ApplicationMaster( credentialRenewerThread.join() } + // Call this to force generation of secret so it gets populated into the Hadoop UGI. + val securityMgr = new SecurityManager(sparkConf) + if (isClusterMode) { runDriver(securityMgr) } else { @@ -297,7 +341,6 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, "Uncaught exception: " + e) } - exitCode } /** @@ -775,10 +818,8 @@ object ApplicationMaster extends Logging { sys.props(k) = v } } - SparkHadoopUtil.get.runAsSparkUser { () => - master = new ApplicationMaster(amArgs, new YarnRMClient) - System.exit(master.run()) - } + master = new ApplicationMaster(amArgs) + System.exit(master.run()) } private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 64b2b4d4db549..1fe25c4ddaabf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -394,7 +394,10 @@ private[spark] class Client( if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the // Kerberos tgt to get delegations again in the client side. - UserGroupInformation.getCurrentUser.addCredentials(credentials) + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 187803cc6050b..e1af8ba087d6e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -347,6 +347,10 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(Long.MaxValue) + private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1m") + // The list of cache-related config entries. This is used by Client and the AM to clean // up the environment so that these settings do not appear on the web UI. private[yarn] val CACHE_CONFIGS = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a01c312d5e497..16c95c53b4201 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -111,12 +111,6 @@ private[hive] class HiveClientImpl( if (clientLoader.isolationOn) { // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) - // Set up kerberos credentials for UserGroupInformation.loginUser within current class loader - if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { - val principal = sparkConf.get("spark.yarn.principal") - val keytab = sparkConf.get("spark.yarn.keytab") - SparkHadoopUtil.get.loginUserFromKeytab(principal, keytab) - } try { newState() } finally { From 5a07aca4d464e96d75ea17bf6768e24b829872ec Mon Sep 17 00:00:00 2001 From: krishna-pandey Date: Thu, 19 Oct 2017 08:33:14 +0100 Subject: [PATCH 1510/1765] [SPARK-22188][CORE] Adding security headers for preventing XSS, MitM and MIME sniffing ## What changes were proposed in this pull request? The HTTP Strict-Transport-Security response header (often abbreviated as HSTS) is a security feature that lets a web site tell browsers that it should only be communicated with using HTTPS, instead of using HTTP. Note: The Strict-Transport-Security header is ignored by the browser when your site is accessed using HTTP; this is because an attacker may intercept HTTP connections and inject the header or remove it. When your site is accessed over HTTPS with no certificate errors, the browser knows your site is HTTPS capable and will honor the Strict-Transport-Security header. The HTTP X-XSS-Protection response header is a feature of Internet Explorer, Chrome and Safari that stops pages from loading when they detect reflected cross-site scripting (XSS) attacks. The HTTP X-Content-Type-Options response header is used to protect against MIME sniffing vulnerabilities. ## How was this patch tested? Checked on my system locally. screen shot 2017-10-03 at 6 49 20 pm Author: krishna-pandey Author: Krishna Pandey Closes #19419 from krishna-pandey/SPARK-22188. --- .../spark/internal/config/package.scala | 18 +++++++ .../org/apache/spark/ui/JettyUtils.scala | 9 ++++ docs/security.md | 47 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0c36bdcdd2904..6f0247b73070d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -452,6 +452,24 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val UI_X_XSS_PROTECTION = + ConfigBuilder("spark.ui.xXssProtection") + .doc("Value for HTTP X-XSS-Protection response header") + .stringConf + .createWithDefaultString("1; mode=block") + + private[spark] val UI_X_CONTENT_TYPE_OPTIONS = + ConfigBuilder("spark.ui.xContentTypeOptions.enabled") + .doc("Set to 'true' for setting X-Content-Type-Options HTTP response header to 'nosniff'") + .booleanConf + .createWithDefault(true) + + private[spark] val UI_STRICT_TRANSPORT_SECURITY = + ConfigBuilder("spark.ui.strictTransportSecurity") + .doc("Value for HTTP Strict Transport Security Response Header") + .stringConf + .createOptional + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") .doc("Class names of listeners to add to SparkContext during initialization.") .stringConf diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 5ee04dad6ed4d..0adeb4058b6e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{pretty, render} import org.apache.spark.{SecurityManager, SparkConf, SSLOptions} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -89,6 +90,14 @@ private[spark] object JettyUtils extends Logging { val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") response.setHeader("X-Frame-Options", xFrameOptionsValue) + response.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION)) + if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) { + response.setHeader("X-Content-Type-Options", "nosniff") + } + if (request.getScheme == "https") { + conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach( + response.setHeader("Strict-Transport-Security", _)) + } response.getWriter.print(servletParams.extractFn(result)) } else { response.setStatus(HttpServletResponse.SC_FORBIDDEN) diff --git a/docs/security.md b/docs/security.md index 1d004003f9a32..15aadf07cf873 100644 --- a/docs/security.md +++ b/docs/security.md @@ -186,7 +186,54 @@ configure those ports.
    Property NameDefaultMeaning
    spark.eventLog.logBlockUpdates.enabledfalse + Whether to log events for every block update, if spark.eventLog.enabled is true. + *Warning*: This will increase the size of the event log considerably. +
    spark.eventLog.compress false
    spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
    spark.yarn.config.gatewayPath (none)
    +### HTTP Security Headers + +Apache Spark can be configured to include HTTP Headers which aids in preventing Cross +Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP +Strict Transport Security. + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
      +
    • 0 (Disables XSS filtering)
    • +
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
    • +
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
    • +
    +
    spark.ui.xContentTypeOptions.enabledtrue + When value is set to "true", X-Content-Type-Options HTTP response header will be set + to "nosniff". Set "false" to disable. +
    spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. +
      +
    • max-age=<expire-time>
    • +
    • max-age=<expire-time>; includeSubDomains
    • +
    • max-age=<expire-time>; preload
    • +
    +
    + See the [configuration page](configuration.html) for more details on the security configuration parameters, and org.apache.spark.SecurityManager for implementation details about security. + From 7fae7995ba05e0333d1decb7ca74ddb7c1b448d7 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Fri, 20 Oct 2017 09:40:00 +0900 Subject: [PATCH 1511/1765] [SPARK-22268][BUILD] Fix lint-java ## What changes were proposed in this pull request? Fix java style issues ## How was this patch tested? Run `./dev/lint-java` locally since it's not run on Jenkins Author: Andrew Ash Closes #19486 from ash211/aash/fix-lint-java. --- .../unsafe/sort/UnsafeInMemorySorter.java | 9 ++++---- .../sort/UnsafeExternalSorterSuite.java | 21 +++++++++++-------- .../sort/UnsafeInMemorySorterSuite.java | 3 ++- .../v2/reader/SupportsPushDownFilters.java | 1 - 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 869ec908be1fb..3bb87a6ed653d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -172,10 +172,11 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); - // the call to consumer.allocateArray may trigger a spill - // which in turn access this instance and eventually re-enter this method and try to free the array again. - // by setting the array to null and its length to 0 we effectively make the spill code-path a no-op. - // setting the array to null also indicates that it has already been de-allocated which prevents a double de-allocation in free(). + // the call to consumer.allocateArray may trigger a spill which in turn access this instance + // and eventually re-enter this method and try to free the array again. by setting the array + // to null and its length to 0 we effectively make the spill code-path a no-op. setting the + // array to null also indicates that it has already been de-allocated which prevents a double + // de-allocation in free(). array = null; usableCapacity = 0; pos = 0; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 6c5451d0fd2a5..d0d0334add0bf 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -516,12 +516,13 @@ public void testOOMDuringSpill() throws Exception { for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { insertNumber(sorter, i); } - // we expect the next insert to attempt growing the pointerssArray - // first allocation is expected to fail, then a spill is triggered which attempts another allocation - // which also fails and we expect to see this OOM here. - // the original code messed with a released array within the spill code - // and ended up with a failed assertion. - // we also expect the location of the OOM to be org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + // we expect the next insert to attempt growing the pointerssArray first + // allocation is expected to fail, then a spill is triggered which + // attempts another allocation which also fails and we expect to see this + // OOM here. the original code messed with a released array within the + // spill code and ended up with a failed assertion. we also expect the + // location of the OOM to be + // org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset memoryManager.markconsequentOOM(2); try { insertNumber(sorter, 1024); @@ -530,9 +531,11 @@ public void testOOMDuringSpill() throws Exception { // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) catch (OutOfMemoryError oom){ String oomStackTrace = Utils.exceptionString(oom); - assertThat("expected OutOfMemoryError in org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", - oomStackTrace, - Matchers.containsString("org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + assertThat("expected OutOfMemoryError in " + + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString( + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 1a3e11efe9787..594f07dd780f9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -179,7 +179,8 @@ public int compare( } catch (OutOfMemoryError oom) { // as expected } - // [SPARK-21907] this failed on NPE at org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + // [SPARK-21907] this failed on NPE at + // org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) sorter.free(); // simulate a 'back to back' free. sorter.free(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d6f297c013375..6b0c9d417eeae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** From b034f2565f72aa73c9f0be1e49d148bb4cf05153 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Oct 2017 20:24:51 -0700 Subject: [PATCH 1512/1765] [SPARK-22026][SQL] data source v2 write path ## What changes were proposed in this pull request? A working prototype for data source v2 write path. The writing framework is similar to the reading framework. i.e. `WriteSupport` -> `DataSourceV2Writer` -> `DataWriterFactory` -> `DataWriter`. Similar to the `FileCommitPotocol`, the writing API has job and task level commit/abort to support the transaction. ## How was this patch tested? new tests Author: Wenchen Fan Closes #19269 from cloud-fan/data-source-v2-write. --- .../spark/sql/sources/v2/WriteSupport.java | 49 ++++ .../sources/v2/writer/DataSourceV2Writer.java | 88 +++++++ .../sql/sources/v2/writer/DataWriter.java | 92 +++++++ .../sources/v2/writer/DataWriterFactory.java | 50 ++++ .../v2/writer/SupportsWriteInternalRow.java | 44 ++++ .../v2/writer/WriterCommitMessage.java | 33 +++ .../apache/spark/sql/DataFrameWriter.scala | 38 ++- .../datasources/v2/DataSourceV2Strategy.scala | 11 +- .../datasources/v2/WriteToDataSourceV2.scala | 133 ++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 69 +++++ .../sources/v2/SimpleWritableDataSource.scala | 249 ++++++++++++++++++ 11 files changed, 842 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java new file mode 100644 index 0000000000000..a8a961598bde3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability and save the data to the data source. + */ +@InterfaceStability.Evolving +public interface WriteSupport { + + /** + * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * sources can return None if there is no writing needed to be done according to the save mode. + * + * @param jobId A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceV2Writer} should + * use this job id to distinguish itself with writers of other jobs. + * @param schema the schema of the data to be written. + * @param mode the save mode which determines what to do when the data are already in this data + * source, please refer to {@link SaveMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java new file mode 100644 index 0000000000000..8d8e33633fb0d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A data source writer that is returned by + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * It can mix in various writing optimization interfaces to speed up the data saving. The actual + * writing logic is delegated to {@link DataWriter}. + * + * The writing procedure is: + * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the + * partitions of the input data(RDD). + * 2. For each partition, create the data writer, and write the data of the partition with this + * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If + * exception happens during the writing, call {@link DataWriter#abort()}. + * 3. If all writers are successfully committed, call {@link #commit(WriterCommitMessage[])}. If + * some writers are aborted, or the job failed with an unknown reason, call + * {@link #abort(WriterCommitMessage[])}. + * + * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if + * they want to retry. + * + * Please refer to the document of commit/abort methods for detailed specifications. + * + * Note that, this interface provides a protocol between Spark and data sources for transactional + * data writing, but the transaction here is Spark-level transaction, which may not be the + * underlying storage transaction. For example, Spark successfully writes data to a Cassandra data + * source, but Cassandra may need some more time to reach consistency at storage level. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Writer { + + /** + * Creates a writer factory which will be serialized and sent to executors. + */ + DataWriterFactory createWriterFactory(); + + /** + * Commits this writing job with a list of commit messages. The commit messages are collected from + * successful data writers and are produced by {@link DataWriter#commit()}. If this method + * fails(throw exception), this writing job is considered to be failed, and + * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible + * to data source readers if this method succeeds. + * + * Note that, one partition may have multiple committed data writers because of speculative tasks. + * Spark will pick the first successful one and get its commit message. Implementations should be + * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer + * can commit successfully, or have a way to clean up the data of already-committed writers. + */ + void commit(WriterCommitMessage[] messages); + + /** + * Aborts this writing job because some data writers are failed to write the records and aborted, + * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} + * fails. If this method fails(throw exception), the underlying data source may have garbage that + * need to be cleaned manually, but these garbage should not be visible to data source readers. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(WriterCommitMessage[] messages); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java new file mode 100644 index 0000000000000..14261419af6f6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * responsible for writing data for an input RDD partition. + * + * One Spark task has one exclusive data writer, so there is no thread-safe concern. + * + * {@link #write(Object)} is called for each record in the input RDD partition. If one record fails + * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will + * not be processed. If all records are successfully written, {@link #commit()} is called. + * + * If this data writer succeeds(all records are successfully written and {@link #commit()} + * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an + * exception will be sent to the driver side, and Spark will retry this writing task for some times, + * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * + * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task + * takes too long to finish. Different from retried tasks, which are launched one by one after the + * previous one fails, speculative tasks are running simultaneously. It's possible that one input + * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * and data sources should guarantee that these data writers don't conflict and can work together. + * Implementations can coordinate with driver during {@link #commit()} to make sure only one of + * these data writers can commit successfully. Or implementations can allow all of them to commit + * successfully, and have a way to revert committed data writers without the commit message, because + * Spark only accepts the commit message that arrives first and ignore others. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers + * that mix in {@link SupportsWriteInternalRow}. + */ +@InterfaceStability.Evolving +public interface DataWriter { + + /** + * Writes one record. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + void write(T record); + + /** + * Commits this writer after all records are written successfully, returns a commit message which + * will be send back to driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * The written data should only be visible to data source readers after + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * do the final commitment via {@link WriterCommitMessage}. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + WriterCommitMessage commit(); + + /** + * Aborts this writer if it is failed. Implementations should clean up the data for already + * written records. + * + * This method will only be called if there is one record failed to write, or {@link #commit()} + * failed. + * + * If this method fails(throw exception), the underlying data source may have garbage that need + * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but + * these garbage should not be visible to data source readers. + */ + void abort(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java new file mode 100644 index 0000000000000..f812d102bda1a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * which is responsible for creating and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface DataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task + * failed, Spark launches a new task wth the same task id but different + * attempt number. Or a task is too slow, Spark launches new tasks wth the + * same task id but different attempt number, which means there are multiple + * tasks with the same task id running at the same time. Implementations can + * use this attempt number to distinguish writers of different task attempts. + */ + DataWriter createWriter(int partitionId, int attemptNumber); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java new file mode 100644 index 0000000000000..a8e95901f3b07 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. + * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get + * changed in the future Spark versions. + */ + +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsWriteInternalRow extends DataSourceV2Writer { + + @Override + default DataWriterFactory createWriterFactory() { + throw new IllegalStateException( + "createWriterFactory should not be called with SupportsWriteInternalRow."); + } + + DataWriterFactory createInternalRowWriterFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java new file mode 100644 index 0000000000000..082d6b5dc409f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side + * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * This is an empty interface, data sources should define their own message class and use it in + * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * implementations. + */ +@InterfaceStability.Evolving +public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c9e45436ed42f..8d95b24c00619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -29,7 +30,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} import org.apache.spark.sql.types.StructType /** @@ -231,12 +234,33 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + cls.newInstance() match { + case ds: WriteSupport => + val options = new DataSourceV2Options(extraOptions.asJava) + // Using a timestamp and a random UUID to distinguish different writing jobs. This is good + // enough as there won't be tons of writing jobs created at the same second. + val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + .format(new Date()) + "-" + UUID.randomUUID() + val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get(), df.logicalPlan) + } + } + + case _ => throw new AnalysisException(s"$cls does not support data writing.") + } + } else { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index f2cda002245e8..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,20 +18,17 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { - // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case DataSourceV2Relation(output, reader) => DataSourceV2ScanExec(output, reader) :: Nil + case WriteToDataSourceV2(writer, query) => + WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala new file mode 100644 index 0000000000000..92c1e1f4a3383 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * The logical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} + +/** + * The physical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writeTask = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${messages.length} partitions.") + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter), + rdd.partitions.indices, + (index, message: WriterCommitMessage) => messages(index) = message + ) + + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } catch { + case cause: Throwable => + logError(s"Data source writer $writer is aborting.") + try { + writer.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source writer $writer failed to abort.") + cause.addSuppressed(t) + throw new SparkException("Writing job failed.", cause) + } + logError(s"Data source writer $writer aborted.") + throw new SparkException("Writing job aborted.", cause) + } + + sparkContext.emptyRDD + } +} + +object DataWritingSparkTask extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + msg + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } +} + +class RowToInternalRowDataWriterFactory( + rowWriterFactory: DataWriterFactory[Row], + schema: StructType) extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new RowToInternalRowDataWriter( + rowWriterFactory.createWriter(partitionId, attemptNumber), + RowEncoder.apply(schema).resolveAndBind()) + } +} + +class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) + extends DataWriter[InternalRow] { + + override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) + + override def commit(): WriterCommitMessage = rowWriter.commit() + + override def abort(): Unit = rowWriter.abort() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f238e565dc2fc..092702a1d5173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,6 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -80,6 +81,74 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple writable data source") { + // TODO: java implementation. + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + // test with different save modes + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("append").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).union(spark.range(10)).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("ignore").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + val e = intercept[Exception] { + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("error").save() + } + assert(e.getMessage.contains("data already exists")) + + // test transaction + val failingUdf = org.apache.spark.sql.functions.udf { + var count = 0 + (id: Long) => { + if (count > 5) { + throw new RuntimeException("testing error") + } + count += 1 + id + } + } + // this input data will fail to read middle way. + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val e2 = intercept[SparkException] { + input.write.format(cls.getName).option("path", path).mode("overwrite").save() + } + assert(e2.getMessage.contains("Writing job aborted")) + // make sure we don't have partial data. + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + // test internal row writer + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).option("internal", "true").mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala new file mode 100644 index 0000000000000..6fb60f4d848d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.text.SimpleDateFormat +import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A HDFS based transactional writable data source. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. + */ +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { + + private val schema = new StructType().add("i", "long").add("j", "long") + + class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + override def readSchema(): StructType = schema + + override def createReadTasks(): JList[ReadTask[Row]] = { + val dataPath = new Path(path) + val fs = dataPath.getFileSystem(conf) + if (fs.exists(dataPath)) { + fs.listStatus(dataPath).filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.map { f => + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + }.toList.asJava + } else { + Collections.emptyList() + } + } + } + + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + override def createWriterFactory(): DataWriterFactory[Row] = { + new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + val finalPath = new Path(path) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + try { + for (file <- fs.listStatus(jobPath).map(_.getPath)) { + val dest = new Path(finalPath, file.getName) + if(!fs.rename(file, dest)) { + throw new IOException(s"failed to rename($file, $dest)") + } + } + } finally { + fs.delete(jobPath, true) + } + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + fs.delete(jobPath, true) + } + } + + class InternalRowWriter(jobId: String, path: String, conf: Configuration) + extends Writer(jobId, path, conf) with SupportsWriteInternalRow { + + override def createWriterFactory(): DataWriterFactory[Row] = { + throw new IllegalArgumentException("not expected!") + } + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + val path = new Path(options.get("path").get()) + val conf = SparkContext.getActive.get.hadoopConfiguration + new Reader(path.toUri.toString, conf) + } + + override def createWriter( + jobId: String, + schema: StructType, + mode: SaveMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) + assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) + + val path = new Path(options.get("path").get()) + val internal = options.get("internal").isPresent + val conf = SparkContext.getActive.get.hadoopConfiguration + val fs = path.getFileSystem(conf) + + if (mode == SaveMode.ErrorIfExists) { + if (fs.exists(path)) { + throw new RuntimeException("data already exists.") + } + } + if (mode == SaveMode.Ignore) { + if (fs.exists(path)) { + return Optional.empty() + } + } + if (mode == SaveMode.Overwrite) { + fs.delete(path, true) + } + + Optional.of(createWriter(jobId, path, conf, internal)) + } + + private def createWriter( + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + val pathStr = path.toUri.toString + if (internal) { + new InternalRowWriter(jobId, pathStr, conf) + } else { + new Writer(jobId, pathStr, conf) + } + } +} + +class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) + extends ReadTask[Row] with DataReader[Row] { + + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ + + override def createReader(): DataReader[Row] = { + val filePath = new Path(path) + val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } +} + +class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[Row] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new SimpleCSVDataWriter(fs, filePath) + } +} + +class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { + + private val out = fs.create(file) + + override def write(record: Row): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} + +class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new InternalRowCSVDataWriter(fs, filePath) + } +} + +class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { + + private val out = fs.create(file) + + override def write(record: InternalRow): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} From b84f61cd79a365edd4cc893a1de416c628d9906b Mon Sep 17 00:00:00 2001 From: Eric Perry Date: Thu, 19 Oct 2017 23:57:41 -0700 Subject: [PATCH 1513/1765] [SQL] Mark strategies with override for clarity. ## What changes were proposed in this pull request? This is a very trivial PR, simply marking `strategies` in `SparkPlanner` with the `override` keyword for clarity since it is overriding `strategies` in `QueryPlanner` two levels up in the class hierarchy. I was reading through the code to learn a bit and got stuck on this fact for a little while, so I figured this may be helpful so that another developer new to the project doesn't get stuck where I was. I did not make a JIRA ticket for this because it is so trivial, but I'm happy to do so to adhere to the contribution guidelines if required. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Perry Closes #19537 from ericjperry/override-strategies. --- .../scala/org/apache/spark/sql/execution/SparkPlanner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b143d44eae17b..74048871f8d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -33,7 +33,7 @@ class SparkPlanner( def numPartitions: Int = conf.numShufflePartitions - def strategies: Seq[Strategy] = + override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( DataSourceV2Strategy :: From 673876b7eadc6f382afc26fc654b0e7916c9ac5c Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 20 Oct 2017 08:28:05 +0100 Subject: [PATCH 1514/1765] [SPARK-22309][ML] Remove unused param in `LDAModel.getTopicDistributionMethod` ## What changes were proposed in this pull request? Remove unused param in `LDAModel.getTopicDistributionMethod` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #19530 from zhengruifeng/lda_bc. --- mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 2 +- .../main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 3da29b1c816b1..4bab670cc159f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -458,7 +458,7 @@ abstract class LDAModel private[ml] ( if ($(topicDistributionCol).nonEmpty) { // TODO: Make the transformer natively in ml framework to avoid extra conversion. - val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) + val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 4ab420058f33d..b8a6e94248421 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -371,7 +371,7 @@ class LocalLDAModel private[spark] ( /** * Get a method usable as a UDF for `topicDistributions()` */ - private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + private[spark] def getTopicDistributionMethod: Vector => Vector = { val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape From e2fea8cd6058a807ff4841b496ea345ff0553044 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 20 Oct 2017 09:43:46 +0100 Subject: [PATCH 1515/1765] [CORE][DOC] Add event log conf. ## What changes were proposed in this pull request? Event Log Server has a total of five configuration parameters, and now the description of the other two configuration parameters on the doc, user-friendly access and use. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19242 from guoxiaolongzte/addEventLogConf. --- docs/configuration.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 7b9e16a382449..d3c358bb74173 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -748,6 +748,20 @@ Apart from these, the following properties are also available, and may be useful finished. + + spark.eventLog.overwrite + false + + Whether to overwrite any existing files. + + + + spark.eventLog.buffer.kb + 100k + + Buffer size in KB to use when writing to output streams. + + spark.ui.enabled true From 16c9cc68c5a70fd50e214f6deba591f0a9ae5cca Mon Sep 17 00:00:00 2001 From: CenYuhai Date: Fri, 20 Oct 2017 09:27:39 -0700 Subject: [PATCH 1516/1765] [SPARK-21055][SQL] replace grouping__id with grouping_id() ## What changes were proposed in this pull request? spark does not support grouping__id, it has grouping_id() instead. But it is not convenient for hive user to change to spark-sql so this pr is to replace grouping__id with grouping_id() hive user need not to alter their scripts ## How was this patch tested? test with SQLQuerySuite.scala Author: CenYuhai Closes #18270 from cenyuhai/SPARK-21055. --- .../sql/catalyst/analysis/Analyzer.scala | 15 +-- .../sql-tests/inputs/group-analytics.sql | 6 +- .../sql-tests/results/group-analytics.sql.out | 43 ++++--- .../sql/hive/execution/SQLQuerySuite.scala | 110 ++++++++++++++++++ 4 files changed, 148 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8edf575db7969..d6a962a14dc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -293,12 +293,6 @@ class Analyzer( Seq(Seq.empty) } - private def hasGroupingAttribute(expr: Expression): Boolean = { - expr.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u - }.isDefined - } - private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.collectFirst { case g: Grouping => g @@ -452,9 +446,6 @@ class Analyzer( // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. - case p if p.expressions.exists(hasGroupingAttribute) => - failAnalysis( - s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") // Ensure group by expressions and aggregate expressions have been resolved. case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) @@ -1174,6 +1165,10 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => + withPosition(u) { + Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)() + } case u @ UnresolvedGenerator(name, children) => withPosition(u) { catalog.lookupFunction(name, children) match { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 8aff4cb524199..9721f8c60ebce 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -38,11 +38,11 @@ SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) GROUP BY CUBE(course, year); SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year; SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year; -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year); +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- GROUPING/GROUPING_ID in having clause SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0; +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0; SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0; @@ -54,7 +54,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index ce7a16a4d0c81..3439a05727f95 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -223,22 +223,29 @@ grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 16 -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 16 schema -struct<> +struct -- !query 16 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 0 +Java 2013 0 +dotNET 2012 0 +dotNET 2013 0 +Java NULL 1 +dotNET NULL 1 +NULL 2012 2 +NULL 2013 2 +NULL NULL 3 -- !query 17 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year -- !query 17 schema struct -- !query 17 output -Java NULL NULL NULL +Java NULL dotNET NULL @@ -263,10 +270,13 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 20 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 -- !query 20 schema -struct<> +struct -- !query 20 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java NULL +NULL 2012 +NULL 2013 +NULL NULL +dotNET NULL -- !query 21 @@ -322,12 +332,19 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 25 -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 25 schema -struct<> +struct -- !query 25 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 +Java 2013 +dotNET 2012 +dotNET 2013 +Java NULL +dotNET NULL +NULL 2012 +NULL 2013 +NULL NULL -- !query 26 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 60935c3e85c43..2476a440ad82c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1388,6 +1388,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + test("SPARK-8976 Wrong Result for Rollup #2") { checkAnswer(sql( """ @@ -1409,6 +1422,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for Rollup #3") { checkAnswer(sql( """ @@ -1430,6 +1464,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #3") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for CUBE #1") { checkAnswer(sql( "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), @@ -1443,6 +1498,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3))) } + test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + test("SPARK-8976 Wrong Result for CUBE #2") { checkAnswer(sql( """ @@ -1464,6 +1532,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for GroupingSet") { checkAnswer(sql( """ @@ -1485,6 +1574,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for GroupingSet") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + ignore("SPARK-10562: partition by column with mixed case name") { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 568763bafb7acfcf5921d6492034d1f6f87875e2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Oct 2017 12:32:45 -0700 Subject: [PATCH 1517/1765] [INFRA] Close stale PRs. Closes #19541 Closes #19542 From b8624b06e5d531ebc14acb05da286f96f4bc9515 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 20 Oct 2017 12:44:30 -0700 Subject: [PATCH 1518/1765] [SPARK-20396][SQL][PYSPARK][FOLLOW-UP] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This is a follow-up of #18732. This pr modifies `GroupedData.apply()` method to convert pandas udf to grouped udf implicitly. ## How was this patch tested? Exisiting tests. Author: Takuya UESHIN Closes #19517 from ueshin/issues/SPARK-20396/fup2. --- .../spark/api/python/PythonRunner.scala | 1 + python/pyspark/serializers.py | 1 + python/pyspark/sql/functions.py | 33 ++++++++++------ python/pyspark/sql/group.py | 14 ++++--- python/pyspark/sql/tests.py | 37 ++++++++++++++++++ python/pyspark/worker.py | 39 ++++++++----------- .../logical/pythonLogicalOperators.scala | 9 +++-- .../spark/sql/RelationalGroupedDataset.scala | 7 ++-- .../execution/python/ExtractPythonUDFs.scala | 6 ++- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../sql/execution/python/PythonUDF.scala | 2 +- .../python/UserDefinedPythonFunction.scala | 13 ++++++- .../python/BatchEvalPythonExecSuite.scala | 2 +- 13 files changed, 114 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 3688a149443c1..d417303bb147d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,6 +36,7 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_GROUPED_UDF = 3 } /** diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ad18bd0c81eaa..a0adeed994456 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -86,6 +86,7 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_GROUPED_UDF = 3 class Serializer(object): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc12c3b7a162..9bc374b93a433 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class PythonUdfType(object): + # row-at-a-time UDFs + NORMAL_UDF = 0 + # scalar vectorized UDFs + PANDAS_UDF = 1 + # grouped vectorized UDFs + PANDAS_GROUPED_UDF = 2 + + class UserDefinedFunction(object): """ User defined function in Python .. versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None, vectorized=False): + def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2058,7 +2067,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self.vectorized = vectorized + self.pythonUdfType = pythonUdfType @property def returnType(self): @@ -2090,7 +2099,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.vectorized) + self._name, wrapped_func, jdt, self.pythonUdfType) return judf def __call__(self, *cols): @@ -2121,15 +2130,15 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType - wrapper.vectorized = self.vectorized + wrapper.pythonUdfType = self.pythonUdfType return wrapper -def _create_udf(f, returnType, vectorized): +def _create_udf(f, returnType, pythonUdfType): - def _udf(f, returnType=StringType(), vectorized=vectorized): - if vectorized: + def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): + if pythonUdfType == PythonUdfType.PANDAS_UDF: import inspect argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: @@ -2137,7 +2146,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): "0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) return udf_obj._wrapped() # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf @@ -2145,9 +2154,9 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) @since(1.3) @@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, vectorized=False) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) @since(2.3) @@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()): .. note:: The user-defined function must be deterministic. """ - return _create_udf(f, returnType=returnType, vectorized=True) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 817d0bc83bb77..e11388d604312 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,6 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import PythonUdfType, UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -235,11 +236,13 @@ def apply(self, udf): .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - from pyspark.sql.functions import pandas_udf + import inspect # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: - raise ValueError("The argument to apply must be a pandas_udf") + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ + or len(inspect.getargspec(udf.func).args) != 1: + raise ValueError("The argument to apply must be a 1-arg pandas_udf") if not isinstance(udf.returnType, StructType): raise ValueError("The returnType of the pandas_udf must be a StructType") @@ -268,8 +271,9 @@ def wrapped(*cols): return [(result[result.columns[i]], arrow_type) for i, arrow_type in enumerate(arrow_return_types)] - wrapped_udf_obj = pandas_udf(wrapped, returnType) - udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + udf_obj = UserDefinedFunction( + wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) + udf_column = udf_obj(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bac2ef84ae7a7..685eebcafefba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3383,6 +3383,15 @@ def test_vectorized_udf_varargs(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, DateType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.select(f(col('dt'))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3492,6 +3501,18 @@ def normalize(pdf): expected = expected.assign(norm=expected.norm.astype('float64')) self.assertFramesEqual(expected, result) + def test_datatype_string(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + "id long, v int, v1 double, v2 long") + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data @@ -3517,9 +3538,25 @@ def test_wrong_args(self): df.groupby('id').apply(sum(df.v)) with self.assertRaisesRegexp(ValueError, 'pandas_udf'): df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'returnType'): df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + def test_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType( + [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, df.schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.groupby('id').apply(f).collect() + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eb6d48688dc0a..5e100e0a9a95d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type, StructType +from pyspark.sql.types import to_arrow_type from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,28 +74,19 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - # If the return_type is a StructType, it indicates this is a groupby apply udf, - # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. - # - # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs - if isinstance(return_type, StructType): - return lambda *a: f(*a) - else: - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type(return_type) - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): @@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + # a groupby apply udf has already been wrapped under apply() + return arg_offsets, row_func else: return arg_offsets, wrap_udf(row_func, return_type) @@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 8abab24bc9b44..254687ec00880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre * This is used by DataFrame.groupby().apply(). */ case class FlatMapGroupsInPandas( - groupingAttributes: Seq[Attribute], - functionExpr: Expression, - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** * This is needed because output attributes are considered `references` when * passed through the constructor. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 33ec3a27110a8..6b45790d5ff6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF +import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Applies a vectorized python user-defined function to each group of data. + * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results * for all groups are combined into a new [[DataFrame]]. @@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, + "Must pass a grouped vectorized python udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the vectorized python udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e3f952e221d53..d6825369f7378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { + if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { + throw new IllegalArgumentException("Can not use grouped vectorized UDFs") + } + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.vectorized) match { + val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index b996b5bb38ba5..5ed88ada428cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 84a6d9e5be59c..9c07c7638de57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - vectorized: Boolean) + pythonUdfType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index a30a80acf5c23..b2fe6c300846a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,6 +22,15 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType +private[spark] object PythonUdfType { + // row-at-a-time UDFs + val NORMAL_UDF = 0 + // scalar vectorized UDFs + val PANDAS_UDF = 1 + // grouped vectorized UDFs + val PANDAS_GROUPED_UDF = 2 +} + /** * A user-defined Python function. This is used by the Python API. */ @@ -29,10 +38,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - vectorized: Boolean) { + pythonUdfType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, vectorized) + PythonUDF(name, func, dataType, e, pythonUdfType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 153e6e1f88c70..95b21fc9f16ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - vectorized = false) + pythonUdfType = PythonUdfType.NORMAL_UDF) From d9f286d261c6ee9e8dcb46e78d4666318ea25af2 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Fri, 20 Oct 2017 20:58:55 -0700 Subject: [PATCH 1519/1765] [SPARK-22326][SQL] Remove unnecessary hashCode and equals methods ## What changes were proposed in this pull request? Plan equality should be computed by `canonicalized`, so we can remove unnecessary `hashCode` and `equals` methods. ## How was this patch tested? Existing tests. Author: Zhenhua Wang Closes #19539 from wzhfy/remove_equals. --- .../apache/spark/sql/catalyst/catalog/interface.scala | 11 ----------- .../sql/execution/datasources/LogicalRelation.scala | 11 ----------- 2 files changed, 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 975b084aa6188..1dbae4d37d8f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -22,8 +22,6 @@ import java.util.Date import scala.collection.mutable -import com.google.common.base.Objects - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -440,15 +438,6 @@ case class HiveTableRelation( def isPartitioned: Boolean = partitionCols.nonEmpty - override def equals(relation: Any): Boolean = relation match { - case other: HiveTableRelation => tableMeta == other.tableMeta && output == other.output - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(tableMeta.identifier, output) - } - override lazy val canonicalized: HiveTableRelation = copy( tableMeta = tableMeta.copy( storage = CatalogStorageFormat.empty, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 17a61074d3b5c..3e98cb28453a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -34,17 +34,6 @@ case class LogicalRelation( override val isStreaming: Boolean) extends LeafNode with MultiInstanceRelation { - // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _, _, isStreaming) => - relation == otherRelation && output == l.output && isStreaming == l.isStreaming - case _ => false - } - - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(relation, output) - } - // Only care about relation when canonicalizing. override lazy val canonicalized: LogicalPlan = copy( output = output.map(QueryPlan.normalizeExprId(_, output)), From d8cada8d1d3fce979a4bc1f9879593206722a3b9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Oct 2017 10:05:45 -0700 Subject: [PATCH 1520/1765] [SPARK-20331][SQL][FOLLOW-UP] Add a SQLConf for enhanced Hive partition pruning predicate pushdown ## What changes were proposed in this pull request? This is a follow-up PR of https://github.com/apache/spark/pull/17633. This PR is to add a conf `spark.sql.hive.advancedPartitionPredicatePushdown.enabled`, which can be used to turn the enhancement off. ## How was this patch tested? Add a test case Author: gatorsmile Closes #19547 from gatorsmile/Spark20331FollowUp. --- .../apache/spark/sql/internal/SQLConf.scala | 10 +++++++ .../spark/sql/hive/client/HiveShim.scala | 29 ++++++++++++++++++ .../spark/sql/hive/client/FiltersSuite.scala | 30 +++++++++++++++---- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 618d4a0d6148a..4cfe53b2c115b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -173,6 +173,13 @@ object SQLConf { .intConf .createWithDefault(4) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = + buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") + .internal() + .doc("When true, advanced partition predicate pushdown into Hive metastore is enabled.") + .booleanConf + .createWithDefault(true) + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs") .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + @@ -1092,6 +1099,9 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def advancedPartitionPredicatePushdownEnabled: Boolean = + getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index cde20da186acd..5c1ff2b76fdaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -585,6 +585,35 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * Unsupported predicates are skipped. */ def convertFilters(table: Table, filters: Seq[Expression]): String = { + if (SQLConf.get.advancedPartitionPredicatePushdownEnabled) { + convertComplexFilters(table, filters) + } else { + convertBasicFilters(table, filters) + } + } + + private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + lazy val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + private def convertComplexFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 031c1a5ec0ec3..19765695fbcb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -26,13 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * A set of tests for the filter conversion logic used when pushing partition pruning into the * metastore */ -class FiltersSuite extends SparkFunSuite with Logging { +class FiltersSuite extends SparkFunSuite with Logging with PlanTest { private val shim = new Shim_v0_13 private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") @@ -72,10 +74,28 @@ class FiltersSuite extends SparkFunSuite with Logging { private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + } + + test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) { + val filters = + (Literal(1) === a("intcol", IntegerType) || + Literal(2) === a("intcol", IntegerType)) :: Nil + val converted = shim.convertFilters(testTable, filters) + if (enabled) { + assert(converted == "(1 = intcol or 2 = intcol)") + } else { + assert(converted.isEmpty) + } } } } From a763607e4fc24f4dc0f455b67a63acba5be1c80a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Oct 2017 10:07:31 -0700 Subject: [PATCH 1521/1765] [SPARK-21055][SQL][FOLLOW-UP] replace grouping__id with grouping_id() ## What changes were proposed in this pull request? Simplifies the test cases that were added in the PR https://github.com/apache/spark/pull/18270. ## How was this patch tested? N/A Author: gatorsmile Closes #19546 from gatorsmile/backportSPARK-21055. --- .../sql/hive/execution/SQLQuerySuite.scala | 306 ++++++------------ 1 file changed, 104 insertions(+), 202 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2476a440ad82c..1cf1c5cd5a472 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1376,223 +1376,125 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-8976 Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } ignore("SPARK-10562: partition by column with mixed case name") { From ff8de99a1c7b4a291e661cd0ad12748f4321e43d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 22 Oct 2017 02:22:35 +0900 Subject: [PATCH 1522/1765] [SPARK-22302][INFRA] Remove manual backports for subprocess and print explicit message for < Python 2.7 ## What changes were proposed in this pull request? Seems there was a mistake - missing import for `subprocess.call`, while refactoring this script a long ago, which should be used for backports of some missing functions in `subprocess`, specifically in < Python 2.7. Reproduction is: ``` cd dev && python2.6 ``` ``` >>> from sparktestsupport import shellutils >>> shellutils.subprocess_check_call("ls") Traceback (most recent call last): File "", line 1, in File "sparktestsupport/shellutils.py", line 46, in subprocess_check_call retcode = call(*popenargs, **kwargs) NameError: global name 'call' is not defined ``` For Jenkins logs, please see https://amplab.cs.berkeley.edu/jenkins/job/NewSparkPullRequestBuilder/3950/console Since we dropped the Python 2.6.x support, looks better we remove those workarounds and print out explicit error messages in order to reduce the efforts to find out the root causes for such cases, for example, `https://github.com/apache/spark/pull/19513#issuecomment-337406734`. ## How was this patch tested? Manually tested: ``` ./dev/run-tests ``` ``` Python versions prior to 2.7 are not supported. ``` ``` ./dev/run-tests-jenkins ``` ``` Python versions prior to 2.7 are not supported. ``` Author: hyukjinkwon Closes #19524 from HyukjinKwon/SPARK-22302. --- dev/run-tests | 6 ++++++ dev/run-tests-jenkins | 7 ++++++- dev/sparktestsupport/shellutils.py | 31 ++---------------------------- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 257d1e8d50bb4..9cf93d000d0ea 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,10 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" +PYTHON_VERSION_CHECK=$(python -c 'import sys; print(sys.version_info < (2, 7, 0))') +if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then + echo "Python versions prior to 2.7 are not supported." + exit -1 +fi + exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f41f1ac79e381..03fd6ff0fba40 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -25,5 +25,10 @@ FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" -export PATH=/home/anaconda/bin:$PATH +PYTHON_VERSION_CHECK=$(python -c 'import sys; print(sys.version_info < (2, 7, 0))') +if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then + echo "Python versions prior to 2.7 are not supported." + exit -1 +fi + exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index 05af87189b18d..c7644da88f770 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -21,35 +21,8 @@ import subprocess import sys - -if sys.version_info >= (2, 7): - subprocess_check_output = subprocess.check_output - subprocess_check_call = subprocess.check_call -else: - # SPARK-8763 - # backported from subprocess module in Python 2.7 - def subprocess_check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - # backported from subprocess module in Python 2.7 - def subprocess_check_call(*popenargs, **kwargs): - retcode = call(*popenargs, **kwargs) - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise CalledProcessError(retcode, cmd) - return 0 +subprocess_check_output = subprocess.check_output +subprocess_check_call = subprocess.check_call def exit_from_command_with_retcode(cmd, retcode): From ca2a780e7c4c4df2488ef933241c6e65264f8d3c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 21 Oct 2017 18:01:45 -0700 Subject: [PATCH 1523/1765] [SPARK-21929][SQL] Support `ALTER TABLE table_name ADD COLUMNS(..)` for ORC data source ## What changes were proposed in this pull request? When [SPARK-19261](https://issues.apache.org/jira/browse/SPARK-19261) implements `ALTER TABLE ADD COLUMNS`, ORC data source is omitted due to SPARK-14387, SPARK-16628, and SPARK-18355. Now, those issues are fixed and Spark 2.3 is [using Spark schema to read ORC table instead of ORC file schema](https://github.com/apache/spark/commit/e6e36004afc3f9fc8abea98542248e9de11b4435). This PR enables `ALTER TABLE ADD COLUMNS` for ORC data source. ## How was this patch tested? Pass the updated and added test cases. Author: Dongjoon Hyun Closes #19545 from dongjoon-hyun/SPARK-21929. --- .../spark/sql/execution/command/tables.scala | 3 +- .../sql/execution/command/DDLSuite.scala | 90 ++++++++++--------- .../sql/hive/execution/HiveDDLSuite.scala | 8 ++ 3 files changed, 58 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8d95ca6921cf8..38f91639c0422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -235,11 +235,10 @@ case class AlterTableAddColumnsCommand( DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" - // OrcFileFormat can not handle difference between user-specified schema and - // inferred schema yet. TODO, once this issue is resolved , we can add Orc back. // Hive type is already considered as hive serde table, so the logic will not // come in here. case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") => case s => throw new AnalysisException( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4ed2cecc5faff..21a2c62929146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2202,56 +2202,64 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + protected def testAddColumn(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + + protected def testAddColumnPartitioned(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int) USING $provider") - sql("INSERT INTO t1 VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 is null"), - Seq(Row(1, null)) - ) - - sql("INSERT INTO t1 VALUES (3, 2)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 2"), - Seq(Row(3, 2)) - ) - } + testAddColumn(provider) } } supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - partitioned - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") - sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null, 2)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 is null"), - Seq(Row(1, null, 2)) - ) - sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 = 3"), - Seq(Row(2, 3, 1)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 1"), - Seq(Row(2, 3, 1)) - ) - } + testAddColumnPartitioned(provider) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 02e26bbe876a0..d3465a641a1a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -166,6 +166,14 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } + + test("alter datasource table add columns - orc") { + testAddColumn("orc") + } + + test("alter datasource table add columns - partitioned - orc") { + testAddColumnPartitioned("orc") + } } class HiveDDLSuite From 57accf6e3965ff69adc4408623916c5003918235 Mon Sep 17 00:00:00 2001 From: Steven Rand Date: Mon, 23 Oct 2017 09:43:45 +0800 Subject: [PATCH 1524/1765] [SPARK-22319][CORE] call loginUserFromKeytab before accessing hdfs In `SparkSubmit`, call `loginUserFromKeytab` before attempting to make RPC calls to the NameNode. I manually tested this patch by: 1. Confirming that my Spark application failed to launch with the error reported in https://issues.apache.org/jira/browse/SPARK-22319. 2. Applying this patch and confirming that the app no longer fails to launch, even when I have not manually run `kinit` on the host. Presumably we also want integration tests for secure clusters so that we catch this sort of thing. I'm happy to take a shot at this if it's feasible and someone can point me in the right direction. Author: Steven Rand Closes #19540 from sjrand/SPARK-22319. Change-Id: Ic306bfe7181107fbcf92f61d75856afcb5b6f761 --- .../org/apache/spark/deploy/SparkSubmit.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 135bbe93bf28e..b7e6d0ea021a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -342,6 +342,22 @@ object SparkSubmit extends CommandLineUtils with Logging { val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (args.principal != null) { + if (args.keytab != null) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } + } + } + // Resolve glob path for different resources. args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull @@ -641,22 +657,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { - if (args.principal != null) { - if (args.keytab != null) { - require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - } - } - } - if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { setRMPrincipal(sysProps) } From 5a5b6b78517b526771ee5b579d56aa1daa4b3ef1 Mon Sep 17 00:00:00 2001 From: Kohki Nishio Date: Mon, 23 Oct 2017 09:55:46 -0700 Subject: [PATCH 1525/1765] [SPARK-22303][SQL] Handle Oracle specific jdbc types in OracleDialect TIMESTAMP (-101), BINARY_DOUBLE (101) and BINARY_FLOAT (100) are handled in OracleDialect ## What changes were proposed in this pull request? When a oracle table contains columns whose type is BINARY_FLOAT or BINARY_DOUBLE, spark sql fails to load a table with SQLException ``` java.sql.SQLException: Unsupported type 101 at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.org$apache$spark$sql$execution$datasources$jdbc$JdbcUtils$$getCatalystType(JdbcUtils.scala:235) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$8.apply(JdbcUtils.scala:292) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$8.apply(JdbcUtils.scala:292) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.getSchema(JdbcUtils.scala:291) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$.resolveTable(JDBCRDD.scala:64) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation.(JDBCRelation.scala:113) at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:47) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:306) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:146) ``` ## How was this patch tested? I updated a UT which covers type conversion test for types (-101, 100, 101), on top of that I tested this change against actual table with those columns and it was able to read and write to the table. Author: Kohki Nishio Closes #19548 from taroplus/oracle_sql_types_101. --- .../sql/jdbc/OracleIntegrationSuite.scala | 43 +++++++++++++++--- .../datasources/jdbc/JdbcUtils.scala | 1 - .../apache/spark/sql/jdbc/OracleDialect.scala | 44 +++++++++++-------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 +++ 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 7680ae3835132..90343182712ed 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import java.util.Properties import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -52,7 +52,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo import testImplicits._ override val db = new DatabaseOnDocker { - override val imageName = "wnameless/oracle-xe-11g:14.04.4" + override val imageName = "wnameless/oracle-xe-11g:16.04" override val env = Map( "ORACLE_ROOT_PASSWORD" -> "oracle" ) @@ -104,15 +104,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() conn.prepareStatement( - "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); - conn.commit(); + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.commit() } test("SPARK-16625 : Importing Oracle numeric types") { - val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) val rows = df.collect() assert(rows.size == 1) val row = rows(0) @@ -307,4 +310,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getInt(1).equals(1)) assert(values.getBoolean(2).equals(false)) } + + test("SPARK-22303: handle BINARY_DOUBLE and BINARY_FLOAT as DoubleType and FloatType") { + val tableName = "oracle_types" + val schema = StructType(Seq( + StructField("d", DoubleType, true), + StructField("f", FloatType, true))) + val props = new Properties() + + // write it back to the table (append mode) + val data = spark.sparkContext.parallelize(Seq(Row(1.1, 2.2f))) + val dfWrite = spark.createDataFrame(data, schema) + dfWrite.write.mode(SaveMode.Append).jdbc(jdbcUrl, tableName, props) + + // read records from oracle_types + val dfRead = sqlContext.read.jdbc(jdbcUrl, tableName, new Properties) + val rows = dfRead.collect() + assert(rows.size == 1) + + // check data types + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DoubleType)) + assert(types(1).equals(FloatType)) + + // check values + val values = rows(0) + assert(values.getDouble(0) === 1.1) + assert(values.getFloat(1) === 2.2f) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 71133666b3249..9debc4ff82748 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -230,7 +230,6 @@ object JdbcUtils extends Logging { case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType - case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 3b44c1de93a61..e3f106c41c7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -23,30 +23,36 @@ import org.apache.spark.sql.types._ private case object OracleDialect extends JdbcDialect { + private[jdbc] val BINARY_FLOAT = 100 + private[jdbc] val BINARY_DOUBLE = 101 + private[jdbc] val TIMESTAMPTZ = -101 override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.NUMERIC) { - val scale = if (null != md) md.build().getLong("scale") else 0L - size match { - // Handle NUMBER fields that have no precision/scale in special way - // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale - // For more details, please see - // https://github.com/apache/spark/pull/8780#issuecomment-145598968 - // and - // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts - // this to NUMERIC with -127 scale - // Not sure if there is a more robust way to identify the field as a float (or other - // numeric types that do not specify a scale. - case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case _ => None - } - } else { - None + sqlType match { + case Types.NUMERIC => + val scale = if (null != md) md.build().getLong("scale") else 0L + size match { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts + // this to NUMERIC with -127 scale + // Not sure if there is a more robust way to identify the field as a float (or other + // numeric types that do not specify a scale. + case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + case _ => None + } + case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle + case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT + case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 34205e0b2bf08..167b3e0190026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -815,6 +815,12 @@ class JDBCSuite extends SparkFunSuite Some(DecimalType(DecimalType.MAX_PRECISION, 10))) assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) == Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_FLOAT, "BINARY_FLOAT", 0, null) == + Some(FloatType)) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_DOUBLE, "BINARY_DOUBLE", 0, null) == + Some(DoubleType)) + assert(oracleDialect.getCatalystType(OracleDialect.TIMESTAMPTZ, "TIMESTAMP", 0, null) == + Some(TimestampType)) } test("table exists query by jdbc dialect") { From f6290aea24efeb238db88bdaef4e24d50740ca4c Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 23 Oct 2017 23:02:36 +0100 Subject: [PATCH 1526/1765] [SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate ## What changes were proposed in this pull request? The current implementation of `ApproxCountDistinctForIntervals` is `ImperativeAggregate`. The number of `aggBufferAttributes` is the number of total words in the hllppHelper array. Each hllppHelper has 52 words by default relativeSD. Since this aggregate function is used in equi-height histogram generation, and the number of buckets in histogram is usually hundreds, the number of `aggBufferAttributes` can easily reach tens of thousands or even more. This leads to a huge method in codegen and causes error: ``` org.codehaus.janino.JaninoRuntimeException: Code of method "apply(Lorg/apache/spark/sql/catalyst/InternalRow;)Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection" grows beyond 64 KB. ``` Besides, huge generated methods also result in performance regression. In this PR, we change its implementation to `TypedImperativeAggregate`. After the fix, `ApproxCountDistinctForIntervals` can deal with more than thousands endpoints without throwing codegen error, and improve performance from `20 sec` to `2 sec` in a test case of 500 endpoints. ## How was this patch tested? Test by an added test case and existing tests. Author: Zhenhua Wang Closes #19506 from wzhfy/change_forIntervals_typedAgg. --- .../ApproxCountDistinctForIntervals.scala | 97 ++++++++++--------- ...ApproxCountDistinctForIntervalsSuite.scala | 34 +++---- ...xCountDistinctForIntervalsQuerySuite.scala | 61 ++++++++++++ 3 files changed, 130 insertions(+), 62 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 096d1b35a8620..d4421ca20a9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -22,9 +22,10 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform /** * This function counts the approximate number of distinct values (ndv) in @@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals( relativeSD: Double = 0.05, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with ExpectsInputTypes { - - def this(child: Expression, endpointsExpression: Expression) = { - this( - child = child, - endpointsExpression = endpointsExpression, - relativeSD = 0.05, - mutableAggBufferOffset = 0, - inputAggBufferOffset = 0) - } + extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes { def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = { this( @@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals( private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length /** Allocate enough words to store all registers. */ - override lazy val aggBufferAttributes: Seq[AttributeReference] = { - Seq.tabulate(totalNumWords) { i => - AttributeReference(s"MS[$i]", LongType)() - } + override def createAggregationBuffer(): Array[Long] = { + Array.fill(totalNumWords)(0L) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - /** Fill all words with zeros. */ - override def initialize(buffer: InternalRow): Unit = { - var word = 0 - while (word < totalNumWords) { - buffer.setLong(mutableAggBufferOffset + word, 0) - word += 1 - } - } - - override def update(buffer: InternalRow, input: InternalRow): Unit = { + override def update(buffer: Array[Long], input: InternalRow): Array[Long] = { val value = child.eval(input) // Ignore empty rows if (value != null) { @@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals( // endpoints are sorted into ascending order already if (endpoints.head > doubleValue || endpoints.last < doubleValue) { // ignore if the value is out of the whole range - return + return buffer } val hllppIndex = findHllppIndex(doubleValue) - val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp - hllppArray(hllppIndex).update(buffer, offset, value, child.dataType) + val offset = hllppIndex * numWordsPerHllpp + hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType) } + buffer } // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value. @@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals( } } - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = { for (i <- hllppArray.indices) { hllppArray(i).merge( - buffer1 = buffer1, - buffer2 = buffer2, - offset1 = mutableAggBufferOffset + i * numWordsPerHllpp, - offset2 = inputAggBufferOffset + i * numWordsPerHllpp) + buffer1 = LongArrayInternalRow(buffer1), + buffer2 = LongArrayInternalRow(buffer2), + offset1 = i * numWordsPerHllpp, + offset2 = i * numWordsPerHllpp) } + buffer1 } - override def eval(buffer: InternalRow): Any = { + override def eval(buffer: Array[Long]): Any = { val ndvArray = hllppResults(buffer) // If the endpoints contains multiple elements with the same value, // we set ndv=1 for intervals between these elements. @@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals( new GenericArrayData(ndvArray) } - def hllppResults(buffer: InternalRow): Array[Long] = { + def hllppResults(buffer: Array[Long]): Array[Long] = { val ndvArray = new Array[Long](hllppArray.length) for (i <- ndvArray.indices) { - ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp) + ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp) } ndvArray } - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(mutableAggBufferOffset = newMutableAggBufferOffset) + } - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(inputAggBufferOffset = newInputAggBufferOffset) + } override def children: Seq[Expression] = Seq(child, endpointsExpression) @@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals( override def dataType: DataType = ArrayType(LongType) override def prettyName: String = "approx_count_distinct_for_intervals" + + override def serialize(obj: Array[Long]): Array[Byte] = { + val byteArray = new Array[Byte](obj.length * 8) + var i = 0 + while (i < obj.length) { + Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i)) + i += 1 + } + byteArray + } + + override def deserialize(bytes: Array[Byte]): Array[Long] = { + assert(bytes.length % 8 == 0) + val length = bytes.length / 8 + val longArray = new Array[Long](length) + var i = 0 + while (i < length) { + longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8) + i += 1 + } + longArray + } + + private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow { + override def getLong(offset: Int): Long = array(offset) + override def setLong(offset: Int, value: Long): Unit = { array(offset) = value } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala index d6c38c3608bf8..73f18d4feef3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala @@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType), MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType)))) wrongColumnTypes.foreach { dataType => - val wrongColumn = new ApproxCountDistinctForIntervals( + val wrongColumn = ApproxCountDistinctForIntervals( AttributeReference("a", dataType)(), endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_)))) assert( @@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { }) } - var wrongEndpoints = new ApproxCountDistinctForIntervals( + var wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = Literal(0.5d)) assert( @@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { case _ => false }) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)()))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The endpoints provided must be constant literals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array(10L).map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array("foobar").map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == @@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { private def createEstimator[T]( endpoints: Array[T], dt: DataType, - rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = { + rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = { val input = new SpecificInternalRow(Seq(dt)) val aggFunc = ApproxCountDistinctForIntervals( BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd) - val buffer = createBuffer(aggFunc) - (aggFunc, input, buffer) - } - - private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = { - val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType)) - aggFunc.initialize(buffer) - buffer + (aggFunc, input, aggFunc.createAggregationBuffer()) } test("merging ApproxCountDistinctForIntervals instances") { val (aggFunc, input, buffer1a) = createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType) - val buffer1b = createBuffer(aggFunc) - val buffer2 = createBuffer(aggFunc) + val buffer1b = aggFunc.createAggregationBuffer() + val buffer2 = aggFunc.createAggregationBuffer() // Add the lower half to `buffer1a`. var i = 0 @@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { } // Check if the buffers are equal. - assert(buffer2 == buffer1a, "Buffers should be equal") + assert(buffer2.sameElements(buffer1a), "Buffers should be equal") } test("test findHllppIndex(value) for values in the range") { @@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2) } + test("round trip serialization") { + val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType) + val longArray = (1L to 100L).toArray + val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray)) + assert(roundtrip.sameElements(longArray)) + } + test("basic operations: update, merge, eval...") { val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0) val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala new file mode 100644 index 0000000000000..c7d86bc955d67 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height + // histogram usually contains hundreds of buckets. So we need to test + // ApproxCountDistinctForIntervals with large number of endpoints + // (the number of endpoints == the number of buckets + 1). + test("test ApproxCountDistinctForIntervals with large number of endpoints") { + val table = "approx_count_distinct_for_intervals_tbl" + withTable(table) { + (1 to 100000).toDF("col").createOrReplaceTempView(table) + // percentiles of 0, 0.001, 0.002 ... 0.999, 1 + val endpoints = (0 to 1000).map(_ * 100000 / 1000) + + // Since approx_count_distinct_for_intervals is not a public function, here we do + // the computation by constructing logical plan. + val relation = spark.table(table).logicalPlan + val attr = relation.output.find(_.name == "col").get + val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_)))) + val aggExpr = aggFunc.toAggregateExpression() + val namedExpr = Alias(aggExpr, aggExpr.toString)() + val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation)) + .executedPlan.executeTake(1).head + val ndvArray = ndvsRow.getArray(0).toLongArray() + assert(endpoints.length == ndvArray.length + 1) + + // Each bucket has 100 distinct values. + val expectedNdv = 100 + for (i <- ndvArray.indices) { + val ndv = ndvArray(i) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.") + } + } + } +} From 884d4f95f7ebfaa9d8c57cf770d10a2c6ab82d62 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 23 Oct 2017 17:21:49 -0700 Subject: [PATCH 1527/1765] [SPARK-21912][SQL][FOLLOW-UP] ORC/Parquet table should not create invalid column names ## What changes were proposed in this pull request? During [SPARK-21912](https://issues.apache.org/jira/browse/SPARK-21912), we skipped testing 'ADD COLUMNS' on ORC tables due to ORC limitation. Since [SPARK-21929](https://issues.apache.org/jira/browse/SPARK-21929) is resolved now, we can test both `ORC` and `PARQUET` completely. ## How was this patch tested? Pass the updated test case. Author: Dongjoon Hyun Closes #19562 from dongjoon-hyun/SPARK-21912-2. --- .../spark/sql/hive/execution/SQLQuerySuite.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1cf1c5cd5a472..39e918c3d5209 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2031,8 +2031,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-21912 ORC/Parquet table should not create invalid column names") { Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => - withTable("t21912") { - Seq("ORC", "PARQUET").foreach { source => + Seq("ORC", "PARQUET").foreach { source => + withTable("t21912") { val m = intercept[AnalysisException] { sql(s"CREATE TABLE t21912(`col$name` INT) USING $source") }.getMessage @@ -2049,15 +2049,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }.getMessage assert(m3.contains(s"contains invalid character(s)")) } - } - // TODO: After SPARK-21929, we need to check ORC, too. - Seq("PARQUET").foreach { source => sql(s"CREATE TABLE t21912(`col` INT) USING $source") - val m = intercept[AnalysisException] { + val m4 = intercept[AnalysisException] { sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)") }.getMessage - assert(m.contains(s"contains invalid character(s)")) + assert(m4.contains(s"contains invalid character(s)")) } } } From d9798c834f3fed060cfd18a8d38c398cb2efcc82 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Oct 2017 12:44:47 +0900 Subject: [PATCH 1528/1765] [SPARK-22313][PYTHON] Mark/print deprecation warnings as DeprecationWarning for deprecated APIs ## What changes were proposed in this pull request? This PR proposes to mark the existing warnings as `DeprecationWarning` and print out warnings for deprecated functions. This could be actually useful for Spark app developers. I use (old) PyCharm and this IDE can detect this specific `DeprecationWarning` in some cases: **Before** **After** For console usage, `DeprecationWarning` is usually disabled (see https://docs.python.org/2/library/warnings.html#warning-categories and https://docs.python.org/3/library/warnings.html#warning-categories): ``` >>> import warnings >>> filter(lambda f: f[2] == DeprecationWarning, warnings.filters) [('ignore', <_sre.SRE_Pattern object at 0x10ba58c00>, , <_sre.SRE_Pattern object at 0x10bb04138>, 0), ('ignore', None, , None, 0)] ``` so, it won't actually mess up the terminal much unless it is intended. If this is intendedly enabled, it'd should as below: ``` >>> import warnings >>> warnings.simplefilter('always', DeprecationWarning) >>> >>> from pyspark.sql import functions >>> functions.approxCountDistinct("a") .../spark/python/pyspark/sql/functions.py:232: DeprecationWarning: Deprecated in 2.1, use approx_count_distinct instead. "Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) ... ``` These instances were found by: ``` cd python/pyspark grep -r "Deprecated" . grep -r "deprecated" . grep -r "deprecate" . ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #19535 from HyukjinKwon/deprecated-warning. --- python/pyspark/ml/util.py | 8 ++- python/pyspark/mllib/classification.py | 2 +- python/pyspark/mllib/evaluation.py | 6 +-- python/pyspark/mllib/regression.py | 8 +-- python/pyspark/sql/dataframe.py | 3 ++ python/pyspark/sql/functions.py | 18 +++++++ python/pyspark/streaming/flume.py | 14 ++++- python/pyspark/streaming/kafka.py | 72 ++++++++++++++++++++++---- 8 files changed, 110 insertions(+), 21 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 67772910c0d38..c3c47bd79459a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -175,7 +175,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jwrite.context(sqlContext._ssql_ctx) return self @@ -256,7 +258,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jread.context(sqlContext._ssql_ctx) return self diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e04eeb2b60d71..cce703d432b5a 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -311,7 +311,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ warnings.warn( "Deprecated in 2.0.0. Use ml.classification.LogisticRegression or " - "LogisticRegressionWithLBFGS.") + "LogisticRegressionWithLBFGS.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index fc2a0b3b5038a..2cd1da3fbf9aa 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -234,7 +234,7 @@ def precision(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("precision") else: return self.call("precision", float(label)) @@ -246,7 +246,7 @@ def recall(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("recall") else: return self.call("recall", float(label)) @@ -259,7 +259,7 @@ def fMeasure(self, label=None, beta=None): if beta is None: if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("fMeasure") else: return self.call("fMeasure", label) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 1b66f5b51044b..ea107d400621d 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -278,7 +278,8 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, A condition which decides iteration termination. (default: 0.001) """ - warnings.warn("Deprecated in 2.0.0. Use ml.regression.LinearRegression.") + warnings.warn( + "Deprecated in 2.0.0. Use ml.regression.LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -421,7 +422,8 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 1.0. " - "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.") + "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", + DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -566,7 +568,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 0.0. " "Note the default regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for " - "LinearRegression.") + "LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 38b01f0011671..c0b574e2b93a1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -130,6 +130,8 @@ def registerTempTable(self, name): .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. """ + warnings.warn( + "Deprecated in 2.0, use createOrReplaceTempView instead.", DeprecationWarning) self._jdf.createOrReplaceTempView(name) @since(2.0) @@ -1308,6 +1310,7 @@ def unionAll(self, other): .. note:: Deprecated in 2.0, use :func:`union` instead. """ + warnings.warn("Deprecated in 2.0, use union instead.", DeprecationWarning) return self.union(other) @since(2.3) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc374b93a433..0d40368c9cd6e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -21,6 +21,7 @@ import math import sys import functools +import warnings if sys.version < "3": from itertools import imap as map @@ -44,6 +45,14 @@ def _(col): return _ +def _wrap_deprecated_function(func, message): + """ Wrap the deprecated function to print out deprecation warnings""" + def _(col): + warnings.warn(message, DeprecationWarning) + return func(col) + return functools.wraps(func)(_) + + def _create_binary_mathfunction(name, doc=""): """ Create a binary mathfunction by name""" def _(col1, col2): @@ -207,6 +216,12 @@ def _(): """returns the relative rank (i.e. percentile) of rows within a window partition.""", } +# Wraps deprecated functions (keys) with the messages (values). +_functions_deprecated = { + 'toDegrees': 'Deprecated in 2.1, use degrees instead.', + 'toRadians': 'Deprecated in 2.1, use radians instead.', +} + for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) for _name, _doc in _functions_1_4.items(): @@ -219,6 +234,8 @@ def _(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) for _name, _doc in _functions_2_1.items(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) +for _name, _message in _functions_deprecated.items(): + globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) del _name, _doc @@ -227,6 +244,7 @@ def approxCountDistinct(col, rsd=None): """ .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ + warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) return approx_count_distinct(col, rsd) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index 2fed5940b31ea..5a975d050b0d8 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -54,8 +54,13 @@ def createStream(ssc, hostname, port, :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) @@ -82,8 +87,13 @@ def createPollingStream(ssc, addresses, :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] ports = [] diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 4af4135c81958..fdb9308604489 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,6 +15,8 @@ # limitations under the License. # +import warnings + from py4j.protocol import Py4JJavaError from pyspark.rdd import RDD @@ -56,8 +58,13 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if kafkaParams is None: kafkaParams = dict() kafkaParams.update({ @@ -105,8 +112,13 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :return: A DStream object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if fromOffsets is None: fromOffsets = dict() if not isinstance(topics, list): @@ -159,8 +171,13 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :return: An RDD object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if leaders is None: leaders = dict() if not isinstance(kafkaParams, dict): @@ -229,7 +246,8 @@ class OffsetRange(object): """ Represents a range of offsets from a single Kafka TopicAndPartition. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, fromOffset, untilOffset): @@ -240,6 +258,10 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.fromOffset = fromOffset @@ -270,7 +292,8 @@ class TopicAndPartition(object): """ Represents a specific topic and partition for Kafka. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition): @@ -279,6 +302,10 @@ def __init__(self, topic, partition): :param topic: Kafka topic name. :param partition: Kafka partition id. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._topic = topic self._partition = partition @@ -303,7 +330,8 @@ class Broker(object): """ Represent the host and port info for a Kafka broker. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, host, port): @@ -312,6 +340,10 @@ def __init__(self, host, port): :param host: Broker's hostname. :param port: Broker's port. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._host = host self._port = port @@ -323,10 +355,15 @@ class KafkaRDD(RDD): """ A Python wrapper of KafkaRDD, to provide additional information on normal RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jrdd, ctx, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) RDD.__init__(self, jrdd, ctx, jrdd_deserializer) def offsetRanges(self): @@ -345,10 +382,15 @@ class KafkaDStream(DStream): """ A Python wrapper of KafkaDStream - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jdstream, ssc, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) DStream.__init__(self, jdstream, ssc, jrdd_deserializer) def foreachRDD(self, func): @@ -383,10 +425,15 @@ class KafkaTransformedDStream(TransformedDStream): """ Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, prev, func): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) TransformedDStream.__init__(self, prev, func) @property @@ -405,7 +452,8 @@ class KafkaMessageAndMetadata(object): """ Kafka message and metadata information. Including topic, partition, offset and message - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, offset, key, message): @@ -419,6 +467,10 @@ def __init__(self, topic, partition, offset, key, message): :param message: actual message payload of this Kafka message, the return data is undecoded bytearray. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.offset = offset From c30d5cfc7117bdadd63bf730e88398139e0f65f4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 24 Oct 2017 08:46:22 +0100 Subject: [PATCH 1529/1765] [SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache ## What changes were proposed in this pull request? This PR generates the Java code to directly get a value for a column in `ColumnVector` without using an iterator (e.g. at lines 54-69 in the generated code example) for table cache (e.g. `dataframe.cache`). This PR improves runtime performance by eliminating data copy from column-oriented storage to `InternalRow` in a `SpecificColumnarIterator` iterator for primitive type. Another PR will support primitive type array. Benchmark result: **1.2x** ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Sum with IntDelta cache: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ InternalRow codegen 731 / 812 43.0 23.2 1.0X ColumnVector codegen 616 / 772 51.0 19.6 1.2X ``` Benchmark program ``` intSumBenchmark(sqlContext, 1024 * 1024 * 30) def intSumBenchmark(sqlContext: SQLContext, values: Int): Unit = { import sqlContext.implicits._ val benchmarkPT = new Benchmark("Int Sum with IntDelta cache", values, 20) Seq(("InternalRow", "false"), ("ColumnVector", "true")).foreach { case (str, value) => withSQLConf(sqlContext, SQLConf. COLUMN_VECTOR_CODEGEN.key -> value) { // tentatively added for benchmarking val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1).toDF().cache() dfPassThrough.count() // force to create df.cache() benchmarkPT.addCase(s"$str codegen") { iter => dfPassThrough.agg(sum("value")).collect } dfPassThrough.unpersist(true) } } benchmarkPT.run() } ``` Motivating example ``` val dsInt = spark.range(3).cache dsInt.count // force to build cache dsInt.filter(_ > 0).collect ``` Generated code ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inmemorytablescan_input; /* 009 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows; /* 010 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_scanTime; /* 011 */ private long inmemorytablescan_scanTime1; /* 012 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch inmemorytablescan_batch; /* 013 */ private int inmemorytablescan_batchIdx; /* 014 */ private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector inmemorytablescan_colInstance0; /* 015 */ private UnsafeRow inmemorytablescan_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder inmemorytablescan_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter inmemorytablescan_rowWriter; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows; /* 019 */ private UnsafeRow filter_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ inmemorytablescan_input = inputs[0]; /* 031 */ inmemorytablescan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 032 */ inmemorytablescan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 033 */ inmemorytablescan_scanTime1 = 0; /* 034 */ inmemorytablescan_batch = null; /* 035 */ inmemorytablescan_batchIdx = 0; /* 036 */ inmemorytablescan_colInstance0 = null; /* 037 */ inmemorytablescan_result = new UnsafeRow(1); /* 038 */ inmemorytablescan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(inmemorytablescan_result, 0); /* 039 */ inmemorytablescan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(inmemorytablescan_holder, 1); /* 040 */ filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 041 */ filter_result = new UnsafeRow(1); /* 042 */ filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 0); /* 043 */ filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 1); /* 044 */ /* 045 */ } /* 046 */ /* 047 */ protected void processNext() throws java.io.IOException { /* 048 */ if (inmemorytablescan_batch == null) { /* 049 */ inmemorytablescan_nextBatch(); /* 050 */ } /* 051 */ while (inmemorytablescan_batch != null) { /* 052 */ int inmemorytablescan_numRows = inmemorytablescan_batch.numRows(); /* 053 */ int inmemorytablescan_localEnd = inmemorytablescan_numRows - inmemorytablescan_batchIdx; /* 054 */ for (int inmemorytablescan_localIdx = 0; inmemorytablescan_localIdx < inmemorytablescan_localEnd; inmemorytablescan_localIdx++) { /* 055 */ int inmemorytablescan_rowIdx = inmemorytablescan_batchIdx + inmemorytablescan_localIdx; /* 056 */ int inmemorytablescan_value = inmemorytablescan_colInstance0.getInt(inmemorytablescan_rowIdx); /* 057 */ /* 058 */ boolean filter_isNull = false; /* 059 */ /* 060 */ boolean filter_value = false; /* 061 */ filter_value = inmemorytablescan_value > 1; /* 062 */ if (!filter_value) continue; /* 063 */ /* 064 */ filter_numOutputRows.add(1); /* 065 */ /* 066 */ filter_rowWriter.write(0, inmemorytablescan_value); /* 067 */ append(filter_result); /* 068 */ if (shouldStop()) { inmemorytablescan_batchIdx = inmemorytablescan_rowIdx + 1; return; } /* 069 */ } /* 070 */ inmemorytablescan_batchIdx = inmemorytablescan_numRows; /* 071 */ inmemorytablescan_batch = null; /* 072 */ inmemorytablescan_nextBatch(); /* 073 */ } /* 074 */ inmemorytablescan_scanTime.add(inmemorytablescan_scanTime1 / (1000 * 1000)); /* 075 */ inmemorytablescan_scanTime1 = 0; /* 076 */ } /* 077 */ /* 078 */ private void inmemorytablescan_nextBatch() throws java.io.IOException { /* 079 */ long getBatchStart = System.nanoTime(); /* 080 */ if (inmemorytablescan_input.hasNext()) { /* 081 */ org.apache.spark.sql.execution.columnar.CachedBatch inmemorytablescan_cachedBatch = (org.apache.spark.sql.execution.columnar.CachedBatch)inmemorytablescan_input.next(); /* 082 */ inmemorytablescan_batch = org.apache.spark.sql.execution.columnar.InMemoryRelation$.MODULE$.createColumn(inmemorytablescan_cachedBatch); /* 083 */ /* 084 */ inmemorytablescan_numOutputRows.add(inmemorytablescan_batch.numRows()); /* 085 */ inmemorytablescan_batchIdx = 0; /* 086 */ inmemorytablescan_colInstance0 = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) inmemorytablescan_batch.column(0); org.apache.spark.sql.execution.columnar.ColumnAccessor$.MODULE$.decompress(inmemorytablescan_cachedBatch.buffers()[0], (org.apache.spark.sql.execution.vectorized.WritableColumnVector) inmemorytablescan_colInstance0, org.apache.spark.sql.types.DataTypes.IntegerType, inmemorytablescan_cachedBatch.numRows()); /* 087 */ /* 088 */ } /* 089 */ inmemorytablescan_scanTime1 += System.nanoTime() - getBatchStart; /* 090 */ } /* 091 */ } ``` ## How was this patch tested? Add test cases into `DataFrameTungstenSuite` and `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #18747 from kiszk/SPARK-20822a. --- .../sql/execution/ColumnarBatchScan.scala | 3 - .../sql/execution/WholeStageCodegenExec.scala | 24 ++++---- .../execution/columnar/ColumnAccessor.scala | 8 +++ .../columnar/InMemoryTableScanExec.scala | 57 +++++++++++++++++-- .../spark/sql/DataFrameTungstenSuite.scala | 36 ++++++++++++ .../execution/WholeStageCodegenSuite.scala | 32 +++++++++++ 6 files changed, 141 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1afe83ea3539e..eb01e126bcbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType @@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType */ private[sql] trait ColumnarBatchScan extends CodegenSupport { - val inMemoryTableScan: InMemoryTableScanExec = null - def vectorTypes: Option[Seq[String]] = None override lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1aaaf896692d1..e37d133ff336a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp object WholeStageCodegenExec { val PIPELINE_DURATION_METRIC = "duration" + + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + + def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = { + numOfNestedFields(dataType) > conf.wholeStageMaxNumFields + } } /** @@ -490,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case _ => true } - private def numOfNestedFields(dataType: DataType): Int = dataType match { - case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum - case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) - case a: ArrayType => numOfNestedFields(a.elementType) - case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) - case _ => 1 - } - private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns val hasTooManyOutputFields = - numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + WholeStageCodegenExec.isTooManyFields(conf, plan.schema) val hasTooManyInputFields = - plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema)) !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 24c8ac81420cb..445933d98e9d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -163,4 +163,12 @@ private[sql] object ColumnAccessor { throw new RuntimeException("Not support non-primitive type now") } } + + def decompress( + array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int): + Unit = { + val byteBuffer = ByteBuffer.wrap(array) + val columnAccessor = ColumnAccessor(dataType, byteBuffer) + decompress(columnAccessor, columnVector, numRows) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 139da1c519da2..43386e7a03c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,21 +23,66 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.UserDefinedType +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.vectorized._ +import org.apache.spark.sql.types._ case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends LeafExecNode with ColumnarBatchScan { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def vectorTypes: Option[Seq[String]] = + Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName)) + + /** + * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. + * If false, get data from UnsafeRow build from ColumnVector + */ + override val supportCodegen: Boolean = { + // In the initial implementation, for ease of review + // support only primitive data types and # of fields is less than wholeStageMaxNumFields + relation.schema.fields.forall(f => f.dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType => true + case _ => false + }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) + } + + private val columnIndices = + attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray + + private val relationSchema = relation.schema.toArray + + private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) + + private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + val rowCount = cachedColumnarBatch.numRows + val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + val columnarBatch = new ColumnarBatch( + columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + columnarBatch.setNumRows(rowCount) + + for (i <- 0 until attributes.length) { + ColumnAccessor.decompress( + cachedColumnarBatch.buffers(columnIndices(i)), + columnarBatch.column(i).asInstanceOf[WritableColumnVector], + columnarBatchSchema.fields(i).dataType, rowCount) + } + columnarBatch + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + assert(supportCodegen) + val buffers = relation.cachedColumnBuffers + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + } override def output: Seq[Attribute] = attributes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index fe6ba83b4cbfb..0881212a64de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } + + test("primitive data type accesses in persist data") { + val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong, + 31.25.toFloat, 63.75, null) + val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, IntegerType) + val schemas = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data))) + val df = spark.createDataFrame(rdd, StructType(schemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } + + test("access cache multiple times") { + val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache + df0.count + val df1 = df0.filter("x > 1") + checkAnswer(df1, Seq(Row(2), Row(3))) + val df2 = df0.filter("x > 2") + checkAnswer(df2, Row(3)) + + val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache + for (_ <- 0 to 2) { + val df11 = df10.filter("x > 5") + checkAnswer(df11, Row(6)) + } + } + + test("access only some column of the all of columns") { + val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d") + df.cache + df.count + assert(df.filter("d < 3").count == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 098e4cfeb15b2..bc05dca578c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -117,6 +118,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { + import testImplicits._ + + val dsInt = spark.range(3).cache + dsInt.count + val dsIntFilter = dsInt.filter(_ > 0) + val planInt = dsIntFilter.queryExecution.executedPlan + assert(planInt.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined + ) + assert(dsIntFilter.collect() === Array(1, 2)) + + // cache for string type is not supported for InMemoryTableScanExec + val dsString = spark.range(3).map(_.toString).cache + dsString.count + val dsStringFilter = dsString.filter(_ == "1") + val planString = dsStringFilter.queryExecution.executedPlan + assert(planString.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec]).isDefined + ) + assert(dsStringFilter.collect() === Array("1")) + } + test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) From 8beeaed66bde0ace44495b38dc967816e16b3464 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 24 Oct 2017 13:56:10 +0100 Subject: [PATCH 1530/1765] [SPARK-21936][SQL][FOLLOW-UP] backward compatibility test framework for HiveExternalCatalog ## What changes were proposed in this pull request? Adjust Spark download in test to use Apache mirrors and respect its load balancer, and use Spark 2.1.2. This follows on a recent PMC list thread about removing the cloudfront download rather than update it further. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #19564 from srowen/SPARK-21936.2. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 305f5b533d592..5f8c9d5799662 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -53,7 +53,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def downloadSpark(version: String): Unit = { import scala.sys.process._ - val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" + val preferredMirror = + Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim + val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! @@ -142,7 +144,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") protected var spark: SparkSession = _ From 3f5ba968c5af7911a2f6c452500b6a629a3de8db Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Oct 2017 09:11:52 -0700 Subject: [PATCH 1531/1765] [SPARK-22301][SQL] Add rule to Optimizer for In with not-nullable value and empty list ## What changes were proposed in this pull request? For performance reason, we should resolve in operation on an empty list as false in the optimizations phase, ad discussed in #19522. ## How was this patch tested? Added UT cc gatorsmile Author: Marco Gaido Author: Marco Gaido Closes #19523 from mgaido91/SPARK-22301. --- .../sql/catalyst/optimizer/expressions.scala | 7 +++++-- .../sql/catalyst/optimizer/OptimizeInSuite.scala | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 273bc6ce27c5d..523b53b39d6b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -169,13 +169,16 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { /** * Optimize IN predicates: - * 1. Removes literal repetitions. - * 2. Replaces [[In (value, seq[Literal])]] with optimized version + * 1. Converts the predicate to false when the list is empty and + * the value is not nullable. + * 2. Removes literal repetitions. + * 3. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index eaad1e32a8aba..d7acd139225cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -175,4 +175,20 @@ class OptimizeInSuite extends PlanTest { } } } + + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + + "when value is not nullable") { + val originalQuery = + testRelation + .where(In(Literal("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(Literal(false)) + .analyze + + comparePlans(optimized, correctAnswer) + } } From bc1e76632ddec8fc64726086905183d1f312bca4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Oct 2017 06:33:44 +0100 Subject: [PATCH 1532/1765] [SPARK-22348][SQL] The table cache providing ColumnarBatch should also do partition batch pruning ## What changes were proposed in this pull request? We enable table cache `InMemoryTableScanExec` to provide `ColumnarBatch` now. But the cached batches are retrieved without pruning. In this case, we still need to do partition batch pruning. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19569 from viirya/SPARK-22348. --- .../columnar/InMemoryTableScanExec.scala | 70 ++++++++++--------- .../columnar/InMemoryColumnarQuerySuite.scala | 27 ++++++- 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 43386e7a03c32..2ae3f35eb1da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,7 +78,7 @@ case class InMemoryTableScanExec( override def inputRDDs(): Seq[RDD[InternalRow]] = { assert(supportCodegen) - val buffers = relation.cachedColumnBuffers + val buffers = filteredCachedBatches() // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) @@ -180,19 +180,11 @@ case class InMemoryTableScanExec( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - + private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex - val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => @@ -201,35 +193,49 @@ case class InMemoryTableScanExec( schema) partitionFilter.initialize(index) + // Do partition batch pruning if enabled + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter.eval(cachedBatch.stats)) { + logDebug { + val statsString = schemaIndex.map { case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") + s"Skipping partition based on stats $statsString" + } + false + } else { + true + } + } + } else { + cachedBatchIterator + } + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => // Find the ordinals and data types of the requested columns. val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => relOutput.indexOf(a.exprId) -> a.dataType }.unzip - // Do partition batch pruning if enabled - val cachedBatchesToScan = - if (inMemoryPartitionPruningEnabled) { - cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter.eval(cachedBatch.stats)) { - logDebug { - val statsString = schemaIndex.map { case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") - s"Skipping partition based on stats $statsString" - } - false - } else { - true - } - } - } else { - cachedBatchIterator - } - // update SQL metrics - val withMetrics = cachedBatchesToScan.map { batch => + val withMetrics = cachedBatchIterator.map { batch => if (enableAccumulators) { readBatches.add(1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2f249c850a088..e662e294228db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -454,4 +454,29 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } + + test("SPARK-22348: table cache should do partition batch pruning") { + Seq("true", "false").foreach { enabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) { + val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y") + df1.unpersist() + df1.cache() + + // Push predicate to the cached table. + val df2 = df1.where("y = 3") + + val planBeforeFilter = df2.queryExecution.executedPlan.collect { + case f: FilterExec => f.child + } + assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) + + val execPlan = if (enabled == "true") { + WholeStageCodegenExec(planBeforeFilter.head) + } else { + planBeforeFilter.head + } + assert(execPlan.executeCollectPublic().length == 0) + } + } + } } From 524abb996abc9970d699623c13469ea3b6d2d3fc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 24 Oct 2017 22:59:46 -0700 Subject: [PATCH 1533/1765] [SPARK-21101][SQL] Catch IllegalStateException when CREATE TEMPORARY FUNCTION ## What changes were proposed in this pull request? It must `override` [`public StructObjectInspector initialize(ObjectInspector[] argOIs)`](https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java#L70) when create a UDTF. If you `override` [`public StructObjectInspector initialize(StructObjectInspector argOIs)`](https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java#L49), `IllegalStateException` will throw. per: [HIVE-12377](https://issues.apache.org/jira/browse/HIVE-12377). This PR catch `IllegalStateException` and point user to `override` `public StructObjectInspector initialize(ObjectInspector[] argOIs)`. ## How was this patch tested? unit tests Source code and binary jar: [SPARK-21101.zip](https://github.com/apache/spark/files/1123763/SPARK-21101.zip) These two source code copy from : https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTFStack.java Author: Yuming Wang Closes #18527 from wangyum/SPARK-21101. --- .../spark/sql/hive/HiveSessionCatalog.scala | 11 ++++++-- .../src/test/resources/SPARK-21101-1.0.jar | Bin 0 -> 7439 bytes .../sql/hive/execution/SQLQuerySuite.scala | 26 ++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/resources/SPARK-21101-1.0.jar diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index b256ffc27b199..1f11adbd4f62e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -94,8 +94,15 @@ private[sql] class HiveSessionCatalog( } } catch { case NonFatal(e) => - val analysisException = - new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") + val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e" + val errorMsg = + if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + s"$noHandlerMsg\nPlease make sure your function overrides " + + "`public StructObjectInspector initialize(ObjectInspector[] args)`." + } else { + noHandlerMsg + } + val analysisException = new AnalysisException(errorMsg) analysisException.setStackTrace(e.getStackTrace) throw analysisException } diff --git a/sql/hive/src/test/resources/SPARK-21101-1.0.jar b/sql/hive/src/test/resources/SPARK-21101-1.0.jar new file mode 100644 index 0000000000000000000000000000000000000000..768b2334db5c3aa8b1e4186af5fde86120d3231b GIT binary patch literal 7439 zcmb7J1z1#D*XAOFgmkw^#~@uINJ=*%LyqJK2t#+*05TE|HH4DV-Q7rsgi4oO8m=&+ zzz3gKz2EQkKKEbq%sJB?WmMT^==sBelT+RTUu57@i7{b9iv1 zQk!>DU~$cfTY0#TTLmbCb$vDaK>|5f8?#3}GD@37MO()ujkB1P7MD0)K%2~mWI+4q z@{Y2AvvS+A{HISWgnp4F`pUw6j<|P&J)E6QcuaWEzJ>L3^ca_6IXGE=5Bz5j+&?|Q zj$m^e%YSer`d>$9N3fIaKe&_qox3yC?jIo3zk=96-2N#t=6}RldRUsfxk;BBJ-@-KWJ^N#rinTf~4%X6gKbOjmDORleAB(l_w?F9>Q9~XJ@2E!6-9*Gl0>o zg@v=KaIdBFt>}1BJs8`8E}_Q2xeK9bsT^54w)1BiX(aZN9c3kyC&D@yWku->4&O&^ zx9Y)^B^Wwt*H%N72iv2hn@APtT1YzFJ5lr|h@q2U6qqs!tfFKRI|LCs)54NMP?+K^ zu`wbEma1?1ak6-@gkSwBZR>tVemCtCEIn{*|?j8tz4td~>TBI^Z{YnJb7)X8!xb6LfD8)fHC)(G&T}nuonup|g_zYt-?nkFHOP|`O zfvKM~21pUp5n{Q3u)SV+6@q`in;hHO&@4+#oIm@xyLvLfipDA3^hTX5Cdchm_K5hY z?WH8yyLWv*l6WC=%LiwMpWbOWH3mw)SFW)FO?fJUa? zRDl4`LpMl4mmk}`%&^1*T;;c5j2YI} z)X#G#nWwHePY{!4PmxK(i6CAt-Lw`0flCv%;yudCGi^^SMjw&ieURn7*2~|i`hg9@ z2b1piI~_@-C-X3J&@AVT_nuN_zFXMTr@cFy9L`a|ocY4Ajo_Fv|=$ z#qki151<*P>H+{-J83ZTf)-ZP0&5x%VL`R~HEm4pw4ylY0Z}-|;MTUt+htq>vf*ii z(I0XR2NTg%KD)cPNRq<}PO+8dwS@H5Hd9s#cA+Y74k?i2kP;PDLjjWNcpi75Nv62#&m!$q!VL|RtgL`K$BiZI;7HSewdKMiDEJ$;k)r{wfZ#+6s6d{u+OK)Ur7xt^9f^}8S74LE8dI#P=XjW0@j4$ zl4wfUIsu74G+j4Hhc{;Lx@BvVpqFScSl}1JJ=DpFBg3 z7{lL_#?Zwkljpz6c_qv*WI0e}Qj#;#kW6TkzT1|N<0JQ41hf13^VVLKZd_3=Ic3I{ z>rMuqPSuqDA{w3o6*oL#P41eA7fK4${(522x0dwWNtIlJ5sC?v*_s>2<2RE9+vN^v ztWFh>hXpsQ8~VKxsG%#y&=2Dl-f{g=Q`U~j*nS1?(W@V1lOl7DlBI&4bUtf%!%HC` zQa=}bx53BIMqAsLx^Us+iaoUp;E3T#?fwy(l%W28DiwPtmB9?-KH0DjigTmrfHou6 zvZd}^sHO<@%%A(zZlJz*8vPy5aMr*>a!VGwG+NKMM{B=dJ_ ztMX@QxDmJ+7nLu>qJCGQ--H@Z#jT3}J|5;(rIRHLf;O{_EALJ&g6(=&nxtzymsV}a zPT4_JzM~p<(MzPFXyHu8%AJ-MjAQIipu$djgXfB_kO9bY3IQ6H=jh|3Liy zn)yT3V*;y#Q67J9-}8O;;i(Nc-Wxg`+BhCalVoIelU8DP9L>Y`4{2=g za=02VYLcti-qL1_>D7y@c!fuLqL^9thPPx#Olr=1842IHCu4k#1QJ!m?m>eS)Cw2fWfI2RaRPPTyA>(kG(nou=B-+B>8VYe`h<$#uM)lD;;N_~`b~|8ujmd!_gZuLVfkdGH5}IT$ zV>woe4WrA0N3{=wrON?I<#e*k*vHTtQQ+*5U?M1Ht_L2XuJ6EdhCb#bE9j3fCADP~ z-l2F~EW_<3d)n1LHu@a9>rG>PR`6?LO;P3O^J3z1cDPQn_)o`{nM#48gIf4|M|E-U zq9{&?-@n`G_IQ^tSt)154eS_(uj2Dx0V8zbbp6?b8r$Va;##99*F+$ceTeTdgC)!PrGp^_^8=^;_pQDV%mqvI-l zg^t>#M~)>=dn)U63zPKuv&+K@R2qlwe^R49VkJFqfTTH%Jkd*k`q~qkYR%$DzB=J0 z;GHacCRCiLti63iqFp34G6ZLu=eQN#%q-NAyxAu6wp~m78kVywpI@7g{(l)?2V+oaKej1+S+9we<#yOXLqZb~ORj@l zX6V5XZwm|IIPwq8ak7WgQdvW`Fr>8=@mcQH4p}}8-vnZPjaGIM7P;hmRuivoV*nJg zXU-LxtEq%RoaYtm!m?-@iH}#KV-HmahAJx3x}Gg)AWpp->-}9`yA9jK4D$uzvEfql zr#$BwUQhTKBYifd=``cGGuX652$>*ezZ;X42JEn>;~8u=)`7eTh^yl6+#Iqgv2{#U zt+wB_nc7snm8Zcnj#V5AFS< zsy|u_Id5w%>f1@&>KGHd9wt%in``MchaULQ$LQva?UqDN^;xdvjrDa_JZs%2%%d|7 zc~ygT>q!-+q?pZ+_&aO}bZBLYK&m4o&!M#3Ec*lXJLPCP>K&Scs!8tZsa*b=0lb&t z6xs?8@EgHR{Vf>gkExCf@LP!Q)GIr6>DT#nIXiKbI#F~80p}hP5-J5!PgwBD<^q&F zLH=gd>@0z5Oj!hvo#Ql!T333LDJ6C@#$9=v56L`9paN zXy}B{tD4>lJB=eM=22*3e!QG2E&uxj^V5*5mznbLBNyZPyIY>Xc=Of`-Gv#my;n(Z zUgT1W85og106p&HW852*`0`N<>-+OQBYgBTY%a}Q@wbKnIQyJ&>?v6oZ)A;Fbu|xDX3y25n zX|J>5>1DCudaLNaw#)7rugT*?;kOhndW$NtI`)0Mp#~>8*AacS_!$)#a0>C%XeT&8 z+|$poRrMQ)B~m@{bG6o|q#$j2%D{o*o5xvULua}OXD(UHLnULNoB2jG1xD6ro#kgp3_HpoB8QGXM<03M2k2VPu7OY{btVwLN+&$ zFF2BU;|~f%T)McqJW>A3Odo92Z3(RHN2QnqdqJyAJILOwgc*ieeVcCk&w9L6*M$5<} z#PnHn)|WVsqB8j*U`hU3w(ZWe?d_LK)W53>DwJlPkL&~*`v*GxX0PpnhbL&Jw(?eD z3cAQSg*B%t zGcwZ{$F-V-{pMTpO`Q)uP^Y{)PnQ72TyHRpiodO=C8Gq)6c=$qN7bW`t#Ys8CK@?QUPv@LUYV8SM#lY8kF zK75w5*G^C)A2%t}2WR9!Lc@90bK#^BTS8rm%%_#S*B5@>CP`<%POxy*x+W~iK4h1uxFTnlXz4`g;+B)A9&BC&p=~^ zYHCZ=EwUM#f|}Jhd*4$zD?xZgWzfg<0`P60@^f{K;Buexy?SU5cC-9Zh5Ff<{AXu# z2Ya5!7T+sZKVxqF5~HIjt*rv&=i}q$`;Ny#dvIB|^zto}eT$BU=5kr_Fm{;WCH37ajo>sf7umUIbI8w#LHvls9MDxr(fF!yGR7`%Cj_;JVp5M61kbIU z>SlQFejKW2z;2MVi?X?&^C-wvjf6Kzp$X7*^W7$h6xu%XBRIzGO&zhq{^ILhDHvl&D^Mc1 zPL-zqa$g`O@wI`2Tv+wM?yO-JUyn2qK~K!}LC=(wfzbzqpf88;3~`zuy1_Unj)Fs6 zC4q7PO^}(WAq^s>*fDnq>UpwsZY!Y?XAoqDTU!NM5-g|K29>MDz(qwK=h1KzNzfO^ zt^m~%SrG8C!_R8EvCkuSeOL%xdSFqE@=8O#x;g4?n!tU{@*p%8&g3x-S3MnlUFV4l z(K+5YL`TSm7~yjsuDj%q6E-^S_wj?X2z$zdJ;td?7++%|gWb)@tu4;TjRF8ccb#`S z-Kxk$H@&K!J^;=<$DJ)An~JuVmPSRnO96_93;3_t!$S}D?mY|sPJc1Ilivb{jk-W(9_T~hN< zTp`lzjbJm=P(A}3Eh0{+-B8}LPtnlE){4^2E`Zf`3X~snre3odX9v_)K7a+}V6%Ex zq27Y`#RVgcv50O$1YG+G`oS(6A!~cx^AVD4P!Q^Wr`o6^v+2`BT2*UWF&Vu#`Mjp>jiMM9Z74wRyi>*SEKsC&Yy$ zS6`M7zLIo(b>N5ZCL3v+L|j=WMGHc#a}| zor^K;%e8Ma`0JO3S*gU@qsxh5arx){*Av9i$-&Xm$pvERtmCFJqykjtR#MYaEYRIo z<5n43(Z;`i*re=CdP9}$23}D>-=t={n5bNPU)w0Mkh?lxR6uHEN^>NDtC_vEs!gC< zSWHU)HdiQDc|NlMkgpztotSo!a<$mj!p6p|4{%xW{|%k=+OHLh%b+if5N${L>im5L z`T_fU1>?&3%CZC5E|1?~f6+63F#Nn0U5O1}?W*CH%ge9Q!>>cV5+A)Vy!gSB z^^ZJ%l`_6D{=Jy-jWauN?O!tfUuomVK>z*M_m373m!tTXVf)(qSKasDXa5oN_ZGf> zi~nfg=QrspZGJUEw2KSLOL6V{yZ9}MetZ6(!B=VY2iW#!@b{$pYmTe5`hz3mmmL3_ zWWVwKeVYBoxD&YWulW8 true, "udtf_stack2" -> true) { + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack1 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val cnt = + sql("SELECT udtf_stack1(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')").count() + assert(cnt === 2) + + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack2 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack2' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql("SELECT udtf_stack2(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')") + } + assert( + e.getMessage.contains("public StructObjectInspector initialize(ObjectInspector[] args)")) + } + } + test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { val table = "test21721" withTable(table) { From 427359f077ad469d78c97972d021535f30a1e418 Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro Date: Tue, 24 Oct 2017 23:02:11 -0700 Subject: [PATCH 1534/1765] [SPARK-13947][SQL] The error message from using an invalid column reference is not clear ## What changes were proposed in this pull request? Rewritten error message for clarity. Added extra information in case of attribute name collision, hinting the user to double-check referencing two different tables ## How was this patch tested? No functional changes, only final message has changed. It has been tested manually against the situation proposed in the JIRA ticket. Automated tests in repository pass. This PR is original work from me and I license this work to the Spark project Author: Ruben Berenguel Montoro Author: Ruben Berenguel Montoro Author: Ruben Berenguel Closes #17100 from rberenguel/SPARK-13947-error-message. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 19 ++++++++++++--- .../analysis/AnalysisErrorSuite.scala | 23 +++++++++++++------ .../invalid-correlation.sql.out | 2 +- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d9906bb6e6ede..b5e8bdd79869e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -272,10 +272,23 @@ trait CheckAnalysis extends PredicateHelper { case o if o.children.nonEmpty && o.missingInput.nonEmpty => val missingAttributes = o.missingInput.mkString(",") val input = o.inputSet.mkString(",") + val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " + + s"from $input in operator ${operator.simpleString}." - failAnalysis( - s"resolved attribute(s) $missingAttributes missing from $input " + - s"in operator ${operator.simpleString}") + val resolver = plan.conf.resolver + val attrsWithSameName = o.missingInput.filter { missing => + o.inputSet.exists(input => resolver(missing.name, input.name)) + } + + val msg = if (attrsWithSameName.nonEmpty) { + val sameNames = attrsWithSameName.map(_.name).mkString(",") + s"$msgForMissingAttributes Attribute(s) with the same name appear in the " + + s"operation: $sameNames. Please check if the right attribute(s) are used." + } else { + msgForMissingAttributes + } + + failAnalysis(msg) case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 884e113537c93..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -408,16 +408,25 @@ class AnalysisErrorSuite extends AnalysisTest { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept // LongType, DoubleType, and DecimalType. We use LongType as the type of a. - val plan = - Aggregate( - Nil, - Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", LongType)(exprId = ExprId(2)))) + val attrA = AttributeReference("a", LongType)(exprId = ExprId(1)) + val otherA = AttributeReference("a", LongType)(exprId = ExprId(2)) + val attrC = AttributeReference("c", LongType)(exprId = ExprId(3)) + val aliases = Alias(sum(attrA), "b")() :: Alias(sum(attrC), "d")() :: Nil + val plan = Aggregate( + Nil, + aliases, + LocalRelation(otherA)) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) + val resolved = s"${attrA.toString},${attrC.toString}" + + val errorMsg = s"Resolved attribute(s) $resolved missing from ${otherA.toString} " + + s"in operator !Aggregate [${aliases.mkString(", ")}]. " + + s"Attribute(s) with the same name appear in the operation: a. " + + "Please check if the right attribute(s) are used." + + assertAnalysisError(plan, errorMsg :: Nil) } test("error test for self-join") { diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index e4b1a2dbc675c..2586f26f71c35 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -63,7 +63,7 @@ WHERE t1a IN (SELECT min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); +Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]).; -- !query 5 From 6c6950839da991bd41accdb8fb03fbc3b588c1e4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 25 Oct 2017 12:51:20 +0100 Subject: [PATCH 1535/1765] [SPARK-22322][CORE] Update FutureAction for compatibility with Scala 2.12 Future ## What changes were proposed in this pull request? Scala 2.12's `Future` defines two new methods to implement, `transform` and `transformWith`. These can be implemented naturally in Spark's `FutureAction` extension and subclasses, but, only in terms of the new methods that don't exist in Scala 2.11. To support both at the same time, reflection is used to implement these. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #19561 from srowen/SPARK-22322. --- .../scala/org/apache/spark/FutureAction.scala | 59 ++++++++++++++++++- pom.xml | 2 +- .../FlatMapGroupsWithStateSuite.scala | 3 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 2 +- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1034fdcae8e8c..036c9a60630ea 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -89,7 +89,11 @@ trait FutureAction[T] extends Future[T] { */ override def value: Option[Try[T]] - // These two methods must be implemented in Scala 2.12, but won't be used by Spark + // These two methods must be implemented in Scala 2.12. They're implemented as a no-op here + // and then filled in with a real implementation in the two subclasses below. The no-op exists + // here so that those implementations can declare "override", necessary in 2.12, while working + // in 2.11, where the method doesn't exist in the superclass. + // After 2.11 support goes away, remove these two: def transform[S](f: (Try[T]) => Try[S])(implicit executor: ExecutionContext): Future[S] = throw new UnsupportedOperationException() @@ -113,6 +117,42 @@ trait FutureAction[T] extends Future[T] { } +/** + * Scala 2.12 defines the two new transform/transformWith methods mentioned above. Impementing + * these for 2.12 in the Spark class here requires delegating to these same methods in an + * underlying Future object. But that only exists in 2.12. But these methods are only called + * in 2.12. So define helper shims to access these methods on a Future by reflection. + */ +private[spark] object FutureAction { + + private val transformTryMethod = + try { + classOf[Future[_]].getMethod("transform", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private val transformWithTryMethod = + try { + classOf[Future[_]].getMethod("transformWith", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private[spark] def transform[T, S]( + future: Future[T], + f: (Try[T]) => Try[S], + executor: ExecutionContext): Future[S] = + transformTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + + private[spark] def transformWith[T, S]( + future: Future[T], + f: (Try[T]) => Future[S], + executor: ExecutionContext): Future[S] = + transformWithTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + +} + /** * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include @@ -153,6 +193,18 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) + + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) } @@ -246,6 +298,11 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T]) def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform(p.future, f, e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith(p.future, f, e) } diff --git a/pom.xml b/pom.xml index b9c972855204a..2d59f06811a82 100644 --- a/pom.xml +++ b/pom.xml @@ -2692,7 +2692,7 @@ scala-2.12 - 2.12.3 + 2.12.4 2.12 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index af08186aadbb0..b906393a379ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} -import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1201,7 +1200,7 @@ object FlatMapGroupsWithStateSuite { } catch { case u: UnsupportedOperationException => return - case _ => + case _: Throwable => throw new TestFailedException("Unexpected exception when trying to get watermark", 20) } throw new TestFailedException("Could get watermark when not expected", 20) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c53889bb8566c..cc693909270f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -744,7 +744,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(returnedValue === expectedReturnValue, "Returned value does not match expected") } } - AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + AwaitTerminationTester.test(expectedBehavior, () => awaitTermFunc()) true // If the control reached here, then everything worked as expected } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index d6e15cfdd2723..ab7c8558321c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -139,7 +139,7 @@ private[streaming] class FileBasedWriteAheadLog( def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) - CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) + CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, () => reader.close()) } if (!closeFileAfterWrite) { logFilesToRead.iterator.map(readFile).flatten.asJava From 1051ebec70bf05971ddc80819d112626b1f1614f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Oct 2017 16:31:58 +0100 Subject: [PATCH 1536/1765] [SPARK-20783][SQL][FOLLOW-UP] Create ColumnVector to abstract existing compressed column ## What changes were proposed in this pull request? Removed one unused method. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19508 from viirya/SPARK-20783-followup. --- .../apache/spark/sql/execution/columnar/ColumnAccessor.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 445933d98e9d4..85c36b7da9498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -63,9 +63,6 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( } protected def underlyingBuffer = buffer - - def getByteBuffer: ByteBuffer = - buffer.duplicate.order(ByteOrder.nativeOrder()) } private[columnar] class NullColumnAccessor(buffer: ByteBuffer) From 3d43a9f939764ec265de945921d1ecf2323ca230 Mon Sep 17 00:00:00 2001 From: liuxian Date: Wed, 25 Oct 2017 21:34:00 +0530 Subject: [PATCH 1537/1765] [SPARK-22349] In on-heap mode, when allocating memory from pool,we should fill memory with `MEMORY_DEBUG_FILL_CLEAN_VALUE` ## What changes were proposed in this pull request? In on-heap mode, when allocating memory from pool,we should fill memory with `MEMORY_DEBUG_FILL_CLEAN_VALUE` ## How was this patch tested? added unit tests Author: liuxian Closes #19572 from 10110346/MEMORY_DEBUG. --- .../spark/unsafe/memory/HeapMemoryAllocator.java | 3 +++ .../org/apache/spark/unsafe/PlatformUtilSuite.java | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 355748238540b..cc9cc429643ad 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -56,6 +56,9 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final MemoryBlock memory = blockReference.get(); if (memory != null) { assert (memory.size() == size); + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { + memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + } return memory; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4ae49d82efa29..4b141339ec816 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -66,10 +66,21 @@ public void overlappingCopyMemory() { public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1); - MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(onheap.getBaseObject(), onheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + MemoryAllocator.HEAP.free(onheap1); + Assert.assertEquals( + Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); + MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Assert.assertEquals( + Platform.getByte(onheap2.getBaseObject(), onheap2.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); From 6ea8a56ca26a7e02e6574f5f763bb91059119a80 Mon Sep 17 00:00:00 2001 From: Andrea zito Date: Wed, 25 Oct 2017 10:10:24 -0700 Subject: [PATCH 1538/1765] [SPARK-21991][LAUNCHER] Fix race condition in LauncherServer#acceptConnections ## What changes were proposed in this pull request? This patch changes the order in which _acceptConnections_ starts the client thread and schedules the client timeout action ensuring that the latter has been scheduled before the former get a chance to cancel it. ## How was this patch tested? Due to the non-deterministic nature of the patch I wasn't able to add a new test for this issue. Author: Andrea zito Closes #19217 from nivox/SPARK-21991. --- .../apache/spark/launcher/LauncherServer.java | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 865d4926da6a9..454bc7a7f924d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -232,20 +232,20 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); - synchronized (timeout) { - clientThread.start(); - synchronized (clients) { - clients.add(clientConnection); - } - long timeoutMs = getConnectionTimeout(); - // 0 is used for testing to avoid issues with clock resolution / thread scheduling, - // and force an immediate timeout. - if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, getConnectionTimeout()); - } else { - timeout.run(); - } + synchronized (clients) { + clients.add(clientConnection); + } + + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, timeoutMs); + } else { + timeout.run(); } + + clientThread.start(); } } catch (IOException ioe) { if (running) { From d212ef14be7c2864cc529e48a02a47584e46f7a5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Oct 2017 13:53:01 -0700 Subject: [PATCH 1539/1765] [SPARK-22341][YARN] Impersonate correct user when preparing resources. The bug was introduced in SPARK-22290, which changed how the app's user is impersonated in the AM. The changed missed an initialization function that needs to be run as the app owner (who has the right credentials to read from HDFS). Author: Marcelo Vanzin Closes #19566 from vanzin/SPARK-22341. --- .../spark/deploy/yarn/ApplicationMaster.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f6167235f89e4..244d912b9f3aa 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -97,9 +97,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private val client = ugi.doAs(new PrivilegedExceptionAction[YarnRMClient]() { - def run: YarnRMClient = new YarnRMClient() - }) + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -178,7 +176,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. - private val localResources = { + private val localResources = doAsUser { logInfo("Preparing Local resources") val resources = HashMap[String, LocalResource]() @@ -240,9 +238,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } final def run(): Int = { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - def run: Unit = runImpl() - }) + doAsUser { + runImpl() + } exitCode } @@ -790,6 +788,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } + private def doAsUser[T](fn: => T): T = { + ugi.doAs(new PrivilegedExceptionAction[T]() { + override def run: T = fn + }) + } + } object ApplicationMaster extends Logging { From b377ef133cdc38d49b460b2cc6ece0b5892804cc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Oct 2017 22:15:44 +0100 Subject: [PATCH 1540/1765] [SPARK-22227][CORE] DiskBlockManager.getAllBlocks now tolerates temp files ## What changes were proposed in this pull request? Prior to this commit getAllBlocks implicitly assumed that the directories managed by the DiskBlockManager contain only the files corresponding to valid block IDs. In reality, this assumption was violated during shuffle, which produces temporary files in the same directory as the resulting blocks. As a result, calls to getAllBlocks during shuffle were unreliable. The fix could be made more efficient, but this is probably good enough. ## How was this patch tested? `DiskBlockManagerSuite` Author: Sergei Lebedev Closes #19458 from superbobry/block-id-option. --- .../scala/org/apache/spark/storage/BlockId.scala | 16 +++++++++++++--- .../apache/spark/storage/DiskBlockManager.scala | 11 ++++++++++- .../org/apache/spark/storage/BlockIdSuite.scala | 9 +++------ .../spark/storage/DiskBlockManagerSuite.scala | 7 +++++++ 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index a441baed2800e..7ac2c71c18eb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi /** @@ -95,6 +96,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +@DeveloperApi +class UnrecognizedBlockId(name: String) + extends SparkException(s"Failed to parse $name into a block ID") + @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r @@ -104,10 +109,11 @@ object BlockId { val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r + val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r - /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String): BlockId = id match { + def apply(name: String): BlockId = name match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => @@ -122,9 +128,13 @@ object BlockId { TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEMP_LOCAL(uuid) => + TempLocalBlockId(UUID.fromString(uuid)) + case TEMP_SHUFFLE(uuid) => + TempShuffleBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) case _ => - throw new IllegalStateException("Unrecognized BlockId: " + id) + throw new UnrecognizedBlockId(name) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aac..a69bcc9259995 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -100,7 +100,16 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { - getAllFiles().map(f => BlockId(f.getName)) + getAllFiles().flatMap { f => + try { + Some(BlockId(f.getName)) + } catch { + case _: UnrecognizedBlockId => + // Skip files which do not correspond to blocks, for example temporary + // files created by [[SortShuffleWriter]]. + None + } + } } /** Produces a unique block id and File suitable for storing local intermediate results. */ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index f0c521b00b583..ff4755833a916 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -35,13 +35,8 @@ class BlockIdSuite extends SparkFunSuite { } test("test-bad-deserialization") { - try { - // Try to deserialize an invalid block id. + intercept[UnrecognizedBlockId] { BlockId("myblock") - fail() - } catch { - case e: IllegalStateException => // OK - case _: Throwable => fail() } } @@ -139,6 +134,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 5) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("temp shuffle") { @@ -151,6 +147,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 1) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("test") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 7859b0bba2b48..0c4f3c48ef802 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import java.util.UUID import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} @@ -79,6 +80,12 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } + test("SPARK-22227: non-block files are skipped") { + val file = diskBlockManager.getFile("unmanaged_file") + writeToFile(file, 10) + assert(diskBlockManager.getAllBlocks().isEmpty) + } + def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) From 841f1d776f420424c20d99cf7110d06c73f9ca20 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 25 Oct 2017 14:31:36 -0700 Subject: [PATCH 1541/1765] [SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cause by test dataset not deterministic) ## What changes were proposed in this pull request? Fix NaiveBayes unit test occasionly fail: Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset. (If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure) ## How was this patch tested? Manually run tests multiple times and check each time output models contains the same values. Author: WeichenXu Closes #19558 from WeichenXu123/fix_nb_test_seed. --- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 9730dd68a3b27..0d3adf993383f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} -import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} @@ -335,6 +335,7 @@ object NaiveBayesSuite { val _pi = pi.map(math.exp) val _theta = theta.map(row => row.map(math.exp)) + implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { From 5433be44caecaeef45ed1fdae10b223c698a9d14 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 25 Oct 2017 14:41:02 -0700 Subject: [PATCH 1542/1765] [SPARK-21991][LAUNCHER][FOLLOWUP] Fix java lint ## What changes were proposed in this pull request? Fix java lint ## How was this patch tested? Run `./dev/lint-java` Author: Andrew Ash Closes #19574 from ash211/aash/fix-java-lint. --- .../main/java/org/apache/spark/launcher/LauncherServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 454bc7a7f924d..4353e3f263c51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -235,7 +235,7 @@ public void run() { synchronized (clients) { clients.add(clientConnection); } - + long timeoutMs = getConnectionTimeout(); // 0 is used for testing to avoid issues with clock resolution / thread scheduling, // and force an immediate timeout. @@ -244,7 +244,7 @@ public void run() { } else { timeout.run(); } - + clientThread.start(); } } catch (IOException ioe) { From 592cfeab9caeff955d115a1ca5014ede7d402907 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 26 Oct 2017 00:29:49 -0700 Subject: [PATCH 1543/1765] [SPARK-22308] Support alternative unit testing styles in external applications ## What changes were proposed in this pull request? Support unit tests of external code (i.e., applications that use spark) using scalatest that don't want to use FunSuite. SharedSparkContext already supports this, but SharedSQLContext does not. I've introduced SharedSparkSession as a parent to SharedSQLContext, written in a way that it does support all scalatest styles. ## How was this patch tested? There are three new unit test suites added that just test using FunSpec, FlatSpec, and WordSpec. Author: Nathan Kronenfeld Closes #19529 from nkronenfeld/alternative-style-tests-2. --- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 45 +++++ .../spark/sql/test/GenericFunSpecSuite.scala | 47 +++++ .../spark/sql/test/GenericWordSpecSuite.scala | 51 ++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++++-------- .../spark/sql/test/SharedSQLContext.scala | 84 +-------- .../spark/sql/test/SharedSparkSession.scala | 119 ++++++++++++ 8 files changed, 381 insertions(+), 165 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 6aedcb1271ff6..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) + /** + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'. + */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PlanTestBase + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..6179585a0d39a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..15139ee8b3047 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..b6548bf95fec8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. @@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually with BeforeAndAfterAll with SQLTestData - with PlanTest { self => + with PlanTestBase { self: Suite => protected def sparkContext = spark.sparkContext - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils protected override def _sqlContext: SQLContext = self.spark.sqlContext } - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils Dataset.ofRows(spark, plan) } - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,4 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ - protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} From 3073344a2551fb198d63f2114a519ab97904cb55 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Oct 2017 15:50:27 +0800 Subject: [PATCH 1544/1765] [SPARK-21840][CORE] Add trait that allows conf to be directly set in application. Currently SparkSubmit uses system properties to propagate configuration to applications. This makes it hard to implement features such as SPARK-11035, which would allow multiple applications to be started in the same JVM. The current code would cause the config data from multiple apps to get mixed up. This change introduces a new trait, currently internal to Spark, that allows the app configuration to be passed directly to the application, without having to use system properties. The current "call main() method" behavior is maintained as an implementation of this new trait. This will be useful to allow multiple cluster mode apps to be submitted from the same JVM. As part of this, SparkSubmit was modified to collect all configuration directly into a SparkConf instance. Most of the changes are to tests so they use SparkConf instead of an opaque map. Tested with existing and added unit tests. Author: Marcelo Vanzin Closes #19519 from vanzin/SPARK-21840. --- .../spark/deploy/SparkApplication.scala | 55 +++++ .../org/apache/spark/deploy/SparkSubmit.scala | 160 +++++++------ .../spark/deploy/SparkSubmitSuite.scala | 213 ++++++++++-------- .../rest/StandaloneRestSubmitSuite.scala | 4 +- 4 files changed, 257 insertions(+), 175 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala new file mode 100644 index 0000000000000..118b4605675b0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.lang.reflect.Modifier + +import org.apache.spark.SparkConf + +/** + * Entry point for a Spark application. Implementations must provide a no-argument constructor. + */ +private[spark] trait SparkApplication { + + def start(args: Array[String], conf: SparkConf): Unit + +} + +/** + * Implementation of SparkApplication that wraps a standard Java class with a "main" method. + * + * Configuration is propagated to the application via system properties, so running multiple + * of these in the same JVM may lead to undefined behavior due to configuration leaks. + */ +private[deploy] class JavaMainApplication(klass: Class[_]) extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { + val mainMethod = klass.getMethod("main", new Array[String](0).getClass) + if (!Modifier.isStatic(mainMethod.getModifiers)) { + throw new IllegalStateException("The main method in the given main class must be static") + } + + val sysProps = conf.getAll.toMap + sysProps.foreach { case (k, v) => + sys.props(k) = v + } + + mainMethod.invoke(null, args) + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b7e6d0ea021a4..73b956ef3e470 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,7 +158,7 @@ object SparkSubmit extends CommandLineUtils with Logging { */ @tailrec private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { - val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args) def doRunMain(): Unit = { if (args.proxyUser != null) { @@ -167,7 +167,7 @@ object SparkSubmit extends CommandLineUtils with Logging { try { proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } }) } catch { @@ -185,7 +185,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } } } else { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } } @@ -235,11 +235,11 @@ object SparkSubmit extends CommandLineUtils with Logging { private[deploy] def prepareSubmitEnvironment( args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], Map[String, String], String) = { + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() - val sysProps = new HashMap[String, String]() + val sparkConf = new SparkConf() var childMainClass = "" // Set the cluster manager @@ -337,7 +337,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - val sparkConf = new SparkConf(false) args.sparkProperties.foreach { case (k, v) => sparkConf.set(k, v) } val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() @@ -351,8 +350,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // for later use; e.g. in spark sql, the isolated class loader used to talk // to HiveMetastore will use these settings. They will be set as Java system // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) + sparkConf.set(KEYTAB, args.keytab) + sparkConf.set(PRINCIPAL, args.principal) UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) } } @@ -364,23 +363,24 @@ object SparkSubmit extends CommandLineUtils with Logging { args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + // This security manager will not need an auth secret, but set a dummy value in case + // spark.authenticate is enabled, otherwise an exception is thrown. + lazy val downloadConf = sparkConf.clone().set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + lazy val secMgr = new SecurityManager(downloadConf) + // In client mode, download remote files. var localPrimaryResource: String = null var localJars: String = null var localPyFiles: String = null if (deployMode == CLIENT) { - // This security manager will not need an auth secret, but set a dummy value in case - // spark.authenticate is enabled, otherwise an exception is thrown. - sparkConf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - val secMgr = new SecurityManager(sparkConf) localPrimaryResource = Option(args.primaryResource).map { - downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localJars = Option(args.jars).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localPyFiles = Option(args.pyFiles).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull } @@ -409,7 +409,7 @@ object SparkSubmit extends CommandLineUtils with Logging { if (file.exists()) { file.toURI.toString } else { - downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(resource, targetDir, downloadConf, hadoopConf, secMgr) } case _ => uri.toString } @@ -449,7 +449,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.files = mergeFileLists(args.files, args.pyFiles) } if (localPyFiles != null) { - sysProps("spark.submit.pyFiles") = localPyFiles + sparkConf.set("spark.submit.pyFiles", localPyFiles) } } @@ -515,69 +515,69 @@ object SparkSubmit extends CommandLineUtils with Logging { } // Special flag to avoid deprecation warnings at the client - sysProps("SPARK_SUBMIT") = "true" + sys.props("SPARK_SUBMIT") = "true" // A list of rules to map each argument to system properties or command-line options in // each deploy mode; we iterate through these below val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"), OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.submit.deployMode"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), - OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), + confKey = "spark.submit.deployMode"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraClassPath"), + confKey = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraJavaOptions"), + confKey = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraLibraryPath"), + confKey = "spark.driver.extraLibraryPath"), // Propagate attributes for dependency resolution at the driver side - OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"), OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.jars.repositories"), - OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.ivy"), OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, - CLUSTER, sysProp = "spark.jars.excludes"), + CLUSTER, confKey = "spark.jars.excludes"), // Yarn only - OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), + OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.instances"), - OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"), - OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), - OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.keytab"), + confKey = "spark.executor.instances"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.cores"), + confKey = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.memory"), + confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.cores.max"), + confKey = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + confKey = "spark.files"), + OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.cores"), + confKey = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.driver.supervise"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, confKey = "spark.jars.ivy"), // An internal option used only for spark-shell to add user jars to repl's classloader, // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to // remote jars, so adding a new option to only specify local jars for spark-shell internally. - OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars") + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.repl.local.jars") ) // In client mode, launch the application main class directly @@ -610,24 +610,24 @@ object SparkSubmit extends CommandLineUtils with Logging { (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } - if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } + if (opt.confKey != null) { sparkConf.set(opt.confKey, opt.value) } } } // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { - sysProps(UI_SHOW_CONSOLE_PROGRESS.key) = "true" + sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true) } // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file if (!isYarnCluster && !args.isPython && !args.isR) { - var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) + var jars = sparkConf.getOption("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) } - sysProps.put("spark.jars", jars.mkString(",")) + sparkConf.set("spark.jars", jars.mkString(",")) } // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). @@ -653,12 +653,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Let YARN know it's a pyspark app, so it distributes needed libraries. if (clusterManager == YARN) { if (args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + sparkConf.set("spark.yarn.isPython", "true") } } if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { - setRMPrincipal(sysProps) + setRMPrincipal(sparkConf) } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -689,7 +689,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // Second argument is main class childArgs += (args.primaryResource, "") if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles + sparkConf.set("spark.submit.pyFiles", args.pyFiles) } } else if (args.isR) { // Second argument is main class @@ -704,12 +704,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { - sysProps.getOrElseUpdate(k, v) + sparkConf.setIfMissing(k, v) } // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= "spark.driver.host" + sparkConf.remove("spark.driver.host") } // Resolve paths in certain spark properties @@ -721,15 +721,15 @@ object SparkSubmit extends CommandLineUtils with Logging { "spark.yarn.dist.jars") pathConfigs.foreach { config => // Replace old URIs with resolved URIs, if they exist - sysProps.get(config).foreach { oldValue => - sysProps(config) = Utils.resolveURIs(oldValue) + sparkConf.getOption(config).foreach { oldValue => + sparkConf.set(config, Utils.resolveURIs(oldValue)) } } // Resolve and format python file paths properly before adding them to the PYTHONPATH. // The resolving part is redundant in the case of --py-files, but necessary if the user // explicitly sets `spark.submit.pyFiles` in his/her default properties file. - sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + sparkConf.getOption("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { PythonRunner.formatPaths(resolvedPyFiles).mkString(",") @@ -739,22 +739,22 @@ object SparkSubmit extends CommandLineUtils with Logging { // locally. resolvedPyFiles } - sysProps("spark.submit.pyFiles") = formattedPyFiles + sparkConf.set("spark.submit.pyFiles", formattedPyFiles) } - (childArgs, childClasspath, sysProps, childMainClass) + (childArgs, childClasspath, sparkConf, childMainClass) } // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we // must trick it into thinking we're YARN. - private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = { + private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" // scalastyle:off println printStream.println(s"Setting ${key} to ${shortUserName}") // scalastyle:off println - sysProps.put(key, shortUserName) + sparkConf.set(key, shortUserName) } /** @@ -766,7 +766,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def runMain( childArgs: Seq[String], childClasspath: Seq[String], - sysProps: Map[String, String], + sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { // scalastyle:off println @@ -774,14 +774,14 @@ object SparkSubmit extends CommandLineUtils with Logging { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") + printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } // scalastyle:on println val loader = - if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { new ChildFirstURLClassLoader(new Array[URL](0), Thread.currentThread.getContextClassLoader) } else { @@ -794,10 +794,6 @@ object SparkSubmit extends CommandLineUtils with Logging { addJarToClasspath(jar, loader) } - for ((key, value) <- sysProps) { - System.setProperty(key, value) - } - var mainClass: Class[_] = null try { @@ -823,14 +819,14 @@ object SparkSubmit extends CommandLineUtils with Logging { System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } - // SPARK-4170 - if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") - } - - val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) - if (!Modifier.isStatic(mainMethod.getModifiers)) { - throw new IllegalStateException("The main method in the given main class must be static") + val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { + mainClass.newInstance().asInstanceOf[SparkApplication] + } else { + // SPARK-4170 + if (classOf[scala.App].isAssignableFrom(mainClass)) { + printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + } + new JavaMainApplication(mainClass) } @tailrec @@ -844,7 +840,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } try { - mainMethod.invoke(null, childArgs.toArray) + app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => findCause(t) match { @@ -1271,4 +1267,4 @@ private case class OptionAssigner( clusterManager: Int, deployMode: Int, clOption: String = null, - sysProp: String = null) + confKey: String = null) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b52da4c0c8bc3..cfbf56fb8c369 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -176,10 +176,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") - sysProps("spark.submit.deployMode") should be ("client") + conf.get("spark.submit.deployMode") should be ("client") // Both cmd line and configuration are specified, cmdline option takes the priority val clArgs1 = Seq( @@ -190,10 +190,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") - sysProps1("spark.submit.deployMode") should be ("cluster") + conf1.get("spark.submit.deployMode") should be ("cluster") // Neither cmdline nor configuration are specified, client mode is the default choice val clArgs2 = Seq( @@ -204,9 +204,9 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") - sysProps2("spark.submit.deployMode") should be ("client") + conf2.get("spark.submit.deployMode") should be ("client") } test("handles YARN cluster mode") { @@ -227,7 +227,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -240,16 +240,16 @@ class SparkSubmitSuite classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.driver.memory") should be ("4g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.app.name") should be ("beauty") - sysProps("spark.ui.enabled") should be ("false") - sysProps("SPARK_SUBMIT") should be ("true") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.driver.memory") should be ("4g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.app.name") should be ("beauty") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles YARN client mode") { @@ -270,7 +270,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -278,17 +278,17 @@ class SparkSubmitSuite classpath(1) should endWith ("one.jar") classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.app.name") should be ("trill") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.executor.instances") should be ("6") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.yarn.dist.jars") should include + conf.get("spark.app.name") should be ("trill") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.executor.instances") should be ("6") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") - sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles standalone cluster mode") { @@ -316,7 +316,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -327,17 +327,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 9 - sysProps.keys should contain ("SPARK_SUBMIT") - sysProps.keys should contain ("spark.master") - sysProps.keys should contain ("spark.app.name") - sysProps.keys should contain ("spark.jars") - sysProps.keys should contain ("spark.driver.memory") - sysProps.keys should contain ("spark.driver.cores") - sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.ui.enabled") - sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") + + val confMap = conf.getAll.toMap + confMap.keys should contain ("spark.master") + confMap.keys should contain ("spark.app.name") + confMap.keys should contain ("spark.jars") + confMap.keys should contain ("spark.driver.memory") + confMap.keys should contain ("spark.driver.cores") + confMap.keys should contain ("spark.driver.supervise") + confMap.keys should contain ("spark.ui.enabled") + confMap.keys should contain ("spark.submit.deployMode") + conf.get("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -352,14 +353,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -374,14 +375,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { @@ -394,23 +395,26 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.master") should be ("yarn") - sysProps("spark.submit.deployMode") should be ("cluster") + val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.master") should be ("yarn") + conf.get("spark.submit.deployMode") should be ("cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") } test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + // Unset from system properties since this config is defined in the root pom's test config. + sys.props -= UI_SHOW_CONSOLE_PROGRESS.key + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) - sysProps1(UI_SHOW_CONSOLE_PROGRESS.key) should be ("true") + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) - sysProps2.keys should not contain UI_SHOW_CONSOLE_PROGRESS.key + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } test("launch simple application with spark-submit") { @@ -585,11 +589,11 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) - sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be (Utils.resolveURIs(files)) + conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be (Utils.resolveURIs(files)) // Test files and archives (Yarn) val clArgs2 = Seq( @@ -600,11 +604,11 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) - sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) // Test python files val clArgs3 = Seq( @@ -615,12 +619,12 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) - sysProps3("spark.submit.pyFiles") should be ( + conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) - sysProps3(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") - sysProps3(PYSPARK_PYTHON.key) should be ("python3.5") + conf3.get(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") + conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } test("resolves config paths correctly") { @@ -644,9 +648,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be(Utils.resolveURIs(files)) + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be(Utils.resolveURIs(files)) // Test files and archives (Yarn) val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) @@ -661,9 +665,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 - sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) // Test python files val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) @@ -676,8 +680,8 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 - sysProps3("spark.submit.pyFiles") should be( + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files @@ -693,11 +697,9 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode - sysProps4("spark.submit.pyFiles") should be( - Utils.resolveURIs(remotePyFiles) - ) + conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } test("user classpath first in driver") { @@ -771,14 +773,14 @@ class SparkSubmitSuite jar2.toString) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.yarn.dist.jars").split(",").toSet should be + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) - sysProps("spark.yarn.dist.files").split(",").toSet should be + conf.get("spark.yarn.dist.files").split(",").toSet should be (Set(file1.toURI.toString, file2.toURI.toString)) - sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be + conf.get("spark.yarn.dist.pyFiles").split(",").toSet should be (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) - sysProps("spark.yarn.dist.archives").split(",").toSet should be + conf.get("spark.yarn.dist.archives").split(",").toSet should be (Set(archive1.toURI.toString, archive2.toURI.toString)) } @@ -897,18 +899,18 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. - sysProps("spark.yarn.dist.jars") should be (tmpJarPath) - sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") - sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.yarn.dist.jars") should be (tmpJarPath) + conf.get("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") + conf.get("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") // Local repl jars should be a local path. - sysProps("spark.repl.local.jars") should (startWith("file:")) + conf.get("spark.repl.local.jars") should (startWith("file:")) // local py files should not be a URI format. - sysProps("spark.submit.pyFiles") should (startWith("/")) + conf.get("spark.submit.pyFiles") should (startWith("/")) } test("download remote resource if it is not supported by yarn service") { @@ -955,9 +957,9 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) - val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + val jars = conf.get("spark.yarn.dist.jars").split(",").toSet // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) @@ -996,6 +998,21 @@ class SparkSubmitSuite conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName) conf.set("fs.s3a.impl.disable.cache", "true") } + + test("start SparkApplication without modifying system properties") { + val args = Array( + "--class", classOf[TestSparkApplication].getName(), + "--master", "local", + "--conf", "spark.test.hello=world", + "spark-internal", + "hello") + + val exception = intercept[SparkException] { + SparkSubmit.main(args) + } + + assert(exception.getMessage() === "hello") + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1115,3 +1132,17 @@ class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { override def open(path: Path): FSDataInputStream = super.open(local(path)) } + +class TestSparkApplication extends SparkApplication with Matchers { + + override def start(args: Array[String], conf: SparkConf): Unit = { + assert(args.size === 1) + assert(args(0) === "hello") + assert(conf.get("spark.test.hello") === "world") + assert(sys.props.get("spark.test.hello") === None) + + // This is how the test verifies the application was actually run. + throw new SparkException(args(0)) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 70887dc5dd97a..490baf040491f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,9 +445,9 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( - mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } /** Return the response as a submit response, or fail with error otherwise. */ From a83d8d5adcb4e0061e43105767242ba9770dda96 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Oct 2017 20:54:36 +0900 Subject: [PATCH 1545/1765] [SPARK-17902][R] Revive stringsAsFactors option for collect() in SparkR ## What changes were proposed in this pull request? This PR proposes to revive `stringsAsFactors` option in collect API, which was mistakenly removed in https://github.com/apache/spark/commit/71a138cd0e0a14e8426f97877e3b52a562bbd02c. Simply, it casts `charactor` to `factor` if it meets the condition, `stringsAsFactors && is.character(vec)` in primitive type conversion. ## How was this patch tested? Unit test in `R/pkg/tests/fulltests/test_sparkSQL.R`. Author: hyukjinkwon Closes #19551 from HyukjinKwon/SPARK-17902. --- R/pkg/R/DataFrame.R | 3 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 176bb3b8a8d0c..aaa3349d57506 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1191,6 +1191,9 @@ setMethod("collect", vec <- do.call(c, col) stopifnot(class(vec) != "list") class(vec) <- PRIMITIVE_TYPES[[colType]] + if (is.character(vec) && stringsAsFactors) { + vec <- as.factor(vec) + } df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4382ef2ed4525..0c8118a7c73f3 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -499,6 +499,12 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17902: collect() with stringsAsFactors enabled", { + df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE)) + expect_equal(class(iris$Species), class(df$Species)) + expect_equal(iris$Species, df$Species) +}) + test_that("SPARK-17811: can create DataFrame containing NA as date and time", { df <- data.frame( id = 1:2, From 0e9a750a8d389b3a17834584d31c204c77c6970d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Oct 2017 11:05:16 -0500 Subject: [PATCH 1546/1765] [SPARK-20643][CORE] Add listener implementation to collect app state. The initial listener code is based on the existing JobProgressListener (and others), and tries to mimic their behavior as much as possible. The change also includes some minor code movement so that some types and methods from the initial history server code code can be reused. The code introduces a few mutable versions of public API types, used internally, to make it easier to update information without ugly copy methods, and also to make certain updates cheaper. Note the code here is not 100% correct. This is meant as a building ground for the UI integration in the next milestones. As different parts of the UI are ported, fixes will be made to the different parts of this code to account for the needed behavior. I also added annotations to API types so that Jackson is able to correctly deserialize options, sequences and maps that store primitive types. Author: Marcelo Vanzin Closes #19383 from vanzin/SPARK-20643. --- .../apache/spark/util/kvstore/KVTypeInfo.java | 2 + .../apache/spark/util/kvstore/LevelDB.java | 2 +- .../spark/status/api/v1/StageStatus.java | 3 +- .../deploy/history/FsHistoryProvider.scala | 37 +- .../apache/spark/deploy/history/config.scala | 6 - .../spark/status/AppStatusListener.scala | 531 ++++++++++++++ .../org/apache/spark/status/KVUtils.scala | 73 ++ .../org/apache/spark/status/LiveEntity.scala | 526 +++++++++++++ .../status/api/v1/AllStagesResource.scala | 4 +- .../org/apache/spark/status/api/v1/api.scala | 11 +- .../org/apache/spark/status/storeTypes.scala | 98 +++ .../history/FsHistoryProviderSuite.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 690 ++++++++++++++++++ project/MimaExcludes.scala | 2 + 14 files changed, 1942 insertions(+), 45 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusListener.scala create mode 100644 core/src/main/scala/org/apache/spark/status/KVUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/status/LiveEntity.scala create mode 100644 core/src/main/scala/org/apache/spark/status/storeTypes.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index a2b077e4531ee..870b484f99068 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -46,6 +46,7 @@ public KVTypeInfo(Class type) throws Exception { KVIndex idx = f.getAnnotation(KVIndex.class); if (idx != null) { checkIndex(idx, indices); + f.setAccessible(true); indices.put(idx.value(), idx); f.setAccessible(true); accessors.put(idx.value(), new FieldAccessor(f)); @@ -58,6 +59,7 @@ public KVTypeInfo(Class type) throws Exception { checkIndex(idx, indices); Preconditions.checkArgument(m.getParameterTypes().length == 0, "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + m.setAccessible(true); indices.put(idx.value(), idx); m.setAccessible(true); accessors.put(idx.value(), new MethodAccessor(m)); diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index ff48b155fab31..4f9e10ca20066 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -76,7 +76,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { this.types = new ConcurrentHashMap<>(); Options options = new Options(); - options.createIfMissing(!path.exists()); + options.createIfMissing(true); this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); byte[] versionData = db().get(STORE_VERSION_KEY); diff --git a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java index 9dbb565aab707..40b5f627369d5 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java @@ -23,7 +23,8 @@ public enum StageStatus { ACTIVE, COMPLETE, FAILED, - PENDING; + PENDING, + SKIPPED; public static StageStatus fromString(String str) { return EnumUtil.parseIgnoreCase(StageStatus.class, str); diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 3889dd097ee59..cf97597b484d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -129,29 +130,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => val dbPath = new File(path, "listing.ldb") - - def openDB(): LevelDB = new LevelDB(dbPath, new KVStoreScalaSerializer()) + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, logDir.toString()) try { - val db = openDB() - val meta = db.getMetadata(classOf[KVStoreMetadata]) - - if (meta == null) { - db.setMetadata(new KVStoreMetadata(CURRENT_LISTING_VERSION, logDir)) - db - } else if (meta.version != CURRENT_LISTING_VERSION || !logDir.equals(meta.logDir)) { - logInfo("Detected mismatched config in existing DB, deleting...") - db.close() - Utils.deleteRecursively(dbPath) - openDB() - } else { - db - } + open(new File(path, "listing.ldb"), metadata) } catch { - case _: UnsupportedStoreVersionException => + case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") Utils.deleteRecursively(dbPath) - openDB() + open(new File(path, "listing.ldb"), metadata) } }.getOrElse(new InMemoryStore()) @@ -720,19 +707,7 @@ private[history] object FsHistoryProvider { private[history] val CURRENT_LISTING_VERSION = 1L } -/** - * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as - * the API serializer. - */ -private class KVStoreScalaSerializer extends KVStoreSerializer { - - mapper.registerModule(DefaultScalaModule) - mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) - mapper.setDateFormat(v1.JacksonMessageWriter.makeISODateFormat) - -} - -private[history] case class KVStoreMetadata( +private[history] case class FsHistoryProviderMetadata( version: Long, logDir: String) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index fb9e997def0dd..52dedc1a2ed41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -19,16 +19,10 @@ package org.apache.spark.deploy.history import java.util.concurrent.TimeUnit -import scala.annotation.meta.getter - import org.apache.spark.internal.config.ConfigBuilder -import org.apache.spark.util.kvstore.KVIndex private[spark] object config { - /** Use this to annotate constructor params to be used as KVStore indices. */ - type KVIndexParam = KVIndex @getter - val DEFAULT_LOG_DIR = "file:/tmp/spark-events" val EVENT_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala new file mode 100644 index 0000000000000..f120685c941df --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.kvstore.KVStore + +/** + * A Spark listener that writes application information to a data store. The types written to the + * store are defined in the `storeTypes.scala` file and are based on the public REST API. + */ +private class AppStatusListener(kvstore: KVStore) extends SparkListener with Logging { + + private var sparkVersion = SPARK_VERSION + private var appInfo: v1.ApplicationInfo = null + private var coresPerTask: Int = 1 + + // Keep track of live entities, so that task metrics can be efficiently updated (without + // causing too many writes to the underlying store, and other expensive operations). + private val liveStages = new HashMap[(Int, Int), LiveStage]() + private val liveJobs = new HashMap[Int, LiveJob]() + private val liveExecutors = new HashMap[String, LiveExecutor]() + private val liveTasks = new HashMap[Long, LiveTask]() + private val liveRDDs = new HashMap[Int, LiveRDD]() + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(version) => sparkVersion = version + case _ => + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + assert(event.appId.isDefined, "Application without IDs are not supported.") + + val attempt = new v1.ApplicationAttemptInfo( + event.appAttemptId, + new Date(event.time), + new Date(-1), + new Date(event.time), + -1L, + event.sparkUser, + false, + sparkVersion) + + appInfo = new v1.ApplicationInfo( + event.appId.get, + event.appName, + None, + None, + None, + None, + Seq(attempt)) + + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { + val old = appInfo.attempts.head + val attempt = new v1.ApplicationAttemptInfo( + old.attemptId, + old.startTime, + new Date(event.time), + new Date(event.time), + event.time - old.startTime.getTime(), + old.sparkUser, + true, + old.appSparkVersion) + + appInfo = new v1.ApplicationInfo( + appInfo.id, + appInfo.name, + None, + None, + None, + None, + Seq(attempt)) + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { + // This needs to be an update in case an executor re-registers after the driver has + // marked it as "dead". + val exec = getOrCreateExecutor(event.executorId) + exec.host = event.executorInfo.executorHost + exec.isActive = true + exec.totalCores = event.executorInfo.totalCores + exec.maxTasks = event.executorInfo.totalCores / coresPerTask + exec.executorLogs = event.executorInfo.logUrlMap + update(exec) + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { + liveExecutors.remove(event.executorId).foreach { exec => + exec.isActive = false + update(exec) + } + } + + override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = { + updateBlackListStatus(event.executorId, true) + } + + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { + updateBlackListStatus(event.executorId, false) + } + + override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = { + updateNodeBlackList(event.hostId, true) + } + + override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = { + updateNodeBlackList(event.hostId, false) + } + + private def updateBlackListStatus(execId: String, blacklisted: Boolean): Unit = { + liveExecutors.get(execId).foreach { exec => + exec.isBlacklisted = blacklisted + update(exec) + } + } + + private def updateNodeBlackList(host: String, blacklisted: Boolean): Unit = { + // Implicitly (un)blacklist every executor associated with the node. + liveExecutors.values.foreach { exec => + if (exec.hostname == host) { + exec.isBlacklisted = blacklisted + update(exec) + } + } + } + + override def onJobStart(event: SparkListenerJobStart): Unit = { + // Compute (a potential over-estimate of) the number of tasks that will be run by this job. + // This may be an over-estimate because the job start event references all of the result + // stages' transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + val numTasks = { + val missingStages = event.stageInfos.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } + + val lastStageInfo = event.stageInfos.lastOption + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + + val jobGroup = Option(event.properties) + .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } + + val job = new LiveJob( + event.jobId, + lastStageName, + Some(new Date(event.time)), + event.stageIds, + jobGroup, + numTasks) + liveJobs.put(event.jobId, job) + update(job) + + event.stageInfos.foreach { stageInfo => + // A new job submission may re-use an existing stage, so this code needs to do an update + // instead of just a write. + val stage = getOrCreateStage(stageInfo) + stage.jobs :+= job + stage.jobIds += event.jobId + update(stage) + } + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + liveJobs.remove(event.jobId).foreach { job => + job.status = event.jobResult match { + case JobSucceeded => JobExecutionStatus.SUCCEEDED + case JobFailed(_) => JobExecutionStatus.FAILED + } + + job.completionTime = Some(new Date(event.time)) + update(job) + } + } + + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + val stage = getOrCreateStage(event.stageInfo) + stage.status = v1.StageStatus.ACTIVE + stage.schedulingPool = Option(event.properties).flatMap { p => + Option(p.getProperty("spark.scheduler.pool")) + }.getOrElse(SparkUI.DEFAULT_POOL_NAME) + + // Look at all active jobs to find the ones that mention this stage. + stage.jobs = liveJobs.values + .filter(_.stageIds.contains(event.stageInfo.stageId)) + .toSeq + stage.jobIds = stage.jobs.map(_.jobId).toSet + + stage.jobs.foreach { job => + job.completedStages = job.completedStages - event.stageInfo.stageId + job.activeStages += 1 + update(job) + } + + event.stageInfo.rddInfos.foreach { info => + if (info.storageLevel.isValid) { + update(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info))) + } + } + + update(stage) + } + + override def onTaskStart(event: SparkListenerTaskStart): Unit = { + val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) + liveTasks.put(event.taskInfo.taskId, task) + update(task) + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + stage.activeTasks += 1 + stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + update(stage) + + stage.jobs.foreach { job => + job.activeTasks += 1 + update(job) + } + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + exec.activeTasks += 1 + exec.totalTasks += 1 + update(exec) + } + } + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = { + // Call update on the task so that the "getting result" time is written to the store; the + // value is part of the mutable TaskInfo state that the live entity already references. + liveTasks.get(event.taskInfo.taskId).foreach { task => + update(task) + } + } + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + // TODO: can this really happen? + if (event.taskInfo == null) { + return + } + + val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => + val errorMessage = event.reason match { + case Success => + None + case k: TaskKilled => + Some(k.reason) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates + Some(e.toErrorString) + case e: TaskFailedReason => // All other failure cases + Some(e.toErrorString) + case other => + logInfo(s"Unhandled task end reason: $other") + None + } + task.errorMessage = errorMessage + val delta = task.updateMetrics(event.taskMetrics) + update(task) + delta + }.orNull + + val (completedDelta, failedDelta) = event.reason match { + case Success => + (1, 0) + case _ => + (0, 1) + } + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + if (metricsDelta != null) { + stage.metrics.update(metricsDelta) + } + stage.activeTasks -= 1 + stage.completedTasks += completedDelta + stage.failedTasks += failedDelta + update(stage) + + stage.jobs.foreach { job => + job.activeTasks -= 1 + job.completedTasks += completedDelta + job.failedTasks += failedDelta + update(job) + } + + val esummary = stage.executorSummary(event.taskInfo.executorId) + esummary.taskTime += event.taskInfo.duration + esummary.succeededTasks += completedDelta + esummary.failedTasks += failedDelta + if (metricsDelta != null) { + esummary.metrics.update(metricsDelta) + } + update(esummary) + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + if (event.taskMetrics != null) { + val readMetrics = event.taskMetrics.shuffleReadMetrics + exec.totalGcTime += event.taskMetrics.jvmGCTime + exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead + exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead + exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten + } + + exec.activeTasks -= 1 + exec.completedTasks += completedDelta + exec.failedTasks += failedDelta + exec.totalDuration += event.taskInfo.duration + update(exec) + } + } + + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { + liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + stage.info = event.stageInfo + + // Because of SPARK-20205, old event logs may contain valid stages without a submission time + // in their start event. In those cases, we can only detect whether a stage was skipped by + // waiting until the completion event, at which point the field would have been set. + stage.status = event.stageInfo.failureReason match { + case Some(_) => v1.StageStatus.FAILED + case _ if event.stageInfo.submissionTime.isDefined => v1.StageStatus.COMPLETE + case _ => v1.StageStatus.SKIPPED + } + update(stage) + + stage.jobs.foreach { job => + stage.status match { + case v1.StageStatus.COMPLETE => + job.completedStages += event.stageInfo.stageId + case v1.StageStatus.SKIPPED => + job.skippedStages += event.stageInfo.stageId + job.skippedTasks += event.stageInfo.numTasks + case _ => + job.failedStages += 1 + } + job.activeStages -= 1 + update(job) + } + + stage.executorSummaries.values.foreach(update) + update(stage) + } + } + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { + // This needs to set fields that are already set by onExecutorAdded because the driver is + // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. + val exec = getOrCreateExecutor(event.blockManagerId.executorId) + exec.hostPort = event.blockManagerId.hostPort + event.maxOnHeapMem.foreach { _ => + exec.totalOnHeap = event.maxOnHeapMem.get + exec.totalOffHeap = event.maxOffHeapMem.get + } + exec.isActive = true + exec.maxMemory = event.maxMem + update(exec) + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { + // Nothing to do here. Covered by onExecutorRemoved. + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { + liveRDDs.remove(event.rddId) + kvstore.delete(classOf[RDDStorageInfoWrapper], event.rddId) + } + + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + event.accumUpdates.foreach { case (taskId, sid, sAttempt, accumUpdates) => + liveTasks.get(taskId).foreach { task => + val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) + val delta = task.updateMetrics(metrics) + update(task) + + liveStages.get((sid, sAttempt)).foreach { stage => + stage.metrics.update(delta) + update(stage) + + val esummary = stage.executorSummary(event.execId) + esummary.metrics.update(delta) + update(esummary) + } + } + } + } + + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + event.blockUpdatedInfo.blockId match { + case block: RDDBlockId => updateRDDBlock(event, block) + case _ => // TODO: API only covers RDD storage. + } + } + + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { + val executorId = event.blockUpdatedInfo.blockManagerId.executorId + + // Whether values are being added to or removed from the existing accounting. + val storageLevel = event.blockUpdatedInfo.storageLevel + val diskDelta = event.blockUpdatedInfo.diskSize * (if (storageLevel.useDisk) 1 else -1) + val memoryDelta = event.blockUpdatedInfo.memSize * (if (storageLevel.useMemory) 1 else -1) + + // Function to apply a delta to a value, but ensure that it doesn't go negative. + def newValue(old: Long, delta: Long): Long = math.max(0, old + delta) + + val updatedStorageLevel = if (storageLevel.isValid) { + Some(storageLevel.description) + } else { + None + } + + // We need information about the executor to update some memory accounting values in the + // RDD info, so read that beforehand. + val maybeExec = liveExecutors.get(executorId) + var rddBlocksDelta = 0 + + // Update the block entry in the RDD info, keeping track of the deltas above so that we + // can update the executor information too. + liveRDDs.get(block.rddId).foreach { rdd => + val partition = rdd.partition(block.name) + + val executors = if (updatedStorageLevel.isDefined) { + if (!partition.executors.contains(executorId)) { + rddBlocksDelta = 1 + } + partition.executors + executorId + } else { + rddBlocksDelta = -1 + partition.executors - executorId + } + + // Only update the partition if it's still stored in some executor, otherwise get rid of it. + if (executors.nonEmpty) { + if (updatedStorageLevel.isDefined) { + partition.storageLevel = updatedStorageLevel.get + } + partition.memoryUsed = newValue(partition.memoryUsed, memoryDelta) + partition.diskUsed = newValue(partition.diskUsed, diskDelta) + partition.executors = executors + } else { + rdd.removePartition(block.name) + } + + maybeExec.foreach { exec => + if (exec.rddBlocks + rddBlocksDelta > 0) { + val dist = rdd.distribution(exec) + dist.memoryRemaining = newValue(dist.memoryRemaining, -memoryDelta) + dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) + dist.diskUsed = newValue(dist.diskUsed, diskDelta) + + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) + dist.offHeapRemaining = newValue(dist.offHeapRemaining, -memoryDelta) + } else { + dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) + dist.onHeapRemaining = newValue(dist.onHeapRemaining, -memoryDelta) + } + } + } else { + rdd.removeDistribution(exec) + } + } + + if (updatedStorageLevel.isDefined) { + rdd.storageLevel = updatedStorageLevel.get + } + rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) + rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) + update(rdd) + } + + maybeExec.foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + } else { + exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + } + } + exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = newValue(exec.diskUsed, diskDelta) + exec.rddBlocks += rddBlocksDelta + if (exec.hasMemoryInfo || rddBlocksDelta != 0) { + update(exec) + } + } + } + + private def getOrCreateExecutor(executorId: String): LiveExecutor = { + liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId)) + } + + private def getOrCreateStage(info: StageInfo): LiveStage = { + val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) + stage.info = info + stage + } + + private def update(entity: LiveEntity): Unit = { + entity.write(kvstore) + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala new file mode 100644 index 0000000000000..4638511944c61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File + +import scala.annotation.meta.getter +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.internal.Logging +import org.apache.spark.util.kvstore._ + +private[spark] object KVUtils extends Logging { + + /** Use this to annotate constructor params to be used as KVStore indices. */ + type KVIndexParam = KVIndex @getter + + /** + * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as + * the API serializer. + */ + private[spark] class KVStoreScalaSerializer extends KVStoreSerializer { + + mapper.registerModule(DefaultScalaModule) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + + } + + /** + * Open or create a LevelDB store. + * + * @param path Location of the store. + * @param metadata Metadata value to compare to the data in the store. If the store does not + * contain any metadata (e.g. it's a new store), this value is written as + * the store's metadata. + */ + def open[M: ClassTag](path: File, metadata: M): LevelDB = { + require(metadata != null, "Metadata is required.") + + val db = new LevelDB(path, new KVStoreScalaSerializer()) + val dbMeta = db.getMetadata(classTag[M].runtimeClass) + if (dbMeta == null) { + db.setMetadata(metadata) + } else if (dbMeta != metadata) { + db.close() + throw new MetadataMismatchException() + } + + db + } + + private[spark] class MetadataMismatchException extends Exception + +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala new file mode 100644 index 0000000000000..63fa36580bc7d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} +import org.apache.spark.status.api.v1 +import org.apache.spark.storage.RDDInfo +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.kvstore.KVStore + +/** + * A mutable representation of a live entity in Spark (jobs, stages, tasks, et al). Every live + * entity uses one of these instances to keep track of their evolving state, and periodically + * flush an immutable view of the entity to the app state store. + */ +private[spark] abstract class LiveEntity { + + def write(store: KVStore): Unit = { + store.write(doUpdate()) + } + + /** + * Returns an updated view of entity data, to be stored in the status store, reflecting the + * latest information collected by the listener. + */ + protected def doUpdate(): Any + +} + +private class LiveJob( + val jobId: Int, + name: String, + submissionTime: Option[Date], + val stageIds: Seq[Int], + jobGroup: Option[String], + numTasks: Int) extends LiveEntity { + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var skippedTasks = 0 + var skippedStages = Set[Int]() + + var status = JobExecutionStatus.RUNNING + var completionTime: Option[Date] = None + + var completedStages: Set[Int] = Set() + var activeStages = 0 + var failedStages = 0 + + override protected def doUpdate(): Any = { + val info = new v1.JobData( + jobId, + name, + None, // description is always None? + submissionTime, + completionTime, + stageIds, + jobGroup, + status, + numTasks, + activeTasks, + completedTasks, + skippedTasks, + failedTasks, + activeStages, + completedStages.size, + skippedStages.size, + failedStages) + new JobDataWrapper(info, skippedStages) + } + +} + +private class LiveTask( + info: TaskInfo, + stageId: Int, + stageAttemptId: Int) extends LiveEntity { + + import LiveEntityHelpers._ + + private var recordedMetrics: v1.TaskMetrics = null + + var errorMessage: Option[String] = None + + /** + * Update the metrics for the task and return the difference between the previous and new + * values. + */ + def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { + if (metrics != null) { + val old = recordedMetrics + recordedMetrics = new v1.TaskMetrics( + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGCTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + new v1.InputMetrics( + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead), + new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten), + new v1.ShuffleReadMetrics( + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead), + new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten)) + if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + } else { + null + } + } + + /** + * Return a new TaskMetrics object containing the delta of the various fields of the given + * metrics objects. This is currently targeted at updating stage data, so it does not + * necessarily calculate deltas for all the fields. + */ + private def calculateMetricsDelta( + metrics: v1.TaskMetrics, + old: v1.TaskMetrics): v1.TaskMetrics = { + val shuffleWriteDelta = new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, + 0L, + metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) + + val shuffleReadDelta = new v1.ShuffleReadMetrics( + 0L, 0L, 0L, + metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk - + old.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) + + val inputDelta = new v1.InputMetrics( + metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) + + val outputDelta = new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) + + new v1.TaskMetrics( + 0L, 0L, + metrics.executorRunTime - old.executorRunTime, + metrics.executorCpuTime - old.executorCpuTime, + 0L, 0L, 0L, + metrics.memoryBytesSpilled - old.memoryBytesSpilled, + metrics.diskBytesSpilled - old.diskBytesSpilled, + inputDelta, + outputDelta, + shuffleReadDelta, + shuffleWriteDelta) + } + + override protected def doUpdate(): Any = { + val task = new v1.TaskData( + info.taskId, + info.index, + info.attemptNumber, + new Date(info.launchTime), + if (info.finished) Some(info.duration) else None, + info.executorId, + info.host, + info.status, + info.taskLocality.toString(), + info.speculative, + newAccumulatorInfos(info.accumulables), + errorMessage, + Option(recordedMetrics)) + new TaskDataWrapper(task) + } + +} + +private class LiveExecutor(val executorId: String) extends LiveEntity { + + var hostPort: String = null + var host: String = null + var isActive = true + var totalCores = 0 + + var rddBlocks = 0 + var memoryUsed = 0L + var diskUsed = 0L + var maxTasks = 0 + var maxMemory = 0L + + var totalTasks = 0 + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + var totalDuration = 0L + var totalGcTime = 0L + var totalInputBytes = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + var isBlacklisted = false + + var executorLogs = Map[String, String]() + + // Memory metrics. They may not be recorded (e.g. old event logs) so if totalOnHeap is not + // initialized, the store will not contain this information. + var totalOnHeap = -1L + var totalOffHeap = 0L + var usedOnHeap = 0L + var usedOffHeap = 0L + + def hasMemoryInfo: Boolean = totalOnHeap >= 0L + + def hostname: String = if (host != null) host else hostPort.split(":")(0) + + override protected def doUpdate(): Any = { + val memoryMetrics = if (totalOnHeap >= 0) { + Some(new v1.MemoryMetrics(usedOnHeap, usedOffHeap, totalOnHeap, totalOffHeap)) + } else { + None + } + + val info = new v1.ExecutorSummary( + executorId, + if (hostPort != null) hostPort else host, + isActive, + rddBlocks, + memoryUsed, + diskUsed, + totalCores, + maxTasks, + activeTasks, + failedTasks, + completedTasks, + totalTasks, + totalDuration, + totalGcTime, + totalInputBytes, + totalShuffleRead, + totalShuffleWrite, + isBlacklisted, + maxMemory, + executorLogs, + memoryMetrics) + new ExecutorSummaryWrapper(info) + } + +} + +/** Metrics tracked per stage (both total and per executor). */ +private class MetricsTracker { + var executorRunTime = 0L + var executorCpuTime = 0L + var inputBytes = 0L + var inputRecords = 0L + var outputBytes = 0L + var outputRecords = 0L + var shuffleReadBytes = 0L + var shuffleReadRecords = 0L + var shuffleWriteBytes = 0L + var shuffleWriteRecords = 0L + var memoryBytesSpilled = 0L + var diskBytesSpilled = 0L + + def update(delta: v1.TaskMetrics): Unit = { + executorRunTime += delta.executorRunTime + executorCpuTime += delta.executorCpuTime + inputBytes += delta.inputMetrics.bytesRead + inputRecords += delta.inputMetrics.recordsRead + outputBytes += delta.outputMetrics.bytesWritten + outputRecords += delta.outputMetrics.recordsWritten + shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + + delta.shuffleReadMetrics.remoteBytesRead + shuffleReadRecords += delta.shuffleReadMetrics.recordsRead + shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten + shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten + memoryBytesSpilled += delta.memoryBytesSpilled + diskBytesSpilled += delta.diskBytesSpilled + } + +} + +private class LiveExecutorStageSummary( + stageId: Int, + attemptId: Int, + executorId: String) extends LiveEntity { + + var taskTime = 0L + var succeededTasks = 0 + var failedTasks = 0 + var killedTasks = 0 + + val metrics = new MetricsTracker() + + override protected def doUpdate(): Any = { + val info = new v1.ExecutorStageSummary( + taskTime, + failedTasks, + succeededTasks, + metrics.inputBytes, + metrics.outputBytes, + metrics.shuffleReadBytes, + metrics.shuffleWriteBytes, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled) + new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) + } + +} + +private class LiveStage extends LiveEntity { + + import LiveEntityHelpers._ + + var jobs = Seq[LiveJob]() + var jobIds = Set[Int]() + + var info: StageInfo = null + var status = v1.StageStatus.PENDING + + var schedulingPool: String = SparkUI.DEFAULT_POOL_NAME + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var firstLaunchTime = Long.MaxValue + + val metrics = new MetricsTracker() + + val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + + def executorSummary(executorId: String): LiveExecutorStageSummary = { + executorSummaries.getOrElseUpdate(executorId, + new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + } + + override protected def doUpdate(): Any = { + val update = new v1.StageData( + status, + info.stageId, + info.attemptId, + + activeTasks, + completedTasks, + failedTasks, + + metrics.executorRunTime, + metrics.executorCpuTime, + info.submissionTime.map(new Date(_)), + if (firstLaunchTime < Long.MaxValue) Some(new Date(firstLaunchTime)) else None, + info.completionTime.map(new Date(_)), + + metrics.inputBytes, + metrics.inputRecords, + metrics.outputBytes, + metrics.outputRecords, + metrics.shuffleReadBytes, + metrics.shuffleReadRecords, + metrics.shuffleWriteBytes, + metrics.shuffleWriteRecords, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + + info.name, + info.details, + schedulingPool, + + newAccumulatorInfos(info.accumulables.values), + None, + None) + + new StageDataWrapper(update, jobIds) + } + +} + +private class LiveRDDPartition(val blockName: String) { + + var executors = Set[String]() + var storageLevel: String = null + var memoryUsed = 0L + var diskUsed = 0L + + def toApi(): v1.RDDPartitionInfo = { + new v1.RDDPartitionInfo( + blockName, + storageLevel, + memoryUsed, + diskUsed, + executors.toSeq.sorted) + } + +} + +private class LiveRDDDistribution(val exec: LiveExecutor) { + + var memoryRemaining = exec.maxMemory + var memoryUsed = 0L + var diskUsed = 0L + + var onHeapUsed = 0L + var offHeapUsed = 0L + var onHeapRemaining = 0L + var offHeapRemaining = 0L + + def toApi(): v1.RDDDataDistribution = { + new v1.RDDDataDistribution( + exec.hostPort, + memoryUsed, + memoryRemaining, + diskUsed, + if (exec.hasMemoryInfo) Some(onHeapUsed) else None, + if (exec.hasMemoryInfo) Some(offHeapUsed) else None, + if (exec.hasMemoryInfo) Some(onHeapRemaining) else None, + if (exec.hasMemoryInfo) Some(offHeapRemaining) else None) + } + +} + +private class LiveRDD(info: RDDInfo) extends LiveEntity { + + var storageLevel: String = info.storageLevel.description + var memoryUsed = 0L + var diskUsed = 0L + + private val partitions = new HashMap[String, LiveRDDPartition]() + private val distributions = new HashMap[String, LiveRDDDistribution]() + + def partition(blockName: String): LiveRDDPartition = { + partitions.getOrElseUpdate(blockName, new LiveRDDPartition(blockName)) + } + + def removePartition(blockName: String): Unit = partitions.remove(blockName) + + def distribution(exec: LiveExecutor): LiveRDDDistribution = { + distributions.getOrElseUpdate(exec.hostPort, new LiveRDDDistribution(exec)) + } + + def removeDistribution(exec: LiveExecutor): Unit = { + distributions.remove(exec.hostPort) + } + + override protected def doUpdate(): Any = { + val parts = if (partitions.nonEmpty) { + Some(partitions.values.toList.sortBy(_.blockName).map(_.toApi())) + } else { + None + } + + val dists = if (distributions.nonEmpty) { + Some(distributions.values.toList.sortBy(_.exec.executorId).map(_.toApi())) + } else { + None + } + + val rdd = new v1.RDDStorageInfo( + info.id, + info.name, + info.numPartitions, + partitions.size, + storageLevel, + memoryUsed, + diskUsed, + dists, + parts) + + new RDDStorageInfoWrapper(rdd) + } + +} + +private object LiveEntityHelpers { + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { + accums + .filter { acc => + // We don't need to store internal or SQL accumulables as their values will be shown in + // other places, so drop them to reduce the memory usage. + !acc.internal && (!acc.metadata.isDefined || + acc.metadata.get != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + } + .map { acc => + new v1.AccumulableInfo( + acc.id, + acc.name.map(_.intern()).orNull, + acc.update.map(_.toString()), + acc.value.map(_.toString()).orNull) + } + .toSeq + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 4a4ed954d689e..5f69949c618fd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -71,7 +71,7 @@ private[v1] object AllStagesResource { val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => - k -> convertTaskData(v, stageUiData.lastUpdateTime) }) + k -> convertTaskData(v, stageUiData.lastUpdateTime) }.toMap) } else { None } @@ -88,7 +88,7 @@ private[v1] object AllStagesResource { memoryBytesSpilled = summary.memoryBytesSpilled, diskBytesSpilled = summary.diskBytesSpilled ) - }) + }.toMap) } else { None } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 31659b25db318..bff6f90823f40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -16,11 +16,11 @@ */ package org.apache.spark.status.api.v1 +import java.lang.{Long => JLong} import java.util.Date -import scala.collection.Map - import com.fasterxml.jackson.annotation.JsonIgnoreProperties +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus @@ -129,9 +129,13 @@ class RDDDataDistribution private[spark]( val memoryUsed: Long, val memoryRemaining: Long, val diskUsed: Long, + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryRemaining: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( @@ -179,7 +183,8 @@ class TaskData private[spark]( val index: Int, val attempt: Int, val launchTime: Date, - val duration: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[JLong]) + val duration: Option[Long], val executorId: String, val host: String, val status: String, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala new file mode 100644 index 0000000000000..9579accd2cba7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import com.fasterxml.jackson.annotation.JsonIgnore + +import org.apache.spark.status.KVUtils._ +import org.apache.spark.status.api.v1._ +import org.apache.spark.util.kvstore.KVIndex + +private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { + + @JsonIgnore @KVIndex + def id: String = info.id + +} + +private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { + + @JsonIgnore @KVIndex + private[this] val id: String = info.id + + @JsonIgnore @KVIndex("active") + private[this] val active: Boolean = info.isActive + + @JsonIgnore @KVIndex("host") + val host: String = info.hostPort.split(":")(0) + +} + +/** + * Keep track of the existing stages when the job was submitted, and those that were + * completed during the job's execution. This allows a more accurate acounting of how + * many tasks were skipped for the job. + */ +private[spark] class JobDataWrapper( + val info: JobData, + val skippedStages: Set[Int]) { + + @JsonIgnore @KVIndex + private[this] val id: Int = info.jobId + +} + +private[spark] class StageDataWrapper( + val info: StageData, + val jobIds: Set[Int]) { + + @JsonIgnore @KVIndex + def id: Array[Int] = Array(info.stageId, info.attemptId) + +} + +private[spark] class TaskDataWrapper(val info: TaskData) { + + @JsonIgnore @KVIndex + def id: Long = info.taskId + +} + +private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { + + @JsonIgnore @KVIndex + def id: Int = info.id + + @JsonIgnore @KVIndex("cached") + def cached: Boolean = info.numCachedPartitions > 0 + +} + +private[spark] class ExecutorStageSummaryWrapper( + val stageId: Int, + val stageAttemptId: Int, + val executorId: String, + val info: ExecutorStageSummary) { + + @JsonIgnore @KVIndex + val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + + @JsonIgnore @KVIndex("stage") + private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2141934c92640..03bd3eaf579f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -611,7 +611,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Manually overwrite the version in the listing db; this should cause the new provider to // discard all data because the versions don't match. - val meta = new KVStoreMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, + val meta = new FsHistoryProviderMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, conf.get(LOCAL_STORE_DIR).get) oldProvider.listing.setMetadata(meta) oldProvider.stop() diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala new file mode 100644 index 0000000000000..6f7a0c14dd684 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -0,0 +1,690 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.util.{Date, Properties} + +import scala.collection.JavaConverters._ +import scala.reflect.{classTag, ClassTag} + +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore._ + +class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { + + private var time: Long = _ + private var testDir: File = _ + private var store: KVStore = _ + + before { + time = 0L + testDir = Utils.createTempDir() + store = KVUtils.open(testDir, getClass().getName()) + } + + after { + store.close() + Utils.deleteRecursively(testDir) + } + + test("scheduler events") { + val listener = new AppStatusListener(store) + + // Start the application. + time += 1 + listener.onApplicationStart(SparkListenerApplicationStart( + "name", + Some("id"), + time, + "user", + Some("attempt"), + None)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(time)) + assert(attempt.lastUpdated === new Date(time)) + assert(attempt.endTime.getTime() === -1L) + assert(attempt.sparkUser === "user") + assert(!attempt.completed) + } + + // Start a couple of executors. + time += 1 + val execIds = Array("1", "2") + + execIds.foreach { id => + listener.onExecutorAdded(SparkListenerExecutorAdded(time, id, + new ExecutorInfo(s"$id.example.com", 1, Map()))) + } + + execIds.foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(exec.info.hostPort === s"$id.example.com") + assert(exec.info.isActive) + } + } + + // Start a job with 2 stages / 4 tasks each + time += 1 + val stages = Seq( + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2")) + + val jobProps = new Properties() + jobProps.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "jobGroup") + jobProps.setProperty("spark.scheduler.pool", "schedPool") + + listener.onJobStart(SparkListenerJobStart(1, time, stages, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.jobId === 1) + assert(job.info.name === stages.last.name) + assert(job.info.description === None) + assert(job.info.status === JobExecutionStatus.RUNNING) + assert(job.info.submissionTime === Some(new Date(time))) + assert(job.info.jobGroup === Some("jobGroup")) + } + + stages.foreach { info => + check[StageDataWrapper](key(info)) { stage => + assert(stage.info.status === v1.StageStatus.PENDING) + assert(stage.jobIds === Set(1)) + } + } + + // Submit stage 1 + time += 1 + stages.head.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.head, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.head.submissionTime.get))) + assert(stage.info.schedulingPool === "schedPool") + } + + // Start tasks from stage 1 + time += 1 + var _taskIdTracker = -1L + def nextTaskId(): Long = { + _taskIdTracker += 1 + _taskIdTracker + } + + def createTasks(count: Int, time: Long): Seq[TaskInfo] = { + (1 to count).map { id => + val exec = execIds(id.toInt % execIds.length) + val taskId = nextTaskId() + new TaskInfo(taskId, taskId.toInt, 1, time, exec, s"$exec.example.com", + TaskLocality.PROCESS_LOCAL, id % 2 == 0) + } + } + + val s1Tasks = createTasks(4, time) + s1Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + } + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numActiveTasks === s1Tasks.size) + assert(stage.info.firstTaskLaunchedTime === Some(new Date(s1Tasks.head.launchTime))) + } + + s1Tasks.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.taskId === task.taskId) + assert(wrapper.info.index === task.index) + assert(wrapper.info.attempt === task.attemptNumber) + assert(wrapper.info.launchTime === new Date(task.launchTime)) + assert(wrapper.info.executorId === task.executorId) + assert(wrapper.info.host === task.host) + assert(wrapper.info.status === task.status) + assert(wrapper.info.taskLocality === task.taskLocality.toString()) + assert(wrapper.info.speculative === task.speculative) + } + } + + // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(1L), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size) + } + + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + } + + // Fail one of the tasks, re-start it. + time += 1 + s1Tasks.head.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", TaskResultLost, s1Tasks.head, null)) + + time += 1 + val reattempt = { + val orig = s1Tasks.head + // Task reattempts have a different ID, but the same index as the original. + new TaskInfo(nextTaskId(), orig.index, orig.attemptNumber + 1, time, orig.executorId, + s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) + } + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + reattempt)) + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === s1Tasks.size) + } + + check[TaskDataWrapper](s1Tasks.head.taskId) { task => + assert(task.info.status === s1Tasks.head.status) + assert(task.info.duration === Some(s1Tasks.head.duration)) + assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + } + + check[TaskDataWrapper](reattempt.taskId) { task => + assert(task.info.index === s1Tasks.head.index) + assert(task.info.attempt === reattempt.attemptNumber) + } + + // Succeed all tasks in stage 1. + val pending = s1Tasks.drop(1) ++ Seq(reattempt) + + val s1Metrics = TaskMetrics.empty + s1Metrics.setExecutorCpuTime(2L) + s1Metrics.setExecutorRunTime(4L) + + time += 1 + pending.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", Success, task, s1Metrics)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === 0) + assert(job.info.numCompletedTasks === pending.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + pending.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.errorMessage === None) + assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) + assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) + } + } + + assert(store.count(classOf[TaskDataWrapper]) === pending.size + 1) + + // End stage 1. + time += 1 + stages.head.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stages.head)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numCompletedStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + // Submit stage 2. + time += 1 + stages.last.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.last, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) + } + + // Start and fail all tasks of stage 2. + time += 1 + val s2Tasks = createTasks(4, time) + s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + } + + time += 1 + s2Tasks.foreach { task => + task.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + "taskType", TaskResultLost, task, null)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1 + s2Tasks.size) + assert(job.info.numActiveTasks === 0) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + } + + // Fail stage 2. + time += 1 + stages.last.completionTime = Some(time) + stages.last.failureReason = Some("uh oh") + listener.onStageCompleted(SparkListenerStageCompleted(stages.last)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 1) + assert(job.info.numFailedStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.FAILED) + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === 0) + } + + // - Re-submit stage 2, all tasks, and succeed them and the stage. + val oldS2 = stages.last + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) + + time += 1 + newS2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(newS2, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 3) + + val newS2Tasks = createTasks(4, time) + + newS2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + } + + time += 1 + newS2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, + task, null)) + } + + time += 1 + newS2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(newS2)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numFailedStages === 1) + assert(job.info.numCompletedStages === 2) + } + + check[StageDataWrapper](key(newS2)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === newS2Tasks.size) + } + + // End job. + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + } + + // Submit a second job that re-uses stage 1 and stage 2. Stage 1 won't be re-run, but + // stage 2 will. In any case, the DAGScheduler creates new info structures that are copies + // of the old stages, so mimic that behavior here. The "new" stage 1 is submitted without + // a submission time, which means it is "skipped", and the stage 2 re-execution should not + // change the stats of the already finished job. + time += 1 + val j2Stages = Seq( + new StageInfo(3, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(4, 0, "stage2", 4, Nil, Seq(3), "details2")) + j2Stages.last.submissionTime = Some(time) + listener.onJobStart(SparkListenerJobStart(2, time, j2Stages, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.head, jobProps)) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.head)) + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.last, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 5) + + time += 1 + val j2s2Tasks = createTasks(4, time) + + j2s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + task)) + } + + time += 1 + j2s2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + "taskType", Success, task, null)) + } + + time += 1 + j2Stages.last.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.last)) + + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 2) + assert(job.info.numCompletedTasks === s1Tasks.size + s2Tasks.size) + } + + check[JobDataWrapper](2) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + assert(job.info.numCompletedStages === 1) + assert(job.info.numCompletedTasks === j2s2Tasks.size) + assert(job.info.numSkippedStages === 1) + assert(job.info.numSkippedTasks === s1Tasks.size) + } + + // Blacklist an executor. + time += 1 + listener.onExecutorBlacklisted(SparkListenerExecutorBlacklisted(time, "1", 42)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted(time, "1")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Blacklist a node. + time += 1 + listener.onNodeBlacklisted(SparkListenerNodeBlacklisted(time, "1.example.com", 2)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onNodeUnblacklisted(SparkListenerNodeUnblacklisted(time, "1.example.com")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Stop executors. + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "1", "Test")) + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "2", "Test")) + + Seq("1", "2").foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(!exec.info.isActive) + } + } + + // End the application. + listener.onApplicationEnd(SparkListenerApplicationEnd(42L)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(1L)) + assert(attempt.lastUpdated === new Date(42L)) + assert(attempt.endTime === new Date(42L)) + assert(attempt.duration === 41L) + assert(attempt.sparkUser === "user") + assert(attempt.completed) + } + } + + test("storage events") { + val listener = new AppStatusListener(store) + val maxMemory = 42L + + // Register a couple of block managers. + val bm1 = BlockManagerId("1", "1.example.com", 42) + val bm2 = BlockManagerId("2", "2.example.com", 84) + Seq(bm1, bm2).foreach { bm => + listener.onExecutorAdded(SparkListenerExecutorAdded(1L, bm.executorId, + new ExecutorInfo(bm.host, 1, Map()))) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm, maxMemory)) + check[ExecutorSummaryWrapper](bm.executorId) { exec => + assert(exec.info.maxMemory === maxMemory) + } + } + + val rdd1b1 = RDDBlockId(1, 1) + val level = StorageLevel.MEMORY_AND_DISK + + // Submit a stage and make sure the RDD is recorded. + val rddInfo = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rddInfo), Nil, "details1") + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.name === rddInfo.name) + assert(wrapper.info.numPartitions === rddInfo.numPartitions) + assert(wrapper.info.storageLevel === rddInfo.storageLevel.description) + } + + // Add partition 1 replicated on two block managers. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 1L) + assert(wrapper.info.diskUsed === 1L) + + assert(wrapper.info.dataDistribution.isDefined) + assert(wrapper.info.dataDistribution.get.size === 1) + + val dist = wrapper.info.dataDistribution.get.head + assert(dist.address === bm1.hostPort) + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + assert(wrapper.info.partitions.isDefined) + assert(wrapper.info.partitions.get.size === 1) + + val part = wrapper.info.partitions.get.head + assert(part.blockName === rdd1b1.name) + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 2L) + assert(wrapper.info.diskUsed === 2L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 1L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm2.hostPort).get + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get(0) + assert(part.memoryUsed === 2L) + assert(part.diskUsed === 2L) + assert(part.executors === Seq(bm1.executorId, bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + // Add a second partition only to bm 1. + val rdd1b2 = RDDBlockId(1, 2) + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b2, level, + 3L, 3L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 5L) + assert(wrapper.info.diskUsed === 5L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 4L) + assert(dist.diskUsed === 4L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 3L) + assert(part.diskUsed === 3L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 2L) + assert(exec.info.memoryUsed === 4L) + assert(exec.info.diskUsed === 4L) + } + + // Remove block 1 from bm 1. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, + StorageLevel.NONE, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 4L) + assert(wrapper.info.diskUsed === 4L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 3L) + assert(dist.diskUsed === 3L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 3L) + assert(exec.info.diskUsed === 3L) + } + + // Remove block 2 from bm 2. This should leave only block 2 info in the store. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, + StorageLevel.NONE, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 3L) + assert(wrapper.info.diskUsed === 3L) + assert(wrapper.info.dataDistribution.get.size === 1L) + assert(wrapper.info.partitions.get.size === 1L) + assert(wrapper.info.partitions.get(0).blockName === rdd1b2.name) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 0L) + assert(exec.info.memoryUsed === 0L) + assert(exec.info.diskUsed === 0L) + } + + // Unpersist RDD1. + listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) + intercept[NoSuchElementException] { + check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } + } + + } + + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + + private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { + val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] + fn(value) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd299e074535e..45b8870f3b62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,8 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-18085: Better History Server scalability for many / large applications + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 4f8dc6b01ea787243a38678ea8199fbb0814cffc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Oct 2017 21:41:45 +0100 Subject: [PATCH 1547/1765] [SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields ## What changes were proposed in this pull request? When the given closure uses some fields defined in super class, `ClosureCleaner` can't figure them and don't set it properly. Those fields will be in null values. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19556 from viirya/SPARK-22328. --- .../apache/spark/util/ClosureCleaner.scala | 73 ++++++++++++++++--- .../spark/util/ClosureCleanerSuite.scala | 72 ++++++++++++++++++ 2 files changed, 133 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 48a1d7b84b61b..dfece5dd0670b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging { (seen - obj.getClass).toList } + /** Initializes the accessed fields for outer classes and their super classes. */ + private def initAccessedFields( + accessedFields: Map[Class[_], Set[String]], + outerClasses: Seq[Class[_]]): Unit = { + for (cls <- outerClasses) { + var currentClass = cls + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + accessedFields(currentClass) = Set.empty[String] + currentClass = currentClass.getSuperclass() + } + } + } + + /** Sets accessed fields for given class in clone object based on given object. */ + private def setAccessedFields( + outerClass: Class[_], + clone: AnyRef, + obj: AnyRef, + accessedFields: Map[Class[_], Set[String]]): Unit = { + for (fieldName <- accessedFields(outerClass)) { + val field = outerClass.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + field.set(clone, value) + } + } + + /** Clones a given object and sets accessed fields in cloned object. */ + private def cloneAndSetFields( + parent: AnyRef, + obj: AnyRef, + outerClass: Class[_], + accessedFields: Map[Class[_], Set[String]]): AnyRef = { + val clone = instantiateClass(outerClass, parent) + + var currentClass = outerClass + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + setAccessedFields(currentClass, clone, obj, accessedFields) + currentClass = currentClass.getSuperclass() + } + + clone + } + /** * Clean the given closure in place. * @@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" + populating accessed fields because this is the starting closure") // Initialize accessed fields with the outer classes first // This step is needed to associate the fields to the correct classes later - for (cls <- outerClasses) { - accessedFields(cls) = Set.empty[String] - } + initAccessedFields(accessedFields, outerClasses) + // Populate accessed fields by visiting all fields and methods accessed by this and // all of its inner closures. If transitive cleaning is enabled, this may recursively // visits methods that belong to other classes in search of transitively referenced fields. @@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging { // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be // its enclosing object. - val clone = instantiateClass(cls, parent) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - field.set(clone, value) - } + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + // If transitive cleaning is enabled, we recursively clean any enclosing closure using // the already populated accessed fields map of the starting closure if (cleanTransitively && isClosure(clone.getClass)) { @@ -395,8 +437,15 @@ private[util] class FieldAccessFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - ClosureCleaner.getClassReader(cl).accept( - new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + + var currentClass = cl + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + ClosureCleaner.getClassReader(currentClass).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + currentClass = currentClass.getSuperclass() + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 4920b7ee8bfb4..9a19baee9569e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite { test("createNullValue") { new TestCreateNullValue().run() } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + val concreteObject = new TestAbstractClass { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] = { + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1) + body(rdd) + } + } + + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ => + (n1, n2, s1, s2, d1, d2) + }.collect() + } + assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + val concreteObject = new TestAbstractClass2 { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2) + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.getData) + assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + } + + test("SPARK-22328: multiple outer classes have the same parent class") { + val concreteObject = new TestAbstractClass2 { + + val innerObject = new TestAbstractClass2 { + override val n1 = 222 + override val s1 = "bbb" + } + + val innerObject2 = new TestAbstractClass2 { + override val n1 = 444 + val n3 = 333 + val s3 = "ccc" + val d3 = 3.0d + + def getData: Int => (Int, Int, String, String, Double, Double, Int, String) = + _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1) + } + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) + assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb"))) + } + } } // A non-serializable class we create in closures to make sure that we aren't @@ -377,3 +434,18 @@ class TestCreateNullValue { nestedClosure() } } + +abstract class TestAbstractClass extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] +} + +abstract class TestAbstractClass2 extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d +} From 5415963d2caaf95604211419ffc4e29fff38e1d7 Mon Sep 17 00:00:00 2001 From: "Susan X. Huynh" Date: Thu, 26 Oct 2017 16:13:48 -0700 Subject: [PATCH 1548/1765] [SPARK-22131][MESOS] Mesos driver secrets ## Background In #18837 , ArtRand added Mesos secrets support to the dispatcher. **This PR is to add the same secrets support to the drivers.** This means if the secret configs are set, the driver will launch executors that have access to either env or file-based secrets. One use case for this is to support TLS in the driver <=> executor communication. ## What changes were proposed in this pull request? Most of the changes are a refactor of the dispatcher secrets support (#18837) - moving it to a common place that can be used by both the dispatcher and drivers. The same goes for the unit tests. ## How was this patch tested? There are four config combinations: [env or file-based] x [value or reference secret]. For each combination: - Added a unit test. - Tested in DC/OS. Author: Susan X. Huynh Closes #19437 from susanxhuynh/sh-mesos-driver-secret. --- docs/running-on-mesos.md | 111 ++++++++++--- .../apache/spark/deploy/mesos/config.scala | 64 ++++---- .../cluster/mesos/MesosClusterScheduler.scala | 138 +++------------- .../MesosCoarseGrainedSchedulerBackend.scala | 31 +++- .../MesosFineGrainedSchedulerBackend.scala | 4 +- .../mesos/MesosSchedulerBackendUtil.scala | 92 ++++++++++- .../mesos/MesosClusterSchedulerSuite.scala | 150 +++--------------- ...osCoarseGrainedSchedulerBackendSuite.scala | 34 +++- .../MesosSchedulerBackendUtilSuite.scala | 7 +- .../spark/scheduler/cluster/mesos/Utils.scala | 107 +++++++++++++ 10 files changed, 434 insertions(+), 304 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index e0944bc9f5f86..b7e3e6473c338 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -485,39 +485,106 @@ See the [configuration page](configuration.html) for information on Spark config - spark.mesos.driver.secret.envkeys - (none) - A comma-separated list that, if set, the contents of the secret referenced - by spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be - set to the provided environment variable in the driver's process. + spark.mesos.driver.secret.values, + spark.mesos.driver.secret.names, + spark.mesos.executor.secret.values, + spark.mesos.executor.secret.names, - - -spark.mesos.driver.secret.filenames (none) - A comma-separated list that, if set, the contents of the secret referenced by - spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be - written to the provided file. Paths are relative to the container's work - directory. Absolute paths must already exist. Consult the Mesos Secret - protobuf for more information. +

    + A secret is specified by its contents and destination. These properties + specify a secret's contents. To specify a secret's destination, see the cell below. +

    +

    + You can specify a secret's contents either (1) by value or (2) by reference. +

    +

    + (1) To specify a secret by value, set the + spark.mesos.[driver|executor].secret.values + property, to make the secret available in the driver or executors. + For example, to make a secret password "guessme" available to the driver process, set: + +

    spark.mesos.driver.secret.values=guessme
    +

    +

    + (2) To specify a secret that has been placed in a secret store + by reference, specify its name within the secret store + by setting the spark.mesos.[driver|executor].secret.names + property. For example, to make a secret password named "password" in a secret store + available to the driver process, set: + +

    spark.mesos.driver.secret.names=password
    +

    +

    + Note: To use a secret store, make sure one has been integrated with Mesos via a custom + SecretResolver + module. +

    +

    + To specify multiple secrets, provide a comma-separated list: + +

    spark.mesos.driver.secret.values=guessme,passwd123
    + + or + +
    spark.mesos.driver.secret.names=password1,password2
    +

    + - spark.mesos.driver.secret.names - (none) - A comma-separated list of secret references. Consult the Mesos Secret - protobuf for more information. + spark.mesos.driver.secret.envkeys, + spark.mesos.driver.secret.filenames, + spark.mesos.executor.secret.envkeys, + spark.mesos.executor.secret.filenames, - - - spark.mesos.driver.secret.values (none) - A comma-separated list of secret values. Consult the Mesos Secret - protobuf for more information. +

    + A secret is specified by its contents and destination. These properties + specify a secret's destination. To specify a secret's contents, see the cell above. +

    +

    + You can specify a secret's destination in the driver or + executors as either (1) an environment variable or (2) as a file. +

    +

    + (1) To make an environment-based secret, set the + spark.mesos.[driver|executor].secret.envkeys property. + The secret will appear as an environment variable with the + given name in the driver or executors. For example, to make a secret password available + to the driver process as $PASSWORD, set: + +

    spark.mesos.driver.secret.envkeys=PASSWORD
    +

    +

    + (2) To make a file-based secret, set the + spark.mesos.[driver|executor].secret.filenames property. + The secret will appear in the contents of a file with the given file name in + the driver or executors. For example, to make a secret password available in a + file named "pwdfile" in the driver process, set: + +

    spark.mesos.driver.secret.filenames=pwdfile
    +

    +

    + Paths are relative to the container's work directory. Absolute paths must + already exist. Note: File-based secrets require a custom + SecretResolver + module. +

    +

    + To specify env vars or file names corresponding to multiple secrets, + provide a comma-separated list: + +

    spark.mesos.driver.secret.envkeys=PASSWORD1,PASSWORD2
    + + or + +
    spark.mesos.driver.secret.filenames=pwdfile1,pwdfile2
    +

    diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 7e85de91c5d36..821534eb4fc38 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -23,6 +23,39 @@ import org.apache.spark.internal.config.ConfigBuilder package object config { + private[spark] class MesosSecretConfig private[config](taskType: String) { + private[spark] val SECRET_NAMES = + ConfigBuilder(s"spark.mesos.$taskType.secret.names") + .doc("A comma-separated list of secret reference names. Consult the Mesos Secret " + + "protobuf for more information.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_VALUES = + ConfigBuilder(s"spark.mesos.$taskType.secret.values") + .doc("A comma-separated list of secret values.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_ENVKEYS = + ConfigBuilder(s"spark.mesos.$taskType.secret.envkeys") + .doc("A comma-separated list of the environment variables to contain the secrets." + + "The environment variable will be set on the driver.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_FILENAMES = + ConfigBuilder(s"spark.mesos.$taskType.secret.filenames") + .doc("A comma-separated list of file paths secret will be written to. Consult the Mesos " + + "Secret protobuf for more information.") + .stringConf + .toSequence + .createOptional + } + /* Common app configuration. */ private[spark] val SHUFFLE_CLEANER_INTERVAL_S = @@ -64,36 +97,9 @@ package object config { .stringConf .createOptional - private[spark] val SECRET_NAME = - ConfigBuilder("spark.mesos.driver.secret.names") - .doc("A comma-separated list of secret reference names. Consult the Mesos Secret protobuf " + - "for more information.") - .stringConf - .toSequence - .createOptional - - private[spark] val SECRET_VALUE = - ConfigBuilder("spark.mesos.driver.secret.values") - .doc("A comma-separated list of secret values.") - .stringConf - .toSequence - .createOptional + private[spark] val driverSecretConfig = new MesosSecretConfig("driver") - private[spark] val SECRET_ENVKEY = - ConfigBuilder("spark.mesos.driver.secret.envkeys") - .doc("A comma-separated list of the environment variables to contain the secrets." + - "The environment variable will be set on the driver.") - .stringConf - .toSequence - .createOptional - - private[spark] val SECRET_FILENAME = - ConfigBuilder("spark.mesos.driver.secret.filenames") - .doc("A comma-seperated list of file paths secret will be written to. Consult the Mesos " + - "Secret protobuf for more information.") - .stringConf - .toSequence - .createOptional + private[spark] val executorSecretConfig = new MesosSecretConfig("executor") private[spark] val DRIVER_FAILOVER_TIMEOUT = ConfigBuilder("spark.mesos.driver.failoverTimeout") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index ec533f91474f2..82470264f2a4a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -28,7 +28,6 @@ import org.apache.mesos.{Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.MesosDriverDescription @@ -394,39 +393,20 @@ private[spark] class MesosClusterScheduler( } // add secret environment variables - getSecretEnvVar(desc).foreach { variable => - if (variable.getSecret.getReference.isInitialized) { - logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName}" + - s"on file ${variable.getName}") - } else { - logInfo(s"Setting secret on environment variable name=${variable.getName}") - } - envBuilder.addVariables(variable) + MesosSchedulerBackendUtil.getSecretEnvVar(desc.conf, config.driverSecretConfig) + .foreach { variable => + if (variable.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName} " + + s"on file ${variable.getName}") + } else { + logInfo(s"Setting secret on environment variable name=${variable.getName}") + } + envBuilder.addVariables(variable) } envBuilder.build() } - private def getSecretEnvVar(desc: MesosDriverDescription): List[Variable] = { - val secrets = getSecrets(desc) - val secretEnvKeys = desc.conf.get(config.SECRET_ENVKEY).getOrElse(Nil) - if (illegalSecretInput(secretEnvKeys, secrets)) { - throw new SparkException( - s"Need to give equal numbers of secrets and environment keys " + - s"for environment-based reference secrets got secrets $secrets, " + - s"and keys $secretEnvKeys") - } - - secrets.zip(secretEnvKeys).map { - case (s, k) => - Variable.newBuilder() - .setName(k) - .setType(Variable.Type.SECRET) - .setSecret(s) - .build - }.toList - } - private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -440,6 +420,23 @@ private[spark] class MesosClusterScheduler( CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) } + private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo(desc.conf) + + MesosSchedulerBackendUtil.getSecretVolume(desc.conf, config.driverSecretConfig) + .foreach { volume => + if (volume.getSource.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName} " + + s"on file ${volume.getContainerPath}") + } else { + logInfo(s"Setting secret on file name=${volume.getContainerPath}") + } + containerInfo.addVolumes(volume) + } + + containerInfo + } + private def getDriverCommandValue(desc: MesosDriverDescription): String = { val dockerDefined = desc.conf.contains("spark.mesos.executor.docker.image") val executorUri = getDriverExecutorURI(desc) @@ -579,89 +576,6 @@ private[spark] class MesosClusterScheduler( .build } - private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { - val containerInfo = MesosSchedulerBackendUtil.containerInfo(desc.conf) - - getSecretVolume(desc).foreach { volume => - if (volume.getSource.getSecret.getReference.isInitialized) { - logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName}" + - s"on file ${volume.getContainerPath}") - } else { - logInfo(s"Setting secret on file name=${volume.getContainerPath}") - } - containerInfo.addVolumes(volume) - } - - containerInfo - } - - - private def getSecrets(desc: MesosDriverDescription): Seq[Secret] = { - def createValueSecret(data: String): Secret = { - Secret.newBuilder() - .setType(Secret.Type.VALUE) - .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes))) - .build() - } - - def createReferenceSecret(name: String): Secret = { - Secret.newBuilder() - .setReference(Secret.Reference.newBuilder().setName(name)) - .setType(Secret.Type.REFERENCE) - .build() - } - - val referenceSecrets: Seq[Secret] = - desc.conf.get(config.SECRET_NAME).getOrElse(Nil).map(s => createReferenceSecret(s)) - - val valueSecrets: Seq[Secret] = { - desc.conf.get(config.SECRET_VALUE).getOrElse(Nil).map(s => createValueSecret(s)) - } - - if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) { - throw new SparkException("Cannot specify VALUE type secrets and REFERENCE types ones") - } - - if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets - } - - private def illegalSecretInput(dest: Seq[String], s: Seq[Secret]): Boolean = { - if (dest.isEmpty) { // no destination set (ie not using secrets of this type - return false - } - if (dest.nonEmpty && s.nonEmpty) { - // make sure there is a destination for each secret of this type - if (dest.length != s.length) { - return true - } - } - false - } - - private def getSecretVolume(desc: MesosDriverDescription): List[Volume] = { - val secrets = getSecrets(desc) - val secretPaths: Seq[String] = - desc.conf.get(config.SECRET_FILENAME).getOrElse(Nil) - - if (illegalSecretInput(secretPaths, secrets)) { - throw new SparkException( - s"Need to give equal numbers of secrets and file paths for file-based " + - s"reference secrets got secrets $secrets, and paths $secretPaths") - } - - secrets.zip(secretPaths).map { - case (s, p) => - val source = Volume.Source.newBuilder() - .setType(Volume.Source.Type.SECRET) - .setSecret(s) - Volume.newBuilder() - .setContainerPath(p) - .setSource(source) - .setMode(Volume.Mode.RO) - .build - }.toList - } - /** * This method takes all the possible candidates and attempt to schedule them with Mesos offers. * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 603c980cb268d..104ed01d293ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -28,7 +28,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future -import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config @@ -244,6 +244,17 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setValue(value) .build()) } + + MesosSchedulerBackendUtil.getSecretEnvVar(conf, executorSecretConfig).foreach { variable => + if (variable.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName} " + + s"on file ${variable.getName}") + } else { + logInfo(s"Setting secret on environment variable name=${variable.getName}") + } + environment.addVariables(variable) + } + val command = CommandInfo.newBuilder() .setEnvironment(environment) @@ -424,6 +435,22 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } } + private def getContainerInfo(conf: SparkConf): ContainerInfo.Builder = { + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo(conf) + + MesosSchedulerBackendUtil.getSecretVolume(conf, executorSecretConfig).foreach { volume => + if (volume.getSource.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName} " + + s"on file ${volume.getContainerPath}") + } else { + logInfo(s"Setting secret on file name=${volume.getContainerPath}") + } + containerInfo.addVolumes(volume) + } + + containerInfo + } + /** * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize * per-task memory and IO, tasks are round-robin assigned to offers. @@ -475,7 +502,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setName(s"${sc.appName} $taskId") .setLabels(MesosProtoUtils.mesosLabels(taskLabels)) .addAllResources(resourcesToUse.asJava) - .setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + .setContainer(getContainerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 66b8e0a640121..d6d939d246109 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -159,7 +160,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - executorInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + executorInfo.setContainer( + MesosSchedulerBackendUtil.buildContainerInfo(sc.conf)) (executorInfo.build(), resourcesAfterMem.asJava) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index f29e541addf23..bfb73611f0530 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,11 +17,15 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Environment, Image, NetworkInfo, Parameter, Secret, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.protobuf.ByteString -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkConf +import org.apache.spark.SparkException import org.apache.spark.deploy.mesos.config.{NETWORK_LABELS, NETWORK_NAME} +import org.apache.spark.deploy.mesos.config.MesosSecretConfig import org.apache.spark.internal.Logging /** @@ -122,7 +126,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } - def containerInfo(conf: SparkConf): ContainerInfo.Builder = { + def buildContainerInfo(conf: SparkConf): ContainerInfo.Builder = { val containerType = if (conf.contains("spark.mesos.executor.docker.image") && conf.get("spark.mesos.containerizer", "docker") == "docker") { ContainerInfo.Type.DOCKER @@ -173,6 +177,88 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { containerInfo } + private def getSecrets(conf: SparkConf, secretConfig: MesosSecretConfig): Seq[Secret] = { + def createValueSecret(data: String): Secret = { + Secret.newBuilder() + .setType(Secret.Type.VALUE) + .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes))) + .build() + } + + def createReferenceSecret(name: String): Secret = { + Secret.newBuilder() + .setReference(Secret.Reference.newBuilder().setName(name)) + .setType(Secret.Type.REFERENCE) + .build() + } + + val referenceSecrets: Seq[Secret] = + conf.get(secretConfig.SECRET_NAMES).getOrElse(Nil).map { s => createReferenceSecret(s) } + + val valueSecrets: Seq[Secret] = { + conf.get(secretConfig.SECRET_VALUES).getOrElse(Nil).map { s => createValueSecret(s) } + } + + if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) { + throw new SparkException("Cannot specify both value-type and reference-type secrets.") + } + + if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets + } + + private def illegalSecretInput(dest: Seq[String], secrets: Seq[Secret]): Boolean = { + if (dest.nonEmpty) { + // make sure there is a one-to-one correspondence between destinations and secrets + if (dest.length != secrets.length) { + return true + } + } + false + } + + def getSecretVolume(conf: SparkConf, secretConfig: MesosSecretConfig): List[Volume] = { + val secrets = getSecrets(conf, secretConfig) + val secretPaths: Seq[String] = + conf.get(secretConfig.SECRET_FILENAMES).getOrElse(Nil) + + if (illegalSecretInput(secretPaths, secrets)) { + throw new SparkException( + s"Need to give equal numbers of secrets and file paths for file-based " + + s"reference secrets got secrets $secrets, and paths $secretPaths") + } + + secrets.zip(secretPaths).map { case (s, p) => + val source = Volume.Source.newBuilder() + .setType(Volume.Source.Type.SECRET) + .setSecret(s) + Volume.newBuilder() + .setContainerPath(p) + .setSource(source) + .setMode(Volume.Mode.RO) + .build + }.toList + } + + def getSecretEnvVar(conf: SparkConf, secretConfig: MesosSecretConfig): + List[Variable] = { + val secrets = getSecrets(conf, secretConfig) + val secretEnvKeys = conf.get(secretConfig.SECRET_ENVKEYS).getOrElse(Nil) + if (illegalSecretInput(secretEnvKeys, secrets)) { + throw new SparkException( + s"Need to give equal numbers of secrets and environment keys " + + s"for environment-based reference secrets got secrets $secrets, " + + s"and keys $secretEnvKeys") + } + + secrets.zip(secretEnvKeys).map { case (s, k) => + Variable.newBuilder() + .setName(k) + .setType(Variable.Type.SECRET) + .setSecret(s) + .build + }.toList + } + private def dockerInfo( image: String, forcePullImage: Boolean, diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index ff63e3f4ccfc3..77acee608f25f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos.{Environment, Secret, TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Value.{Scalar, Type} import org.apache.mesos.SchedulerDriver -import org.apache.mesos.protobuf.ByteString import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar @@ -32,6 +31,7 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.mesos.config class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { @@ -341,132 +341,33 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("Creates an env-based reference secrets.") { - setScheduler() - - val mem = 1000 - val cpu = 1 - val secretName = "/path/to/secret,/anothersecret" - val envKey = "SECRET_ENV_KEY,PASSWORD" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.names" -> secretName, - "spark.mesos.driver.secret.envkeys" -> envKey), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - assert(launchedTasks.head - .getCommand - .getEnvironment - .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the secret - val variableOne = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "SECRET_ENV_KEY").head - assert(variableOne.getSecret.isInitialized) - assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) - assert(variableOne.getSecret.getReference.getName == "/path/to/secret") - assert(variableOne.getType == Environment.Variable.Type.SECRET) - val variableTwo = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "PASSWORD").head - assert(variableTwo.getSecret.isInitialized) - assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE) - assert(variableTwo.getSecret.getReference.getName == "/anothersecret") - assert(variableTwo.getType == Environment.Variable.Type.SECRET) + val launchedTasks = launchDriverTask( + Utils.configEnvBasedRefSecrets(config.driverSecretConfig)) + Utils.verifyEnvBasedRefSecrets(launchedTasks) } test("Creates an env-based value secrets.") { - setScheduler() - val mem = 1000 - val cpu = 1 - val secretValues = "user,password" - val envKeys = "USER,PASSWORD" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.values" -> secretValues, - "spark.mesos.driver.secret.envkeys" -> envKeys), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - assert(launchedTasks.head - .getCommand - .getEnvironment - .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the secret - val variableOne = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "USER").head - assert(variableOne.getSecret.isInitialized) - assert(variableOne.getSecret.getType == Secret.Type.VALUE) - assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes)) - assert(variableOne.getType == Environment.Variable.Type.SECRET) - val variableTwo = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "PASSWORD").head - assert(variableTwo.getSecret.isInitialized) - assert(variableTwo.getSecret.getType == Secret.Type.VALUE) - assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) - assert(variableTwo.getType == Environment.Variable.Type.SECRET) + val launchedTasks = launchDriverTask( + Utils.configEnvBasedValueSecrets(config.driverSecretConfig)) + Utils.verifyEnvBasedValueSecrets(launchedTasks) } test("Creates file-based reference secrets.") { - setScheduler() - val mem = 1000 - val cpu = 1 - val secretName = "/path/to/secret,/anothersecret" - val secretPath = "/topsecret,/mypassword" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.names" -> secretName, - "spark.mesos.driver.secret.filenames" -> secretPath), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - val volumes = launchedTasks.head.getContainer.getVolumesList - assert(volumes.size() == 2) - val secretVolOne = volumes.get(0) - assert(secretVolOne.getContainerPath == "/topsecret") - assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE) - assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret") - val secretVolTwo = volumes.get(1) - assert(secretVolTwo.getContainerPath == "/mypassword") - assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE) - assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret") + val launchedTasks = launchDriverTask( + Utils.configFileBasedRefSecrets(config.driverSecretConfig)) + Utils.verifyFileBasedRefSecrets(launchedTasks) } test("Creates a file-based value secrets.") { + val launchedTasks = launchDriverTask( + Utils.configFileBasedValueSecrets(config.driverSecretConfig)) + Utils.verifyFileBasedValueSecrets(launchedTasks) + } + + private def launchDriverTask(addlSparkConfVars: Map[String, String]): List[TaskInfo] = { setScheduler() val mem = 1000 val cpu = 1 - val secretValues = "user,password" - val secretPath = "/whoami,/mypassword" val driverDesc = new MesosDriverDescription( "d1", "jar", @@ -475,27 +376,14 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi true, command, Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.values" -> secretValues, - "spark.mesos.driver.secret.filenames" -> secretPath), + "spark.app.name" -> "test") ++ + addlSparkConfVars, "s1", new Date()) val response = scheduler.submitDriver(driverDesc) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu) scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - val volumes = launchedTasks.head.getContainer.getVolumesList - assert(volumes.size() == 2) - val secretVolOne = volumes.get(0) - assert(secretVolOne.getContainerPath == "/whoami") - assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE) - assert(secretVolOne.getSource.getSecret.getValue.getData == - ByteString.copyFrom("user".getBytes)) - val secretVolTwo = volumes.get(1) - assert(secretVolTwo.getContainerPath == "/mypassword") - assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.VALUE) - assert(secretVolTwo.getSource.getSecret.getValue.getData == - ByteString.copyFrom("password".getBytes)) + Utils.verifyTaskLaunched(driver, "o1") } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 6c40792112f49..f4bd1ee9da6f7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.concurrent.duration._ -import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ @@ -38,7 +37,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor} import org.apache.spark.scheduler.cluster.mesos.Utils._ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite @@ -653,6 +652,37 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite offerResourcesAndVerify(2, true) } + test("Creates an env-based reference secrets.") { + val launchedTasks = launchExecutorTasks(configEnvBasedRefSecrets(executorSecretConfig)) + verifyEnvBasedRefSecrets(launchedTasks) + } + + test("Creates an env-based value secrets.") { + val launchedTasks = launchExecutorTasks(configEnvBasedValueSecrets(executorSecretConfig)) + verifyEnvBasedValueSecrets(launchedTasks) + } + + test("Creates file-based reference secrets.") { + val launchedTasks = launchExecutorTasks(configFileBasedRefSecrets(executorSecretConfig)) + verifyFileBasedRefSecrets(launchedTasks) + } + + test("Creates a file-based value secrets.") { + val launchedTasks = launchExecutorTasks(configFileBasedValueSecrets(executorSecretConfig)) + verifyFileBasedValueSecrets(launchedTasks) + } + + private def launchExecutorTasks(sparkConfVars: Map[String, String]): List[TaskInfo] = { + setBackend(sparkConfVars) + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + + verifyTaskLaunched(driver, "o1") + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala index f49d7c29eda49..442c43960ec1f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.mesos.config class MesosSchedulerBackendUtilSuite extends SparkFunSuite { @@ -26,7 +27,8 @@ class MesosSchedulerBackendUtilSuite extends SparkFunSuite { conf.set("spark.mesos.executor.docker.parameters", "a,b") conf.set("spark.mesos.executor.docker.image", "test") - val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo( + conf) val params = containerInfo.getDocker.getParametersList assert(params.size() == 0) @@ -37,7 +39,8 @@ class MesosSchedulerBackendUtilSuite extends SparkFunSuite { conf.set("spark.mesos.executor.docker.parameters", "a=1,b=2,c=3") conf.set("spark.mesos.executor.docker.image", "test") - val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo( + conf) val params = containerInfo.getDocker.getParametersList assert(params.size() == 3) assert(params.get(0).getKey == "a") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 833db0c1ff334..5636ac52bd4a7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -24,9 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.{Range => MesosRange, Ranges, Scalar} import org.apache.mesos.SchedulerDriver +import org.apache.mesos.protobuf.ByteString import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ +import org.apache.spark.deploy.mesos.config.MesosSecretConfig + object Utils { val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() @@ -105,4 +108,108 @@ object Utils { def createTaskId(taskId: String): TaskID = { TaskID.newBuilder().setValue(taskId).build() } + + def configEnvBasedRefSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretName = "/path/to/secret,/anothersecret" + val envKey = "SECRET_ENV_KEY,PASSWORD" + Map( + secretConfig.SECRET_NAMES.key -> secretName, + secretConfig.SECRET_ENVKEYS.key -> envKey + ) + } + + def verifyEnvBasedRefSecrets(launchedTasks: List[TaskInfo]): Unit = { + val envVars = launchedTasks.head + .getCommand + .getEnvironment + .getVariablesList + .asScala + assert(envVars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + val variableOne = envVars.filter(_.getName == "SECRET_ENV_KEY").head + assert(variableOne.getSecret.isInitialized) + assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) + assert(variableOne.getSecret.getReference.getName == "/path/to/secret") + assert(variableOne.getType == Environment.Variable.Type.SECRET) + val variableTwo = envVars.filter(_.getName == "PASSWORD").head + assert(variableTwo.getSecret.isInitialized) + assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE) + assert(variableTwo.getSecret.getReference.getName == "/anothersecret") + assert(variableTwo.getType == Environment.Variable.Type.SECRET) + } + + def configEnvBasedValueSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretValues = "user,password" + val envKeys = "USER,PASSWORD" + Map( + secretConfig.SECRET_VALUES.key -> secretValues, + secretConfig.SECRET_ENVKEYS.key -> envKeys + ) + } + + def verifyEnvBasedValueSecrets(launchedTasks: List[TaskInfo]): Unit = { + val envVars = launchedTasks.head + .getCommand + .getEnvironment + .getVariablesList + .asScala + assert(envVars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + val variableOne = envVars.filter(_.getName == "USER").head + assert(variableOne.getSecret.isInitialized) + assert(variableOne.getSecret.getType == Secret.Type.VALUE) + assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes)) + assert(variableOne.getType == Environment.Variable.Type.SECRET) + val variableTwo = envVars.filter(_.getName == "PASSWORD").head + assert(variableTwo.getSecret.isInitialized) + assert(variableTwo.getSecret.getType == Secret.Type.VALUE) + assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) + assert(variableTwo.getType == Environment.Variable.Type.SECRET) + } + + def configFileBasedRefSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretName = "/path/to/secret,/anothersecret" + val secretPath = "/topsecret,/mypassword" + Map( + secretConfig.SECRET_NAMES.key -> secretName, + secretConfig.SECRET_FILENAMES.key -> secretPath + ) + } + + def verifyFileBasedRefSecrets(launchedTasks: List[TaskInfo]): Unit = { + val volumes = launchedTasks.head.getContainer.getVolumesList + assert(volumes.size() == 2) + val secretVolOne = volumes.get(0) + assert(secretVolOne.getContainerPath == "/topsecret") + assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE) + assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret") + val secretVolTwo = volumes.get(1) + assert(secretVolTwo.getContainerPath == "/mypassword") + assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE) + assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret") + } + + def configFileBasedValueSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretValues = "user,password" + val secretPath = "/whoami,/mypassword" + Map( + secretConfig.SECRET_VALUES.key -> secretValues, + secretConfig.SECRET_FILENAMES.key -> secretPath + ) + } + + def verifyFileBasedValueSecrets(launchedTasks: List[TaskInfo]): Unit = { + val volumes = launchedTasks.head.getContainer.getVolumesList + assert(volumes.size() == 2) + val secretVolOne = volumes.get(0) + assert(secretVolOne.getContainerPath == "/whoami") + assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE) + assert(secretVolOne.getSource.getSecret.getValue.getData == + ByteString.copyFrom("user".getBytes)) + val secretVolTwo = volumes.get(1) + assert(secretVolTwo.getContainerPath == "/mypassword") + assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.VALUE) + assert(secretVolTwo.getSource.getSecret.getValue.getData == + ByteString.copyFrom("password".getBytes)) + } } From 8e9863531bebbd4d83eafcbc2b359b8bd0ac5734 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 26 Oct 2017 16:55:30 -0700 Subject: [PATCH 1549/1765] [SPARK-22366] Support ignoring missing files ## What changes were proposed in this pull request? Add a flag "spark.sql.files.ignoreMissingFiles" to parallel the existing flag "spark.sql.files.ignoreCorruptFiles". ## How was this patch tested? new unit test Author: Jose Torres Closes #19581 from joseph-torres/SPARK-22366. --- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../execution/datasources/FileScanRDD.scala | 13 +++++--- .../parquet/ParquetQuerySuite.scala | 33 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4cfe53b2c115b..21e4685fcc456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -614,6 +614,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val IGNORE_MISSING_FILES = buildConf("spark.sql.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + val MAX_RECORDS_PER_FILE = buildConf("spark.sql.files.maxRecordsPerFile") .doc("Maximum number of records to write out to a single file. " + "If this value is zero or negative, there is no limit.") @@ -1014,6 +1020,8 @@ class SQLConf extends Serializable with Logging { def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES) + def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE) def useCompression: Boolean = getConf(COMPRESS_CACHED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 9df20731c71d5..8731ee88f87f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -66,6 +66,7 @@ class FileScanRDD( extends RDD[InternalRow](sparkSession.sparkContext, Nil) { private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + private val ignoreMissingFiles = sparkSession.sessionState.conf.ignoreMissingFiles override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { @@ -142,7 +143,7 @@ class FileScanRDD( // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - if (ignoreCorruptFiles) { + if (ignoreMissingFiles || ignoreCorruptFiles) { currentIterator = new NextIterator[Object] { // The readFunction may read some bytes before consuming the iterator, e.g., // vectorized Parquet reader. Here we use lazy val to delay the creation of @@ -158,9 +159,13 @@ class FileScanRDD( null } } catch { - // Throw FileNotFoundException even `ignoreCorruptFiles` is true - case e: FileNotFoundException => throw e - case e @ (_: RuntimeException | _: IOException) => + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => logWarning( s"Skipped the rest of the content in the corrupted file: $currentFile", e) finished = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2efff3f57d7d3..e822e40b146ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + testQuietly("Enabling/disabling ignoreMissingFiles") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + fs.delete(thirdPath, true) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + /** * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop * to increase the chance of failure From 9b262f6a08c0c1b474d920d49b9fdd574c401d39 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:39:53 -0700 Subject: [PATCH 1550/1765] [SPARK-22356][SQL] data source table should support overlapped columns between data and partition schema ## What changes were proposed in this pull request? This is a regression introduced by #14207. After Spark 2.1, we store the inferred schema when creating the table, to avoid inferring schema again at read path. However, there is one special case: overlapped columns between data and partition. For this case, it breaks the assumption of table schema that there is on ovelap between data and partition schema, and partition columns should be at the end. The result is, for Spark 2.1, the table scan has incorrect schema that puts partition columns at the end. For Spark 2.2, we add a check in CatalogTable to validate table schema, which fails at this case. To fix this issue, a simple and safe approach is to fallback to old behavior when overlapeed columns detected, i.e. store empty schema in metastore. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19579 from cloud-fan/bug2. --- .../command/createDataSourceTables.scala | 35 +++++++++++++---- .../datasources/HadoopFsRelation.scala | 25 ++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++ .../HiveExternalCatalogVersionsSuite.scala | 38 ++++++++++++++----- 4 files changed, 89 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 9e3907996995c..306f43dc4214a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType /** * A command used to create a data source table. @@ -85,14 +86,32 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo } } - val newTable = table.copy( - schema = dataSource.schema, - partitionColumnNames = partitionColumnNames, - // If metastore partition management for file source tables is enabled, we start off with - // partition provider hive, but no partitions in the metastore. The user has to call - // `msck repair table` to populate the table partitions. - tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && - sessionState.conf.manageFilesourcePartitions) + val newTable = dataSource match { + // Since Spark 2.1, we store the inferred schema of data source in metastore, to avoid + // inferring the schema again at read path. However if the data source has overlapped columns + // between data and partition schema, we can't store it in metastore as it breaks the + // assumption of table schema. Here we fallback to the behavior of Spark prior to 2.1, store + // empty schema in metastore and infer it at runtime. Note that this also means the new + // scalable partitioning handling feature(introduced at Spark 2.1) is disabled in this case. + case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty => + logWarning("It is not recommended to create a table with overlapped data and partition " + + "columns, as Spark cannot store a valid table schema and has to infer it at runtime, " + + "which hurts performance. Please check your data files and remove the partition " + + "columns in it.") + table.copy(schema = new StructType(), partitionColumnNames = Nil) + + case _ => + table.copy( + schema = dataSource.schema, + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && + sessionState.conf.manageFilesourcePartitions) + + } + // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false. sessionState.catalog.createTable(newTable, ignoreIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 9a08524476baa..89d8a85a9cbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import scala.collection.mutable import org.apache.spark.sql.{SparkSession, SQLContext} @@ -50,15 +52,22 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext - val schema: StructType = { - val getColName: (StructField => String) = - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } + private def getColName(f: StructField): String = { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } + + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if (dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField } + } + + val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index caf332d050d7b..5d0bba69daca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2741,4 +2741,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert (aggregateExpressions.isDefined) assert (aggregateExpressions.get.size == 2) } + + test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { + withTempPath { path => + Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j") + .write.mode("overwrite").parquet(new File(path, "p=1").getCanonicalPath) + withTable("t") { + sql(s"create table t using parquet options(path='${path.getCanonicalPath}')") + // We should respect the column order in data schema. + assert(spark.table("t").columns === Array("i", "p", "j")) + checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + // The DESC TABLE should report same schema as table scan. + assert(sql("desc t").select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 5f8c9d5799662..6859432c406a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -40,7 +40,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to // avoid downloading Spark of different versions in each run. - private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val sparkTestingDir = new File("/tmp/test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { @@ -77,35 +77,38 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.beforeAll() val tempPyFile = File.createTempFile("test", ".py") + // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, s""" |from pyspark.sql import SparkSession + |import os | |spark = SparkSession.builder.enableHiveSupport().getOrCreate() |version_index = spark.conf.get("spark.sql.test.version.index", None) | |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) | - |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ - | " using parquet as select 1 i") + |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index)) | |json_file = "${genDataDir("json_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) - |spark.sql("create table external_data_source_tbl_" + version_index + \\ - | "(i int) using json options (path '{}')".format(json_file)) + |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file)) | |parquet_file = "${genDataDir("parquet_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) - |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ - | "(i int) using parquet options (path '{}')".format(parquet_file)) + |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file)) | |json_file2 = "${genDataDir("json2_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) - |spark.sql("create table external_table_without_schema_" + version_index + \\ - | " using json options (path '{}')".format(json_file2)) + |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2)) + | + |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index) + |spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1")) + |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2)) | |spark.sql("create view v_{} as select 1 i".format(version_index)) """.stripMargin.getBytes("utf8")) + // scalastyle:on line.size.limit PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") @@ -153,6 +156,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .enableHiveSupport() .getOrCreate() spark = session + import session.implicits._ testingVersions.indices.foreach { index => Seq( @@ -194,6 +198,22 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // test permanent view checkAnswer(sql(s"select i from v_$index"), Row(1)) + + // SPARK-22356: overlapped columns between data and partition schema in data source tables + val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { + spark.sql("msck repair table " + tbl_with_col_overlap) + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,j,p")) + } else { + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } } } } From 5c3a1f3fad695317c2fff1243cdb9b3ceb25c317 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:51:16 -0700 Subject: [PATCH 1551/1765] [SPARK-22355][SQL] Dataset.collect is not threadsafe ## What changes were proposed in this pull request? It's possible that users create a `Dataset`, and call `collect` of this `Dataset` in many threads at the same time. Currently `Dataset#collect` just call `encoder.fromRow` to convert spark rows to objects of type T, and this encoder is per-dataset. This means `Dataset#collect` is not thread-safe, because the encoder uses a projection to output the object to a re-usable row. This PR fixes this problem, by creating a new projection when calling `Dataset#collect`, so that we have the re-usable row for each method call, instead of each Dataset. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19577 from cloud-fan/encoder. --- .../scala/org/apache/spark/sql/Dataset.scala | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b70dfc05330f8..0e23983786b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} @@ -198,15 +199,10 @@ class Dataset[T] private[sql]( */ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) - /** - * Encoder is used mostly as a container of serde expressions in Dataset. We build logical - * plans by these serde expressions and execute it within the query framework. However, for - * performance reasons we may want to use encoder as a function to deserialize internal rows to - * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its - * `fromRow` method later. - */ - private val boundEnc = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + // The deserializer expression which can be used to build a projection and turn rows to objects + // of type T, after collecting rows to the driver side. + private val deserializer = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag @@ -2661,7 +2657,15 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - plan.executeToIterator().map(boundEnc.fromRow).asJava + // This projection writes output to a `InternalRow`, which means applying this projection is + // not thread-safe. Here we create the projection inside this method to make `Dataset` + // thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeToIterator().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + }.asJava } } @@ -3102,7 +3106,14 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - plan.executeCollect().map(boundEnc.fromRow) + // This projection writes output to a `InternalRow`, which means applying this projection is not + // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeCollect().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + } } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { From 17af727e38c3faaeab5b91a8cdab5f2181cf3fc4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 26 Oct 2017 23:02:46 -0700 Subject: [PATCH 1552/1765] [SPARK-21375][PYSPARK][SQL] Add Date and Timestamp support to ArrowConverters for toPandas() Conversion ## What changes were proposed in this pull request? Adding date and timestamp support with Arrow for `toPandas()` and `pandas_udf`s. Timestamps are stored in Arrow as UTC and manifested to the user as timezone-naive localized to the Python system timezone. ## How was this patch tested? Added Scala tests for date and timestamp types under ArrowConverters, ArrowUtils, and ArrowWriter suites. Added Python tests for `toPandas()` and `pandas_udf`s with date and timestamp types. Author: Bryan Cutler Author: Takuya UESHIN Closes #18664 from BryanCutler/arrow-date-timestamp-SPARK-21375. --- python/pyspark/serializers.py | 24 +++- python/pyspark/sql/dataframe.py | 7 +- python/pyspark/sql/tests.py | 106 ++++++++++++++-- python/pyspark/sql/types.py | 36 ++++++ .../vectorized/ArrowColumnVector.java | 34 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../sql/execution/arrow/ArrowConverters.scala | 3 +- .../sql/execution/arrow/ArrowUtils.scala | 30 +++-- .../sql/execution/arrow/ArrowWriter.scala | 39 +++++- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 5 +- .../python/FlatMapGroupsInPandasExec.scala | 4 +- .../arrow/ArrowConvertersSuite.scala | 120 ++++++++++++++++-- .../sql/execution/arrow/ArrowUtilsSuite.scala | 26 +++- .../execution/arrow/ArrowWriterSuite.scala | 24 ++-- .../vectorized/ArrowColumnVectorSuite.scala | 22 ++-- .../vectorized/ColumnarBatchSuite.scala | 4 +- 17 files changed, 417 insertions(+), 73 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a0adeed994456..d7979f095da76 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -214,6 +214,7 @@ def __repr__(self): def _create_batch(series): + from pyspark.sql.types import _check_series_convert_timestamps_internal import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ @@ -224,12 +225,25 @@ def _create_batch(series): # If a nullable integer series has been promoted to floating point with NaNs, need to cast # NOTE: this is not necessary with Arrow >= 0.7 def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): + if type(t) == pa.TimestampType: + # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 + return _check_series_convert_timestamps_internal(s.fillna(0))\ + .values.astype('datetime64[us]', copy=False) + elif t == pa.date32(): + # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 + return s.dt.date + elif t is None or s.dtype == t.to_pandas_dtype(): return s else: return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + # Some object types don't support masks in Arrow, see ARROW-1721 + def create_array(s, t): + casted = cast_series(s, t) + mask = None if casted.dtype == 'object' else s.isnull() + return pa.Array.from_pandas(casted, mask=mask, type=t) + + arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) @@ -260,11 +274,13 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) for batch in reader: - table = pa.Table.from_batches([batch]) - yield [c.to_pandas() for c in table.itercolumns()] + # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 + pdf = _check_dataframe_localize_timestamps(batch.to_pandas()) + yield [c for _, c in pdf.iteritems()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c0b574e2b93a1..406686e6df724 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1883,11 +1883,13 @@ def toPandas(self): import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) - return table.to_pandas() + pdf = table.to_pandas() + return _check_dataframe_localize_timestamps(pdf) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: @@ -1955,6 +1957,7 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. + NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -1965,6 +1968,8 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 + elif type(dt) == DateType: + return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 685eebcafefba..98afae662b42d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase): @classmethod def setUpClass(cls): + from datetime import datetime ReusedPySparkTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] + StructField("5_double_t", DoubleType(), True), + StructField("6_date_t", DateType(), True), + StructField("7_timestamp_t", TimestampType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + schema = StructType([StructField("decimal", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) @@ -3385,13 +3405,77 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) - f = pandas_udf(lambda x: x, DateType()) + schema = StructType([StructField("dt", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + f = pandas_udf(lambda x: x, DecimalType()) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.select(f(col('dt'))).collect() + def test_vectorized_udf_null_date(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import date + schema = StructType().add("date", DateType()) + data = [(date(1969, 1, 1),), + (date(2012, 2, 2),), + (None,), + (date(2100, 4, 4),)] + df = self.spark.createDataFrame(data, schema=schema) + date_f = pandas_udf(lambda t: t, returnType=DateType()) + res = df.select(date_f(col("date"))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_timestamps(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(0, datetime(1969, 1, 1, 1, 1, 1)), + (1, datetime(2012, 2, 2, 2, 2, 2)), + (2, None), + (3, datetime(2100, 4, 4, 4, 4, 4))] + df = self.spark.createDataFrame(data, schema=schema) + + # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc + f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) + df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) + + @pandas_udf(returnType=BooleanType()) + def check_data(idx, timestamp, timestamp_copy): + is_equal = timestamp.isnull() # use this array to check values are equal + for i in range(len(idx)): + # Check that timestamps are as expected in the UDF + is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \ + timestamp[i].to_pydatetime() == data[idx[i]][1] + return is_equal + + result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"), + col("timestamp_copy"))).collect() + # Check that collection values are correct + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected + + def test_vectorized_udf_return_timestamp_tz(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + + @pandas_udf(returnType=TimestampType()) + def gen_timestamps(id): + ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id] + return pd.Series(ts) + + result = df.withColumn("ts", gen_timestamps(col("id"))).collect() + spark_ts_t = TimestampType() + for r in result: + i, ts = r + ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime() + expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) + self.assertEquals(expected, ts) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3550,8 +3634,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f65273d5f0b6c..7dd8fa04160e0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1619,11 +1619,47 @@ def to_arrow_type(dt): arrow_type = pa.decimal(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == DateType: + arrow_type = pa.date32() + elif type(dt) == TimestampType: + # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read + arrow_type = pa.timestamp('us', tz='UTC') else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type +def _check_dataframe_localize_timestamps(pdf): + """ + Convert timezone aware timestamps to timezone-naive in local time + + :param pdf: pandas.DataFrame + :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive + """ + from pandas.api.types import is_datetime64tz_dtype + for column, series in pdf.iteritems(): + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(series.dtype): + pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None) + return pdf + + +def _check_series_convert_timestamps_internal(s): + """ + Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage + :param s: a pandas.Series + :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone + """ + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64_dtype(s.dtype): + return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') + elif is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert('UTC') + else: + return s + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 1f171049820b2..51ea719f8c4a6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -320,6 +320,10 @@ public ArrowColumnVector(ValueVector vector) { accessor = new StringAccessor((NullableVarCharVector) vector); } else if (vector instanceof NullableVarBinaryVector) { accessor = new BinaryAccessor((NullableVarBinaryVector) vector); + } else if (vector instanceof NullableDateDayVector) { + accessor = new DateAccessor((NullableDateDayVector) vector); + } else if (vector instanceof NullableTimeStampMicroTZVector) { + accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -575,6 +579,36 @@ final byte[] getBinary(int rowId) { } } + private static class DateAccessor extends ArrowVectorAccessor { + + private final NullableDateDayVector.Accessor accessor; + + DateAccessor(NullableDateDayVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final NullableTimeStampMicroTZVector.Accessor accessor; + + TimestampAccessor(NullableTimeStampMicroTZVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + private static class ArrayAccessor extends ArrowVectorAccessor { private final UInt4Vector.Accessor accessor; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e23983786b08..fe4e192e43dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3154,9 +3154,11 @@ class Dataset[T] private[sql]( private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone queryExecution.toRdd.mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) + ArrowConverters.toPayloadIterator( + iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 561a067a2f81f..05ea1517fcac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -74,9 +74,10 @@ private[sql] object ArrowConverters { rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, + timeZoneId: String, context: TaskContext): Iterator[ArrowPayload] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 2caf1ef02909a..6ad11bda84bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.types._ @@ -31,7 +31,8 @@ object ArrowUtils { // todo: support more types. - def toArrowType(dt: DataType): ArrowType = dt match { + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -42,6 +43,13 @@ object ArrowUtils { case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => + if (timeZoneId == null) { + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + } else { + new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + } case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -58,22 +66,27 @@ object ArrowUtils { case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } - def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { + /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava) + new Field(name, fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) new Field(name, fieldType, fields.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.toSeq.asJava) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType), null) + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -94,9 +107,10 @@ object ArrowUtils { } } - def toArrowSchema(schema: StructType): Schema = { + /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { new Schema(schema.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.asJava) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0b740735ffe19..e4af4f65da127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.util.DecimalUtility +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -29,8 +29,8 @@ import org.apache.spark.sql.types._ object ArrowWriter { - def create(schema: StructType): ArrowWriter = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + def create(schema: StructType, timeZoneId: String): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) } @@ -55,6 +55,8 @@ object ArrowWriter { case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) + case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) @@ -69,9 +71,7 @@ object ArrowWriter { } } -class ArrowWriter( - val root: VectorSchemaRoot, - fields: Array[ArrowFieldWriter]) { +class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { def schema: StructType = StructType(fields.map { f => StructField(f.name, f.dataType, f.nullable) @@ -255,6 +255,33 @@ private[arrow] class BinaryWriter( } } +private[arrow] class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter { + + override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class TimestampWriter( + val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } +} + private[arrow] class ArrayWriter( val valueVector: ListVector, val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 81896187ecc46..0db463a5fbd89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -79,7 +79,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index f6c03c415dc66..94c05b9b5e49f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -43,7 +43,8 @@ class ArrowPythonRunner( reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], - schema: StructType) + schema: StructType, + timeZoneId: String) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -60,7 +61,7 @@ class ArrowPythonRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 5ed88ada428cb..cc93fda9f81da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,8 +94,8 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) - .compute(grouped, context.partitionId(), context) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 30422b657742c..ba2903babbba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,6 +32,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -793,6 +795,103 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "binaryData.json") } + test("date type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "date", + | "type" : { + | "name" : "date", + | "unit" : "DAY" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "date", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1, 0, 16533, 382607 ] + | } ] + | } ] + |} + """.stripMargin + + val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" + val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" + val d3 = Date.valueOf("2015-04-08") + val d4 = Date.valueOf("3017-07-18") + + val df = Seq(d1, d2, d3, d4).toDF("date") + + collectAndValidate(df, json, "dateData.json") + } + + test("timestamp type conversion") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "timestamp", + | "type" : { + | "name" : "timestamp", + | "unit" : "MICROSECOND", + | "timezone" : "America/Los_Angeles" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "timestamp", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1234, 0, 1365383415567000, 33057298500000000 ] + | } ] + | } ] + |} + """.stripMargin + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = DateTimeUtils.toJavaTimestamp(-1234L) + val ts2 = DateTimeUtils.toJavaTimestamp(0L) + val ts3 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts4 = new Timestamp(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) + val data = Seq(ts1, ts2, ts3, ts4) + + val df = data.toDF("timestamp") + + collectAndValidate(df, json, "timestampData.json", "America/Los_Angeles") + } + } + test("floating-point NaN") { val json = s""" @@ -1486,15 +1585,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } } test("test Arrow Validator") { @@ -1638,7 +1728,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx) + val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) assert(schema.equals(outputRowIter.schema)) @@ -1657,22 +1747,24 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + private def collectAndValidate( + df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val arrowPayload = df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) + validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, arrowPayload: ArrowPayload, - jsonFile: File): Unit = { + jsonFile: File, + timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index 638619fd39d06..d801f62b62323 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.arrow +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class ArrowUtilsSuite extends SparkFunSuite { @@ -25,7 +28,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null)) === schema) case _ => roundtrip(new StructType().add("value", dt)) } @@ -42,6 +45,27 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(StringType) roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) + roundtrip(DateType) + val tsExMsg = intercept[UnsupportedOperationException] { + roundtrip(TimestampType) + } + assert(tsExMsg.getMessage.contains("timeZoneId")) + } + + test("timestamp") { + + def roundtripWithTz(timeZoneId: String): Unit = { + val schema = new StructType().add("value", TimestampType) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] + assert(fieldType.getTimezone() === timeZoneId) + assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) + } + + roundtripWithTz(DateTimeUtils.defaultTimeZone().getID) + roundtripWithTz("Asia/Tokyo") + roundtripWithTz("UTC") + roundtripWithTz("America/Los_Angeles") } test("array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index e9a629315f5f4..a71e30aa3ca96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -51,6 +51,8 @@ class ArrowWriterSuite extends SparkFunSuite { case DoubleType => reader.getDouble(rowId) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) + case DateType => reader.getInt(rowId) + case TimestampType => reader.getLong(rowId) } assert(value === datum) } @@ -66,12 +68,14 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + check(DateType, Seq(0, 1, 2, null, 4)) + check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") } test("get multiple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = false) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -88,6 +92,8 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLongs(0, data.size) case FloatType => reader.getFloats(0, data.size) case DoubleType => reader.getDoubles(0, data.size) + case DateType => reader.getInts(0, data.size) + case TimestampType => reader.getLongs(0, data.size) } assert(values === data) @@ -100,12 +106,14 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, (0 until 10).map(_.toLong)) check(FloatType, (0 until 10).map(_.toFloat)) check(DoubleType, (0 until 10).map(_.toDouble)) + check(DateType, (0 until 10)) + check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") } test("array") { val schema = new StructType() .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) @@ -144,7 +152,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested array") { val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array( @@ -195,7 +203,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("struct") { val schema = new StructType() .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) @@ -231,7 +239,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested struct") { val schema = new StructType().add("struct", new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index d24a9e1f4bd16..068a17bf772e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -29,7 +29,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBitVector] vector.allocateNew() val mutator = vector.getMutator() @@ -58,7 +58,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableTinyIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -87,7 +87,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableSmallIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -116,7 +116,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -145,7 +145,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("long", LongType, nullable = true) + val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBigIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -174,7 +174,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat4Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -203,7 +203,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat8Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -232,7 +232,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("string", StringType, nullable = true) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarCharVector] vector.allocateNew() val mutator = vector.getMutator() @@ -260,7 +260,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarBinaryVector] vector.allocateNew() val mutator = vector.getMutator() @@ -288,7 +288,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("array") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() val mutator = vector.getMutator() @@ -345,7 +345,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) - val vector = ArrowUtils.toArrowField("struct", schema, nullable = true) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() val mutator = vector.getMutator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0b179aa97c479..4cfc776e51db1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1249,11 +1249,11 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector1.allocateNew() val mutator1 = vector1.getMutator() - val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true) + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector2.allocateNew() val mutator2 = vector2.getMutator() From 36b826f5d17ae7be89135cb2c43ff797f9e7fe48 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 27 Oct 2017 07:52:10 -0700 Subject: [PATCH 1553/1765] [TRIVIAL][SQL] Code cleaning in ResolveReferences ## What changes were proposed in this pull request? This PR is to clean the related codes majorly based on the today's code review on https://github.com/apache/spark/pull/19559 ## How was this patch tested? N/A Author: gatorsmile Closes #19585 from gatorsmile/trivialFixes. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../scala/org/apache/spark/sql/Column.scala | 10 ++++----- .../spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../sql/execution/WholeStageCodegenExec.scala | 5 ++--- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d6a962a14dc9c..6384a141e83b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -783,6 +783,17 @@ class Analyzer( } } + private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case u @ UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(resolve(_, q)) + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -841,15 +852,7 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q.transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - } + q.mapExpressions(resolve(_, q)) } def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8468a8a96349a..92988680871a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -44,7 +44,7 @@ private[sql] object Column { e match { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => a.aggregateFunction.toString - case expr => usePrettyExpression(expr).sql + case expr => toPrettySQL(expr) } } } @@ -137,7 +137,7 @@ class Column(val expr: Expression) extends Logging { case _ => UnresolvedAttribute.quotedString(name) }) - override def toString: String = usePrettyExpression(expr).sql + override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -175,7 +175,7 @@ class Column(val expr: Expression) extends Logging { case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c) } match { case ne: NamedExpression => ne - case other => Alias(expr, usePrettyExpression(expr).sql)() + case _ => Alias(expr, toPrettySQL(expr))() } case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => @@ -184,7 +184,7 @@ class Column(val expr: Expression) extends Logging { // Wait until the struct is resolved. This will generate a nicer looking alias. case struct: CreateNamedStructLike => UnresolvedAlias(struct) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6b45790d5ff6e..21e94fa8bb0b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf @@ -85,7 +85,7 @@ class RelationalGroupedDataset protected[sql]( case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index e37d133ff336a..286cb3bb0767c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -521,10 +521,9 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) - case j @ SortMergeJoinExec(_, _, _, _, left, right) => + case j: SortMergeJoinExec => // The children of SortMergeJoin should do codegen separately. - j.copy(left = InputAdapter(insertWholeStageCodegen(left)), - right = InputAdapter(insertWholeStageCodegen(right))) + j.withNewChildren(j.children.map(child => InputAdapter(insertWholeStageCodegen(child)))) case p => p.withNewChildren(p.children.map(insertInputAdapter)) } From b3d8fc3dc458d42cf11d961762ce99f551f68548 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Oct 2017 13:43:09 -0700 Subject: [PATCH 1554/1765] [SPARK-22226][SQL] splitExpression can create too many method calls in the outer class ## What changes were proposed in this pull request? SPARK-18016 introduced `NestedClass` to avoid that the many methods generated by `splitExpressions` contribute to the outer class' constant pool, making it growing too much. Unfortunately, despite their definition is stored in the `NestedClass`, they all are invoked in the outer class and for each method invocation, there are two entries added to the constant pool: a `Methodref` and a `Utf8` entry (you can easily check this compiling a simple sample class with `janinoc` and looking at its Constant Pool). This limits the scalability of the solution with very large methods which are split in a lot of small ones. This means that currently we are generating classes like this one: ``` class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection { ... public UnsafeRow apply(InternalRow i) { rowWriter.zeroOutNullBytes(); apply_0(i); apply_1(i); ... nestedClassInstance.apply_862(i); nestedClassInstance.apply_863(i); ... nestedClassInstance1.apply_1612(i); nestedClassInstance1.apply_1613(i); ... } ... private class NestedClass { private void apply_862(InternalRow i) { ... } private void apply_863(InternalRow i) { ... } ... } private class NestedClass1 { private void apply_1612(InternalRow i) { ... } private void apply_1613(InternalRow i) { ... } ... } } ``` This PR reduce the Constant Pool size of the outer class by adding a new method to each nested class: in this method we invoke all the small methods generated by `splitExpression` in that nested class. In this way, in the outer class there is only one method invocation per nested class, reducing by orders of magnitude the entries in its constant pool because of method invocations. This means that after the patch the generated code becomes: ``` class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection { ... public UnsafeRow apply(InternalRow i) { rowWriter.zeroOutNullBytes(); apply_0(i); apply_1(i); ... nestedClassInstance.apply(i); nestedClassInstance1.apply(i); ... } ... private class NestedClass { private void apply_862(InternalRow i) { ... } private void apply_863(InternalRow i) { ... } ... private void apply(InternalRow i) { apply_862(i); apply_863(i); ... } } private class NestedClass1 { private void apply_1612(InternalRow i) { ... } private void apply_1613(InternalRow i) { ... } ... private void apply(InternalRow i) { apply_1612(i); apply_1613(i); ... } } } ``` ## How was this patch tested? Added UT and existing UTs Author: Marco Gaido Author: Marco Gaido Closes #19480 from mgaido91/SPARK-22226. --- .../expressions/codegen/CodeGenerator.scala | 156 ++++++++++++++++-- .../expressions/CodeGenerationSuite.scala | 17 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++ 3 files changed, 167 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2cb66599076a9..58738b52b299f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -77,6 +77,22 @@ case class SubExprEliminationState(isNull: String, value: String) */ case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState]) +/** + * The main information about a new added function. + * + * @param functionName String representing the name of the function + * @param innerClassName Optional value which is empty if the function is added to + * the outer class, otherwise it contains the name of the + * inner class in which the function has been added. + * @param innerClassInstance Optional value which is empty if the function is added to + * the outer class, otherwise it contains the name of the + * instance of the inner class in the outer class. + */ +private[codegen] case class NewFunctionSpec( + functionName: String, + innerClassName: Option[String], + innerClassInstance: Option[String]) + /** * A context for codegen, tracking a list of objects that could be passed into generated Java * function. @@ -228,8 +244,8 @@ class CodegenContext { /** * Holds the class and instance names to be generated, where `OuterClass` is a placeholder * standing for whichever class is generated as the outermost class and which will contain any - * nested sub-classes. All other classes and instance names in this list will represent private, - * nested sub-classes. + * inner sub-classes. All other classes and instance names in this list will represent private, + * inner sub-classes. */ private val classes: mutable.ListBuffer[(String, String)] = mutable.ListBuffer[(String, String)](outerClassName -> null) @@ -260,8 +276,8 @@ class CodegenContext { /** * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the - * function will be inlined into a new private, nested class, and a class-qualified name for the - * function will be returned. Otherwise, the function will be inined to the `OuterClass` the + * function will be inlined into a new private, inner class, and a class-qualified name for the + * function will be returned. Otherwise, the function will be inlined to the `OuterClass` the * simple `funcName` will be returned. * * @param funcName the class-unqualified name of the function @@ -271,19 +287,27 @@ class CodegenContext { * it is eventually referenced and a returned qualified function name * cannot otherwise be accessed. * @return the name of the function, qualified by class if it will be inlined to a private, - * nested sub-class + * inner class */ def addNewFunction( funcName: String, funcCode: String, inlineToOuterClass: Boolean = false): String = { - // The number of named constants that can exist in the class is limited by the Constant Pool - // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a - // threshold of 1600k bytes to determine when a function should be inlined to a private, nested - // sub-class. + val newFunction = addNewFunctionInternal(funcName, funcCode, inlineToOuterClass) + newFunction match { + case NewFunctionSpec(functionName, None, None) => functionName + case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) => + innerClassInstance + "." + functionName + } + } + + private[this] def addNewFunctionInternal( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > 1600000) { + } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -294,17 +318,23 @@ class CodegenContext { currClass() } - classSize(className) += funcCode.length - classFunctions(className) += funcName -> funcCode + addNewFunctionToClass(funcName, funcCode, className) if (className == outerClassName) { - funcName + NewFunctionSpec(funcName, None, None) } else { - - s"$classInstance.$funcName" + NewFunctionSpec(funcName, Some(className), Some(classInstance)) } } + private[this] def addNewFunctionToClass( + funcName: String, + funcCode: String, + className: String) = { + classSize(className) += funcCode.length + classFunctions(className) += funcName -> funcCode + } + /** * Declares all function code. If the added functions are too many, split them into nested * sub-classes to avoid hitting Java compiler constant pool limitation. @@ -738,7 +768,7 @@ class CodegenContext { /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow - * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it + * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * * @param row the variable name of row that is used by expressions @@ -801,10 +831,90 @@ class CodegenContext { | ${makeSplitFunction(body)} |} """.stripMargin - addNewFunction(name, code) + addNewFunctionInternal(name, code, inlineToOuterClass = false) } - foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) + val (outerClassFunctions, innerClassFunctions) = functions.partition(_.innerClassName.isEmpty) + + val argsString = arguments.map(_._2).mkString(", ") + val outerClassFunctionCalls = outerClassFunctions.map(f => s"${f.functionName}($argsString)") + + val innerClassFunctionCalls = generateInnerClassesFunctionCalls( + innerClassFunctions, + func, + arguments, + returnType, + makeSplitFunction, + foldFunctions) + + foldFunctions(outerClassFunctionCalls ++ innerClassFunctionCalls) + } + } + + /** + * Here we handle all the methods which have been added to the inner classes and + * not to the outer class. + * Since they can be many, their direct invocation in the outer class adds many entries + * to the outer class' constant pool. This can cause the constant pool to past JVM limit. + * Moreover, this can cause also the outer class method where all the invocations are + * performed to grow beyond the 64k limit. + * To avoid these problems, we group them and we call only the grouping methods in the + * outer class. + * + * @param functions a [[Seq]] of [[NewFunctionSpec]] defined in the inner classes + * @param funcName the split function name base. + * @param arguments the list of (type, name) of the arguments of the split function. + * @param returnType the return type of the split function. + * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. + * @param foldFunctions folds the split function calls. + * @return an [[Iterable]] containing the methods' invocations + */ + private def generateInnerClassesFunctionCalls( + functions: Seq[NewFunctionSpec], + funcName: String, + arguments: Seq[(String, String)], + returnType: String, + makeSplitFunction: String => String, + foldFunctions: Seq[String] => String): Iterable[String] = { + val innerClassToFunctions = mutable.LinkedHashMap.empty[(String, String), Seq[String]] + functions.foreach(f => { + val key = (f.innerClassName.get, f.innerClassInstance.get) + val value = f.functionName +: innerClassToFunctions.getOrElse(key, Seq.empty[String]) + innerClassToFunctions.put(key, value) + }) + + val argDefinitionString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") + val argInvocationString = arguments.map(_._2).mkString(", ") + + innerClassToFunctions.flatMap { + case ((innerClassName, innerClassInstance), innerClassFunctions) => + // for performance reasons, the functions are prepended, instead of appended, + // thus here they are in reversed order + val orderedFunctions = innerClassFunctions.reverse + if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + // Adding a new function to each inner class which contains the invocation of all the + // ones which have been added to that inner class. For example, + // private class NestedClass { + // private void apply_862(InternalRow i) { ... } + // private void apply_863(InternalRow i) { ... } + // ... + // private void apply(InternalRow i) { + // apply_862(i); + // apply_863(i); + // ... + // } + // } + val body = foldFunctions(orderedFunctions.map(name => s"$name($argInvocationString)")) + val code = s""" + |private $returnType $funcName($argDefinitionString) { + | ${makeSplitFunction(body)} + |} + """.stripMargin + addNewFunctionToClass(funcName, code, innerClassName) + Seq(s"$innerClassInstance.$funcName($argInvocationString)") + } else { + orderedFunctions.map(f => s"$innerClassInstance.$f($argInvocationString)") + } } } @@ -1013,6 +1123,16 @@ object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + // This is the threshold over which the methods in an inner class are grouped in a single + // method which is going to be called by the outer class instead of the many small ones + val MERGE_SPLIT_METHODS_THRESHOLD = 3 + + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a + // threshold of 1000k bytes to determine when a function should be inlined to a private, inner + // class. + val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + /** * Compile the Java source code into a Java class, using Janino. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7ea0bec145481..1e6f7b65e7e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -201,6 +201,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-22226: group splitted expressions into one method per nested class") { + val length = 10000 + val expressions = Seq.fill(length) { + ToUTCTimestamp( + Literal.create(Timestamp.valueOf("2017-10-10 00:00:00"), TimestampType), + Literal.create("PST", StringType)) + } + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) + val expected = Seq.fill(length)( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2017-10-10 07:00:00"))) + + if (actual != expected) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 473c355cf3c7f..17c88b0690800 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2106,6 +2106,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { + val colNumber = 10000 + val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*)) + val df = sqlContext.createDataFrame(input, StructType( + (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false)))) + val newCols = (1 to colNumber).flatMap { colIndex => + Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"), + expr(s"sqrt(_$colIndex)")) + } + df.select(newCols: _*).collect() + } + test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") From 20eb95e5e9c562261b44e4e47cad67a31390fa59 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 27 Oct 2017 15:19:27 -0700 Subject: [PATCH 1555/1765] [SPARK-21911][ML][PYSPARK] Parallel Model Evaluation for ML Tuning in PySpark ## What changes were proposed in this pull request? Add parallelism support for ML tuning in pyspark. ## How was this patch tested? Test updated. Author: WeichenXu Closes #19122 from WeichenXu123/par-ml-tuning-py. --- .../spark/ml/tuning/CrossValidatorSuite.scala | 4 +- .../ml/tuning/TrainValidationSplitSuite.scala | 4 +- python/pyspark/ml/tests.py | 39 +++++++++ python/pyspark/ml/tuning.py | 86 ++++++++++++------- 4 files changed, 96 insertions(+), 37 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index a01744f7b67fd..853eeb39bf8df 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -137,8 +137,8 @@ class CrossValidatorSuite cv.setParallelism(2) val cvParallelModel = cv.fit(dataset) - val serialMetrics = cvSerialModel.avgMetrics.sorted - val parallelMetrics = cvParallelModel.avgMetrics.sorted + val serialMetrics = cvSerialModel.avgMetrics + val parallelMetrics = cvParallelModel.avgMetrics assert(serialMetrics === parallelMetrics) val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 2ed4fbb601b61..f8d9c66be2c40 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -138,8 +138,8 @@ class TrainValidationSplitSuite cv.setParallelism(2) val cvParallelModel = cv.fit(dataset) - val serialMetrics = cvSerialModel.validationMetrics.sorted - val parallelMetrics = cvParallelModel.validationMetrics.sorted + val serialMetrics = cvSerialModel.validationMetrics + val parallelMetrics = cvParallelModel.validationMetrics assert(serialMetrics === parallelMetrics) val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 8b8bcc7b13a38..2f1f3af957e4d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -836,6 +836,27 @@ def test_save_load_simple_estimator(self): loadedModel = CrossValidatorModel.load(cvModelPath) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cv.setParallelism(1) + cvSerialModel = cv.fit(dataset) + cv.setParallelism(2) + cvParallelModel = cv.fit(dataset) + self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -986,6 +1007,24 @@ def test_save_load_simple_estimator(self): loadedModel = TrainValidationSplitModel.load(tvsModelPath) self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvs.setParallelism(1) + tvsSerialModel = tvs.fit(dataset) + tvs.setParallelism(2) + tvsParallelModel = tvs.fit(dataset) + self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 00c348aa9f7de..47351133524e7 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -14,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import itertools import numpy as np +from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasSeed +from pyspark.ml.param.shared import HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -170,7 +170,7 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -193,7 +193,8 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): >>> lr = LogisticRegression() >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() - >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + ... parallelism=2) >>> cvModel = cv.fit(dataset) >>> cvModel.avgMetrics[0] 0.5 @@ -208,23 +209,23 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None): + seed=None, parallelism=1): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None) + seed=None, parallelism=1) """ super(CrossValidator, self).__init__() - self._setDefault(numFolds=3) + self._setDefault(numFolds=3, parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None): + seed=None, parallelism=1): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None): + seed=None, parallelism=1): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -255,18 +256,27 @@ def _fit(self, dataset): randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) metrics = [0.0] * numModels + + pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + for i in range(nFolds): validateLB = i * h validateUB = (i + 1) * h condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) - validation = df.filter(condition) - train = df.filter(~condition) - models = est.fit(train, epm) - for j in range(numModels): - model = models[j] + validation = df.filter(condition).cache() + train = df.filter(~condition).cache() + + def singleTrain(paramMap): + model = est.fit(train, paramMap) # TODO: duplicate evaluator to take extra params from input - metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric/nFolds + metric = eva.evaluate(model.transform(validation, paramMap)) + return metric + + currentFoldMetrics = pool.map(singleTrain, epm) + for j in range(numModels): + metrics[j] += (currentFoldMetrics[j] / nFolds) + validation.unpersist() + train.unpersist() if eva.isLargerBetter(): bestIndex = np.argmax(metrics) @@ -316,9 +326,10 @@ def _from_java(cls, java_stage): estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() + parallelism = java_stage.getParallelism() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed) + numFolds=numFolds, seed=seed, parallelism=parallelism) py_stage._resetUid(java_stage.uid()) return py_stage @@ -337,6 +348,7 @@ def _to_java(self): _java_obj.setEstimator(estimator) _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) + _java_obj.setParallelism(self.getParallelism()) return _java_obj @@ -427,7 +439,7 @@ def _to_java(self): return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): """ .. note:: Experimental @@ -448,7 +460,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): >>> lr = LogisticRegression() >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() - >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + ... parallelism=2) >>> tvsModel = tvs.fit(dataset) >>> evaluator.evaluate(tvsModel.transform(dataset)) 0.8333... @@ -461,23 +474,23 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - seed=None): + parallelism=1, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - seed=None) + parallelism=1, seed=None) """ super(TrainValidationSplit, self).__init__() - self._setDefault(trainRatio=0.75) + self._setDefault(trainRatio=0.75, parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - seed=None): + parallelism=1, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - seed=None): + parallelism=1, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -506,15 +519,20 @@ def _fit(self, dataset): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) - validation = df.filter(condition) - train = df.filter(~condition) - models = est.fit(train, epm) - for j in range(numModels): - model = models[j] - metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + validation = df.filter(condition).cache() + train = df.filter(~condition).cache() + + def singleTrain(paramMap): + model = est.fit(train, paramMap) + metric = eva.evaluate(model.transform(validation, paramMap)) + return metric + + pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + metrics = pool.map(singleTrain, epm) + train.unpersist() + validation.unpersist() + if eva.isLargerBetter(): bestIndex = np.argmax(metrics) else: @@ -563,9 +581,10 @@ def _from_java(cls, java_stage): estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() + parallelism = java_stage.getParallelism() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed) + trainRatio=trainRatio, seed=seed, parallelism=parallelism) py_stage._resetUid(java_stage.uid()) return py_stage @@ -584,6 +603,7 @@ def _to_java(self): _java_obj.setEstimator(estimator) _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) + _java_obj.setParallelism(self.getParallelism()) return _java_obj From 01f6ba0e7a12ef818d56e7d5b1bd889b79f2b57c Mon Sep 17 00:00:00 2001 From: Sathiya Date: Fri, 27 Oct 2017 18:57:08 -0700 Subject: [PATCH 1556/1765] [SPARK-22181][SQL] Adds ReplaceExceptWithFilter rule ## What changes were proposed in this pull request? Adds a new optimisation rule 'ReplaceExceptWithNotFilter' that replaces Except logical with Filter operator and schedule it before applying 'ReplaceExceptWithAntiJoin' rule. This way we can avoid expensive join operation if one or both of the datasets of the Except operation are fully derived out of Filters from a same parent. ## How was this patch tested? The patch is tested locally using spark-shell + unit test. Author: Sathiya Closes #19451 from sathiyapk/SPARK-22181-optimize-exceptWithFilter. --- .../sql/catalyst/expressions/subquery.scala | 10 ++ .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../optimizer/ReplaceExceptWithFilter.scala | 101 +++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 15 +++ .../optimizer/ReplaceOperatorSuite.scala | 106 +++++++++++++++++- .../resources/sql-tests/inputs/except.sql | 57 ++++++++++ .../sql-tests/results/except.sql.out | 105 +++++++++++++++++ 7 files changed, 394 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/except.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/except.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index c6146042ef1a6..6acc87a3e7367 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -89,6 +89,16 @@ object SubqueryExpression { case _ => false }.isDefined } + + /** + * Returns true when an expression contains a subquery + */ + def hasSubquery(e: Expression): Boolean = { + e.find { + case _: SubqueryExpression => true + case _ => false + }.isDefined + } } object SubExprUtils extends PredicateHelper { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d829e01441dcc..3273a61dc7b35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -76,6 +76,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, + ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala new file mode 100644 index 0000000000000..89bfcee078fba --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * If one or both of the datasets in the logical [[Except]] operator are purely transformed using + * [[Filter]], this rule will replace logical [[Except]] operator with a [[Filter]] operator by + * flipping the filter condition of the right child. + * {{{ + * SELECT a1, a2 FROM Tab1 WHERE a2 = 12 EXCEPT SELECT a1, a2 FROM Tab1 WHERE a1 = 5 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 WHERE a2 = 12 AND (a1 is null OR a1 <> 5) + * }}} + * + * Note: + * Before flipping the filter condition of the right node, we should: + * 1. Combine all it's [[Filter]]. + * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition). + */ +object ReplaceExceptWithFilter extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.conf.replaceExceptWithFilter) { + return plan + } + + plan.transform { + case Except(left, right) if isEligible(left, right) => + Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + } + } + + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + val filterCondition = + InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition + + val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap + + filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + } + + // TODO: This can be further extended in the future. + private def isEligible(left: LogicalPlan, right: LogicalPlan): Boolean = (left, right) match { + case (_, right @ (Project(_, _: Filter) | Filter(_, _))) => verifyConditions(left, right) + case _ => false + } + + private def verifyConditions(left: LogicalPlan, right: LogicalPlan): Boolean = { + val leftProjectList = projectList(left) + val rightProjectList = projectList(right) + + left.output.size == left.output.map(_.name).distinct.size && + left.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + right.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + Project(leftProjectList, nonFilterChild(skipProject(left))).sameResult( + Project(rightProjectList, nonFilterChild(skipProject(right)))) + } + + private def projectList(node: LogicalPlan): Seq[NamedExpression] = node match { + case p: Project => p.projectList + case x => x.output + } + + private def skipProject(node: LogicalPlan): LogicalPlan = node match { + case p: Project => p.child + case x => x + } + + private def nonFilterChild(plan: LogicalPlan) = plan.find(!_.isInstanceOf[Filter]).getOrElse { + throw new IllegalStateException("Leaf node is expected") + } + + private def combineFilters(plan: LogicalPlan): LogicalPlan = { + @tailrec + def iterate(plan: LogicalPlan, acc: LogicalPlan): LogicalPlan = { + if (acc.fastEquals(plan)) acc else iterate(acc, CombineFilters(acc)) + } + iterate(plan, CombineFilters(plan)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 21e4685fcc456..5203e8833fbbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -948,6 +948,19 @@ object SQLConf { .intConf .createWithDefault(10000) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") + .internal() + .doc("When true, the apply function of the rule verifies whether the right node of the" + + " except operation is of type Filter or Project followed by Filter. If yes, the rule" + + " further verifies 1) Excluding the filter operations from the right (as well as the" + + " left node, if any) on the top, whether both the nodes evaluates to a same result." + + " 2) The left and right nodes don't contain any SubqueryExpressions. 3) The output" + + " column names of the left node are distinct. If all the conditions are met, the" + + " rule will replace the except operation with a Filter by flipping the filter" + + " condition(s) of the right node.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1233,6 +1246,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 85988d2fb948c..0fa1aaeb9e164 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -31,6 +32,7 @@ class ReplaceOperatorSuite extends PlanTest { val batches = Batch("Replace Operators", FixedPoint(100), ReplaceDistinctWithAggregate, + ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, ReplaceDeduplicateWithAggregate) :: Nil @@ -50,6 +52,108 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter while both the nodes are of type Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + val table3 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while only right node is of type Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) + + val query = Except(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), table1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while both the nodes are of type Project") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Project(Seq(attributeA, attributeB), table1) + val table3 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Project(Seq(attributeA, attributeB), table1))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while only right node is of type Project") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + val table3 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while left node is Project and right node is Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + val table3 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA === 1 && attributeB === 2)), + Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Except with Left-anti Join") { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) diff --git a/sql/core/src/test/resources/sql-tests/inputs/except.sql b/sql/core/src/test/resources/sql-tests/inputs/except.sql new file mode 100644 index 0000000000000..1d579e65f3473 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except.sql @@ -0,0 +1,57 @@ +-- Tests different scenarios of except operation +create temporary view t1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", NULL) + as t1(k, v); + +create temporary view t2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5), + ("one", NULL), + (NULL, 5) + as t2(k, v); + + +-- Except operation that will be replaced by left anti join +SELECT * FROM t1 EXCEPT SELECT * FROM t2; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT * FROM t1 EXCEPT SELECT * FROM t1 where v <> 1 and v <> 2; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT * FROM t1 where v <> 1 and v <> 22 EXCEPT SELECT * FROM t1 where v <> 2 and v >= 3; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT t1.* FROM t1, t2 where t1.k = t2.k +EXCEPT +SELECT t1.* FROM t1, t2 where t1.k = t2.k and t1.k != 'one'; + + +-- Except operation that will be replaced by left anti join +SELECT * FROM t2 where v >= 1 and v <> 22 EXCEPT SELECT * FROM t1; + + +-- Except operation that will be replaced by left anti join +SELECT (SELECT min(k) FROM t2 WHERE t2.k = t1.k) min_t2 FROM t1 +MINUS +SELECT (SELECT min(k) FROM t2) abs_min_t2 FROM t1 WHERE t1.k = 'one'; + + +-- Except operation that will be replaced by left anti join +SELECT t1.k +FROM t1 +WHERE t1.v <= (SELECT max(t2.v) + FROM t2 + WHERE t2.k = t1.k) +MINUS +SELECT t1.k +FROM t1 +WHERE t1.v >= (SELECT min(t2.v) + FROM t2 + WHERE t2.k = t1.k); diff --git a/sql/core/src/test/resources/sql-tests/results/except.sql.out b/sql/core/src/test/resources/sql-tests/results/except.sql.out new file mode 100644 index 0000000000000..c9b712d4d2949 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except.sql.out @@ -0,0 +1,105 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +create temporary view t1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", NULL) + as t1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5), + ("one", NULL), + (NULL, 5) + as t2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM t1 EXCEPT SELECT * FROM t2 +-- !query 2 schema +struct +-- !query 2 output +three 3 +two 2 + + +-- !query 3 +SELECT * FROM t1 EXCEPT SELECT * FROM t1 where v <> 1 and v <> 2 +-- !query 3 schema +struct +-- !query 3 output +one 1 +one NULL +two 2 + + +-- !query 4 +SELECT * FROM t1 where v <> 1 and v <> 22 EXCEPT SELECT * FROM t1 where v <> 2 and v >= 3 +-- !query 4 schema +struct +-- !query 4 output +two 2 + + +-- !query 5 +SELECT t1.* FROM t1, t2 where t1.k = t2.k +EXCEPT +SELECT t1.* FROM t1, t2 where t1.k = t2.k and t1.k != 'one' +-- !query 5 schema +struct +-- !query 5 output +one 1 +one NULL + + +-- !query 6 +SELECT * FROM t2 where v >= 1 and v <> 22 EXCEPT SELECT * FROM t1 +-- !query 6 schema +struct +-- !query 6 output +NULL 5 +one 5 + + +-- !query 7 +SELECT (SELECT min(k) FROM t2 WHERE t2.k = t1.k) min_t2 FROM t1 +MINUS +SELECT (SELECT min(k) FROM t2) abs_min_t2 FROM t1 WHERE t1.k = 'one' +-- !query 7 schema +struct +-- !query 7 output +NULL +two + + +-- !query 8 +SELECT t1.k +FROM t1 +WHERE t1.v <= (SELECT max(t2.v) + FROM t2 + WHERE t2.k = t1.k) +MINUS +SELECT t1.k +FROM t1 +WHERE t1.v >= (SELECT min(t2.v) + FROM t2 + WHERE t2.k = t1.k) +-- !query 8 schema +struct +-- !query 8 output +two From c42d208e197ff061cc5f49c75568c047d8d1126c Mon Sep 17 00:00:00 2001 From: donnyzone Date: Fri, 27 Oct 2017 23:40:59 -0700 Subject: [PATCH 1557/1765] [SPARK-22333][SQL] timeFunctionCall(CURRENT_DATE, CURRENT_TIMESTAMP) has conflicts with columnReference ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-22333 In current version, users can use CURRENT_DATE() and CURRENT_TIMESTAMP() without specifying braces. However, when a table has columns named as "current_date" or "current_timestamp", it will still be parsed as function call. There are many such cases in our production cluster. We get the wrong answer due to this inappropriate behevior. In general, ColumnReference should get higher priority than timeFunctionCall. ## How was this patch tested? unit test manul test Author: donnyzone Closes #19559 from DonnyZone/master. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 +-- .../sql/catalyst/analysis/Analyzer.scala | 37 +++++++++++++- .../sql/catalyst/parser/AstBuilder.scala | 13 ----- .../parser/ExpressionParserSuite.scala | 5 -- .../resources/sql-tests/inputs/datetime.sql | 17 +++++++ .../sql-tests/results/datetime.sql.out | 49 ++++++++++++++++++- 6 files changed, 102 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 17c8404f8a79c..6fe995f650d55 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -564,8 +564,7 @@ valueExpression ; primaryExpression - : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + : CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct @@ -747,7 +746,7 @@ nonReserved | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN | UNBOUNDED | WHEN - | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP + | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | DIRECTORY | BOTH | LEADING | TRAILING ; @@ -983,8 +982,6 @@ OPTION: 'OPTION'; ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; -CURRENT_DATE: 'CURRENT_DATE'; -CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6384a141e83b3..e5c93b5f0e059 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -786,7 +786,12 @@ class Analyzer( private def resolve(e: Expression, q: LogicalPlan): Expression = e match { case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + val result = + withPosition(u) { + q.resolveChildren(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, q)) + .getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -925,6 +930,30 @@ class Analyzer( exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) } + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction( + nameParts: Seq[String], + attribute: UnresolvedAttribute, + plan: LogicalPlan): Option[Expression] = { + if (nameParts.length != 1) return None + val isNamedExpression = plan match { + case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute) + case Project(projectList, _) => projectList.contains(attribute) + case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute) + case _ => false + } + val wrapper: Expression => Expression = + if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity + // support CURRENT_DATE and CURRENT_TIMESTAMP + val literalFunctions = Seq(CurrentDate(), CurrentTimestamp()) + val name = nameParts.head + val func = literalFunctions.find(e => resolver(e.prettyName, name)) + func.map(wrapper) + } + protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, @@ -937,7 +966,11 @@ class Analyzer( expr transformUp { case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } + withPosition(u) { + plan.resolve(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, plan)) + .getOrElse(u) + } case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ce367145bc637..7651d11ee65a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1234,19 +1234,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } - /** - * Create a current timestamp/date expression. These are different from regular function because - * they do not require the user to specify braces when calling them. - */ - override def visitTimeFunctionCall(ctx: TimeFunctionCallContext): Expression = withOrigin(ctx) { - ctx.name.getType match { - case SqlBaseParser.CURRENT_DATE => - CurrentDate() - case SqlBaseParser.CURRENT_TIMESTAMP => - CurrentTimestamp() - } - } - /** * Create a function database (optional) and name pair. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 76c79b3d0760c..2b9783a3295c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -592,11 +592,6 @@ class ExpressionParserSuite extends PlanTest { intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") } - test("current date/timestamp braceless expressions") { - assertEqual("current_date", CurrentDate()) - assertEqual("current_timestamp", CurrentTimestamp()) - } - test("SPARK-17364, fully qualified column name which starts with number") { assertEqual("123_", UnresolvedAttribute("123_")) assertEqual("1a.123_", UnresolvedAttribute("1a.123_")) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 616b6caee3f20..adea2bfa82cd3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -8,3 +8,20 @@ select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd'); select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); + +-- [SPARK-22333]: timeFunctionCall has conflicts with columnReference +create temporary view ttf1 as select * from values + (1, 2), + (2, 3) + as ttf1(current_date, current_timestamp); + +select current_date, current_timestamp from ttf1; + +create temporary view ttf2 as select * from values + (1, 2), + (2, 3) + as ttf2(a, b); + +select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; + +select a, b from ttf2 order by a, current_date; diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index a28b91c77324b..7b2f46f6c2a66 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 9 -- !query 0 @@ -32,3 +32,50 @@ select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27') struct -- !query 3 output 7 5 7 NULL 6 + + +-- !query 4 +create temporary view ttf1 as select * from values + (1, 2), + (2, 3) + as ttf1(current_date, current_timestamp) +-- !query 4 schema +struct<> +-- !query 4 output + + +-- !query 5 +select current_date, current_timestamp from ttf1 +-- !query 5 schema +struct +-- !query 5 output +1 2 +2 3 + + +-- !query 6 +create temporary view ttf2 as select * from values + (1, 2), + (2, 3) + as ttf2(a, b) +-- !query 6 schema +struct<> +-- !query 6 output + + +-- !query 7 +select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2 +-- !query 7 schema +struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean,a:int,b:int> +-- !query 7 output +true true 1 2 +true true 2 3 + + +-- !query 8 +select a, b from ttf2 order by a, current_date +-- !query 8 schema +struct +-- !query 8 output +1 2 +2 3 From d28d5732ae205771f1f443b15b10e64dcffb5ff0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 Oct 2017 23:44:24 -0700 Subject: [PATCH 1558/1765] [SPARK-21619][SQL] Fail the execution of canonicalized plans explicitly ## What changes were proposed in this pull request? Canonicalized plans are not supposed to be executed. I ran into a case in which there's some code that accidentally calls execute on a canonicalized plan. This patch throws a more explicit exception when that happens. ## How was this patch tested? Added a test case in SparkPlanSuite. Author: Reynold Xin Closes #18828 from rxin/SPARK-21619. --- .../sql/catalyst/catalog/interface.scala | 5 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 30 +++++++++++++--- .../plans/logical/basicLogicalOperators.scala | 2 +- .../sql/catalyst/plans/logical/hints.scala | 2 +- .../sql/execution/DataSourceScanExec.scala | 4 +-- .../spark/sql/execution/SparkPlan.scala | 6 ++++ .../spark/sql/execution/SparkSqlParser.scala | 2 +- .../execution/basicPhysicalOperators.scala | 2 +- .../spark/sql/execution/command/cache.scala | 5 ++- .../datasources/LogicalRelation.scala | 2 +- .../exchange/BroadcastExchangeExec.scala | 2 +- .../sql/execution/exchange/Exchange.scala | 2 +- .../spark/sql/execution/SparkPlanSuite.scala | 36 +++++++++++++++++++ .../hive/execution/HiveTableScanExec.scala | 4 +-- 14 files changed, 86 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1dbae4d37d8f5..b87bbb4874670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -438,7 +438,7 @@ case class HiveTableRelation( def isPartitioned: Boolean = partitionCols.nonEmpty - override lazy val canonicalized: HiveTableRelation = copy( + override def doCanonicalize(): HiveTableRelation = copy( tableMeta = tableMeta.copy( storage = CatalogStorageFormat.empty, createTime = -1 @@ -448,7 +448,8 @@ case class HiveTableRelation( }, partitionCols = partitionCols.zipWithIndex.map { case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) - }) + } + ) override def computeStats(): Statistics = { tableMeta.stats.map(_.toPlanStats(output)).getOrElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c7952e3ff8280..d21b4afa2f06c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -180,6 +180,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override protected def innerChildren: Seq[QueryPlan[_]] = subqueries + /** + * A private mutable variable to indicate whether this plan is the result of canonicalization. + * This is used solely for making sure we wouldn't execute a canonicalized plan. + * See [[canonicalized]] on how this is set. + */ + @transient private var _isCanonicalizedPlan: Boolean = false + + protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan + /** * Returns a plan where a best effort attempt has been made to transform `this` in a way * that preserves the result but removes cosmetic variations (case sensitivity, ordering for @@ -188,10 +197,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same * result. * - * Some nodes should overwrite this to provide proper canonicalize logic, but they should remove - * expressions cosmetic variations themselves. + * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. + * They should remove expressions cosmetic variations themselves. + */ + @transient final lazy val canonicalized: PlanType = { + var plan = doCanonicalize() + // If the plan has not been changed due to canonicalization, make a copy of it so we don't + // mutate the original plan's _isCanonicalizedPlan flag. + if (plan eq this) { + plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) + } + plan._isCanonicalizedPlan = true + plan + } + + /** + * Defines how the canonicalization should work for the current plan. */ - lazy val canonicalized: PlanType = { + protected def doCanonicalize(): PlanType = { val canonicalizedChildren = children.map(_.canonicalized) var id = -1 mapExpressions { @@ -213,7 +236,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }.withNewChildren(canonicalizedChildren) } - /** * Returns true when the given query plan will return the same results as this query plan. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 80243d3d356ca..c2750c3079814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -760,7 +760,7 @@ case class SubqueryAlias( child: LogicalPlan) extends UnaryNode { - override lazy val canonicalized: LogicalPlan = child.canonicalized + override def doCanonicalize(): LogicalPlan = child.canonicalized override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 29a43528124d8..cbb626590d1d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -41,7 +41,7 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def output: Seq[Attribute] = child.output - override lazy val canonicalized: LogicalPlan = child.canonicalized + override def doCanonicalize(): LogicalPlan = child.canonicalized } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8d0fc32feac99..e9f65031143b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -139,7 +139,7 @@ case class RowDataSourceScanExec( } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. - override lazy val canonicalized: SparkPlan = + override def doCanonicalize(): SparkPlan = copy( fullOutput.map(QueryPlan.normalizeExprId(_, fullOutput)), rdd = null, @@ -522,7 +522,7 @@ case class FileSourceScanExec( } } - override lazy val canonicalized: FileSourceScanExec = { + override def doCanonicalize(): FileSourceScanExec = { FileSourceScanExec( relation, output.map(QueryPlan.normalizeExprId(_, output)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2ffd948f984bf..657b265260135 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -111,6 +111,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override `doExecute`. */ final def execute(): RDD[InternalRow] = executeQuery { + if (isCanonicalizedPlan) { + throw new IllegalStateException("A canonicalized plan is not supposed to be executed.") + } doExecute() } @@ -121,6 +124,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override `doExecuteBroadcast`. */ final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery { + if (isCanonicalizedPlan) { + throw new IllegalStateException("A canonicalized plan is not supposed to be executed.") + } doExecuteBroadcast() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 6de9ea0efd2c6..29b584b55972c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -286,7 +286,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create a [[ClearCacheCommand]] logical plan. */ override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { - ClearCacheCommand + ClearCacheCommand() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d15ece304cac4..e58c3cec2df15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -350,7 +350,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override lazy val canonicalized: SparkPlan = { + override def doCanonicalize(): SparkPlan = { RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 140f920eaafae..687994d82a003 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -66,10 +66,13 @@ case class UncacheTableCommand( /** * Clear all cached data from the in-memory cache. */ -case object ClearCacheCommand extends RunnableCommand { +case class ClearCacheCommand() extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.catalog.clearCache() Seq.empty[Row] } + + /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */ + override def makeCopy(newArgs: Array[AnyRef]): ClearCacheCommand = ClearCacheCommand() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3e98cb28453a2..236995708a12f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -35,7 +35,7 @@ case class LogicalRelation( extends LeafNode with MultiInstanceRelation { // Only care about relation when canonicalizing. - override lazy val canonicalized: LogicalPlan = copy( + override def doCanonicalize(): LogicalPlan = copy( output = output.map(QueryPlan.normalizeExprId(_, output)), catalogTable = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 880e18c6808b0..daea6c39624d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -48,7 +48,7 @@ case class BroadcastExchangeExec( override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - override lazy val canonicalized: SparkPlan = { + override def doCanonicalize(): SparkPlan = { BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 4b52f3e4c49b0..09f79a2de0ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -50,7 +50,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan extends LeafExecNode { // Ignore this wrapper for canonicalizing. - override lazy val canonicalized: SparkPlan = child.canonicalized + override def doCanonicalize(): SparkPlan = child.canonicalized def doExecute(): RDD[InternalRow] = { child.execute() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala new file mode 100644 index 0000000000000..750d9e4adf8b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +class SparkPlanSuite extends QueryTest with SharedSQLContext { + + test("SPARK-21619 execution of a canonicalized plan should fail") { + val plan = spark.range(10).queryExecution.executedPlan.canonicalized + + intercept[IllegalStateException] { plan.execute() } + intercept[IllegalStateException] { plan.executeCollect() } + intercept[IllegalStateException] { plan.executeCollectPublic() } + intercept[IllegalStateException] { plan.executeToIterator() } + intercept[IllegalStateException] { plan.executeBroadcast() } + intercept[IllegalStateException] { plan.executeTake(1) } + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 4f8dab9cd6172..7dcaf170f9693 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -203,11 +203,11 @@ case class HiveTableScanExec( } } - override lazy val canonicalized: HiveTableScanExec = { + override def doCanonicalize(): HiveTableScanExec = { val input: AttributeSeq = relation.output HiveTableScanExec( requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), - relation.canonicalized, + relation.canonicalized.asInstanceOf[HiveTableRelation], QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession) } From 683ffe0620e69fd6e9f92c1037eef7996029aba8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 28 Oct 2017 21:47:15 +0900 Subject: [PATCH 1559/1765] [SPARK-22335][SQL] Clarify union behavior on Dataset of typed objects in the document ## What changes were proposed in this pull request? Seems that end users can be confused by the union's behavior on Dataset of typed objects. We can clarity it more in the document of `union` function. ## How was this patch tested? Only document change. Author: Liang-Chi Hsieh Closes #19570 from viirya/SPARK-22335. --- .../scala/org/apache/spark/sql/Dataset.scala | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fe4e192e43dfe..bd99ec52ce93f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1747,7 +1747,26 @@ class Dataset[T] private[sql]( * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does * deduplication of elements), use this function followed by a [[distinct]]. * - * Also as standard in SQL, this function resolves columns by position (not by name). + * Also as standard in SQL, this function resolves columns by position (not by name): + * + * {{{ + * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") + * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") + * df1.union(df2).show + * + * // output: + * // +----+----+----+ + * // |col0|col1|col2| + * // +----+----+----+ + * // | 1| 2| 3| + * // | 4| 5| 6| + * // +----+----+----+ + * }}} + * + * Notice that the column positions in the schema aren't necessarily matched with the + * fields in the strongly typed objects in a Dataset. This function resolves columns + * by their positions in the schema, not the fields in the strongly typed objects. Use + * [[unionByName]] to resolve columns by field name in the typed objects. * * @group typedrel * @since 2.0.0 From 4c5269f1aa529e6a397b68d6dc409d89e32685bd Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 28 Oct 2017 18:33:09 +0100 Subject: [PATCH 1560/1765] [SPARK-22370][SQL][PYSPARK] Config values should be captured in Driver. ## What changes were proposed in this pull request? `ArrowEvalPythonExec` and `FlatMapGroupsInPandasExec` are refering config values of `SQLConf` in function for `mapPartitions`/`mapPartitionsInternal`, but we should capture them in Driver. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #19587 from ueshin/issues/SPARK-22370. --- python/pyspark/sql/tests.py | 20 +++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++++ .../python/ArrowEvalPythonExec.scala | 6 ++++-- .../python/FlatMapGroupsInPandasExec.scala | 3 ++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98afae662b42d..8ed37c9da98b8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3476,6 +3476,26 @@ def gen_timestamps(id): expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) self.assertEquals(expected, ts) + def test_vectorized_udf_check_config(self): + from pyspark.sql.functions import pandas_udf, col + orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) + self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) + try: + df = self.spark.range(10, numPartitions=1) + + @pandas_udf(returnType=LongType()) + def check_records_per_batch(x): + self.assertTrue(x.size <= 3) + return x + + result = df.select(check_records_per_batch(col("id"))) + self.assertEquals(df.collect(), result.collect()) + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + else: + self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d21b4afa2f06c..ddf2cbf2ab911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -25,6 +25,12 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => + /** + * The active config object within the current scope. + * Note that if you want to refer config values during execution, you have to capture them + * in Driver and use the captured values in Executors. + * See [[SQLConf.get]] for more information. + */ def conf: SQLConf = SQLConf.get def output: Seq[Attribute] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 0db463a5fbd89..bcda2dae92e53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -61,6 +61,9 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { + private val batchSize = conf.arrowMaxRecordsPerBatch + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + protected override def evaluate( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, @@ -73,13 +76,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) - val batchSize = conf.arrowMaxRecordsPerBatch // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index cc93fda9f81da..e1e04e34e0c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -77,6 +77,7 @@ case class FlatMapGroupsInPandasExec( val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) val schema = StructType(child.schema.drop(groupingAttributes.length)) + val sessionLocalTimeZone = conf.sessionLocalTimeZone inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { @@ -94,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) From e80da8129a6b8ebaeac0eeac603ddc461144aec3 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Sat, 28 Oct 2017 17:20:35 -0700 Subject: [PATCH 1561/1765] [MINOR] Remove false comment from planStreamingAggregation ## What changes were proposed in this pull request? AggUtils.planStreamingAggregation has some comments about DISTINCT aggregates, while streaming aggregation does not support DISTINCT. This seems to have been wrongly copy-pasted over. ## How was this patch tested? Only a comment change. Author: Juliusz Sompolski Closes #18937 from juliuszsompolski/streaming-agg-doc. --- .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 12f8cffb6774a..ebbdf1aaa024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -263,9 +263,6 @@ object AggUtils { val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // We will group by the original grouping expression, plus an additional expression for the - // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // expressions will be [key, value]. createAggregate( groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, From 7fdacbc77bbcf98c2c045a1873e749129769dcc0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 28 Oct 2017 18:24:18 -0700 Subject: [PATCH 1562/1765] [SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/17075 , to fix the bug in codegen path. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19576 from cloud-fan/bug. --- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../expressions/decimalExpressions.scala | 2 +- .../expressions/mathExpressions.scala | 10 ++---- .../org/apache/spark/sql/types/Decimal.scala | 31 ++++++++++--------- .../apache/spark/sql/types/DecimalSuite.scala | 2 +- .../apache/spark/sql/MathFunctionsSuite.scala | 12 +++++++ 7 files changed, 36 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4ebdb139fe0f..474ec592201d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,7 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - decimal.toPrecision(dataType.precision, dataType.scale).orNull + decimal.toPrecision(dataType.precision, dataType.scale) } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d949b8f1d6696..bc809f559d586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String /** * Create new `Decimal` with precision and scale given in `decimalType` (if any), * returning null if it overflows or creating a new `value` and returning it if successful. - * */ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = - value.toPrecision(decimalType.precision, decimalType.scale).orNull + value.toPrecision(decimalType.precision, decimalType.scale) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c2211ae5d594b..752dea23e1f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 547d5be0e908e..d8dc0862f1141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1044,7 +1044,7 @@ abstract class RoundBase(child: Expression, scale: Expression, dataType match { case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, s, mode).orNull + decimal.toPrecision(decimal.precision, s, mode) case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1076,12 +1076,8 @@ abstract class RoundBase(child: Expression, scale: Expression, val evaluationCode = dataType match { case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, - java.math.BigDecimal.${modeStr})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.isNull} = true; - }""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr()); + ${ev.isNull} = ${ev.value} == null;""" case ByteType => if (_scale < 0) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 1f1fb51addfd8..6da4f28b12962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } - def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { - case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) - case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) - } - /** * Create new `Decimal` with given precision and scale. * - * @return `Some(decimal)` if successful or `None` if overflow would occur + * @return a non-null `Decimal` value if successful or `null` if overflow would occur. */ private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + if (copy.changePrecision(precision, scale, roundMode)) copy else null } /** @@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * * @return true if successful, false if overflow would occur */ - private[sql] def changePrecision(precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value): Boolean = { + private[sql] def changePrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_FLOOR) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } def ceil: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_CEILING) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 3193d1320ad9d..10de90c6a44ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") - val copy = d.toPrecision(10, 0, mode).orNull + val copy = d.toPrecision(10, 0, mode) assert(copy !== null) assert(d.ne(copy)) assert(d === copy) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index c2d08a06569bf..5be8c581e9ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with table columns") { + withTable("t") { + Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") + checkAnswer( + sql("select i, round(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, bround(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + } + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } From 544a1ba678810b331d78fe9e63c7bb2342ab3d99 Mon Sep 17 00:00:00 2001 From: Xin Lu Date: Sun, 29 Oct 2017 15:29:23 +0900 Subject: [PATCH 1563/1765] =?UTF-8?q?[SPARK-22375][TEST]=20Test=20script?= =?UTF-8?q?=20can=20fail=20if=20eggs=20are=20installed=20by=20set=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …up.py during test process ## What changes were proposed in this pull request? Ignore the python/.eggs folder when running lint-python ## How was this patch tested? 1) put a bad python file in python/.eggs and ran the original script. results were: xins-MBP:spark xinlu$ dev/lint-python PEP8 checks failed. ./python/.eggs/worker.py:33:4: E121 continuation line under-indented for hanging indent ./python/.eggs/worker.py:34:5: E131 continuation line unaligned for hanging indent 2) test same situation with change: xins-MBP:spark xinlu$ dev/lint-python PEP8 checks passed. The sphinx-build command was not found. Skipping pydoc checks for now Author: Xin Lu Closes #19597 from xynny/SPARK-22375. --- dev/tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/tox.ini b/dev/tox.ini index eeeb637460cfb..eb8b1eb2c2886 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -16,4 +16,4 @@ [pep8] ignore=E402,E731,E241,W503,E226 max-line-length=100 -exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py +exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* From bc7ca9786e162e33f29d57c4aacb830761b97221 Mon Sep 17 00:00:00 2001 From: Jen-Ming Chung Date: Sun, 29 Oct 2017 18:11:48 +0100 Subject: [PATCH 1564/1765] [SPARK-22291][SQL] Conversion error when transforming array types of uuid, inet and cidr to StingType in PostgreSQL ## What changes were proposed in this pull request? This PR fixes the conversion error when reads data from a PostgreSQL table that contains columns of `uuid[]`, `inet[]` and `cidr[]` data types. For example, create a table with the uuid[] data type, and insert the test data. ```SQL CREATE TABLE users ( id smallint NOT NULL, name character varying(50), user_ids uuid[], PRIMARY KEY (id) ) INSERT INTO users ("id", "name","user_ids") VALUES (1, 'foo', ARRAY ['7be8aaf8-650e-4dbb-8186-0a749840ecf2' ,'205f9bfc-018c-4452-a605-609c0cfad228']::UUID[] ) ``` Then it will throw the following exceptions when trying to load the data. ``` java.lang.ClassCastException: [Ljava.util.UUID; cannot be cast to [Ljava.lang.String; at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$14.apply(JdbcUtils.scala:459) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$14.apply(JdbcUtils.scala:458) ... ``` ## How was this patch tested? Added test in `PostgresIntegrationSuite`. Author: Jen-Ming Chung Closes #19567 from jmchung/SPARK-22291. --- .../sql/jdbc/PostgresIntegrationSuite.scala | 37 ++++++++++++++++++- .../datasources/jdbc/JdbcUtils.scala | 5 ++- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index eb3c458360e7b..48aba90afc787 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -60,7 +60,22 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "(id integer, tstz TIMESTAMP WITH TIME ZONE, ttz TIME WITH TIME ZONE)") .executeUpdate() conn.prepareStatement("INSERT INTO ts_with_timezone VALUES " + - "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', TIME WITH TIME ZONE '17:22:31.949271+00')") + "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', " + + "TIME WITH TIME ZONE '17:22:31.949271+00')") + .executeUpdate() + + conn.prepareStatement("CREATE TABLE st_with_array (c0 uuid, c1 inet, c2 cidr," + + "c3 json, c4 jsonb, c5 uuid[], c6 inet[], c7 cidr[], c8 json[], c9 jsonb[])") + .executeUpdate() + conn.prepareStatement("INSERT INTO st_with_array VALUES ( " + + "'0a532531-cdf1-45e3-963d-5de90b6a30f1', '172.168.22.1', '192.168.100.128/25', " + + """'{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}', """ + + "ARRAY['7be8aaf8-650e-4dbb-8186-0a749840ecf2'," + + "'205f9bfc-018c-4452-a605-609c0cfad228']::uuid[], ARRAY['172.16.0.41', " + + "'172.16.0.42']::inet[], ARRAY['192.168.0.0/24', '10.1.0.0/16']::cidr[], " + + """ARRAY['{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}']::json[], """ + + """ARRAY['{"a": 1, "b": 2, "c": 3}']::jsonb[])""" + ) .executeUpdate() } @@ -134,11 +149,29 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(schema(1).dataType == ShortType) } - test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE should be recognized") { + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " + + "should be recognized") { val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) val rows = dfRead.collect() val types = rows(0).toSeq.map(x => x.getClass.toString) assert(types(1).equals("class java.sql.Timestamp")) assert(types(2).equals("class java.sql.Timestamp")) } + + test("SPARK-22291: Conversion error when transforming array types of " + + "uuid, inet and cidr to StingType in PostgreSQL") { + val df = sqlContext.read.jdbc(jdbcUrl, "st_with_array", new Properties) + val rows = df.collect() + assert(rows(0).getString(0) == "0a532531-cdf1-45e3-963d-5de90b6a30f1") + assert(rows(0).getString(1) == "172.168.22.1") + assert(rows(0).getString(2) == "192.168.100.128/25") + assert(rows(0).getString(3) == "{\"a\": \"foo\", \"b\": \"bar\"}") + assert(rows(0).getString(4) == "{\"a\": 1, \"b\": 2}") + assert(rows(0).getSeq(5) == Seq("7be8aaf8-650e-4dbb-8186-0a749840ecf2", + "205f9bfc-018c-4452-a605-609c0cfad228")) + assert(rows(0).getSeq(6) == Seq("172.16.0.41", "172.16.0.42")) + assert(rows(0).getSeq(7) == Seq("192.168.0.0/24", "10.1.0.0/16")) + assert(rows(0).getSeq(8) == Seq("""{"a": "foo", "b": "bar"}""", """{"a": 1, "b": 2}""")) + assert(rows(0).getSeq(9) == Seq("""{"a": 1, "b": 2, "c": 3}""")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 9debc4ff82748..75c94fc486493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -456,8 +456,9 @@ object JdbcUtils extends Logging { case StringType => (array: Object) => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) + // some underling types are not String such as uuid, inet, cidr, etc. + array.asInstanceOf[Array[java.lang.Object]] + .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString)) case DateType => (array: Object) => From 659acf18daf0d91fc0595227b7e29d732b99f4aa Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 29 Oct 2017 10:37:25 -0700 Subject: [PATCH 1565/1765] Revert "[SPARK-22308] Support alternative unit testing styles in external applications" This reverts commit 592cfeab9caeff955d115a1ca5014ede7d402907. --- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 45 ----- .../spark/sql/test/GenericFunSpecSuite.scala | 47 ----- .../spark/sql/test/GenericWordSpecSuite.scala | 51 ------ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++---------- .../spark/sql/test/SharedSQLContext.scala | 84 ++++++++- .../spark/sql/test/SharedSparkSession.scala | 119 ------------ 8 files changed, 165 insertions(+), 381 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 1aa1c421d792e..6aedcb1271ff6 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,23 +29,10 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) - /** - * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in - * test using styles other than FunSuite, there is often code that relies on the session between - * test group constructs and the actual tests, which may need this session. It is purely a - * semantic difference, but semantically, it makes more sense to call 'initializeContext' between - * a 'describe' and an 'it' call than it does to call 'beforeAll'. - */ - protected def initializeContext(): Unit = { - if (null == _sc) { - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) - } - } - override def beforeAll() { super.beforeAll() - initializeContext() + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 82c5307d54360..10bdfafd6f933 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.Suite - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -31,13 +29,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PlanTestBase - -/** - * Provides helper methods for comparing plans, but without the overhead of - * mandating a FunSuite. - */ -trait PlanTestBase extends PredicateHelper { self: Suite => +trait PlanTest extends SparkFunSuite with PredicateHelper { // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala deleted file mode 100644 index 6179585a0d39a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import org.scalatest.FlatSpec - -/** - * The purpose of this suite is to make sure that generic FlatSpec-based scala - * tests work with a shared spark session - */ -class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { - import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS - - "A Simple Dataset" should "have the specified number of elements" in { - assert(8 === ds.count) - } - it should "have the specified number of unique elements" in { - assert(8 === ds.distinct.count) - } - it should "have the specified number of elements in each column" in { - assert(8 === ds.select("_1").count) - assert(8 === ds.select("_2").count) - } - it should "have the correct number of distinct elements in each column" in { - assert(8 === ds.select("_1").distinct.count) - assert(4 === ds.select("_2").distinct.count) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala deleted file mode 100644 index 15139ee8b3047..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import org.scalatest.FunSpec - -/** - * The purpose of this suite is to make sure that generic FunSpec-based scala - * tests work with a shared spark session - */ -class GenericFunSpecSuite extends FunSpec with SharedSparkSession { - import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS - - describe("Simple Dataset") { - it("should have the specified number of elements") { - assert(8 === ds.count) - } - it("should have the specified number of unique elements") { - assert(8 === ds.distinct.count) - } - it("should have the specified number of elements in each column") { - assert(8 === ds.select("_1").count) - assert(8 === ds.select("_2").count) - } - it("should have the correct number of distinct elements in each column") { - assert(8 === ds.select("_1").distinct.count) - assert(4 === ds.select("_2").distinct.count) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala deleted file mode 100644 index b6548bf95fec8..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import org.scalatest.WordSpec - -/** - * The purpose of this suite is to make sure that generic WordSpec-based scala - * tests work with a shared spark session - */ -class GenericWordSpecSuite extends WordSpec with SharedSparkSession { - import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS - - "A Simple Dataset" when { - "looked at as complete rows" should { - "have the specified number of elements" in { - assert(8 === ds.count) - } - "have the specified number of unique elements" in { - assert(8 === ds.distinct.count) - } - } - "refined to specific columns" should { - "have the specified number of elements in each column" in { - assert(8 === ds.select("_1").count) - assert(8 === ds.select("_2").count) - } - "have the correct number of distinct elements in each column" in { - assert(8 === ds.select("_1").distinct.count) - assert(4 === ds.select("_2").distinct.count) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b4248b74f50ab..a14a1441a4313 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,17 +36,14 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.UninterruptibleThread -import org.apache.spark.util.Utils +import org.apache.spark.util.{UninterruptibleThread, Utils} /** - * Helper trait that should be extended by all SQL test suites within the Spark - * code base. + * Helper trait that should be extended by all SQL test suites. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. @@ -55,99 +52,17 @@ import org.apache.spark.util.Utils * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } -} - -/** - * Helper trait that can be extended by all external SQL test suites. - * - * This allows subclasses to plugin a custom `SQLContext`. - * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. - * - * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is - * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. - */ -private[sql] trait SQLTestUtilsBase - extends Eventually +private[sql] trait SQLTestUtils + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData - with PlanTestBase { self: Suite => + with PlanTest { self => protected def sparkContext = spark.sparkContext + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -162,6 +77,21 @@ private[sql] trait SQLTestUtilsBase protected override def _sqlContext: SQLContext = self.spark.sqlContext } + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -367,6 +297,61 @@ private[sql] trait SQLTestUtilsBase Dataset.ofRows(spark, plan) } + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f5494..cd8d0708d8a32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,86 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. + */ + protected override def beforeAll(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala deleted file mode 100644 index e0568a3c5c99f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.concurrent.duration._ - -import org.scalatest.{BeforeAndAfterEach, Suite} -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSparkSession - extends SQLTestUtilsBase - with BeforeAndAfterEach - with Eventually { self: Suite => - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeSession' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. - */ - protected def initializeSession(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - } - - /** - * Make sure the [[TestSparkSession]] is initialized before any tests are run. - */ - protected override def beforeAll(): Unit = { - initializeSession() - - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} From 1fe27612d7bcb8b6478a36bc16ddd4802e4ee2fc Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 29 Oct 2017 18:53:47 -0700 Subject: [PATCH 1566/1765] [SPARK-22344][SPARKR] Set java.io.tmpdir for SparkR tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR sets the java.io.tmpdir for CRAN checks and also disables the hsperfdata for the JVM when running CRAN checks. Together this prevents files from being left behind in `/tmp` ## How was this patch tested? Tested manually on a clean EC2 machine Author: Shivaram Venkataraman Closes #19589 from shivaram/sparkr-tmpdir-clean. --- R/pkg/inst/tests/testthat/test_basic.R | 6 ++++-- R/pkg/tests/run-all.R | 9 +++++++++ R/pkg/vignettes/sparkr-vignettes.Rmd | 8 +++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index de47162d5325f..823d26f12feee 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,7 +18,8 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) i <- 4 df <- createDataFrame(data.frame(dummy = 1:i)) @@ -49,7 +50,8 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) training <- suppressWarnings(createDataFrame(iris)) # gaussian family diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index a1834a220261d..a7f913e5fad11 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -36,8 +36,17 @@ invisible(lapply(sparkRWhitelistSQLDirs, sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRTestMaster <- "local[1]" +sparkRTestConfig <- list() if (identical(Sys.getenv("NOT_CRAN"), "true")) { sparkRTestMaster <- "" +} else { + # Disable hsperfdata on CRAN + old_java_opt <- Sys.getenv("_JAVA_OPTIONS") + Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt)) + tmpDir <- tempdir() + tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) + sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, + spark.executor.extraJavaOptions = tmpArg) } test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index caeae72e37bbf..907bbb3d66018 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -36,6 +36,12 @@ opts_hooks$set(eval = function(options) { } options }) +r_tmp_dir <- tempdir() +tmp_arg <- paste("-Djava.io.tmpdir=", r_tmp_dir, sep = "") +sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg, + spark.executor.extraJavaOptions = tmp_arg) +old_java_opt <- Sys.getenv("_JAVA_OPTIONS") +Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ``` ## Overview @@ -57,7 +63,7 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() -sparkR.session(master = "local[1]") +sparkR.session(master = "local[1]", sparkConfig = sparkSessionConfig, enableHiveSupport = FALSE) ``` ```{r, eval=FALSE} sparkR.session() From 188b47e68350731da775efccc2cda9c61610aa14 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Oct 2017 11:50:22 +0900 Subject: [PATCH 1567/1765] [SPARK-22379][PYTHON] Reduce duplication setUpClass and tearDownClass in PySpark SQL tests ## What changes were proposed in this pull request? This PR propose to add `ReusedSQLTestCase` which deduplicate `setUpClass` and `tearDownClass` in `sql/tests.py`. ## How was this patch tested? Jenkins tests and manual tests. Author: hyukjinkwon Closes #19595 from HyukjinKwon/reduce-dupe. --- python/pyspark/sql/tests.py | 63 +++++++++++++------------------------ 1 file changed, 21 insertions(+), 42 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8ed37c9da98b8..483f39aeef66a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -179,6 +179,18 @@ def __init__(self, key, value): self.value = value +class ReusedSQLTestCase(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -214,21 +226,19 @@ def test_struct_field_type_name(self): self.assertRaises(TypeError, struct_field.typeName) -class SQLTests(ReusedPySparkTestCase): +class SQLTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() + ReusedSQLTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) - cls.spark = SparkSession(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] cls.df = cls.spark.createDataFrame(cls.testData) @classmethod def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() + ReusedSQLTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_sqlcontext_reuses_sparksession(self): @@ -2623,17 +2633,7 @@ def test_hivecontext(self): self.assertTrue(os.path.exists(metastore_path)) -class SQLTests2(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTests2(ReusedSQLTestCase): # We can't include this test into SQLTests because we will stop class's SparkContext and cause # other tests failed. @@ -3082,12 +3082,12 @@ def __init__(self, **kwargs): @unittest.skipIf(not _have_arrow, "Arrow not installed") -class ArrowTests(ReusedPySparkTestCase): +class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): from datetime import datetime - ReusedPySparkTestCase.setUpClass() + ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java cls.tz_prev = os.environ.get("TZ", None) # save current tz if set @@ -3095,7 +3095,6 @@ def setUpClass(cls): os.environ["TZ"] = tz time.tzset() - cls.spark = SparkSession(cls.sc) cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ @@ -3116,8 +3115,7 @@ def tearDownClass(cls): if cls.tz_prev is not None: os.environ["TZ"] = cls.tz_prev time.tzset() - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() + ReusedSQLTestCase.tearDownClass() def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -3169,17 +3167,7 @@ def test_filtered_frame(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class VectorizedUDFTests(ReusedSQLTestCase): def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col @@ -3498,16 +3486,7 @@ def check_records_per_batch(x): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class GroupbyApplyTests(ReusedSQLTestCase): def assertFramesEqual(self, expected, result): msg = ("DataFrames are not equal: " + From 6eda55f728a6f2e265ae12a7e01dae88e4172715 Mon Sep 17 00:00:00 2001 From: tengpeng Date: Mon, 30 Oct 2017 07:24:55 +0000 Subject: [PATCH 1568/1765] Added more information to Imputer Often times we want to impute custom values other than 'NaN'. My addition helps people locate this function without reading the API. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: tengpeng Closes #19600 from tengpeng/patch-5. --- docs/ml-features.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 86a0e09997b8e..72643137d96b1 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1373,7 +1373,9 @@ for more details on the API. The `Imputer` transformer completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly -creates incorrect values for columns containing categorical features. +creates incorrect values for columns containing categorical features. Imputer can impute custom values +other than 'NaN' by `.setMissingValue(custom_value)`. For example, `.setMissingValue(0)` will impute +all occurrences of (0). **Note** all `null` values in the input columns are treated as missing, and so are also imputed. From 9f5c77ae32890b892b69b45a2833b9d6d6866aea Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 30 Oct 2017 07:45:54 +0000 Subject: [PATCH 1569/1765] [SPARK-21983][SQL] Fix Antlr 4.7 deprecation warnings ## What changes were proposed in this pull request? Fix three deprecation warnings introduced by move to ANTLR 4.7: * Use ParserRuleContext.addChild(TerminalNode) in preference to deprecated ParserRuleContext.addChild(Token) interface. * TokenStream.reset() is deprecated in favour of seek(0) * Replace use of deprecated ANTLRInputStream with stream returned by CharStreams.fromString() The last item changed the way we construct ANTLR's input stream (from direct instantiation to factory construction), so necessitated a change to how we override the LA() method to always return an upper-case char. The ANTLR object is now wrapped, rather than inherited-from. * Also fix incorrect usage of CharStream.getText() which expects the rhs of the supplied interval to be the last char to be returned, i.e. the interval is inclusive, and work around bug in ANTLR 4.7 where empty streams or intervals may cause getText() to throw an error. ## How was this patch tested? Ran all the sql tests. Confirmed that LA() override has coverage by breaking it, and noting that tests failed. Author: Henry Robinson Closes #19578 from henryr/spark-21983. --- .../sql/catalyst/parser/ParseDriver.scala | 39 +++++++++++++++---- .../sql/catalyst/parser/ParserUtils.scala | 4 +- .../catalyst/parser/ParserUtilsSuite.scala | 4 +- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 0d9ad218e48db..4c20f2368bded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.parser import org.antlr.v4.runtime._ import org.antlr.v4.runtime.atn.PredictionMode -import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -80,7 +81,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { logDebug(s"Parsing command: $command") - val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) @@ -99,7 +100,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { catch { case e: ParseCancellationException => // if we fail, parse with LL mode - tokenStream.reset() // rewind input stream + tokenStream.seek(0) // rewind input stream parser.reset() // Try Again. @@ -148,12 +149,33 @@ object CatalystSqlParser extends AbstractSqlParser { * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead * function and is purely used for matching lexical rules. This also means that the grammar will * only accept capitalized tokens in case it is run from other tools like antlrworks which do not - * have the ANTLRNoCaseStringStream implementation. + * have the UpperCaseCharStream implementation. */ -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { +private[parser] class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = { + // ANTLR 4.7's CodePointCharStream implementations have bugs when + // getText() is called with an empty stream, or intervals where + // the start > end. See + // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix + // that is not yet in a released ANTLR artifact. + if (size() > 0 && (interval.b - interval.a >= 0)) { + wrapped.getText(interval) + } else { + "" + } + } + override def LA(i: Int): Int = { - val la = super.LA(i) + val la = wrapped.LA(i) if (la == 0 || la == IntStream.EOF) la else Character.toUpperCase(la) } @@ -244,11 +266,12 @@ case object PostProcessor extends SqlBaseBaseListener { val parent = ctx.getParent parent.removeLastChild() val token = ctx.getChild(0).getPayload.asInstanceOf[Token] - parent.addChild(f(new CommonToken( + val newToken = new CommonToken( new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), SqlBaseParser.IDENTIFIER, token.getChannel, token.getStartIndex + stripMargins, - token.getStopIndex - stripMargins))) + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9c1031e8033e7..9b127f91648e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -32,7 +32,7 @@ object ParserUtils { /** Get the command which created the token. */ def command(ctx: ParserRuleContext): String = { val stream = ctx.getStart.getInputStream - stream.getText(Interval.of(0, stream.size())) + stream.getText(Interval.of(0, stream.size() - 1)) } def operationNotAllowed(message: String, ctx: ParserRuleContext): Nothing = { @@ -58,7 +58,7 @@ object ParserUtils { /** Get all the text which comes after the given token. */ def remainder(token: Token): String = { val stream = token.getInputStream - val interval = Interval.of(token.getStopIndex + 1, stream.size()) + val interval = Interval.of(token.getStopIndex + 1, stream.size() - 1) stream.getText(interval) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index d5748a4ff18f8..768030f0a9bc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.parser -import org.antlr.v4.runtime.{CommonTokenStream, ParserRuleContext} +import org.antlr.v4.runtime.{CharStreams, CommonTokenStream, ParserRuleContext} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -57,7 +57,7 @@ class ParserUtilsSuite extends SparkFunSuite { } private def buildContext[T](command: String)(toResult: SqlBaseParser => T): T = { - val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) toResult(parser) From 9f02d7dc537b73988468b11337dbb14a8602f246 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Oct 2017 11:00:44 +0100 Subject: [PATCH 1570/1765] [SPARK-22385][SQL] MapObjects should not access list element by index ## What changes were proposed in this pull request? This issue was discovered and investigated by Ohad Raviv and Sean Owen in https://issues.apache.org/jira/browse/SPARK-21657. The input data of `MapObjects` may be a `List` which has O(n) complexity for accessing by index. When converting input data to catalyst array, `MapObjects` gets element by index in each loop, and results to bad performance. This PR fixes this issue by accessing elements via Iterator. ## How was this patch tested? using the test script in https://issues.apache.org/jira/browse/SPARK-21657 ``` val BASE = 100000000 val N = 100000 val df = sc.parallelize(List(("1234567890", (BASE to (BASE+N)).map(x => (x.toString, (x+1).toString, (x+2).toString, (x+3).toString)).toList ))).toDF("c1", "c_arr") spark.time(df.queryExecution.toRdd.foreach(_ => ())) ``` We can see 50x speed up. Author: Wenchen Fan Closes #19603 from cloud-fan/map-objects. --- .../expressions/objects/objects.scala | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9b28a18035b1c..6ae3490a3f863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -591,18 +591,43 @@ case class MapObjects private( case _ => inputData.dataType } - val (getLength, getLoopVar) = inputDataType match { + // `MapObjects` generates a while loop to traverse the elements of the input collection. We + // need to take care of Seq and List because they may have O(n) complexity for indexed accessing + // like `list.get(1)`. Here we use Iterator to traverse Seq and List. + val (getLength, prepareLoop, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"scala.collection.Iterator $it = ${genInputData.value}.toIterator();", + s"$it.next()" + ) case ObjectType(cls) if cls.isArray => - s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" + ( + s"${genInputData.value}.length", + "", + s"${genInputData.value}[$loopIndex]" + ) case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"java.util.Iterator $it = ${genInputData.value}.iterator();", + s"$it.next()" + ) case ArrayType(et, _) => - s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) + ( + s"${genInputData.value}.numElements()", + "", + ctx.getValue(genInputData.value, et, loopIndex) + ) case ObjectType(cls) if cls == classOf[Object] => - s"$seq == null ? $array.length : $seq.size()" -> - s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" + val it = ctx.freshName("it") + ( + s"$seq == null ? $array.length : $seq.size()", + s"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();", + s"$it == null ? $array[$loopIndex] : $it.next()" + ) } // Make a copy of the data if it's unsafe-backed @@ -676,6 +701,7 @@ case class MapObjects private( $initCollection int $loopIndex = 0; + $prepareLoop while ($loopIndex < $dataLength) { $loopValue = ($elementJavaType) ($getLoopVar); $loopNullCheck From 3663764254615cef442d4af55c11808445e5b03a Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 30 Oct 2017 12:14:38 +0000 Subject: [PATCH 1571/1765] [WEB-UI] Add count in fair scheduler pool page ## What changes were proposed in this pull request? Add count in fair scheduler pool page. The purpose is to know the statistics clearly. For specific reasons, please refer to PR of https://github.com/apache/spark/pull/18525 fix before: ![1](https://user-images.githubusercontent.com/26266482/31641589-4b17b970-b318-11e7-97eb-f5a36db428f6.png) ![2](https://user-images.githubusercontent.com/26266482/31641643-97b6345a-b318-11e7-8c20-4b164ade228d.png) fix after: ![3](https://user-images.githubusercontent.com/26266482/31641688-e6ceacb6-b318-11e7-8204-6a816c581a29.png) ![4](https://user-images.githubusercontent.com/26266482/31641766-7310b0c0-b319-11e7-871d-a57f874f1e8b.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19507 from guoxiaolongzte/add_count_in_fair_scheduler_pool_page. --- .../src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index a30c13592947c..dc5b03c5269a9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -113,7 +113,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { var content = summary ++ { if (sc.isDefined && isFairScheduler) { -

    {pools.size} Fair Scheduler Pools

    ++ poolTable.toNodeSeq +

    Fair Scheduler Pools ({pools.size})

    ++ poolTable.toNodeSeq } else { Seq.empty[Node] } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 819fe57e14b2d..4b8c7b203771d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -57,7 +57,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { var content =

    Summary

    ++ poolTable.toNodeSeq if (shouldShowActiveStages) { - content ++=

    {activeStages.size} Active Stages

    ++ activeStagesTable.toNodeSeq + content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq } UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) From 079a2609d7ad0a7dd2ec3eaa594e6ed8801a8008 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Oct 2017 17:53:06 +0100 Subject: [PATCH 1572/1765] [SPARK-17788][SPARK-21033][SQL] fix the potential OOM in UnsafeExternalSorter and ShuffleExternalSorter ## What changes were proposed in this pull request? In `UnsafeInMemorySorter`, one record may take 32 bytes: 1 `long` for pointer, 1 `long` for key-prefix, and another 2 `long`s as the temporary buffer for radix sort. In `UnsafeExternalSorter`, we set the `DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD` to be `1024 * 1024 * 1024 / 2`, and hoping the max size of point array to be 8 GB. However this is wrong, `1024 * 1024 * 1024 / 2 * 32` is actually 16 GB, and if we grow the point array before reach this limitation, we may hit the max-page-size error. Users may see exception like this on large dataset: ``` Caused by: java.lang.IllegalArgumentException: Cannot allocate a page with more than 17179869176 bytes at org.apache.spark.memory.TaskMemoryManager.allocatePage(TaskMemoryManager.java:241) at org.apache.spark.memory.MemoryConsumer.allocatePage(MemoryConsumer.java:121) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPageIfNecessary(UnsafeExternalSorter.java:374) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.insertRecord(UnsafeExternalSorter.java:396) at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:94) ... ``` Setting `DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD` to a smaller number is not enough, users can still set the config to a big number and trigger the too large page size issue. This PR fixes it by explicitly handling the too large page size exception in the sorter and spill. This PR also change the type of `spark.shuffle.spill.numElementsForceSpillThreshold` to int, because it's only compared with `numRecords`, which is an int. This is an internal conf so we don't have a serious compatibility issue. ## How was this patch tested? TODO Author: Wenchen Fan Closes #18251 from cloud-fan/sort. --- .../apache/spark/memory/MemoryConsumer.java | 8 ++++++- .../spark/memory/TaskMemoryManager.java | 5 ++-- .../spark/memory/TooLargePageException.java | 24 +++++++++++++++++++ .../shuffle/sort/ShuffleExternalSorter.java | 16 ++++++++----- .../unsafe/sort/UnsafeExternalSorter.java | 21 +++++++++------- .../spark/internal/config/package.scala | 10 ++++++++ .../sort/UnsafeExternalSorterSuite.java | 11 +++++---- .../execution/UnsafeExternalRowSorter.java | 5 ++-- .../apache/spark/sql/internal/SQLConf.scala | 6 ++--- .../UnsafeFixedWidthAggregationMap.java | 6 ++--- .../sql/execution/UnsafeKVExternalSorter.java | 4 ++-- .../aggregate/ObjectAggregationIterator.scala | 6 ++--- .../aggregate/ObjectAggregationMap.scala | 5 ++-- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 3 ++- .../UnsafeKVExternalSorterSuite.scala | 4 ++-- 15 files changed, 92 insertions(+), 42 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/TooLargePageException.java diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 0efae16e9838c..2dff241900e82 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -83,7 +83,13 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Allocates a LongArray of `size`. + * Allocates a LongArray of `size`. Note that this method may throw `OutOfMemoryError` if Spark + * doesn't have enough memory for this allocation, or throw `TooLargePageException` if this + * `LongArray` is too large to fit in a single page. The caller side should take care of these + * two exceptions, or make sure the `size` is small enough that won't trigger exceptions. + * + * @throws OutOfMemoryError + * @throws TooLargePageException */ public LongArray allocateArray(long size) { long required = size * 8L; diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 44b60c1e4e8c8..f6b5ea3c0ad26 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -270,13 +270,14 @@ public long pageSizeBytes() { * * Returns `null` if there was not enough memory to allocate the page. May return a page that * contains fewer bytes than requested, so callers should verify the size of returned pages. + * + * @throws TooLargePageException */ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { assert(consumer != null); assert(consumer.getMode() == tungstenMemoryMode); if (size > MAXIMUM_PAGE_SIZE_BYTES) { - throw new IllegalArgumentException( - "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); + throw new TooLargePageException(size); } long acquired = acquireExecutionMemory(size, consumer); diff --git a/core/src/main/java/org/apache/spark/memory/TooLargePageException.java b/core/src/main/java/org/apache/spark/memory/TooLargePageException.java new file mode 100644 index 0000000000000..4abee77ff67b2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/TooLargePageException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +public class TooLargePageException extends RuntimeException { + TooLargePageException(long size) { + super("Cannot allocate a page of " + size + " bytes."); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index b4f46306f2827..e80f9734ecf7b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -31,8 +31,10 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockManager; @@ -43,7 +45,6 @@ import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; -import org.apache.spark.internal.config.package$; /** * An external sorter that is specialized for sort-based shuffle. @@ -75,10 +76,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { private final ShuffleWriteMetrics writeMetrics; /** - * Force this sorter to spill when there are this many elements in memory. The default value is - * 1024 * 1024 * 1024, which allows the maximum size of the pointer array to be 8G. + * Force this sorter to spill when there are this many elements in memory. */ - private final long numElementsForSpillThreshold; + private final int numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -123,7 +123,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = - conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); + (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); @@ -325,7 +325,7 @@ public void cleanupResources() { * array and grows the array if additional space is required. If the required space cannot be * obtained, then the in-memory data will be spilled to disk. */ - private void growPointerArrayIfNecessary() { + private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); @@ -333,6 +333,10 @@ private void growPointerArrayIfNecessary() { try { // could trigger spilling array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + return; } catch (OutOfMemoryError e) { // should have trigger spilling if (!inMemSorter.hasSpaceForAnotherRecord()) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index e749f7ba87c6e..8b8e15e3f78ed 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,6 +32,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; @@ -68,12 +69,10 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final int fileBufferSizeBytes; /** - * Force this sorter to spill when there are this many elements in memory. The default value is - * 1024 * 1024 * 1024 / 2 which allows the maximum size of the pointer array to be 8G. + * Force this sorter to spill when there are this many elements in memory. */ - public static final long DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD = 1024 * 1024 * 1024 / 2; + private final int numElementsForSpillThreshold; - private final long numElementsForSpillThreshold; /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -103,11 +102,11 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, - numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */); + pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. sorter.inMemSorter = null; @@ -123,7 +122,7 @@ public static UnsafeExternalSorter create( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, @@ -139,7 +138,7 @@ private UnsafeExternalSorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, @Nullable UnsafeInMemorySorter existingInMemorySorter, boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); @@ -338,7 +337,7 @@ public void cleanupResources() { * array and grows the array if additional space is required. If the required space cannot be * obtained, then the in-memory data will be spilled to disk. */ - private void growPointerArrayIfNecessary() { + private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); @@ -346,6 +345,10 @@ private void growPointerArrayIfNecessary() { try { // could trigger spilling array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + return; } catch (OutOfMemoryError e) { // should have trigger spilling if (!inMemSorter.hasSpaceForAnotherRecord()) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6f0247b73070d..57e2da8353d6d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -475,4 +475,14 @@ package object config { .stringConf .toSequence .createOptional + + private[spark] val SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD = + ConfigBuilder("spark.shuffle.spill.numElementsForceSpillThreshold") + .internal() + .doc("The maximum number of elements in memory before forcing the shuffle sorter to spill. " + + "By default it's Integer.MAX_VALUE, which means we never force the sorter to spill, " + + "until we reach some limitations, like the max page size limitation for the pointer " + + "array in the sorter.") + .intConf + .createWithDefault(Integer.MAX_VALUE) } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index d0d0334add0bf..af4975c888d65 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.JavaSerializer; @@ -86,6 +87,9 @@ public int compare( private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m"); + private final int spillThreshold = + (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -159,7 +163,7 @@ private UnsafeExternalSorter newSorter() throws IOException { prefixComparator, /* initialSize */ 1024, pageSizeBytes, - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, + spillThreshold, shouldUseRadixSort()); } @@ -383,7 +387,7 @@ public void forcedSpillingWithoutComparator() throws Exception { null, /* initialSize */ 1024, pageSizeBytes, - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, + spillThreshold, shouldUseRadixSort()); long[] record = new long[100]; int recordSize = record.length * 8; @@ -445,7 +449,7 @@ public void testPeakMemoryUsed() throws Exception { prefixComparator, 1024, pageSizeBytes, - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, + spillThreshold, shouldUseRadixSort()); // Peak memory should be monotonically increasing. More specifically, every time @@ -548,4 +552,3 @@ private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) } } } - diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 12a123ee0bcff..6b002f0d3f8e8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -27,6 +27,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; @@ -89,8 +90,8 @@ public UnsafeExternalRowSorter( sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), pageSizeBytes, - SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + (int) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), canUseRadixSort ); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5203e8833fbbb..ede116e964a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -884,7 +884,7 @@ object SQLConf { .internal() .doc("Threshold for number of rows to be spilled by window operator") .intConf - .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") @@ -899,7 +899,7 @@ object SQLConf { .internal() .doc("Threshold for number of rows to be spilled by sort merge join operator") .intConf - .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") @@ -914,7 +914,7 @@ object SQLConf { .internal() .doc("Threshold for number of rows to be spilled by cartesian product operator") .intConf - .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames") .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" + diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 8fea46a58e857..c7c4c7b3e7715 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; @@ -29,7 +30,6 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -238,8 +238,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti SparkEnv.get().blockManager(), SparkEnv.get().serializerManager(), map.getPageSizeBytes(), - SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + (int) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 6aa52f1aae048..eb2fe82007af3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -57,7 +57,7 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, - long numElementsForSpillThreshold) throws IOException { + int numElementsForSpillThreshold) throws IOException { this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, numElementsForSpillThreshold, null); } @@ -68,7 +68,7 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index c68dbc73f0447..43514f5271ac8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -315,9 +315,7 @@ class SortBasedAggregator( SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong( - "spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), null ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala index f2d4f6c6ebd5b..b5372bcca89dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import java.{util => ju} import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.config import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate} @@ -73,9 +74,7 @@ class ObjectAggregationMap() { SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong( - "spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), null ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index efe28afab08e5..59397dbcb1cab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.internal.config import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.Benchmark @@ -231,6 +232,6 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X */ testAgainstRawUnsafeExternalSorter( - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4) + config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get, 10 * 1000, 1 << 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 3d869c77e9608..359525fcd05a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,13 +22,13 @@ import java.util.Properties import scala.util.Random import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -125,7 +125,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD) + pageSize, config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => From 65338de5fbaf90774fb3f4c51321359d324ace56 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 30 Oct 2017 10:19:34 -0700 Subject: [PATCH 1573/1765] [SPARK-22396][SQL] Better Error Message for InsertIntoDir using Hive format without enabling Hive Support ## What changes were proposed in this pull request? When Hive support is not on, users can hit unresolved plan node when trying to call `INSERT OVERWRITE DIRECTORY` using Hive format. ``` "unresolved operator 'InsertIntoDir true, Storage(Location: /private/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/spark-b4227606-9311-46a8-8c02-56355bf0e2bc, Serde Library: org.apache.hadoop.hive.ql.io.orc.OrcSerde, InputFormat: org.apache.hadoop.hive.ql.io.orc.OrcInputFormat, OutputFormat: org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat), hive, true;; ``` This PR is to issue a better error message. ## How was this patch tested? Added a test case. Author: gatorsmile Closes #19608 from gatorsmile/hivesupportInsertOverwrite. --- .../spark/sql/execution/datasources/rules.scala | 3 +++ .../apache/spark/sql/sources/InsertSuite.scala | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7a2c85e8e01f6..60c430bcfece2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -404,6 +404,9 @@ object HiveOnlyCheck extends (LogicalPlan => Unit) { plan.foreach { case CreateTable(tableDesc, _, _) if DDLUtils.isHiveTable(tableDesc) => throw new AnalysisException("Hive support is required to CREATE Hive TABLE (AS SELECT)") + case i: InsertIntoDir if DDLUtils.isHiveTable(i.provider) => + throw new AnalysisException( + "Hive support is required to INSERT OVERWRITE DIRECTORY with the Hive format") case _ => // OK } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 875b74551addb..8b7e2e5f45946 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -408,6 +408,22 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } + test("Insert overwrite directory using Hive serde without turning on Hive support") { + withTempDir { dir => + val path = dir.toURI.getPath + val e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '$path' + |STORED AS orc + |SELECT 1, 2 + """.stripMargin) + }.getMessage + assert(e.contains( + "Hive support is required to INSERT OVERWRITE DIRECTORY with the Hive format")) + } + } + test("insert overwrite directory to data source not providing FileFormat") { withTempDir { dir => val path = dir.toURI.getPath From 44c4003155c1d243ffe0f73d5537b4c8b3f3b564 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 30 Oct 2017 10:21:05 -0700 Subject: [PATCH 1574/1765] [SPARK-22400][SQL] rename some APIs and classes to make their meaning clearer ## What changes were proposed in this pull request? Both `ReadSupport` and `ReadTask` have a method called `createReader`, but they create different things. This could cause some confusion for data source developers. The same issue exists between `WriteSupport` and `DataWriterFactory`, both of which have a method called `createWriter`. This PR renames the method of `ReadTask`/`DataWriterFactory` to `createDataReader`/`createDataWriter`. Besides, the name of `RowToInternalRowDataWriterFactory` is not correct, because it actually converts `InternalRow`s to `Row`s. It should be renamed `InternalRowDataWriterFactory`. ## How was this patch tested? Only renaming, should be covered by existing tests. Author: Zhenhua Wang Closes #19610 from wzhfy/rename. --- .../spark/sql/sources/v2/reader/DataReader.java | 4 ++-- .../spark/sql/sources/v2/reader/ReadTask.java | 2 +- .../spark/sql/sources/v2/writer/DataWriter.java | 6 +++--- .../sql/sources/v2/writer/DataWriterFactory.java | 2 +- .../execution/datasources/v2/DataSourceRDD.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 5 +++-- .../datasources/v2/WriteToDataSourceV2.scala | 14 +++++++------- .../sql/sources/v2/JavaAdvancedDataSourceV2.java | 2 +- .../sql/sources/v2/JavaSimpleDataSourceV2.java | 2 +- .../sql/sources/v2/JavaUnsafeRowDataSourceV2.java | 2 +- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 8 +++++--- .../sql/sources/v2/SimpleWritableDataSource.scala | 9 ++++----- 12 files changed, 30 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 95e091569b614..52bb138673fc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -22,8 +22,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data - * for a RDD partition. + * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for + * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 01362df0978cb..44786db419a32 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -45,5 +45,5 @@ default String[] preferredLocations() { /** * Returns a data reader to do the actual reading work for this read task. */ - DataReader createReader(); + DataReader createDataReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 14261419af6f6..d84afbae32892 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -34,7 +34,7 @@ * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task @@ -64,7 +64,7 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which - * will be send back to driver side and pass to + * will be sent back to driver side and passed to * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index f812d102bda1a..fe56cc00d1c7a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -46,5 +46,5 @@ public interface DataWriterFactory extends Serializable { * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. */ - DataWriter createWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index b8fe5ac8e3d94..5f30be5ed4af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -39,7 +39,7 @@ class DataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createReader() + val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[UnsafeRow] { private[this] var valuePrepared = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index addc12a3f0901..3f243dc44e043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -67,8 +67,9 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations - override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) + override def createDataReader: DataReader[UnsafeRow] = { + new RowToUnsafeDataReader( + rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 92c1e1f4a3383..b72d15ed15aed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -48,7 +48,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) override protected def doExecute(): RDD[InternalRow] = { val writeTask = writer match { case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } val rdd = query.execute() @@ -93,7 +93,7 @@ object DataWritingSparkTask extends Logging { writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -111,18 +111,18 @@ object DataWritingSparkTask extends Logging { } } -class RowToInternalRowDataWriterFactory( +class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new RowToInternalRowDataWriter( - rowWriterFactory.createWriter(partitionId, attemptNumber), + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new InternalRowDataWriter( + rowWriterFactory.createDataWriter(partitionId, attemptNumber), RowEncoder.apply(schema).resolveAndBind()) } } -class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) +class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) extends DataWriter[InternalRow] { override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index da2c13f70c52a..1cfdc08217e6e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -100,7 +100,7 @@ static class JavaAdvancedReadTask implements ReadTask, DataReader { } @Override - public DataReader createReader() { + public DataReader createDataReader() { return new JavaAdvancedReadTask(start - 1, end, requiredSchema); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 08469f14c257a..2d458b7f7e906 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -58,7 +58,7 @@ static class JavaSimpleReadTask implements ReadTask, DataReader { } @Override - public DataReader createReader() { + public DataReader createDataReader() { return new JavaSimpleReadTask(start - 1, end); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index 9efe7c791a936..f6aa00869a681 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -58,7 +58,7 @@ static class JavaUnsafeRowReadTask implements ReadTask, DataReader createReader() { + public DataReader createDataReader() { return new JavaUnsafeRowReadTask(start - 1, end); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 092702a1d5173..ab37e4984bd1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -167,7 +167,7 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { private var current = start - 1 - override def createReader(): DataReader[Row] = new SimpleReadTask(start, end) + override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end) override def next(): Boolean = { current += 1 @@ -233,7 +233,9 @@ class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) private var current = start - 1 - override def createReader(): DataReader[Row] = new AdvancedReadTask(start, end, requiredSchema) + override def createDataReader(): DataReader[Row] = { + new AdvancedReadTask(start, end, requiredSchema) + } override def close(): Unit = {} @@ -273,7 +275,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 6fb60f4d848d7..cd7252eb2e3d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.text.SimpleDateFormat -import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} +import java.util.{Collections, List => JList, Optional} import scala.collection.JavaConverters._ @@ -157,7 +156,7 @@ class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createReader(): DataReader[Row] = { + override def createDataReader(): DataReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) @@ -185,7 +184,7 @@ class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -218,7 +217,7 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) From ded3ed97337427477d33bd5fad76649e96f6e50a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 30 Oct 2017 21:44:24 -0700 Subject: [PATCH 1575/1765] [SPARK-22327][SPARKR][TEST] check for version warning ## What changes were proposed in this pull request? Will need to port to this to branch-1.6, -2.0, -2.1, -2.2 ## How was this patch tested? manually Jenkins, AppVeyor Author: Felix Cheung Closes #19549 from felixcheung/rcranversioncheck. --- R/run-tests.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/run-tests.sh b/R/run-tests.sh index 29764f48bd156..f38c86e3e6b1d 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -38,6 +38,7 @@ FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)" +HAS_PACKAGE_VERSION_WARN="$(grep -c "Insufficient package version" $CRAN_CHECK_LOG_FILE)" if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then cat $LOGFILE @@ -46,9 +47,10 @@ if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then echo -en "\033[0m" # No color exit -1 else - # We have 2 existing NOTEs for new maintainer, attach() - # We have one more NOTE in Jenkins due to "No repository set" - if [[ $NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3 ]]; then + # We have 2 NOTEs for RoxygenNote, attach(); and one in Jenkins only "No repository set" + # For non-latest version branches, one WARNING for package version + if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3) && + ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) ]]; then cat $CRAN_CHECK_LOG_FILE echo -en "\033[31m" # Red echo "Had CRAN check errors; see logs." From 1ff41d8693ade5bf34ee41a1140254488abfbbc7 Mon Sep 17 00:00:00 2001 From: "pj.fanning" Date: Tue, 31 Oct 2017 08:16:54 +0000 Subject: [PATCH 1576/1765] [SPARK-21708][BUILD] update some sbt plugins ## What changes were proposed in this pull request? These are just some straightforward upgrades to use the latest versions of some sbt plugins that also support sbt 1.0. The remaining sbt plugins that need upgrading will require bigger changes. ## How was this patch tested? Tested sbt use manually. Author: pj.fanning Closes #19609 from pjfanning/SPARK-21708. --- project/plugins.sbt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 3c5442b04b8e4..96bdb9067ae59 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,11 +1,9 @@ // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -// sbt 1.0.0 support: https://github.com/typesafehub/sbteclipse/issues/343 -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.1.0") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.3") -// sbt 1.0.0 support: https://github.com/jrudolph/sbt-dependency-graph/issues/134 -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") @@ -20,8 +18,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") // need to make changes to uptake sbt 1.0 support in "com.cavorite" % "sbt-avro-1-7" % "1.1.2" addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") -// sbt 1.0.0 support: https://github.com/spray/sbt-revolver/issues/62 -addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1") libraryDependencies += "org.ow2.asm" % "asm" % "5.1" From aa6db57e39d4931658089d9237dbf2a29acfe5ed Mon Sep 17 00:00:00 2001 From: bomeng Date: Tue, 31 Oct 2017 08:20:23 +0000 Subject: [PATCH 1577/1765] [SPARK-22399][ML] update the location of reference paper ## What changes were proposed in this pull request? Update the url of reference paper. ## How was this patch tested? It is comments, so nothing tested. Author: bomeng Closes #19614 from bomeng/22399. --- docs/mllib-clustering.md | 2 +- .../examples/mllib/PowerIterationClusteringExample.scala | 3 ++- .../spark/mllib/clustering/PowerIterationClustering.scala | 6 +++--- python/pyspark/mllib/clustering.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 8990e95796b67..df2be92d860e4 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -134,7 +134,7 @@ Refer to the [`GaussianMixture` Python docs](api/python/pyspark.mllib.html#pyspa Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a graph given pairwise similarities as edge properties, -described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). +described in [Lin and Cohen, Power Iteration Clustering](http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. `spark.mllib` includes an implementation of PIC using GraphX as its backend. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 986496c0d9435..65603252c4384 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -28,7 +28,8 @@ import org.apache.spark.mllib.clustering.PowerIterationClustering import org.apache.spark.rdd.RDD /** - * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. + * An example Power Iteration Clustering app. + * http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf * Takes an input of K concentric circles and the number of points in the innermost circle. * The output should be K clusters - each cluster containing precisely the points associated * with each of the input circles. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index b2437b845f826..9444f29a91ed8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -103,9 +103,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode /** * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by - * Lin and Cohen. From the abstract: PIC finds - * a very low-dimensional embedding of a dataset using truncated power iteration on a normalized - * pair-wise similarity matrix of the data. + * Lin and Cohen. + * From the abstract: PIC finds a very low-dimensional embedding of a dataset using + * truncated power iteration on a normalized pair-wise similarity matrix of the data. * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 91123ace3387e..bb687a7da6ffd 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -636,7 +636,7 @@ def load(cls, sc, path): class PowerIterationClustering(object): """ Power Iteration Clustering (PIC), a scalable graph clustering algorithm - developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. + developed by [[http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf Lin and Cohen]]. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. From 59589bc6545b6665432febfa9ee4891a96d119c4 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 31 Oct 2017 11:13:48 +0100 Subject: [PATCH 1578/1765] [SPARK-22310][SQL] Refactor join estimation to incorporate estimation logic for different kinds of statistics ## What changes were proposed in this pull request? The current join estimation logic is only based on basic column statistics (such as ndv, etc). If we want to add estimation for other kinds of statistics (such as histograms), it's not easy to incorporate into the current algorithm: 1. When we have multiple pairs of join keys, the current algorithm computes cardinality in a single formula. But if different join keys have different kinds of stats, the computation logic for each pair of join keys become different, so the previous formula does not apply. 2. Currently it computes cardinality and updates join keys' column stats separately. It's better to do these two steps together, since both computation and update logic are different for different kinds of stats. ## How was this patch tested? Only refactor, covered by existing tests. Author: Zhenhua Wang Closes #19531 from wzhfy/join_est_refactor. --- .../BasicStatsPlanVisitor.scala | 4 +- .../statsEstimation/JoinEstimation.scala | 172 +++++++++--------- 2 files changed, 85 insertions(+), 91 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 4cff72d45a400..ca0775a2e8408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.LongType /** * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. @@ -54,7 +52,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitIntersect(p: Intersect): Statistics = fallback(p) override def visitJoin(p: Join): Statistics = { - JoinEstimation.estimate(p).getOrElse(fallback(p)) + JoinEstimation(p).estimate.getOrElse(fallback(p)) } override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index dcbe36da91dfc..b073108c26ee5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -28,60 +28,58 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -object JoinEstimation extends Logging { +case class JoinEstimation(join: Join) extends Logging { + + private val leftStats = join.left.stats + private val rightStats = join.right.stats + /** * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(join: Join): Option[Statistics] = { + def estimate: Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(join).doEstimate() + estimateInnerOuterJoin() case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(join).doEstimate() + estimateLeftSemiAntiJoin() case _ => logDebug(s"[CBO] Unsupported join type: ${join.joinType}") None } } -} - -case class InnerOuterEstimation(join: Join) extends Logging { - - private val leftStats = join.left.stats - private val rightStats = join.right.stats /** * Estimate output size and number of rows after a join operator, and update output column stats. */ - def doEstimate(): Option[Statistics] = join match { + private def estimateInnerOuterJoin(): Option[Statistics] = join match { case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) - val selectivity = joinSelectivity(joinKeyPairs) + val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) // 2. Estimate the number of output rows val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) // Make sure outputRows won't be too small based on join type. val outputRows = joinType match { case LeftOuter => // All rows from left side should be in the result. - leftRows.max(innerJoinedRows) + leftRows.max(numInnerJoinedRows) case RightOuter => // All rows from right side should be in the result. - rightRows.max(innerJoinedRows) + rightRows.max(numInnerJoinedRows) case FullOuter => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) - leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows + leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows case _ => + assert(joinType == Inner || joinType == Cross) // Don't change for inner or cross join - innerJoinedRows + numInnerJoinedRows } // 3. Update statistics based on the output of join @@ -93,7 +91,7 @@ case class InnerOuterEstimation(join: Join) extends Logging { val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (selectivity == 0) { + } else if (numInnerJoinedRows == 0) { joinType match { // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side @@ -113,26 +111,28 @@ case class InnerOuterEstimation(join: Join) extends Logging { val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } - case _ => Nil + case _ => + assert(joinType == Inner || joinType == Cross) + Nil } - } else if (selectivity == 1) { + } else if (numInnerJoinedRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats inputAttrStats.toSeq } else { - val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { // For outer joins, don't update column stats from the outer side. case LeftOuter => fromLeft.map(a => (a, inputAttrStats(a))) ++ - updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, fromRight, inputAttrStats, keyStatsAfterJoin) case RightOuter => - updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ + updateOutputStats(outputRows, fromLeft, inputAttrStats, keyStatsAfterJoin) ++ fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => inputAttrStats.toSeq case _ => + assert(joinType == Inner || joinType == Cross) // Update column stats from both sides for inner or cross join. - updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin) } } @@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging { // scalastyle:off /** * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula: - * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of - * that column. The underlying assumption for this formula is: each value of the smaller domain - * is included in the larger domain. - * Generally, inner join with multiple join keys can also be estimated based on the above - * formula: + * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), + * where V is the number of distinct values (ndv) of that column. The underlying assumption for + * this formula is: each value of the smaller domain is included in the larger domain. + * + * Generally, inner join with multiple join keys can be estimated based on the above formula: * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn))) * However, the denominator can become very large and excessively reduce the result, so we use a * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. + * + * That is, join estimation is based on the most selective join keys. We follow this strategy + * when different types of column statistics are available. E.g., if card1 is the cardinality + * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms + * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). + * + * @param keyPairs pairs of join keys + * + * @return join cardinality, and column stats for join keys after the join */ // scalastyle:on - def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { - var ndvDenom: BigInt = -1 + private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)]) + : (BigInt, AttributeMap[ColumnStat]) = { + // If there's no column stats available for join keys, estimate as cartesian product. + var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() var i = 0 - while(i < joinKeyPairs.length && ndvDenom != 0) { - val (leftKey, rightKey) = joinKeyPairs(i) + while(i < keyPairs.length && joinCard != 0) { + val (leftKey, rightKey) = keyPairs(i) // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) + val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) + keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) + // Return cardinality estimated from the most selective join keys. + if (card < joinCard) joinCard = card } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 + // One of the join key pairs is disjoint, thus the two sides of join is disjoint. + joinCard = 0 } i += 1 } + (joinCard, AttributeMap(keyStatsAfterJoin.toSeq)) + } - if (ndvDenom < 0) { - // We can't find any join key pairs with column stats, estimate it as cartesian join. - 1 - } else if (ndvDenom == 0) { - // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - 0 - } else { - 1 / BigDecimal(ndvDenom) - } + /** Returns join cardinality and the column stat for this pair of join keys. */ + private def computeByNdv( + leftKey: AttributeReference, + rightKey: AttributeReference, + newMin: Option[Any], + newMax: Option[Any]): (BigInt, ColumnStat) = { + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + // Compute cardinality by the basic formula. + val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) + + // Get the intersected column stat. + val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + + (ceil(card), newStats) } /** * Propagate or update column stats for output attributes. */ - private def updateAttrStats( + private def updateOutputStats( outputRows: BigInt, - attributes: Seq[Attribute], + output: Seq[Attribute], oldAttrStats: AttributeMap[ColumnStat], - joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { + keyStatsAfterJoin: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - attributes.foreach { a => + output.foreach { a => // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + if (keyStatsAfterJoin.contains(a)) { + outputAttrStats += a -> keyStatsAfterJoin(a) } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount @@ -231,34 +257,6 @@ case class InnerOuterEstimation(join: Join) extends Logging { outputAttrStats } - /** Get intersected column stats for join keys. */ - private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) - : AttributeMap[ColumnStat] = { - - val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() - joinKeyPairs.foreach { case (leftKey, rightKey) => - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - // When we reach here, join selectivity is not zero, so each pair of join keys should be - // intersected. - assert(ValueInterval.isIntersected(lInterval, rInterval)) - - // Update intersected column stats - assert(leftKey.dataType.sameType(rightKey.dataType)) - val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) - val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - - intersectedStats.put(leftKey, newStats) - intersectedStats.put(rightKey, newStats) - } - AttributeMap(intersectedStats.toSeq) - } - private def extractJoinKeysWithColStats( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { @@ -270,10 +268,8 @@ case class InnerOuterEstimation(join: Join) extends Logging { if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } -} -case class LeftSemiAntiEstimation(join: Join) { - def doEstimate(): Option[Statistics] = { + private def estimateLeftSemiAntiJoin(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. From 4d9ebf3835dde1abbf9cff29a55675d9f4227620 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 31 Oct 2017 11:35:32 +0100 Subject: [PATCH 1579/1765] [SPARK-19611][SQL][FOLLOWUP] set dataSchema correctly in HiveMetastoreCatalog.convertToLogicalRelation ## What changes were proposed in this pull request? We made a mistake in https://github.com/apache/spark/pull/16944 . In `HiveMetastoreCatalog#inferIfNeeded` we infer the data schema, merge with full schema, and return the new full schema. At caller side we treat the full schema as data schema and set it to `HadoopFsRelation`. This doesn't cause any problem because both parquet and orc can work with a wrong data schema that has extra columns, but it's better to fix this mistake. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19615 from cloud-fan/infer. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f0f2c493498b3..5ac65973e70e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -164,13 +164,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - val (dataSchema, updatedTable) = - inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) + val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, - dataSchema = dataSchema, + dataSchema = updatedTable.dataSchema, bucketSpec = None, fileFormat = fileFormat, options = options)(sparkSession = sparkSession) @@ -191,13 +190,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormatClass, None) val logicalRelation = cached.getOrElse { - val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat) + val updatedTable = inferIfNeeded(relation, options, fileFormat) val created = LogicalRelation( DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, - userSpecifiedSchema = Option(dataSchema), + userSpecifiedSchema = Option(updatedTable.dataSchema), bucketSpec = None, options = options, className = fileType).resolveRelation(), @@ -224,7 +223,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log relation: HiveTableRelation, options: Map[String, String], fileFormat: FileFormat, - fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = { + fileIndexOpt: Option[FileIndex] = None): CatalogTable = { val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase val tableName = relation.tableMeta.identifier.unquotedString @@ -241,21 +240,22 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log sparkSession, options, fileIndex.listFiles(Nil, Nil).flatMap(_.files)) - .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _)) + .map(mergeWithMetastoreSchema(relation.tableMeta.dataSchema, _)) inferredSchema match { - case Some(schema) => + case Some(dataSchema) => + val schema = StructType(dataSchema ++ relation.tableMeta.partitionSchema) if (inferenceMode == INFER_AND_SAVE) { updateCatalogSchema(relation.tableMeta.identifier, schema) } - (schema, relation.tableMeta.copy(schema = schema)) + relation.tableMeta.copy(schema = schema) case None => logWarning(s"Unable to infer schema for table $tableName from file format " + s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.") - (relation.tableMeta.schema, relation.tableMeta) + relation.tableMeta } } else { - (relation.tableMeta.schema, relation.tableMeta) + relation.tableMeta } } From 7986cc09b1b2100fc061d0aea8aa2e1e1b162c75 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Tue, 31 Oct 2017 09:49:58 -0700 Subject: [PATCH 1580/1765] [SPARK-11334][CORE] Fix bug in Executor allocation manager in running tasks calculation ## What changes were proposed in this pull request? We often see the issue of Spark jobs stuck because the Executor Allocation Manager does not ask for any executor even if there are pending tasks in case dynamic allocation is turned on. Looking at the logic in Executor Allocation Manager, which calculates the running tasks, it can happen that the calculation will be wrong and the number of running tasks can become negative. ## How was this patch tested? Added unit test Author: Sital Kedia Closes #19580 from sitalkedia/skedia/fix_stuck_job. --- .../spark/ExecutorAllocationManager.scala | 29 ++++++++++++------- .../ExecutorAllocationManagerSuite.scala | 22 ++++++++++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 119b426a9af34..5bc2d9ef1b949 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -267,6 +267,10 @@ private[spark] class ExecutorAllocationManager( (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor } + private def totalRunningTasks(): Int = synchronized { + listener.totalRunningTasks + } + /** * This is called at a fixed interval to regulate the number of pending executor requests * and number of executors running. @@ -602,12 +606,11 @@ private[spark] class ExecutorAllocationManager( private class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] + // Number of running tasks per stage including speculative tasks. + // Should be 0 when no stages are active. + private val stageIdToNumRunningTask = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] - // Number of tasks currently running on the cluster including speculative tasks. - // Should be 0 when no stages are active. - private var numRunningTasks: Int = _ - // Number of speculative tasks to be scheduled in each stage private val stageIdToNumSpeculativeTasks = new mutable.HashMap[Int, Int] // The speculative tasks started in each stage @@ -625,6 +628,7 @@ private[spark] class ExecutorAllocationManager( val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks + stageIdToNumRunningTask(stageId) = 0 allocationManager.onSchedulerBacklogged() // Compute the number of tasks requested by the stage on each host @@ -651,6 +655,7 @@ private[spark] class ExecutorAllocationManager( val stageId = stageCompleted.stageInfo.stageId allocationManager.synchronized { stageIdToNumTasks -= stageId + stageIdToNumRunningTask -= stageId stageIdToNumSpeculativeTasks -= stageId stageIdToTaskIndices -= stageId stageIdToSpeculativeTaskIndices -= stageId @@ -663,10 +668,6 @@ private[spark] class ExecutorAllocationManager( // This is needed in case the stage is aborted for any reason if (stageIdToNumTasks.isEmpty && stageIdToNumSpeculativeTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() - if (numRunningTasks != 0) { - logWarning("No stages are running, but numRunningTasks != 0") - numRunningTasks = 0 - } } } } @@ -678,7 +679,9 @@ private[spark] class ExecutorAllocationManager( val executorId = taskStart.taskInfo.executorId allocationManager.synchronized { - numRunningTasks += 1 + if (stageIdToNumRunningTask.contains(stageId)) { + stageIdToNumRunningTask(stageId) += 1 + } // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is // possible because these events are posted in different threads. (see SPARK-4951) @@ -709,7 +712,9 @@ private[spark] class ExecutorAllocationManager( val taskIndex = taskEnd.taskInfo.index val stageId = taskEnd.stageId allocationManager.synchronized { - numRunningTasks -= 1 + if (stageIdToNumRunningTask.contains(stageId)) { + stageIdToNumRunningTask(stageId) -= 1 + } // If the executor is no longer running any scheduled tasks, mark it as idle if (executorIdToTaskIds.contains(executorId)) { executorIdToTaskIds(executorId) -= taskId @@ -787,7 +792,9 @@ private[spark] class ExecutorAllocationManager( /** * The number of tasks currently running across all stages. */ - def totalRunningTasks(): Int = numRunningTasks + def totalRunningTasks(): Int = { + stageIdToNumRunningTask.values.sum + } /** * Return true if an executor is not currently running a task, and false otherwise. diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a91e09b7cb69f..90b7ec4384abd 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -227,6 +227,23 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + test("ignore task end events from completed stages") { + sc = createSparkContext(0, 10, 0) + val manager = sc.executorAllocationManager.get + val stage = createStageInfo(0, 5) + post(sc.listenerBus, SparkListenerStageSubmitted(stage)) + val taskInfo1 = createTaskInfo(0, 0, "executor-1") + val taskInfo2 = createTaskInfo(1, 1, "executor-1") + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo1)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo2)) + + post(sc.listenerBus, SparkListenerStageCompleted(stage)) + + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, taskInfo1, null)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, taskInfo2, null)) + assert(totalRunningTasks(manager) === 0) + } + test("cancel pending executors when no longer needed") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get @@ -1107,6 +1124,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) private val _onSpeculativeTaskSubmitted = PrivateMethod[Unit]('onSpeculativeTaskSubmitted) + private val _totalRunningTasks = PrivateMethod[Int]('totalRunningTasks) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -1190,6 +1208,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _localityAwareTasks() } + private def totalRunningTasks(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _totalRunningTasks() + } + private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { manager invokePrivate _hostToLocalTaskCount() } From 73231860baaa40f6001db347e5dcb6b5bb65e032 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 31 Oct 2017 11:53:50 -0700 Subject: [PATCH 1581/1765] [SPARK-22305] Write HDFSBackedStateStoreProvider.loadMap non-recursively ## What changes were proposed in this pull request? Write HDFSBackedStateStoreProvider.loadMap non-recursively. This prevents stack overflow if too many deltas stack up in a low memory environment. ## How was this patch tested? existing unit tests for functional equivalence, new unit test to check for stack overflow Author: Jose Torres Closes #19611 from joseph-torres/SPARK-22305. --- .../state/HDFSBackedStateStoreProvider.scala | 45 +++++++++++++++---- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 36d6569a4187a..3f5002a4e6937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -297,17 +297,44 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { - if (version <= 0) return new MapType - synchronized { loadedMaps.get(version) }.getOrElse { - val mapFromFile = readSnapshotFile(version).getOrElse { - val prevMap = loadMap(version - 1) - val newMap = new MapType(prevMap) - updateFromDeltaFile(version, newMap) - newMap + + // Shortcut if the map for this version is already there to avoid a redundant put. + val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) } + if (loadedCurrentVersionMap.isDefined) { + return loadedCurrentVersionMap.get + } + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get + } + + // Find the most recent map before this version that we can. + // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { loadedMaps.get(lastAvailableVersion) } + .orElse(readSnapshotFile(lastAvailableVersion)) } - loadedMaps.put(version, mapFromFile) - mapFromFile } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) + } + + synchronized { loadedMaps.put(version, resultMap) } + resultMap } private def writeUpdateToDeltaFile( From 556b5d21512b17027a6e451c6a82fb428940e95a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 1 Nov 2017 08:45:11 +0000 Subject: [PATCH 1582/1765] [SPARK-5484][FOLLOWUP] PeriodicRDDCheckpointer doc cleanup ## What changes were proposed in this pull request? PeriodicRDDCheckpointer was already moved out of mllib in Spark-5484 ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #19618 from zhengruifeng/checkpointer_doc. --- .../org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index facbb830a60d8..5e181a9822534 100644 --- a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -73,8 +73,6 @@ import org.apache.spark.util.PeriodicCheckpointer * * @param checkpointInterval RDDs will be checkpointed at this interval * @tparam T RDD element type - * - * TODO: Move this out of MLlib? */ private[spark] class PeriodicRDDCheckpointer[T]( checkpointInterval: Int, From 96798d14f07208796fa0a90af0ab369879bacd6c Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 1 Nov 2017 18:07:39 +0800 Subject: [PATCH 1583/1765] [SPARK-22172][CORE] Worker hangs when the external shuffle service port is already in use ## What changes were proposed in this pull request? Handling the NonFatal exceptions while starting the external shuffle service, if there are any NonFatal exceptions it logs and continues without the external shuffle service. ## How was this patch tested? I verified it manually, it logs the exception and continues to serve without external shuffle service when BindException occurs. Author: Devaraj K Closes #19396 from devaraj-kavali/SPARK-22172. --- .../org/apache/spark/deploy/worker/Worker.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ed5fa4b839cd4..3962d422f81d3 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -199,7 +199,7 @@ private[deploy] class Worker( logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - shuffleService.startIfEnabled() + startExternalShuffleService() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -367,6 +367,16 @@ private[deploy] class Worker( } } + private def startExternalShuffleService() { + try { + shuffleService.startIfEnabled() + } catch { + case e: Exception => + logError("Failed to start external shuffle service", e) + System.exit(1) + } + } + private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { masterEndpoint.send(RegisterWorker( workerId, From 07f390a27d7b793291c352a643d4bbd5f47294a6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Nov 2017 13:09:35 +0100 Subject: [PATCH 1584/1765] [SPARK-22347][PYSPARK][DOC] Add document to notice users for using udfs with conditional expressions ## What changes were proposed in this pull request? Under the current execution mode of Python UDFs, we don't well support Python UDFs as branch values or else value in CaseWhen expression. Since to fix it might need the change not small (e.g., #19592) and this issue has simpler workaround. We should just notice users in the document about this. ## How was this patch tested? Only document change. Author: Liang-Chi Hsieh Closes #19617 from viirya/SPARK-22347-3. --- python/pyspark/sql/functions.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0d40368c9cd6e..39815497f3956 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2185,6 +2185,13 @@ def udf(f=None, returnType=StringType()): duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. + .. note:: The user-defined functions do not support conditional execution by using them with + SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no + matter the conditions are met or not. So the output is correct if the functions can be + correctly run on all rows without failure. If the functions can cause runtime failure on the + rows that do not satisfy the conditions, the suggested workaround is to incorporate the + condition logic into the functions. + :param f: python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object @@ -2278,6 +2285,13 @@ def pandas_udf(f=None, returnType=StringType()): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. + + .. note:: The user-defined functions do not support conditional execution by using them with + SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no + matter the conditions are met or not. So the output is correct if the functions can be + correctly run on all rows without failure. If the functions can cause runtime failure on the + rows that do not satisfy the conditions, the suggested workaround is to incorporate the + condition logic into the functions. """ return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) From 444bce1c98c45147fe63e2132e9743a0c5e49598 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Wed, 1 Nov 2017 14:54:08 +0100 Subject: [PATCH 1585/1765] [SPARK-19112][CORE] Support for ZStandard codec ## What changes were proposed in this pull request? Using zstd compression for Spark jobs spilling 100s of TBs of data, we could reduce the amount of data written to disk by as much as 50%. This translates to significant latency gain because of reduced disk io operations. There is a degradation CPU time by 2 - 5% because of zstd compression overhead, but for jobs which are bottlenecked by disk IO, this hit can be taken. ## Benchmark Please note that this benchmark is using real world compute heavy production workload spilling TBs of data to disk | | zstd performance as compred to LZ4 | | ------------- | -----:| | spill/shuffle bytes | -48% | | cpu time | + 3% | | cpu reservation time | -40%| | latency | -40% | ## How was this patch tested? Tested by running few jobs spilling large amount of data on the cluster and amount of intermediate data written to disk reduced by as much as 50%. Author: Sital Kedia Closes #18805 from sitalkedia/skedia/upstream_zstd. --- LICENSE | 2 ++ core/pom.xml | 4 +++ .../apache/spark/io/CompressionCodec.scala | 36 +++++++++++++++++-- .../spark/io/CompressionCodecSuite.scala | 18 ++++++++++ dev/deps/spark-deps-hadoop-2.6 | 1 + dev/deps/spark-deps-hadoop-2.7 | 1 + docs/configuration.md | 20 ++++++++++- licenses/LICENSE-zstd-jni.txt | 26 ++++++++++++++ licenses/LICENSE-zstd.txt | 30 ++++++++++++++++ pom.xml | 5 +++ 10 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 licenses/LICENSE-zstd-jni.txt create mode 100644 licenses/LICENSE-zstd.txt diff --git a/LICENSE b/LICENSE index 39fe0dc462385..c2b0d72663b55 100644 --- a/LICENSE +++ b/LICENSE @@ -269,6 +269,8 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) + (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) + (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) ======================================================================== MIT licenses diff --git a/core/pom.xml b/core/pom.xml index 54f7a34a6c37e..fa138d3e7a4e0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -198,6 +198,10 @@ org.lz4 lz4-java + + com.github.luben + zstd-jni + org.roaringbitmap RoaringBitmap diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 27f2e429395db..7722db56ee297 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -20,6 +20,7 @@ package org.apache.spark.io import java.io._ import java.util.Locale +import com.github.luben.zstd.{ZstdInputStream, ZstdOutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} @@ -50,13 +51,14 @@ private[spark] object CompressionCodec { private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] - || codec.isInstanceOf[LZ4CompressionCodec]) + || codec.isInstanceOf[LZ4CompressionCodec] || codec.isInstanceOf[ZStdCompressionCodec]) } private val shortCompressionCodecNames = Map( "lz4" -> classOf[LZ4CompressionCodec].getName, "lzf" -> classOf[LZFCompressionCodec].getName, - "snappy" -> classOf[SnappyCompressionCodec].getName) + "snappy" -> classOf[SnappyCompressionCodec].getName, + "zstd" -> classOf[ZStdCompressionCodec].getName) def getCodecName(conf: SparkConf): String = { conf.get(configKey, DEFAULT_COMPRESSION_CODEC) @@ -219,3 +221,33 @@ private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends Ou } } } + +/** + * :: DeveloperApi :: + * ZStandard implementation of [[org.apache.spark.io.CompressionCodec]]. For more + * details see - http://facebook.github.io/zstd/ + * + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. + */ +@DeveloperApi +class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec { + + private val bufferSize = conf.getSizeAsBytes("spark.io.compression.zstd.bufferSize", "32k").toInt + // Default compression level for zstd compression to 1 because it is + // fastest of all with reasonably high compression ratio. + private val level = conf.getInt("spark.io.compression.zstd.level", 1) + + override def compressedOutputStream(s: OutputStream): OutputStream = { + // Wrap the zstd output stream in a buffered output stream, so that we can + // avoid overhead excessive of JNI call while trying to compress small amount of data. + new BufferedOutputStream(new ZstdOutputStream(s, level), bufferSize) + } + + override def compressedInputStream(s: InputStream): InputStream = { + // Wrap the zstd input stream in a buffered input stream so that we can + // avoid overhead excessive of JNI call while trying to uncompress small amount of data. + new BufferedInputStream(new ZstdInputStream(s), bufferSize) + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 9e9c2b0165e13..7b40e3e58216d 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -104,6 +104,24 @@ class CompressionCodecSuite extends SparkFunSuite { testConcatenationOfSerializedStreams(codec) } + test("zstd compression codec") { + val codec = CompressionCodec.createCodec(conf, classOf[ZStdCompressionCodec].getName) + assert(codec.getClass === classOf[ZStdCompressionCodec]) + testCodec(codec) + } + + test("zstd compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "zstd") + assert(codec.getClass === classOf[ZStdCompressionCodec]) + testCodec(codec) + } + + test("zstd supports concatenation of serialized zstd") { + val codec = CompressionCodec.createCodec(conf, classOf[ZStdCompressionCodec].getName) + assert(codec.getClass === classOf[ZStdCompressionCodec]) + testConcatenationOfSerializedStreams(codec) + } + test("bad compression codec") { intercept[IllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 6e2fc63d67108..21c8a75796387 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -189,3 +189,4 @@ xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar zookeeper-3.4.6.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index c2bbc253d723a..7173426c7bf74 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -190,3 +190,4 @@ xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar zookeeper-3.4.6.jar +zstd-jni-1.3.2-2.jar diff --git a/docs/configuration.md b/docs/configuration.md index d3c358bb74173..9b9583d9165ef 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -889,7 +889,8 @@ Apart from these, the following properties are also available, and may be useful e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, - and org.apache.spark.io.SnappyCompressionCodec. + org.apache.spark.io.SnappyCompressionCodec, + and org.apache.spark.io.ZstdCompressionCodec. @@ -908,6 +909,23 @@ Apart from these, the following properties are also available, and may be useful is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. + + spark.io.compression.zstd.level + 1 + + Compression level for Zstd compression codec. Increasing the compression level will result in better + compression at the expense of more CPU and memory. + + + + spark.io.compression.zstd.bufferSize + 32k + + Buffer size used in Zstd compression, in the case when Zstd compression codec + is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it + might increase the compression cost because of excessive JNI call overhead. + + spark.kryo.classesToRegister (none) diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses/LICENSE-zstd-jni.txt new file mode 100644 index 0000000000000..32c6bbdd980d6 --- /dev/null +++ b/licenses/LICENSE-zstd-jni.txt @@ -0,0 +1,26 @@ +Zstd-jni: JNI bindings to Zstd Library + +Copyright (c) 2015-2016, Luben Karavelov/ All rights reserved. + +BSD License + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE-zstd.txt b/licenses/LICENSE-zstd.txt new file mode 100644 index 0000000000000..a793a80289256 --- /dev/null +++ b/licenses/LICENSE-zstd.txt @@ -0,0 +1,30 @@ +BSD License + +For Zstandard software + +Copyright (c) 2016-present, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pom.xml b/pom.xml index 2d59f06811a82..652aed4d12f7a 100644 --- a/pom.xml +++ b/pom.xml @@ -537,6 +537,11 @@ lz4-java 1.4.0 + + com.github.luben + zstd-jni + 1.3.2-2 + com.clearspring.analytics stream From 1ffe03d9e87fb784cc8a0bae232c81c7b14deac9 Mon Sep 17 00:00:00 2001 From: LucaCanali Date: Wed, 1 Nov 2017 15:40:25 +0100 Subject: [PATCH 1586/1765] [SPARK-22190][CORE] Add Spark executor task metrics to Dropwizard metrics ## What changes were proposed in this pull request? This proposed patch is about making Spark executor task metrics available as Dropwizard metrics. This is intended to be of aid in monitoring Spark jobs and when drilling down on performance troubleshooting issues. ## How was this patch tested? Manually tested on a Spark cluster (see JIRA for an example screenshot). Author: LucaCanali Closes #19426 from LucaCanali/SparkTaskMetricsDropWizard. --- .../org/apache/spark/executor/Executor.scala | 41 ++++++++++++++++ .../spark/executor/ExecutorSource.scala | 48 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2ecbb749d1fb7..e3e555eaa0277 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -406,6 +406,47 @@ private[spark] class Executor( task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization) + // Expose task metrics using the Dropwizard metrics system. + // Update task metrics counters + executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime) + executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime) + executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime) + executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime) + executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime) + executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime) + executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME + .inc(task.metrics.shuffleReadMetrics.fetchWaitTime) + executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime) + executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.totalBytesRead) + executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.remoteBytesRead) + executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK + .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk) + executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.localBytesRead) + executorSource.METRIC_SHUFFLE_RECORDS_READ + .inc(task.metrics.shuffleReadMetrics.recordsRead) + executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED + .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched) + executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED + .inc(task.metrics.shuffleReadMetrics.localBlocksFetched) + executorSource.METRIC_SHUFFLE_BYTES_WRITTEN + .inc(task.metrics.shuffleWriteMetrics.bytesWritten) + executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN + .inc(task.metrics.shuffleWriteMetrics.recordsWritten) + executorSource.METRIC_INPUT_BYTES_READ + .inc(task.metrics.inputMetrics.bytesRead) + executorSource.METRIC_INPUT_RECORDS_READ + .inc(task.metrics.inputMetrics.recordsRead) + executorSource.METRIC_OUTPUT_BYTES_WRITTEN + .inc(task.metrics.outputMetrics.bytesWritten) + executorSource.METRIC_OUTPUT_RECORDS_WRITTEN + .inc(task.metrics.inputMetrics.recordsRead) + executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize) + executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled) + executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled) + // Note: accumulator updates must be collected after TaskMetrics is updated val accumUpdates = task.collectAccumulatorUpdates() // TODO: do not serialize value twice diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index d16f4a1fc4e3b..669ce63325d0e 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -72,4 +72,52 @@ class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends registerFileSystemStat(scheme, "largeRead_ops", _.getLargeReadOps(), 0) registerFileSystemStat(scheme, "write_ops", _.getWriteOps(), 0) } + + // Expose executor task metrics using the Dropwizard metrics system. + // The list is taken from TaskMetrics.scala + val METRIC_CPU_TIME = metricRegistry.counter(MetricRegistry.name("cpuTime")) + val METRIC_RUN_TIME = metricRegistry.counter(MetricRegistry.name("runTime")) + val METRIC_JVM_GC_TIME = metricRegistry.counter(MetricRegistry.name("jvmGCTime")) + val METRIC_DESERIALIZE_TIME = + metricRegistry.counter(MetricRegistry.name("deserializeTime")) + val METRIC_DESERIALIZE_CPU_TIME = + metricRegistry.counter(MetricRegistry.name("deserializeCpuTime")) + val METRIC_RESULT_SERIALIZE_TIME = + metricRegistry.counter(MetricRegistry.name("resultSerializationTime")) + val METRIC_SHUFFLE_FETCH_WAIT_TIME = + metricRegistry.counter(MetricRegistry.name("shuffleFetchWaitTime")) + val METRIC_SHUFFLE_WRITE_TIME = + metricRegistry.counter(MetricRegistry.name("shuffleWriteTime")) + val METRIC_SHUFFLE_TOTAL_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("shuffleTotalBytesRead")) + val METRIC_SHUFFLE_REMOTE_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("shuffleRemoteBytesRead")) + val METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK = + metricRegistry.counter(MetricRegistry.name("shuffleRemoteBytesReadToDisk")) + val METRIC_SHUFFLE_LOCAL_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("shuffleLocalBytesRead")) + val METRIC_SHUFFLE_RECORDS_READ = + metricRegistry.counter(MetricRegistry.name("shuffleRecordsRead")) + val METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED = + metricRegistry.counter(MetricRegistry.name("shuffleRemoteBlocksFetched")) + val METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED = + metricRegistry.counter(MetricRegistry.name("shuffleLocalBlocksFetched")) + val METRIC_SHUFFLE_BYTES_WRITTEN = + metricRegistry.counter(MetricRegistry.name("shuffleBytesWritten")) + val METRIC_SHUFFLE_RECORDS_WRITTEN = + metricRegistry.counter(MetricRegistry.name("shuffleRecordsWritten")) + val METRIC_INPUT_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("bytesRead")) + val METRIC_INPUT_RECORDS_READ = + metricRegistry.counter(MetricRegistry.name("recordsRead")) + val METRIC_OUTPUT_BYTES_WRITTEN = + metricRegistry.counter(MetricRegistry.name("bytesWritten")) + val METRIC_OUTPUT_RECORDS_WRITTEN = + metricRegistry.counter(MetricRegistry.name("recordsWritten")) + val METRIC_RESULT_SIZE = + metricRegistry.counter(MetricRegistry.name("resultSize")) + val METRIC_DISK_BYTES_SPILLED = + metricRegistry.counter(MetricRegistry.name("diskBytesSpilled")) + val METRIC_MEMORY_BYTES_SPILLED = + metricRegistry.counter(MetricRegistry.name("memoryBytesSpilled")) } From d43e1f06bd545d00bfcaf1efb388b469effd5d64 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Nov 2017 18:39:15 +0100 Subject: [PATCH 1587/1765] [MINOR] Data source v2 docs update. ## What changes were proposed in this pull request? This patch includes some doc updates for data source API v2. I was reading the code and noticed some minor issues. ## How was this patch tested? This is a doc only change. Author: Reynold Xin Closes #19626 from rxin/dsv2-update. --- .../org/apache/spark/sql/sources/v2/DataSourceV2.java | 9 ++++----- .../org/apache/spark/sql/sources/v2/WriteSupport.java | 4 ++-- .../sql/sources/v2/reader/DataSourceV2Reader.java | 10 +++++----- .../v2/reader/SupportsPushDownCatalystFilters.java | 2 -- .../sql/sources/v2/reader/SupportsScanUnsafeRow.java | 2 -- .../sql/sources/v2/writer/DataSourceV2Writer.java | 11 +++-------- .../spark/sql/sources/v2/writer/DataWriter.java | 10 +++++----- 7 files changed, 19 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index dbcbe326a7510..6234071320dc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -20,12 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * The base interface for data source v2. Implementations must have a public, no arguments - * constructor. + * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface, data source implementations should mix-in at least one of - * the plug-in interfaces like {@link ReadSupport}. Otherwise it's just a dummy data source which is - * un-readable/writable. + * Note that this is an empty interface. Data source implementations should mix-in at least one of + * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. Otherwise it's just + * a dummy data source which is un-readable/writable. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index a8a961598bde3..8fdfdfd19ea1e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -36,8 +36,8 @@ public interface WriteSupport { * sources can return None if there is no writing needed to be done according to the save mode. * * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceV2Writer} should - * use this job id to distinguish itself with writers of other jobs. + * jobs running at the same time, and the returned {@link DataSourceV2Writer} can + * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 5989a4ac8440b..88c3219a75c1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -34,11 +34,11 @@ * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column - * pruning), etc. These push-down interfaces are named like `SupportsPushDownXXX`. - * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. These - * reporting interfaces are named like `SupportsReportingXXX`. - * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. These scan interfaces are named - * like `SupportsScanXXX`. + * pruning), etc. Names of these interfaces start with `SupportsPushDown`. + * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. + * Names of these interfaces start with `SupportsReporting`. + * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. + * Names of these interfaces start with `SupportsScan`. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index d6091774d75aa..efc42242f4421 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -31,8 +31,6 @@ * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only * process this interface. */ -@InterfaceStability.Evolving -@Experimental @InterfaceStability.Unstable public interface SupportsPushDownCatalystFilters { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index d5eada808a16c..6008fb5f71cc1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -30,8 +30,6 @@ * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get * changed in the future Spark versions. */ -@InterfaceStability.Evolving -@Experimental @InterfaceStability.Unstable public interface SupportsScanUnsafeRow extends DataSourceV2Reader { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 8d8e33633fb0d..37bb15f87c59a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -40,15 +40,10 @@ * some writers are aborted, or the job failed with an unknown reason, call * {@link #abort(WriterCommitMessage[])}. * - * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if - * they want to retry. + * While Spark will retry failed writing tasks, Spark won't retry failed writing jobs. Users should + * do it manually in their Spark applications if they want to retry. * - * Please refer to the document of commit/abort methods for detailed specifications. - * - * Note that, this interface provides a protocol between Spark and data sources for transactional - * data writing, but the transaction here is Spark-level transaction, which may not be the - * underlying storage transaction. For example, Spark successfully writes data to a Cassandra data - * source, but Cassandra may need some more time to reach consistency at storage level. + * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving public interface DataSourceV2Writer { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index d84afbae32892..dc1aab33bdcef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -57,8 +57,8 @@ public interface DataWriter { /** * Writes one record. * - * If this method fails(throw exception), {@link #abort()} will be called and this data writer is - * considered to be failed. + * If this method fails (by throwing an exception), {@link #abort()} will be called and this + * data writer is considered to have been failed. */ void write(T record); @@ -70,10 +70,10 @@ public interface DataWriter { * The written data should only be visible to data source readers after * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to - * do the final commitment via {@link WriterCommitMessage}. + * do the final commit via {@link WriterCommitMessage}. * - * If this method fails(throw exception), {@link #abort()} will be called and this data writer is - * considered to be failed. + * If this method fails (by throwing an exception), {@link #abort()} will be called and this + * data writer is considered to have been failed. */ WriterCommitMessage commit(); From b04eefae49b96e2ef5a8d75334db29ef4e19ce58 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 2 Nov 2017 09:30:03 +0900 Subject: [PATCH 1588/1765] [MINOR][DOC] automatic type inference supports also Date and Timestamp ## What changes were proposed in this pull request? Easy fix in the documentation, which is reporting that only numeric types and string are supported in type inference for partition columns, while Date and Timestamp are supported too since 2.1.0, thanks to SPARK-17388. ## How was this patch tested? n/a Author: Marco Gaido Closes #19628 from mgaido91/SPARK-22398. --- docs/sql-programming-guide.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 639a8ea7bb8ad..ce377875ff2b1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -800,10 +800,11 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. Sometimes users may not want to automatically -infer the data types of the partitioning columns. For these use cases, the automatic type inference -can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to -`true`. When type inference is disabled, string type will be used for the partitioning columns. +numeric data types, date, timestamp and string type are supported. Sometimes users may not want +to automatically infer the data types of the partitioning columns. For these use cases, the +automatic type inference can be configured by +`spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to `true`. When type +inference is disabled, string type will be used for the partitioning columns. Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths by default. For the above example, if users pass `path/to/table/gender=male` to either From 849b465bbf472d6ca56308fb3ccade86e2244e01 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 2 Nov 2017 09:45:34 +0000 Subject: [PATCH 1589/1765] [SPARK-14650][REPL][BUILD] Compile Spark REPL for Scala 2.12 ## What changes were proposed in this pull request? Spark REPL changes for Scala 2.12.4: use command(), not processLine() in ILoop; remove direct dependence on older jline. Not sure whether this became needed in 2.12.4 or just missed this before. This makes spark-shell work in 2.12. ## How was this patch tested? Existing tests; manual run of spark-shell in 2.11, 2.12 builds Author: Sean Owen Closes #19612 from srowen/SPARK-14650.2. --- pom.xml | 15 ++++++++++----- repl/pom.xml | 4 ---- .../scala/org/apache/spark/repl/SparkILoop.scala | 13 +++++-------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/pom.xml b/pom.xml index 652aed4d12f7a..8570338df6878 100644 --- a/pom.xml +++ b/pom.xml @@ -735,6 +735,12 @@ scalap ${scala.version} + + + jline + jline + 2.12.1 + org.scalatest scalatest_${scala.binary.version} @@ -1188,6 +1194,10 @@ org.jboss.netty netty + + jline + jline + @@ -1925,11 +1935,6 @@ antlr4-runtime ${antlr4.version} - - jline - jline - 2.12.1 - org.apache.commons commons-crypto diff --git a/repl/pom.xml b/repl/pom.xml index bd2cfc465aaf0..1cb0098d0eca3 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -69,10 +69,6 @@ org.scala-lang scala-reflect ${scala.version} - - - jline - jline org.slf4j diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 413594021987d..900edd63cb90e 100644 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -19,9 +19,6 @@ package org.apache.spark.repl import java.io.BufferedReader -// scalastyle:off println -import scala.Predef.{println => _, _} -// scalastyle:on println import scala.tools.nsc.Settings import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} import scala.tools.nsc.util.stringFromStream @@ -37,7 +34,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { - processLine(""" + command(""" @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { org.apache.spark.repl.Main.sparkSession } else { @@ -64,10 +61,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) _sc } """) - processLine("import org.apache.spark.SparkContext._") - processLine("import spark.implicits._") - processLine("import spark.sql") - processLine("import org.apache.spark.sql.functions._") + command("import org.apache.spark.SparkContext._") + command("import spark.implicits._") + command("import spark.sql") + command("import org.apache.spark.sql.functions._") } } From 277b1924b46a70ab25414f5670eb784906dbbfdf Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Thu, 2 Nov 2017 14:19:21 +0100 Subject: [PATCH 1590/1765] [SPARK-22408][SQL] RelationalGroupedDataset's distinct pivot value calculation launches unnecessary stages ## What changes were proposed in this pull request? Adding a global limit on top of the distinct values before sorting and collecting will reduce the overall work in the case where we have more distinct values. We will also eagerly perform a collect rather than a take because we know we only have at most (maxValues + 1) rows. ## How was this patch tested? Existing tests cover sorted order Author: Patrick Woody Closes #19629 from pwoody/SPARK-22408. --- .../scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 21e94fa8bb0b1..3e4edd4ea8cf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -321,10 +321,10 @@ class RelationalGroupedDataset protected[sql]( // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() + .limit(maxValues + 1) .sort(pivotColumn) // ensure that the output columns are in a consistent logical order - .rdd + .collect() .map(_.get(0)) - .take(maxValues + 1) .toSeq if (values.length > maxValues) { From b2463fad718d25f564d62c50d587610de3d0c5bd Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 2 Nov 2017 13:25:48 +0000 Subject: [PATCH 1591/1765] [SPARK-22145][MESOS] fix supervise with checkpointing on mesos ## What changes were proposed in this pull request? - Fixes the issue with the frameworkId being recovered by checkpointed data overwriting the one sent by the dipatcher. - Keeps submission driver id as the only index for all data structures in the dispatcher. Allocates a different task id per driver retry to satisfy the mesos requirements. Check the relevant ticket for the details on that. ## How was this patch tested? Manually tested this with DC/OS 1.10. Launched a streaming job with checkpointing to hdfs, made the driver fail several times and observed behavior: ![image](https://user-images.githubusercontent.com/7945591/30940500-f7d2a744-a3e9-11e7-8c56-f2ccbb271e80.png) ![image](https://user-images.githubusercontent.com/7945591/30940550-19bc15de-a3ea-11e7-8a11-f48abfe36720.png) ![image](https://user-images.githubusercontent.com/7945591/30940524-083ea308-a3ea-11e7-83ae-00d3fa17b928.png) ![image](https://user-images.githubusercontent.com/7945591/30940579-2f0fb242-a3ea-11e7-82f9-86179da28b8c.png) ![image](https://user-images.githubusercontent.com/7945591/30940591-3b561b0e-a3ea-11e7-9dbd-e71912bb2ef3.png) ![image](https://user-images.githubusercontent.com/7945591/30940605-49c810ca-a3ea-11e7-8af5-67930851fd38.png) ![image](https://user-images.githubusercontent.com/7945591/30940631-59f4a288-a3ea-11e7-88cb-c3741b72bb13.png) ![image](https://user-images.githubusercontent.com/7945591/30940642-62346c9e-a3ea-11e7-8935-82e494925f67.png) ![image](https://user-images.githubusercontent.com/7945591/30940653-6c46d53c-a3ea-11e7-8dd1-5840d484d28c.png) Author: Stavros Kontopoulos Closes #19374 from skonto/fix_retry. --- .../scala/org/apache/spark/SparkContext.scala | 1 + .../cluster/mesos/MesosClusterScheduler.scala | 90 +++++++++++-------- .../apache/spark/streaming/Checkpoint.scala | 3 +- 3 files changed, 57 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6f25d346e6e54..c7dd635ad4c96 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -310,6 +310,7 @@ class SparkContext(config: SparkConf) extends Logging { * (i.e. * in case of local spark app something like 'local-1433865536131' * in case of YARN something like 'application_1433865536131_34483' + * in case of MESOS something like 'driver-20170926223339-0001' * ) */ def applicationId: String = _applicationId diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 82470264f2a4a..de846c85d53a6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -134,22 +134,24 @@ private[spark] class MesosClusterScheduler( private val useFetchCache = conf.getBoolean("spark.mesos.fetchCache.enable", false) private val schedulerState = engineFactory.createEngine("scheduler") private val stateLock = new Object() + // Keyed by submission id private val finishedDrivers = new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) private var frameworkId: String = null - // Holds all the launched drivers and current launch state, keyed by driver id. + // Holds all the launched drivers and current launch state, keyed by submission id. private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. // All drivers that are loaded after failover are added here, as we need get the latest - // state of the tasks from Mesos. + // state of the tasks from Mesos. Keyed by task Id. private val pendingRecover = new mutable.HashMap[String, SlaveID]() - // Stores all the submitted drivers that hasn't been launched. + // Stores all the submitted drivers that hasn't been launched, keyed by submission id private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() - // All supervised drivers that are waiting to retry after termination. + // All supervised drivers that are waiting to retry after termination, keyed by submission id private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() private val queuedDriversState = engineFactory.createEngine("driverQueue") private val launchedDriversState = engineFactory.createEngine("launchedDrivers") private val pendingRetryDriversState = engineFactory.createEngine("retryList") + private final val RETRY_SEP = "-retry-" // Flag to mark if the scheduler is ready to be called, which is until the scheduler // is registered with Mesos master. @volatile protected var ready = false @@ -192,8 +194,8 @@ private[spark] class MesosClusterScheduler( // 3. Check if it's in the retry list. // 4. Check if it has already completed. if (launchedDrivers.contains(submissionId)) { - val task = launchedDrivers(submissionId) - schedulerDriver.killTask(task.taskId) + val state = launchedDrivers(submissionId) + schedulerDriver.killTask(state.taskId) k.success = true k.message = "Killing running driver" } else if (removeFromQueuedDrivers(submissionId)) { @@ -275,7 +277,7 @@ private[spark] class MesosClusterScheduler( private def recoverState(): Unit = { stateLock.synchronized { launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => - launchedDrivers(state.taskId.getValue) = state + launchedDrivers(state.driverDescription.submissionId) = state pendingRecover(state.taskId.getValue) = state.slaveId } queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) @@ -353,7 +355,8 @@ private[spark] class MesosClusterScheduler( .setSlaveId(slaveId) .setState(MesosTaskState.TASK_STAGING) .build() - launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + launchedDrivers.get(getSubmissionIdFromTaskId(taskId)) + .map(_.mesosTaskStatus.getOrElse(newStatus)) .getOrElse(newStatus) } // TODO: Page the status updates to avoid trying to reconcile @@ -369,10 +372,19 @@ private[spark] class MesosClusterScheduler( } private def getDriverFrameworkID(desc: MesosDriverDescription): String = { - val retries = desc.retryState.map { d => s"-retry-${d.retries.toString}" }.getOrElse("") + val retries = desc.retryState.map { d => s"${RETRY_SEP}${d.retries.toString}" }.getOrElse("") s"${frameworkId}-${desc.submissionId}${retries}" } + private def getDriverTaskId(desc: MesosDriverDescription): String = { + val sId = desc.submissionId + desc.retryState.map(state => sId + s"${RETRY_SEP}${state.retries.toString}").getOrElse(sId) + } + + private def getSubmissionIdFromTaskId(taskId: String): String = { + taskId.split(s"${RETRY_SEP}").head + } + private def adjust[A, B](m: collection.Map[A, B], k: A, default: B)(f: B => B) = { m.updated(k, f(m.getOrElse(k, default))) } @@ -551,7 +563,7 @@ private[spark] class MesosClusterScheduler( } private def createTaskInfo(desc: MesosDriverDescription, offer: ResourceOffer): TaskInfo = { - val taskId = TaskID.newBuilder().setValue(desc.submissionId).build() + val taskId = TaskID.newBuilder().setValue(getDriverTaskId(desc)).build() val (remainingResources, cpuResourcesToUse) = partitionResources(offer.remainingResources, "cpus", desc.cores) @@ -604,7 +616,7 @@ private[spark] class MesosClusterScheduler( val task = createTaskInfo(submission, offer) queuedTasks += task logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + - submission.submissionId) + submission.submissionId + s" with taskId: ${task.getTaskId.toString}") val newState = new MesosClusterSubmissionState( submission, task.getTaskId, @@ -718,45 +730,51 @@ private[spark] class MesosClusterScheduler( logInfo(s"Received status update: taskId=${taskId}" + s" state=${status.getState}" + s" message=${status.getMessage}" + - s" reason=${status.getReason}"); + s" reason=${status.getReason}") stateLock.synchronized { - if (launchedDrivers.contains(taskId)) { + val subId = getSubmissionIdFromTaskId(taskId) + if (launchedDrivers.contains(subId)) { if (status.getReason == Reason.REASON_RECONCILIATION && !pendingRecover.contains(taskId)) { // Task has already received update and no longer requires reconciliation. return } - val state = launchedDrivers(taskId) + val state = launchedDrivers(subId) // Check if the driver is supervise enabled and can be relaunched. if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { - removeFromLaunchedDrivers(taskId) + removeFromLaunchedDrivers(subId) state.finishDate = Some(new Date()) val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState val (retries, waitTimeSec) = retryState .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } .getOrElse{ (1, 1) } val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) - val newDriverDescription = state.driverDescription.copy( retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) - addDriverToPending(newDriverDescription, taskId); + addDriverToPending(newDriverDescription, newDriverDescription.submissionId) } else if (TaskState.isFinished(mesosToTaskState(status.getState))) { - removeFromLaunchedDrivers(taskId) - state.finishDate = Some(new Date()) - if (finishedDrivers.size >= retainedDrivers) { - val toRemove = math.max(retainedDrivers / 10, 1) - finishedDrivers.trimStart(toRemove) - } - finishedDrivers += state + retireDriver(subId, state) } state.mesosTaskStatus = Option(status) } else { - logError(s"Unable to find driver $taskId in status update") + logError(s"Unable to find driver with $taskId in status update") } } } + private def retireDriver( + submissionId: String, + state: MesosClusterSubmissionState) = { + removeFromLaunchedDrivers(submissionId) + state.finishDate = Some(new Date()) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + override def frameworkMessage( driver: SchedulerDriver, executorId: ExecutorID, @@ -769,31 +787,31 @@ private[spark] class MesosClusterScheduler( slaveId: SlaveID, status: Int): Unit = {} - private def removeFromQueuedDrivers(id: String): Boolean = { - val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + private def removeFromQueuedDrivers(subId: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(subId)) if (index != -1) { queuedDrivers.remove(index) - queuedDriversState.expunge(id) + queuedDriversState.expunge(subId) true } else { false } } - private def removeFromLaunchedDrivers(id: String): Boolean = { - if (launchedDrivers.remove(id).isDefined) { - launchedDriversState.expunge(id) + private def removeFromLaunchedDrivers(subId: String): Boolean = { + if (launchedDrivers.remove(subId).isDefined) { + launchedDriversState.expunge(subId) true } else { false } } - private def removeFromPendingRetryDrivers(id: String): Boolean = { - val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + private def removeFromPendingRetryDrivers(subId: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(subId)) if (index != -1) { pendingRetryDrivers.remove(index) - pendingRetryDriversState.expunge(id) + pendingRetryDriversState.expunge(subId) true } else { false @@ -810,8 +828,8 @@ private[spark] class MesosClusterScheduler( revive() } - private def addDriverToPending(desc: MesosDriverDescription, taskId: String) = { - pendingRetryDriversState.persist(taskId, desc) + private def addDriverToPending(desc: MesosDriverDescription, subId: String) = { + pendingRetryDriversState.persist(subId, desc) pendingRetryDrivers += desc revive() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b8c780db07c98..40a0b8e3a407d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -58,7 +58,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.credentials.file", "spark.yarn.credentials.renewalTime", "spark.yarn.credentials.updateTime", - "spark.ui.filters") + "spark.ui.filters", + "spark.mesos.driver.frameworkId") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") From 41b60125b673bad0c133cd5c825d353ac2e6dfd6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 2 Nov 2017 15:22:52 +0100 Subject: [PATCH 1592/1765] [SPARK-22369][PYTHON][DOCS] Exposes catalog API documentation in PySpark ## What changes were proposed in this pull request? This PR proposes to add a link from `spark.catalog(..)` to `Catalog` and expose Catalog APIs in PySpark as below: 2017-10-29 12 25 46 2017-10-29 12 26 33 Note that this is not shown in the list on the top - https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#module-pyspark.sql 2017-10-29 12 30 58 This is basically similar with `DataFrameReader` and `DataFrameWriter`. ## How was this patch tested? Manually built the doc. Author: hyukjinkwon Closes #19596 from HyukjinKwon/SPARK-22369. --- python/pyspark/sql/__init__.py | 3 ++- python/pyspark/sql/session.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 22ec416f6c584..c3c06c8124362 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -46,6 +46,7 @@ from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column +from pyspark.sql.catalog import Catalog from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter @@ -54,7 +55,7 @@ __all__ = [ 'SparkSession', 'SQLContext', 'HiveContext', 'UDFRegistration', - 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 2cc0e2d1d7b8d..c3dc1a46fd3c1 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -271,6 +271,8 @@ def conf(self): def catalog(self): """Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc. + + :return: :class:`Catalog` """ if not hasattr(self, "_catalog"): self._catalog = Catalog(self) From e3f67a97f126abfb7eeb864f657bfc9221bb195e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 2 Nov 2017 18:28:56 +0100 Subject: [PATCH 1593/1765] [SPARK-22416][SQL] Move OrcOptions from `sql/hive` to `sql/core` ## What changes were proposed in this pull request? According to the [discussion](https://github.com/apache/spark/pull/19571#issuecomment-339472976) on SPARK-15474, we will add new OrcFileFormat in `sql/core` module and allow users to use both old and new OrcFileFormat. To do that, `OrcOptions` should be visible in `sql/core` module, too. Previously, it was `private[orc]` in `sql/hive`. This PR removes `private[orc]` because we don't use `private[sql]` in `sql/execution` package after [SPARK-16964](https://github.com/apache/spark/pull/14554). ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #19636 from dongjoon-hyun/SPARK-22416. --- .../spark/sql/execution/datasources}/orc/OrcOptions.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala | 1 + .../org/apache/spark/sql/hive/orc/OrcSourceSuite.scala | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) rename sql/{hive/src/main/scala/org/apache/spark/sql/hive => core/src/main/scala/org/apache/spark/sql/execution/datasources}/orc/OrcOptions.scala (95%) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala similarity index 95% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index 6ce90c07b4921..c866dd834a525 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.util.Locale @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the ORC data source. */ -private[orc] class OrcOptions( +class OrcOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -59,7 +59,7 @@ private[orc] class OrcOptions( } } -private[orc] object OrcOptions { +object OrcOptions { // The ORC compression short names private val shortOrcCompressionCodecNames = Map( "none" -> "NONE", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index d26ec15410d95..3b33a9ff082f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index ef9e67c743837..2a086be57f517 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -24,6 +24,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ From 882079f5c6de40913a27873f2fa3306e5d827393 Mon Sep 17 00:00:00 2001 From: ZouChenjun Date: Thu, 2 Nov 2017 11:06:37 -0700 Subject: [PATCH 1594/1765] [SPARK-22243][DSTREAM] spark.yarn.jars should reload from config when checkpoint recovery ## What changes were proposed in this pull request? the previous [PR](https://github.com/apache/spark/pull/19469) is deleted by mistake. the solution is straight forward. adding "spark.yarn.jars" to propertiesToReload so this property will load from config. ## How was this patch tested? manual tests Author: ZouChenjun Closes #19637 from ChenjunZou/checkpoint-yarn-jars. --- .../src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 40a0b8e3a407d..9ebb91b8cab3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,6 +53,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.port", "spark.master", + "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", "spark.yarn.credentials.file", From 2fd12af4372a1e2c3faf0eb5d0a1cf530abc0016 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Nov 2017 23:41:16 +0100 Subject: [PATCH 1595/1765] [SPARK-22306][SQL] alter table schema should not erase the bucketing metadata at hive side forward-port https://github.com/apache/spark/pull/19622 to master branch. This bug doesn't exist in master because we've added hive bucketing support and the hive bucketing metadata can be recognized by Spark, but we should still port it to master: 1) there may be other unsupported hive metadata removed by Spark. 2) reduce code difference between master and 2.2 to ease the backport in the feature. *** When we alter table schema, we set the new schema to spark `CatalogTable`, convert it to hive table, and finally call `hive.alterTable`. This causes a problem in Spark 2.2, because hive bucketing metedata is not recognized by Spark, which means a Spark `CatalogTable` representing a hive table is always non-bucketed, and when we convert it to hive table and call `hive.alterTable`, the original hive bucketing metadata will be removed. To fix this bug, we should read out the raw hive table metadata, update its schema, and call `hive.alterTable`. By doing this we can guarantee only the schema is changed, and nothing else. Author: Wenchen Fan Closes #19644 from cloud-fan/infer. --- .../catalyst/catalog/ExternalCatalog.scala | 12 ++--- .../catalyst/catalog/InMemoryCatalog.scala | 7 +-- .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++----- .../catalog/ExternalCatalogSuite.scala | 10 ++-- .../catalog/SessionCatalogSuite.scala | 8 ++-- .../spark/sql/execution/command/ddl.scala | 14 ++++-- .../spark/sql/execution/command/tables.scala | 14 ++---- .../datasources/DataSourceStrategy.scala | 4 +- .../datasources/orc/OrcFileFormat.scala | 5 +- .../parquet/ParquetFileFormat.scala | 5 +- .../parquet/ParquetSchemaConverter.scala | 5 +- .../spark/sql/hive/HiveExternalCatalog.scala | 47 ++++++++++++------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 11 ++--- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../spark/sql/hive/client/HiveClient.scala | 11 +++++ .../sql/hive/client/HiveClientImpl.scala | 45 +++++++++++------- .../sql/hive/HiveExternalCatalogSuite.scala | 18 +++++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../hive/execution/Hive_2_1_DDLSuite.scala | 2 +- 19 files changed, 146 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index d4c58db3708e3..223094d485936 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -150,17 +150,15 @@ abstract class ExternalCatalog def alterTable(tableDefinition: CatalogTable): Unit /** - * Alter the schema of a table identified by the provided database and table name. The new schema - * should still contain the existing bucket columns and partition columns used by the table. This - * method will also update any Spark SQL-related parameters stored as Hive table properties (such - * as the schema itself). + * Alter the data schema of a table identified by the provided database and table name. The new + * data schema should not have conflict column names with the existing partition columns, and + * should still contain all the existing data columns. * * @param db Database that table to alter schema for exists in * @param table Name of table to alter schema for - * @param schema Updated schema to be used for the table (must contain existing partition and - * bucket columns) + * @param newDataSchema Updated data schema to be used for the table. */ - def alterTableSchema(db: String, table: String, schema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 98370c12a977c..9504140d51e99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -303,13 +303,14 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def alterTableSchema( + override def alterTableDataSchema( db: String, table: String, - schema: StructType): Unit = synchronized { + newDataSchema: StructType): Unit = synchronized { requireTableExists(db, table) val origTable = catalog(db).tables(table).table - catalog(db).tables(table).table = origTable.copy(schema = schema) + val newSchema = StructType(newDataSchema ++ origTable.partitionSchema) + catalog(db).tables(table).table = origTable.copy(schema = newSchema) } override def alterTableStats( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 95bc3d674b4f8..a129896230775 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -324,18 +324,16 @@ class SessionCatalog( } /** - * Alter the schema of a table identified by the provided table identifier. The new schema - * should still contain the existing bucket columns and partition columns used by the table. This - * method will also update any Spark SQL-related parameters stored as Hive table properties (such - * as the schema itself). + * Alter the data schema of a table identified by the provided table identifier. The new data + * schema should not have conflict column names with the existing partition columns, and should + * still contain all the existing data columns. * * @param identifier TableIdentifier - * @param newSchema Updated schema to be used for the table (must contain existing partition and - * bucket columns, and partition columns need to be at the end) + * @param newDataSchema Updated data schema to be used for the table */ - def alterTableSchema( + def alterTableDataSchema( identifier: TableIdentifier, - newSchema: StructType): Unit = { + newDataSchema: StructType): Unit = { val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -343,10 +341,10 @@ class SessionCatalog( requireTableExists(tableIdentifier) val catalogTable = externalCatalog.getTable(db, table) - val oldSchema = catalogTable.schema - + val oldDataSchema = catalogTable.dataSchema // not supporting dropping columns yet - val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _)) + val nonExistentColumnNames = + oldDataSchema.map(_.name).filterNot(columnNameResolved(newDataSchema, _)) if (nonExistentColumnNames.nonEmpty) { throw new AnalysisException( s""" @@ -355,8 +353,7 @@ class SessionCatalog( """.stripMargin) } - // assuming the newSchema has all partition columns at the end as required - externalCatalog.alterTableSchema(db, table, newSchema) + externalCatalog.alterTableDataSchema(db, table, newDataSchema) } private def columnNameResolved(schema: StructType, colName: String): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 94593ef7efa50..b376108399c1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -245,14 +245,12 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter table schema") { val catalog = newBasicCatalog() - val newSchema = StructType(Seq( + val newDataSchema = StructType(Seq( StructField("col1", IntegerType), - StructField("new_field_2", StringType), - StructField("a", IntegerType), - StructField("b", StringType))) - catalog.alterTableSchema("db2", "tbl1", newSchema) + StructField("new_field_2", StringType))) + catalog.alterTableDataSchema("db2", "tbl1", newDataSchema) val newTbl1 = catalog.getTable("db2", "tbl1") - assert(newTbl1.schema == newSchema) + assert(newTbl1.dataSchema == newDataSchema) } test("alter table stats") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 1cce19989c60e..95c87ffa20cb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -463,9 +463,9 @@ abstract class SessionCatalogSuite extends AnalysisTest { withBasicCatalog { sessionCatalog => sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") - sessionCatalog.alterTableSchema( + sessionCatalog.alterTableDataSchema( TableIdentifier("t1", Some("default")), - StructType(oldTab.dataSchema.add("c3", IntegerType) ++ oldTab.partitionSchema)) + StructType(oldTab.dataSchema.add("c3", IntegerType))) val newTab = sessionCatalog.externalCatalog.getTable("default", "t1") // construct the expected table schema @@ -480,8 +480,8 @@ abstract class SessionCatalogSuite extends AnalysisTest { sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") val e = intercept[AnalysisException] { - sessionCatalog.alterTableSchema( - TableIdentifier("t1", Some("default")), StructType(oldTab.schema.drop(1))) + sessionCatalog.alterTableDataSchema( + TableIdentifier("t1", Some("default")), StructType(oldTab.dataSchema.drop(1))) }.getMessage assert(e.contains("We don't support dropping columns yet.")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 162e1d5be2938..a9cd65e3242c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -857,19 +857,23 @@ object DDLUtils { } } - private[sql] def checkDataSchemaFieldNames(table: CatalogTable): Unit = { + private[sql] def checkDataColNames(table: CatalogTable): Unit = { + checkDataColNames(table, table.dataSchema.fieldNames) + } + + private[sql] def checkDataColNames(table: CatalogTable, colNames: Seq[String]): Unit = { table.provider.foreach { _.toLowerCase(Locale.ROOT) match { case HIVE_PROVIDER => val serde = table.storage.serde if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) { - OrcFileFormat.checkFieldNames(table.dataSchema) + OrcFileFormat.checkFieldNames(colNames) } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde || serde == Some("parquet.hive.serde.ParquetHiveSerDe")) { - ParquetSchemaConverter.checkFieldNames(table.dataSchema) + ParquetSchemaConverter.checkFieldNames(colNames) } - case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema) - case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema) + case "parquet" => ParquetSchemaConverter.checkFieldNames(colNames) + case "orc" => OrcFileFormat.checkFieldNames(colNames) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 38f91639c0422..95f16b0f4baea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -186,7 +186,7 @@ case class AlterTableRenameCommand( */ case class AlterTableAddColumnsCommand( table: TableIdentifier, - columns: Seq[StructField]) extends RunnableCommand { + colsToAdd: Seq[StructField]) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val catalogTable = verifyAlterTableAddColumn(catalog, table) @@ -199,17 +199,13 @@ case class AlterTableAddColumnsCommand( } catalog.refreshTable(table) - // make sure any partition columns are at the end of the fields - val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema - val newSchema = catalogTable.schema.copy(fields = reorderedSchema.toArray) - SchemaUtils.checkColumnNameDuplication( - reorderedSchema.map(_.name), "in the table definition of " + table.identifier, + (colsToAdd ++ catalogTable.schema).map(_.name), + "in the table definition of " + table.identifier, conf.caseSensitiveAnalysis) - DDLUtils.checkDataSchemaFieldNames(catalogTable.copy(schema = newSchema)) - - catalog.alterTableSchema(table, newSchema) + DDLUtils.checkDataColNames(catalogTable, colsToAdd.map(_.name)) + catalog.alterTableDataSchema(table, StructType(catalogTable.dataSchema ++ colsToAdd)) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 018f24e290b4b..04d6f3f56eb02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -133,12 +133,12 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc) + DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc.copy(schema = query.schema)) + DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 2eeb0065455f3..215740e90fe84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -35,8 +35,7 @@ private[sql] object OrcFileFormat { } } - def checkFieldNames(schema: StructType): StructType = { - schema.fieldNames.foreach(checkFieldName) - schema + def checkFieldNames(names: Seq[String]): Unit = { + names.foreach(checkFieldName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index c1535babbae1f..61bd65dd48144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -23,7 +23,6 @@ import java.net.URI import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.parallel.ForkJoinTaskSupport -import scala.concurrent.forkjoin.ForkJoinPool import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration @@ -306,10 +305,10 @@ class ParquetFileFormat hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetSchemaConverter.checkFieldNames(requiredSchema).json) + requiredSchema.json) hadoopConf.set( ParquetWriteSupport.SPARK_ROW_SCHEMA, - ParquetSchemaConverter.checkFieldNames(requiredSchema).json) + requiredSchema.json) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index b3781cfc4a607..cd384d17d0cda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -571,9 +571,8 @@ private[sql] object ParquetSchemaConverter { """.stripMargin.split("\n").mkString(" ").trim) } - def checkFieldNames(schema: StructType): StructType = { - schema.fieldNames.foreach(checkFieldName) - schema + def checkFieldNames(names: Seq[String]): Unit = { + names.foreach(checkFieldName) } def checkConversionRequirement(f: => Boolean, message: String): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 96dc983b0bfc6..f8a947bf527e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -138,16 +138,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Checks the validity of column names. Hive metastore disallows the table to use comma in + * Checks the validity of data column names. Hive metastore disallows the table to use comma in * data column names. Partition columns do not have such a restriction. Views do not have such * a restriction. */ - private def verifyColumnNames(table: CatalogTable): Unit = { - if (table.tableType != VIEW) { - table.dataSchema.map(_.name).foreach { colName => + private def verifyDataSchema( + tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = { + if (tableType != VIEW) { + dataSchema.map(_.name).foreach { colName => if (colName.contains(",")) { throw new AnalysisException("Cannot create a table having a column whose name contains " + - s"commas in Hive metastore. Table: ${table.identifier}; Column: $colName") + s"commas in Hive metastore. Table: $tableName; Column: $colName") } } } @@ -218,7 +219,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) - verifyColumnNames(tableDefinition) + verifyDataSchema( + tableDefinition.identifier, tableDefinition.tableType, tableDefinition.dataSchema) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) @@ -296,7 +298,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat storage = table.storage.copy( locationUri = None, properties = storagePropsWithLocation), - schema = table.partitionSchema, + schema = StructType(EMPTY_DATA_SCHEMA ++ table.partitionSchema), bucketSpec = None, properties = table.properties ++ tableProperties) } @@ -617,32 +619,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { + override def alterTableDataSchema( + db: String, table: String, newDataSchema: StructType): Unit = withClient { requireTableExists(db, table) - val rawTable = getRawTable(db, table) - // Add table metadata such as table schema, partition columns, etc. to table properties. - val updatedProperties = rawTable.properties ++ tableMetaToTableProps(rawTable, schema) - val withNewSchema = rawTable.copy(properties = updatedProperties, schema = schema) - verifyColumnNames(withNewSchema) + val oldTable = getTable(db, table) + verifyDataSchema(oldTable.identifier, oldTable.tableType, newDataSchema) + val schemaProps = + tableMetaToTableProps(oldTable, StructType(newDataSchema ++ oldTable.partitionSchema)).toMap - if (isDatasourceTable(rawTable)) { + if (isDatasourceTable(oldTable)) { // For data source tables, first try to write it with the schema set; if that does not work, // try again with updated properties and the partition schema. This is a simplified version of // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive // (for example, the schema does not match the data source schema, or does not match the // storage descriptor). try { - client.alterTable(withNewSchema) + client.alterTableDataSchema(db, table, newDataSchema, schemaProps) } catch { case NonFatal(e) => val warningMessage = - s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + s"Could not alter schema of table ${oldTable.identifier.quotedString} in a Hive " + "compatible way. Updating Hive metastore in Spark SQL specific format." logWarning(warningMessage, e) - client.alterTable(withNewSchema.copy(schema = rawTable.partitionSchema)) + client.alterTableDataSchema(db, table, EMPTY_DATA_SCHEMA, schemaProps) } } else { - client.alterTable(withNewSchema) + client.alterTableDataSchema(db, table, newDataSchema, schemaProps) } } @@ -1297,6 +1299,15 @@ object HiveExternalCatalog { val CREATED_SPARK_VERSION = SPARK_SQL_PREFIX + "create.version" + // When storing data source tables in hive metastore, we need to set data schema to empty if the + // schema is hive-incompatible. However we need a hack to preserve existing behavior. Before + // Spark 2.0, we do not set a default serde here (this was done in Hive), and so if the user + // provides an empty schema Hive would automatically populate the schema with a single field + // "col". However, after SPARK-14388, we set the default serde to LazySimpleSerde so this + // implicit behavior no longer happens. Therefore, we need to do it in Spark ourselves. + val EMPTY_DATA_SCHEMA = new StructType() + .add("col", "array", nullable = true, comment = "from deserializer") + /** * Returns the fully qualified name used in table properties for a particular column stat. * For example, for column "mycol", and "min" stat, this should return diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5ac65973e70e1..8adfda07d29d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -244,11 +244,11 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log inferredSchema match { case Some(dataSchema) => - val schema = StructType(dataSchema ++ relation.tableMeta.partitionSchema) if (inferenceMode == INFER_AND_SAVE) { - updateCatalogSchema(relation.tableMeta.identifier, schema) + updateDataSchema(relation.tableMeta.identifier, dataSchema) } - relation.tableMeta.copy(schema = schema) + val newSchema = StructType(dataSchema ++ relation.tableMeta.partitionSchema) + relation.tableMeta.copy(schema = newSchema) case None => logWarning(s"Unable to infer schema for table $tableName from file format " + s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.") @@ -259,10 +259,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - private def updateCatalogSchema(identifier: TableIdentifier, schema: StructType): Unit = try { - val db = identifier.database.get + private def updateDataSchema(identifier: TableIdentifier, newDataSchema: StructType): Unit = try { logInfo(s"Saving case-sensitive schema for table ${identifier.unquotedString}") - sparkSession.sharedState.externalCatalog.alterTableSchema(db, identifier.table, schema) + sparkSession.sessionState.catalog.alterTableDataSchema(identifier, newDataSchema) } catch { case NonFatal(ex) => logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3592b8f4846d1..ee1f6ee173063 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -152,11 +152,11 @@ object HiveAnalysis extends Rule[LogicalPlan] { InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc) + DDLUtils.checkDataColNames(tableDesc) CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc) + DDLUtils.checkDataColNames(tableDesc) CreateHiveTableAsSelectCommand(tableDesc, query, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index ee3eb2ee8abe5..f69717441d615 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType /** @@ -100,6 +101,16 @@ private[hive] trait HiveClient { */ def alterTable(dbName: String, tableName: String, table: CatalogTable): Unit + /** + * Updates the given table with a new data schema and table properties, and keep everything else + * unchanged. + * + * TODO(cloud-fan): it's a little hacky to introduce the schema table properties here in + * `HiveClient`, but we don't have a cleaner solution now. + */ + def alterTableDataSchema( + dbName: String, tableName: String, newDataSchema: StructType, schemaProps: Map[String, String]) + /** Creates a new database with the given name. */ def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 16c95c53b4201..b5a5890d47b03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -39,7 +39,6 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.AnalysisException @@ -51,7 +50,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveExternalCatalog.{DATASOURCE_SCHEMA, DATASOURCE_SCHEMA_NUMPARTS, DATASOURCE_SCHEMA_PART_PREFIX} import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -515,6 +514,33 @@ private[hive] class HiveClientImpl( shim.alterTable(client, qualifiedTableName, hiveTable) } + override def alterTableDataSchema( + dbName: String, + tableName: String, + newDataSchema: StructType, + schemaProps: Map[String, String]): Unit = withHiveState { + val oldTable = client.getTable(dbName, tableName) + val hiveCols = newDataSchema.map(toHiveColumn) + oldTable.setFields(hiveCols.asJava) + + // remove old schema table properties + val it = oldTable.getParameters.entrySet.iterator + while (it.hasNext) { + val entry = it.next() + val isSchemaProp = entry.getKey.startsWith(DATASOURCE_SCHEMA_PART_PREFIX) || + entry.getKey == DATASOURCE_SCHEMA || entry.getKey == DATASOURCE_SCHEMA_NUMPARTS + if (isSchemaProp) { + it.remove() + } + } + + // set new schema table properties + schemaProps.foreach { case (k, v) => oldTable.setProperty(k, v) } + + val qualifiedTableName = s"$dbName.$tableName" + shim.alterTable(client, qualifiedTableName, oldTable) + } + override def createPartitions( db: String, table: String, @@ -896,20 +922,7 @@ private[hive] object HiveClientImpl { val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => table.partitionColumnNames.contains(c.getName) } - // after SPARK-19279, it is not allowed to create a hive table with an empty schema, - // so here we should not add a default col schema - if (schema.isEmpty && HiveExternalCatalog.isDatasourceTable(table)) { - // This is a hack to preserve existing behavior. Before Spark 2.0, we do not - // set a default serde here (this was done in Hive), and so if the user provides - // an empty schema Hive would automatically populate the schema with a single - // field "col". However, after SPARK-14388, we set the default serde to - // LazySimpleSerde so this implicit behavior no longer happens. Therefore, - // we need to do it in Spark ourselves. - hiveTable.setFields( - Seq(new FieldSchema("col", "array", "from deserializer")).asJava) - } else { - hiveTable.setFields(schema.asJava) - } + hiveTable.setFields(schema.asJava) hiveTable.setPartCols(partCols.asJava) userName.foreach(hiveTable.setOwner) hiveTable.setCreateTime((table.createTime / 1000).toInt) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index d43534d5914d1..2e35fdeba464d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -89,4 +89,22 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { assert(restoredTable.schema == newSchema) } } + + test("SPARK-22306: alter table schema should not erase the bucketing metadata at hive side") { + val catalog = newBasicCatalog() + externalCatalog.client.runSqlHive( + """ + |CREATE TABLE db1.t(a string, b string) + |CLUSTERED BY (a, b) SORTED BY (a, b) INTO 10 BUCKETS + |STORED AS PARQUET + """.stripMargin) + + val newSchema = new StructType().add("a", "string").add("b", "string").add("c", "string") + catalog.alterTableDataSchema("db1", "t", newSchema) + + assert(catalog.getTable("db1", "t").schema == newSchema) + val bucketString = externalCatalog.client.runSqlHive("DESC FORMATTED db1.t") + .filter(_.contains("Num Buckets")).head + assert(bucketString.contains("10")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f5d41c91270a5..a1060476f2211 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -741,7 +741,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val hiveTable = CatalogTable( identifier = TableIdentifier(tableName, Some("default")), tableType = CatalogTableType.MANAGED, - schema = new StructType, + schema = HiveExternalCatalog.EMPTY_DATA_SCHEMA, provider = Some("json"), storage = CatalogStorageFormat( locationUri = None, @@ -1266,7 +1266,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val hiveTable = CatalogTable( identifier = TableIdentifier("t", Some("default")), tableType = CatalogTableType.MANAGED, - schema = new StructType, + schema = HiveExternalCatalog.EMPTY_DATA_SCHEMA, provider = Some("json"), storage = CatalogStorageFormat.empty, properties = Map( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index 5c248b9acd04f..bc828877e35ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -117,7 +117,7 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with Before spark.sql(createTableStmt) val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName) catalog.createTable(oldTable, true) - catalog.alterTableSchema("default", tableName, updatedSchema) + catalog.alterTableDataSchema("default", tableName, updatedSchema) val updatedTable = catalog.getTable("default", tableName) assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames) From 51145f13768ec04c14762397527b1bc2648e2374 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Fri, 3 Nov 2017 12:20:17 +0000 Subject: [PATCH 1596/1765] [SPARK-22407][WEB-UI] Add rdd id column on storage page to speed up navigating ## What changes were proposed in this pull request? Add rdd id column on storage page to speed up navigating. Example has attached on [SPARK-22407](https://issues.apache.org/jira/browse/SPARK-22407) An example below: ![add-rddid](https://user-images.githubusercontent.com/26762018/32361127-da0758ac-c097-11e7-9f8c-0ea7ffb87e12.png) ![rdd-cache](https://user-images.githubusercontent.com/26762018/32361128-da3c1574-c097-11e7-8ab1-2def66466f33.png) ## How was this patch tested? Current unit test and manually deploy an history server for testing Author: zhoukang Closes #19625 from caneGuy/zhoukang/add-rddid. --- .../scala/org/apache/spark/ui/storage/StoragePage.scala | 2 ++ .../org/apache/spark/ui/storage/StoragePageSuite.scala | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index aa84788f1df88..b6c764d1728e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -49,6 +49,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Header fields for the RDD table */ private val rddHeader = Seq( + "ID", "RDD Name", "Storage Level", "Cached Partitions", @@ -60,6 +61,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private def rddRow(rdd: RDDInfo): Seq[Node] = { // scalastyle:off + {rdd.id} {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 350c174e24742..4a48b3c686725 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -57,6 +57,7 @@ class StoragePageSuite extends SparkFunSuite { val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) val headers = Seq( + "ID", "RDD Name", "Storage Level", "Cached Partitions", @@ -67,19 +68,19 @@ class StoragePageSuite extends SparkFunSuite { assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) + Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) + Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === - Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) + Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=3")) From 89158866085ee3aa18759efaa7d3b3846b9c6504 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 3 Nov 2017 22:03:58 -0700 Subject: [PATCH 1597/1765] [SPARK-22418][SQL][TEST] Add test cases for NULL Handling ## What changes were proposed in this pull request? Added a test class to check NULL handling behavior. The expected behavior is defined as the one of the most well-known databases as specified here: https://sqlite.org/nulls.html. SparkSQL behaves like other DBs: - Adding anything to null gives null -> YES - Multiplying null by zero gives null -> YES - nulls are distinct in SELECT DISTINCT -> NO - nulls are distinct in a UNION -> NO - "CASE WHEN null THEN 1 ELSE 0 END" is 0? -> YES - "null OR true" is true -> YES - "not (null AND false)" is true -> YES - null in aggregation are skipped -> YES ## How was this patch tested? Added test class Author: Marco Gaido Closes #19653 from mgaido91/SPARK-22418. --- .../sql-tests/inputs/null-handling.sql | 48 +++ .../sql-tests/results/null-handling.sql.out | 305 ++++++++++++++++++ 2 files changed, 353 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/null-handling.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/null-handling.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql new file mode 100644 index 0000000000000..b90b0a6ac7500 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql @@ -0,0 +1,48 @@ +-- Create a test table with data +create table t1(a int, b int, c int) using parquet; +insert into t1 values(1,0,0); +insert into t1 values(2,0,1); +insert into t1 values(3,1,0); +insert into t1 values(4,1,1); +insert into t1 values(5,null,0); +insert into t1 values(6,null,1); +insert into t1 values(7,null,null); + +-- Adding anything to null gives null +select a, b+c from t1; + +-- Multiplying null by zero gives null +select a+10, b*0 from t1; + +-- nulls are NOT distinct in SELECT DISTINCT +select distinct b from t1; + +-- nulls are NOT distinct in UNION +select b from t1 union select b from t1; + +-- CASE WHEN null THEN 1 ELSE 0 END is 0 +select a+20, case b when c then 1 else 0 end from t1; +select a+30, case c when b then 1 else 0 end from t1; +select a+40, case when b<>0 then 1 else 0 end from t1; +select a+50, case when not b<>0 then 1 else 0 end from t1; +select a+60, case when b<>0 and c<>0 then 1 else 0 end from t1; + +-- "not (null AND false)" is true +select a+70, case when not (b<>0 and c<>0) then 1 else 0 end from t1; + +-- "null OR true" is true +select a+80, case when b<>0 or c<>0 then 1 else 0 end from t1; +select a+90, case when not (b<>0 or c<>0) then 1 else 0 end from t1; + +-- null with aggregate operators +select count(*), count(b), sum(b), avg(b), min(b), max(b) from t1; + +-- Check the behavior of NULLs in WHERE clauses +select a+100 from t1 where b<10; +select a+110 from t1 where not b>10; +select a+120 from t1 where b<10 OR c=1; +select a+130 from t1 where b<10 AND c=1; +select a+140 from t1 where not (b<10 AND c=1); +select a+150 from t1 where not (c=1 AND b<10); + +drop table t1; diff --git a/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out new file mode 100644 index 0000000000000..5005dfeb6cd14 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out @@ -0,0 +1,305 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +create table t1(a int, b int, c int) using parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +insert into t1 values(1,0,0) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +insert into t1 values(2,0,1) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +insert into t1 values(3,1,0) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +insert into t1 values(4,1,1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +insert into t1 values(5,null,0) +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +insert into t1 values(6,null,1) +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +insert into t1 values(7,null,null) +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +select a, b+c from t1 +-- !query 8 schema +struct +-- !query 8 output +1 0 +2 1 +3 1 +4 2 +5 NULL +6 NULL +7 NULL + + +-- !query 9 +select a+10, b*0 from t1 +-- !query 9 schema +struct<(a + 10):int,(b * 0):int> +-- !query 9 output +11 0 +12 0 +13 0 +14 0 +15 NULL +16 NULL +17 NULL + + +-- !query 10 +select distinct b from t1 +-- !query 10 schema +struct +-- !query 10 output +0 +1 +NULL + + +-- !query 11 +select b from t1 union select b from t1 +-- !query 11 schema +struct +-- !query 11 output +0 +1 +NULL + + +-- !query 12 +select a+20, case b when c then 1 else 0 end from t1 +-- !query 12 schema +struct<(a + 20):int,CASE WHEN (b = c) THEN 1 ELSE 0 END:int> +-- !query 12 output +21 1 +22 0 +23 0 +24 1 +25 0 +26 0 +27 0 + + +-- !query 13 +select a+30, case c when b then 1 else 0 end from t1 +-- !query 13 schema +struct<(a + 30):int,CASE WHEN (c = b) THEN 1 ELSE 0 END:int> +-- !query 13 output +31 1 +32 0 +33 0 +34 1 +35 0 +36 0 +37 0 + + +-- !query 14 +select a+40, case when b<>0 then 1 else 0 end from t1 +-- !query 14 schema +struct<(a + 40):int,CASE WHEN (NOT (b = 0)) THEN 1 ELSE 0 END:int> +-- !query 14 output +41 0 +42 0 +43 1 +44 1 +45 0 +46 0 +47 0 + + +-- !query 15 +select a+50, case when not b<>0 then 1 else 0 end from t1 +-- !query 15 schema +struct<(a + 50):int,CASE WHEN (NOT (NOT (b = 0))) THEN 1 ELSE 0 END:int> +-- !query 15 output +51 1 +52 1 +53 0 +54 0 +55 0 +56 0 +57 0 + + +-- !query 16 +select a+60, case when b<>0 and c<>0 then 1 else 0 end from t1 +-- !query 16 schema +struct<(a + 60):int,CASE WHEN ((NOT (b = 0)) AND (NOT (c = 0))) THEN 1 ELSE 0 END:int> +-- !query 16 output +61 0 +62 0 +63 0 +64 1 +65 0 +66 0 +67 0 + + +-- !query 17 +select a+70, case when not (b<>0 and c<>0) then 1 else 0 end from t1 +-- !query 17 schema +struct<(a + 70):int,CASE WHEN (NOT ((NOT (b = 0)) AND (NOT (c = 0)))) THEN 1 ELSE 0 END:int> +-- !query 17 output +71 1 +72 1 +73 1 +74 0 +75 1 +76 0 +77 0 + + +-- !query 18 +select a+80, case when b<>0 or c<>0 then 1 else 0 end from t1 +-- !query 18 schema +struct<(a + 80):int,CASE WHEN ((NOT (b = 0)) OR (NOT (c = 0))) THEN 1 ELSE 0 END:int> +-- !query 18 output +81 0 +82 1 +83 1 +84 1 +85 0 +86 1 +87 0 + + +-- !query 19 +select a+90, case when not (b<>0 or c<>0) then 1 else 0 end from t1 +-- !query 19 schema +struct<(a + 90):int,CASE WHEN (NOT ((NOT (b = 0)) OR (NOT (c = 0)))) THEN 1 ELSE 0 END:int> +-- !query 19 output +91 1 +92 0 +93 0 +94 0 +95 0 +96 0 +97 0 + + +-- !query 20 +select count(*), count(b), sum(b), avg(b), min(b), max(b) from t1 +-- !query 20 schema +struct +-- !query 20 output +7 4 2 0.5 0 1 + + +-- !query 21 +select a+100 from t1 where b<10 +-- !query 21 schema +struct<(a + 100):int> +-- !query 21 output +101 +102 +103 +104 + + +-- !query 22 +select a+110 from t1 where not b>10 +-- !query 22 schema +struct<(a + 110):int> +-- !query 22 output +111 +112 +113 +114 + + +-- !query 23 +select a+120 from t1 where b<10 OR c=1 +-- !query 23 schema +struct<(a + 120):int> +-- !query 23 output +121 +122 +123 +124 +126 + + +-- !query 24 +select a+130 from t1 where b<10 AND c=1 +-- !query 24 schema +struct<(a + 130):int> +-- !query 24 output +132 +134 + + +-- !query 25 +select a+140 from t1 where not (b<10 AND c=1) +-- !query 25 schema +struct<(a + 140):int> +-- !query 25 output +141 +143 +145 + + +-- !query 26 +select a+150 from t1 where not (c=1 AND b<10) +-- !query 26 schema +struct<(a + 150):int> +-- !query 26 output +151 +153 +155 + + +-- !query 27 +drop table t1 +-- !query 27 schema +struct<> +-- !query 27 output + From bc1e101039ae3700eab42e633571256440a42b9d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 3 Nov 2017 23:35:57 -0700 Subject: [PATCH 1598/1765] [SPARK-22254][CORE] Fix the arrayMax in BufferHolder ## What changes were proposed in this pull request? This PR replaces the old the maximum array size (`Int.MaxValue`) with the new one (`ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`). This PR also refactor the code to calculate the new array size to easily understand why we have to use `newSize - 2` for allocating a new array. ## How was this patch tested? Used the existing test Author: Kazuaki Ishizaki Closes #19650 from kiszk/SPARK-22254. --- .../spark/util/collection/CompactBuffer.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index f5d2fa14e49cb..5d3693190cc1f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -19,6 +19,8 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag +import org.apache.spark.unsafe.array.ByteArrayMethods + /** * An append-only buffer similar to ArrayBuffer, but more memory-efficient for small buffers. * ArrayBuffer always allocates an Object array to store the data, with 16 entries by default, @@ -126,16 +128,16 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable /** Increase our size to newSize and grow the backing array if needed. */ private def growToSize(newSize: Int): Unit = { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - val arrayMax = Int.MaxValue - 8 - if (newSize < 0 || newSize - 2 > arrayMax) { + // since two fields are hold in element0 and element1, an array holds newSize - 2 elements + val newArraySize = newSize - 2 + val arrayMax = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + if (newSize < 0 || newArraySize > arrayMax) { throw new UnsupportedOperationException(s"Can't grow buffer past $arrayMax elements") } - val capacity = if (otherElements != null) otherElements.length + 2 else 2 - if (newSize > capacity) { + val capacity = if (otherElements != null) otherElements.length else 0 + if (newArraySize > capacity) { var newArrayLen = 8L - while (newSize - 2 > newArrayLen) { + while (newArraySize > newArrayLen) { newArrayLen *= 2 } if (newArrayLen > arrayMax) { From e7adb7d7a6d017bb95da566e76b39d9d96ab42c1 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 4 Nov 2017 16:59:58 +0900 Subject: [PATCH 1599/1765] [SPARK-22437][PYSPARK] default mode for jdbc is wrongly set to None ## What changes were proposed in this pull request? When writing using jdbc with python currently we are wrongly assigning by default None as writing mode. This is due to wrongly calling mode on the `_jwrite` object instead of `self` and it causes an exception. ## How was this patch tested? manual tests Author: Marco Gaido Closes #19654 from mgaido91/SPARK-22437. --- python/pyspark/sql/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f3092918abb54..3d87567ab673d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -915,7 +915,7 @@ def jdbc(self, url, table, mode=None, properties=None): jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) - self._jwrite.mode(mode).jdbc(url, table, jprop) + self.mode(mode)._jwrite.jdbc(url, table, jprop) def _test(): From 7a8412352e3aaf14527f97c82d0d62f9de39e753 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sat, 4 Nov 2017 11:51:10 +0000 Subject: [PATCH 1600/1765] [SPARK-22423][SQL] Scala test source files like TestHiveSingleton.scala should be in scala source root ## What changes were proposed in this pull request? Scala test source files like TestHiveSingleton.scala should be in scala source root ## How was this patch tested? Just move scala file from java directory to scala directory No new test case in this PR. ``` renamed: mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala -> mllib/src/test/scala/org/apache/spark/ml/util/IdentifiableSuite.scala renamed: streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala -> streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala renamed: streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala -> streaming/src/test/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala renamed: sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala ``` Author: xubo245 <601450868@qq.com> Closes #19639 from xubo245/scalaDirectory. --- .../org/apache/spark/ml/util/IdentifiableSuite.scala | 0 .../org/apache/spark/sql/hive/test/TestHiveSingleton.scala | 0 .../org/apache/spark/streaming/JavaTestUtils.scala | 0 .../streaming/api/java/JavaStreamingListenerWrapperSuite.scala | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename mllib/src/test/{java => scala}/org/apache/spark/ml/util/IdentifiableSuite.scala (100%) rename sql/hive/src/test/{java => scala}/org/apache/spark/sql/hive/test/TestHiveSingleton.scala (100%) rename streaming/src/test/{java => scala}/org/apache/spark/streaming/JavaTestUtils.scala (100%) rename streaming/src/test/{java => scala}/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala (100%) diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/IdentifiableSuite.scala similarity index 100% rename from mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/util/IdentifiableSuite.scala diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala similarity index 100% rename from streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala rename to streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala similarity index 100% rename from streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala From 0c2aee69b0efeea5ce8d39c0564e9e4511faf387 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 4 Nov 2017 13:11:09 +0100 Subject: [PATCH 1601/1765] [SPARK-22410][SQL] Remove unnecessary output from BatchEvalPython's children plans ## What changes were proposed in this pull request? When we insert `BatchEvalPython` for Python UDFs into a query plan, if its child has some outputs that are not used by the original parent node, `BatchEvalPython` will still take those outputs and save into the queue. When the data for those outputs are big, it is easily to generate big spill on disk. For example, the following reproducible code is from the JIRA ticket. ```python from pyspark.sql.functions import * from pyspark.sql.types import * lines_of_file = [ "this is a line" for x in xrange(10000) ] file_obj = [ "this_is_a_foldername/this_is_a_filename", lines_of_file ] data = [ file_obj for x in xrange(5) ] small_df = spark.sparkContext.parallelize(data).map(lambda x : (x[0], x[1])).toDF(["file", "lines"]) exploded = small_df.select("file", explode("lines")) def split_key(s): return s.split("/")[1] split_key_udf = udf(split_key, StringType()) with_filename = exploded.withColumn("filename", split_key_udf("file")) with_filename.explain(True) ``` The physical plan before/after this change: Before: ``` *Project [file#0, col#5, pythonUDF0#14 AS filename#9] +- BatchEvalPython [split_key(file#0)], [file#0, lines#1, col#5, pythonUDF0#14] +- Generate explode(lines#1), true, false, [col#5] +- Scan ExistingRDD[file#0,lines#1] ``` After: ``` *Project [file#0, col#5, pythonUDF0#14 AS filename#9] +- BatchEvalPython [split_key(file#0)], [col#5, file#0, pythonUDF0#14] +- *Project [col#5, file#0] +- Generate explode(lines#1), true, false, [col#5] +- Scan ExistingRDD[file#0,lines#1] ``` Before this change, `lines#1` is a redundant input to `BatchEvalPython`. This patch removes it by adding a Project. ## How was this patch tested? Manually test. Author: Liang-Chi Hsieh Closes #19642 from viirya/SPARK-22410. --- .../sql/execution/python/ExtractPythonUDFs.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index d6825369f7378..e15e760136e81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -127,8 +127,19 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { // If there aren't any, we are done. plan } else { + val inputsForPlan = plan.references ++ plan.outputSet + val prunedChildren = plan.children.map { child => + val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq + if (allNeededOutput.length != child.output.length) { + ProjectExec(allNeededOutput, child) + } else { + child + } + } + val planWithNewChildren = plan.withNewChildren(prunedChildren) + val attributeMap = mutable.HashMap[PythonUDF, Expression]() - val splitFilter = trySplitFilter(plan) + val splitFilter = trySplitFilter(planWithNewChildren) // Rewrite the child that has the input required for the UDF val newChildren = splitFilter.children.map { child => // Pick the UDF we are going to evaluate From f7f4e9c2db405b887832fcb592cd4522795d00ca Mon Sep 17 00:00:00 2001 From: Vinitha Gankidi Date: Sat, 4 Nov 2017 11:09:47 -0700 Subject: [PATCH 1602/1765] [SPARK-22412][SQL] Fix incorrect comment in DataSourceScanExec ## What changes were proposed in this pull request? Next fit decreasing bin packing algorithm is used to combine splits in DataSourceScanExec but the comment incorrectly states that first fit decreasing algorithm is used. The current implementation doesn't go back to a previously used bin other than the bin that the last element was put into. Author: Vinitha Gankidi Closes #19634 from vgankidi/SPARK-22412. --- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index e9f65031143b7..a607ec0bf8c9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -469,7 +469,7 @@ case class FileSourceScanExec( currentSize = 0 } - // Assign files to partitions using "First Fit Decreasing" (FFD) + // Assign files to partitions using "Next Fit Decreasing" splitFiles.foreach { file => if (currentSize + file.length > maxSplitBytes) { closePartition() From 6c6626614e59b2e8d66ca853a74638d3d6267d73 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Sat, 4 Nov 2017 22:47:25 -0700 Subject: [PATCH 1603/1765] [SPARK-22211][SQL] Remove incorrect FOJ limit pushdown ## What changes were proposed in this pull request? It's not safe in all cases to push down a LIMIT below a FULL OUTER JOIN. If the limit is pushed to one side of the FOJ, the physical join operator can not tell if a row in the non-limited side would have a match in the other side. *If* the join operator guarantees that unmatched tuples from the limited side are emitted before any unmatched tuples from the other side, pushing down the limit is safe. But this is impractical for some join implementations, e.g. SortMergeJoin. For now, disable limit pushdown through a FULL OUTER JOIN, and we can evaluate whether a more complicated solution is necessary in the future. ## How was this patch tested? Ran org.apache.spark.sql.* tests. Altered full outer join tests in LimitPushdownSuite. Author: Henry Robinson Closes #19647 from henryr/spark-22211. --- .../sql/catalyst/optimizer/Optimizer.scala | 24 +++----------- .../optimizer/LimitPushdownSuite.scala | 33 +++++++++---------- 2 files changed, 21 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3273a61dc7b35..3a3ccd5ff5e60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -332,12 +332,11 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit. case LocalLimit(exp, Union(children)) => LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) - // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the - // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side - // because we need to ensure that rows from the limited side still have an opportunity to match - // against all candidates from the non-limited side. We also need to ensure that this limit - // pushdown rule will not eventually introduce limits on both sides if it is applied multiple - // times. Therefore: + // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to + // the left and right sides, respectively. It's not safe to push limits below FULL OUTER + // JOIN in the general case without a more invasive rewrite. + // We also need to ensure that this limit pushdown rule will not eventually introduce limits + // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. // - If neither side is limited, limit the side that is estimated to be bigger. @@ -345,19 +344,6 @@ object LimitPushDown extends Rule[LogicalPlan] { val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) - case FullOuter => - (left.maxRows, right.maxRows) match { - case (None, None) => - if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { - join.copy(left = maybePushLocalLimit(exp, left)) - } else { - join.copy(right = maybePushLocalLimit(exp, right)) - } - case (Some(_), Some(_)) => join - case (Some(_), None) => join.copy(left = maybePushLocalLimit(exp, left)) - case (None, Some(_)) => join.copy(right = maybePushLocalLimit(exp, right)) - - } case _ => join } LocalLimit(exp, newJoin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index f50e2e86516f0..cc98d2350c777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -113,35 +113,34 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and both sides have same statistics") { assert(x.stats.sizeInBytes === y.stats.sizeInBytes) - val originalQuery = x.join(y, FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = x.join(y, FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes) - val originalQuery = xBig.join(y, FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = xBig.join(y, FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) - val originalQuery = x.join(yBig, FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = x.join(yBig, FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } test("full outer join where both sides are limited") { - val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } } - From 3bba8621cf0a97f5c3134c9a160b1c8c5e97ba97 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 4 Nov 2017 22:57:12 -0700 Subject: [PATCH 1604/1765] [SPARK-22378][SQL] Eliminate redundant null check in generated code for extracting an element from complex types ## What changes were proposed in this pull request? This PR eliminates redundant null check in generated code for extracting an element from complex types `GetArrayItem`, `GetMapValue`, and `GetArrayStructFields`. Since these code generation does not take care of `nullable` in `DataType` such as `ArrayType`, the generated code always has `isNullAt(index)`. This PR avoids to generate `isNullAt(index)` if `nullable` is false in `DataType`. Example ``` val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false)) checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1) ``` Before this PR ``` /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ boolean isNull = true; /* 040 */ int value = -1; /* 041 */ /* 042 */ /* 043 */ /* 044 */ isNull = false; // resultCode could change nullability. /* 045 */ /* 046 */ final int index = (int) 0; /* 047 */ if (index >= ((ArrayData) references[0]).numElements() || index < 0 || ((ArrayData) references[0]).isNullAt(index)) { /* 048 */ isNull = true; /* 049 */ } else { /* 050 */ value = ((ArrayData) references[0]).getInt(index); /* 051 */ } /* 052 */ isNull_0 = isNull; /* 053 */ value_0 = value; /* 054 */ /* 055 */ // copy all the results into MutableRow /* 056 */ /* 057 */ if (!isNull_0) { /* 058 */ mutableRow.setInt(0, value_0); /* 059 */ } else { /* 060 */ mutableRow.setNullAt(0); /* 061 */ } /* 062 */ /* 063 */ return mutableRow; /* 064 */ } ``` After this PR (Line 47 is changed) ``` /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ boolean isNull = true; /* 040 */ int value = -1; /* 041 */ /* 042 */ /* 043 */ /* 044 */ isNull = false; // resultCode could change nullability. /* 045 */ /* 046 */ final int index = (int) 0; /* 047 */ if (index >= ((ArrayData) references[0]).numElements() || index < 0) { /* 048 */ isNull = true; /* 049 */ } else { /* 050 */ value = ((ArrayData) references[0]).getInt(index); /* 051 */ } /* 052 */ isNull_0 = isNull; /* 053 */ value_0 = value; /* 054 */ /* 055 */ // copy all the results into MutableRow /* 056 */ /* 057 */ if (!isNull_0) { /* 058 */ mutableRow.setInt(0, value_0); /* 059 */ } else { /* 060 */ mutableRow.setNullAt(0); /* 061 */ } /* 062 */ /* 063 */ return mutableRow; /* 064 */ } ``` ## How was this patch tested? Added test cases into `ComplexTypeSuite` Author: Kazuaki Ishizaki Closes #19598 from kiszk/SPARK-22378. --- .../expressions/complexTypeExtractors.scala | 28 +++++++++++++++---- .../expressions/ComplexTypeSuite.scala | 11 ++++++-- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index ef88cfb543ebb..7e53ca3908905 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -186,6 +186,16 @@ case class GetArrayStructFields( val values = ctx.freshName("values") val j = ctx.freshName("j") val row = ctx.freshName("row") + val nullSafeEval = if (field.nullable) { + s""" + if ($row.isNullAt($ordinal)) { + $values[$j] = null; + } else + """ + } else { + "" + } + s""" final int $n = $eval.numElements(); final Object[] $values = new Object[$n]; @@ -194,9 +204,7 @@ case class GetArrayStructFields( $values[$j] = null; } else { final InternalRow $row = $eval.getStruct($j, $numFields); - if ($row.isNullAt($ordinal)) { - $values[$j] = null; - } else { + $nullSafeEval { $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; } } @@ -242,9 +250,14 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index") + val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) { + s" || $eval1.isNullAt($index)" + } else { + "" + } s""" final int $index = (int) $eval2; - if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { + if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; @@ -309,6 +322,11 @@ case class GetMapValue(child: Expression, key: Expression) val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") + val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + s" || $values.isNullAt($index)" + } else { + "" + } nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -326,7 +344,7 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!$found || $values.isNullAt($index)) { + if (!$found$nullCheck) { ${ev.isNull} = true; } else { ${ev.value} = ${ctx.getValue(values, dataType, index)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 5f8a8f44d48e6..b0eaad1c80f89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -51,6 +51,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetArrayItem(array, nullInt), null) checkEvaluation(GetArrayItem(nullArray, nullInt), null) + val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false)) + checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1) + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } @@ -66,6 +69,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetMapValue(nullMap, nullString), null) checkEvaluation(GetMapValue(map, nullString), null) + val nonNullMap = Literal.create(Map("a" -> 1), MapType(StringType, IntegerType, false)) + checkEvaluation(GetMapValue(nonNullMap, Literal("a")), 1) + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } @@ -101,9 +107,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayStructFields") { - val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) :: Nil)) + val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) - val nullArrayStruct = Literal.create(null, typeAS) + val nullArrayStruct = Literal.create(null, typeNullAS) def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { expr.dataType match { From 572284c5b08515901267a37adef4f8e55df3780e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 4 Nov 2017 23:07:24 -0700 Subject: [PATCH 1605/1765] [SPARK-22443][SQL] add implementation of quoteIdentifier, getTableExistsQuery and getSchemaQuery in AggregatedDialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … ## What changes were proposed in this pull request? override JDBCDialects methods quoteIdentifier, getTableExistsQuery and getSchemaQuery in AggregatedDialect ## How was this patch tested? Test the new implementation in JDBCSuite test("Aggregated dialects") Author: Huaxin Gao Closes #19658 from huaxingao/spark-22443. --- .../apache/spark/sql/jdbc/AggregatedDialect.scala | 12 ++++++++++++ .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 1419d69f983ab..f3bfea5f6bfc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -42,6 +42,18 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect dialects.flatMap(_.getJDBCType(dt)).headOption } + override def quoteIdentifier(colName: String): String = { + dialects.head.quoteIdentifier(colName) + } + + override def getTableExistsQuery(table: String): String = { + dialects.head.getTableExistsQuery(table) + } + + override def getSchemaQuery(table: String): String = { + dialects.head.getSchemaQuery(table) + } + override def isCascadingTruncateTable(): Option[Boolean] = { // If any dialect claims cascading truncate, this dialect is also cascading truncate. // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 167b3e0190026..88a5f618d604d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -740,6 +740,15 @@ class JDBCSuite extends SparkFunSuite } else { None } + override def quoteIdentifier(colName: String): String = { + s"My $colName quoteIdentifier" + } + override def getTableExistsQuery(table: String): String = { + s"My $table Table" + } + override def getSchemaQuery(table: String): String = { + s"My $table Schema" + } override def isCascadingTruncateTable(): Option[Boolean] = Some(true) }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) @@ -747,6 +756,9 @@ class JDBCSuite extends SparkFunSuite assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) assert(agg.isCascadingTruncateTable() === Some(true)) + assert(agg.quoteIdentifier ("Dummy") === "My Dummy quoteIdentifier") + assert(agg.getTableExistsQuery ("Dummy") === "My Dummy Table") + assert(agg.getSchemaQuery ("Dummy") === "My Dummy Schema") } test("Aggregated dialects: isCascadingTruncateTable") { From fe258a7963361c1f31bc3dc3a2a2ee4a5834bb58 Mon Sep 17 00:00:00 2001 From: Tristan Stevens Date: Sun, 5 Nov 2017 09:10:40 +0000 Subject: [PATCH 1606/1765] [SPARK-22429][STREAMING] Streaming checkpointing code does not retry after failure ## What changes were proposed in this pull request? SPARK-14930/SPARK-13693 put in a change to set the fs object to null after a failure, however the retry loop does not include initialization. Moved fs initialization inside the retry while loop to aid recoverability. ## How was this patch tested? Passes all existing unit tests. Author: Tristan Stevens Closes #19645 from tmgstevens/SPARK-22429. --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 9ebb91b8cab3c..3cfbcedd519d6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -211,9 +211,6 @@ class CheckpointWriter( if (latestCheckpointTime == null || latestCheckpointTime < checkpointTime) { latestCheckpointTime = checkpointTime } - if (fs == null) { - fs = new Path(checkpointDir).getFileSystem(hadoopConf) - } var attempts = 0 val startTime = System.currentTimeMillis() val tempFile = new Path(checkpointDir, "temp") @@ -233,7 +230,9 @@ class CheckpointWriter( attempts += 1 try { logInfo(s"Saving checkpoint for time $checkpointTime to file '$checkpointFile'") - + if (fs == null) { + fs = new Path(checkpointDir).getFileSystem(hadoopConf) + } // Write checkpoint to temp file fs.delete(tempFile, true) // just in case it exists val fos = fs.create(tempFile) From db389f71972754697ebaa89b731d117d23367dd1 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 5 Nov 2017 20:10:15 -0800 Subject: [PATCH 1607/1765] [SPARK-21625][DOC] Add incompatible Hive UDF describe to DOC ## What changes were proposed in this pull request? Add incompatible Hive UDF describe to DOC. ## How was this patch tested? N/A Author: Yuming Wang Closes #18833 from wangyum/SPARK-21625. --- docs/sql-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ce377875ff2b1..686fcb159d09d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1958,6 +1958,14 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. + +### Incompatible Hive UDF + +Below are the scenarios in which Hive and Spark generate different results: + +* `SQRT(n)` If n < 0, Hive returns null, Spark SQL returns NaN. +* `ACOS(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN. +* `ASIN(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN. # Reference From 4bacddb602e19fcd4e1ec75a7b10bed524e6989a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 5 Nov 2017 21:21:12 -0800 Subject: [PATCH 1608/1765] [SPARK-7146][ML] Expose the common params as a DeveloperAPI for other ML developers ## What changes were proposed in this pull request? Expose the common params from Spark ML as a Developer API. ## How was this patch tested? Existing tests. Author: Holden Karau Author: Holden Karau Closes #18699 from holdenk/SPARK-7146-ml-shared-params-developer-api. --- .../ml/param/shared/SharedParamsCodeGen.scala | 7 +- .../spark/ml/param/shared/sharedParams.scala | 145 ++++++++++++------ 2 files changed, 102 insertions(+), 50 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1860fe8361749..a932d28fadbd8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -177,9 +177,11 @@ private[shared] object SharedParamsCodeGen { s""" |/** - | * Trait for shared param $name$defaultValueDoc. + | * Trait for shared param $name$defaultValueDoc. This trait may be changed or + | * removed between minor versions. | */ - |private[ml] trait Has$Name extends Params { + |@DeveloperApi + |trait Has$Name extends Params { | | /** | * Param for $htmlCompliantDoc. @@ -215,6 +217,7 @@ private[shared] object SharedParamsCodeGen { | |package org.apache.spark.ml.param.shared | + |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 6061d9ca0a084..e6bdf5236e72d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.param.shared +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -24,9 +25,11 @@ import org.apache.spark.ml.param._ // scalastyle:off /** - * Trait for shared param regParam. + * Trait for shared param regParam. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasRegParam extends Params { +@DeveloperApi +trait HasRegParam extends Params { /** * Param for regularization parameter (>= 0). @@ -39,9 +42,11 @@ private[ml] trait HasRegParam extends Params { } /** - * Trait for shared param maxIter. + * Trait for shared param maxIter. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasMaxIter extends Params { +@DeveloperApi +trait HasMaxIter extends Params { /** * Param for maximum number of iterations (>= 0). @@ -54,9 +59,11 @@ private[ml] trait HasMaxIter extends Params { } /** - * Trait for shared param featuresCol (default: "features"). + * Trait for shared param featuresCol (default: "features"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasFeaturesCol extends Params { +@DeveloperApi +trait HasFeaturesCol extends Params { /** * Param for features column name. @@ -71,9 +78,11 @@ private[ml] trait HasFeaturesCol extends Params { } /** - * Trait for shared param labelCol (default: "label"). + * Trait for shared param labelCol (default: "label"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasLabelCol extends Params { +@DeveloperApi +trait HasLabelCol extends Params { /** * Param for label column name. @@ -88,9 +97,11 @@ private[ml] trait HasLabelCol extends Params { } /** - * Trait for shared param predictionCol (default: "prediction"). + * Trait for shared param predictionCol (default: "prediction"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasPredictionCol extends Params { +@DeveloperApi +trait HasPredictionCol extends Params { /** * Param for prediction column name. @@ -105,9 +116,11 @@ private[ml] trait HasPredictionCol extends Params { } /** - * Trait for shared param rawPredictionCol (default: "rawPrediction"). + * Trait for shared param rawPredictionCol (default: "rawPrediction"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasRawPredictionCol extends Params { +@DeveloperApi +trait HasRawPredictionCol extends Params { /** * Param for raw prediction (a.k.a. confidence) column name. @@ -122,9 +135,11 @@ private[ml] trait HasRawPredictionCol extends Params { } /** - * Trait for shared param probabilityCol (default: "probability"). + * Trait for shared param probabilityCol (default: "probability"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasProbabilityCol extends Params { +@DeveloperApi +trait HasProbabilityCol extends Params { /** * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. @@ -139,9 +154,11 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param varianceCol. + * Trait for shared param varianceCol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasVarianceCol extends Params { +@DeveloperApi +trait HasVarianceCol extends Params { /** * Param for Column name for the biased sample variance of prediction. @@ -154,9 +171,11 @@ private[ml] trait HasVarianceCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasThreshold extends Params { +@DeveloperApi +trait HasThreshold extends Params { /** * Param for threshold in binary classification prediction, in range [0, 1]. @@ -169,9 +188,11 @@ private[ml] trait HasThreshold extends Params { } /** - * Trait for shared param thresholds. + * Trait for shared param thresholds. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasThresholds extends Params { +@DeveloperApi +trait HasThresholds extends Params { /** * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. @@ -184,9 +205,11 @@ private[ml] trait HasThresholds extends Params { } /** - * Trait for shared param inputCol. + * Trait for shared param inputCol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasInputCol extends Params { +@DeveloperApi +trait HasInputCol extends Params { /** * Param for input column name. @@ -199,9 +222,11 @@ private[ml] trait HasInputCol extends Params { } /** - * Trait for shared param inputCols. + * Trait for shared param inputCols. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasInputCols extends Params { +@DeveloperApi +trait HasInputCols extends Params { /** * Param for input column names. @@ -214,9 +239,11 @@ private[ml] trait HasInputCols extends Params { } /** - * Trait for shared param outputCol (default: uid + "__output"). + * Trait for shared param outputCol (default: uid + "__output"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasOutputCol extends Params { +@DeveloperApi +trait HasOutputCol extends Params { /** * Param for output column name. @@ -231,9 +258,11 @@ private[ml] trait HasOutputCol extends Params { } /** - * Trait for shared param checkpointInterval. + * Trait for shared param checkpointInterval. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasCheckpointInterval extends Params { +@DeveloperApi +trait HasCheckpointInterval extends Params { /** * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. @@ -246,9 +275,11 @@ private[ml] trait HasCheckpointInterval extends Params { } /** - * Trait for shared param fitIntercept (default: true). + * Trait for shared param fitIntercept (default: true). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasFitIntercept extends Params { +@DeveloperApi +trait HasFitIntercept extends Params { /** * Param for whether to fit an intercept term. @@ -263,9 +294,11 @@ private[ml] trait HasFitIntercept extends Params { } /** - * Trait for shared param handleInvalid. + * Trait for shared param handleInvalid. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasHandleInvalid extends Params { +@DeveloperApi +trait HasHandleInvalid extends Params { /** * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later. @@ -278,9 +311,11 @@ private[ml] trait HasHandleInvalid extends Params { } /** - * Trait for shared param standardization (default: true). + * Trait for shared param standardization (default: true). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasStandardization extends Params { +@DeveloperApi +trait HasStandardization extends Params { /** * Param for whether to standardize the training features before fitting the model. @@ -295,9 +330,11 @@ private[ml] trait HasStandardization extends Params { } /** - * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasSeed extends Params { +@DeveloperApi +trait HasSeed extends Params { /** * Param for random seed. @@ -312,9 +349,11 @@ private[ml] trait HasSeed extends Params { } /** - * Trait for shared param elasticNetParam. + * Trait for shared param elasticNetParam. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasElasticNetParam extends Params { +@DeveloperApi +trait HasElasticNetParam extends Params { /** * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. @@ -327,9 +366,11 @@ private[ml] trait HasElasticNetParam extends Params { } /** - * Trait for shared param tol. + * Trait for shared param tol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasTol extends Params { +@DeveloperApi +trait HasTol extends Params { /** * Param for the convergence tolerance for iterative algorithms (>= 0). @@ -342,9 +383,11 @@ private[ml] trait HasTol extends Params { } /** - * Trait for shared param stepSize. + * Trait for shared param stepSize. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasStepSize extends Params { +@DeveloperApi +trait HasStepSize extends Params { /** * Param for Step size to be used for each iteration of optimization (> 0). @@ -357,9 +400,11 @@ private[ml] trait HasStepSize extends Params { } /** - * Trait for shared param weightCol. + * Trait for shared param weightCol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasWeightCol extends Params { +@DeveloperApi +trait HasWeightCol extends Params { /** * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. @@ -372,9 +417,11 @@ private[ml] trait HasWeightCol extends Params { } /** - * Trait for shared param solver. + * Trait for shared param solver. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasSolver extends Params { +@DeveloperApi +trait HasSolver extends Params { /** * Param for the solver algorithm for optimization. @@ -387,9 +434,11 @@ private[ml] trait HasSolver extends Params { } /** - * Trait for shared param aggregationDepth (default: 2). + * Trait for shared param aggregationDepth (default: 2). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasAggregationDepth extends Params { +@DeveloperApi +trait HasAggregationDepth extends Params { /** * Param for suggested depth for treeAggregate (>= 2). From 472db58cb19bbd3025eabbd185d920aab0ebb4da Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Nov 2017 15:10:44 +0100 Subject: [PATCH 1609/1765] [SPARK-22445][SQL] move CodegenContext.copyResult to CodegenSupport ## What changes were proposed in this pull request? `CodegenContext.copyResult` is kind of a global status for whole stage codegen. But the tricky part is, it is only used to transfer an information from child to parent when calling the `consume` chain. We have to be super careful in `produce`/`consume`, to set it to true when producing multiple result rows, and set it to false in operators that start new pipeline(like sort). This PR moves the `copyResult` to `CodegenSupport`, and call it at `WholeStageCodegenExec`. This is much easier to reason about. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19656 from cloud-fan/whole-sage. --- .../expressions/codegen/CodeGenerator.scala | 10 ------ .../sql/execution/ColumnarBatchScan.scala | 2 +- .../spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 3 +- .../apache/spark/sql/execution/SortExec.scala | 14 ++++---- .../sql/execution/WholeStageCodegenExec.scala | 35 ++++++++++++++----- .../aggregate/HashAggregateExec.scala | 14 ++++---- .../execution/basicPhysicalOperators.scala | 5 +-- .../joins/BroadcastHashJoinExec.scala | 16 +++++++-- .../execution/joins/SortMergeJoinExec.scala | 3 +- 10 files changed, 66 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 58738b52b299f..98eda2a1ba92c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -139,16 +139,6 @@ class CodegenContext { */ var currentVars: Seq[ExprCode] = null - /** - * Whether should we copy the result rows or not. - * - * If any operator inside WholeStageCodegen generate multiple rows from a single row (for - * example, Join), this should be true. - * - * If an operator starts a new pipeline, this should be reset to false before calling `consume()`. - */ - var copyResult: Boolean = false - /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a * 3-tuple: java type, variable name, code to init it. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index eb01e126bcbef..1925bad8c3545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -115,7 +115,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val numRows = ctx.freshName("numRows") - val shouldStop = if (isShouldStopRequired) { + val shouldStop = if (parent.needStopCheck) { s"if (shouldStop()) { $idx = $rowidx + 1; return; }" } else { "// shouldStop check is eliminated" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index d5603b3b00914..33849f4389b92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -93,6 +93,8 @@ case class ExpandExec( child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def needCopyResult: Boolean = true + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { /* * When the projections list looks like: @@ -187,7 +189,6 @@ case class ExpandExec( val i = ctx.freshName("i") // these column have to declared before the loop. val evaluate = evaluateVariables(outputColumns) - ctx.copyResult = true s""" |$evaluate |for (int $i = 0; $i < ${projections.length}; $i ++) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 65ca37491b6a1..c142d3b5ed4f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -132,9 +132,10 @@ case class GenerateExec( child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def needCopyResult: Boolean = true + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { ctx.currentVars = input - ctx.copyResult = true // Add input rows to the values when we are joining val values = if (join) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ff71fd4dc7bb7..21765cdbd94cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -124,6 +124,14 @@ case class SortExec( // Name of sorter variable used in codegen. private var sorterVariable: String = _ + // The result rows come from the sort buffer, so this operator doesn't need to copy its result + // even if its child does. + override def needCopyResult: Boolean = false + + // Sort operator always consumes all the input rows before outputting any result, so we don't need + // a stop check before sorting. + override def needStopCheck: Boolean = false + override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") @@ -148,10 +156,6 @@ case class SortExec( | } """.stripMargin.trim) - // The child could change `copyResult` to true, but we had already consumed all the rows, - // so `copyResult` should be reset to `false`. - ctx.copyResult = false - val outputRow = ctx.freshName("outputRow") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") @@ -177,8 +181,6 @@ case class SortExec( """.stripMargin.trim } - protected override val shouldStopRequired = false - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { s""" |${row.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 286cb3bb0767c..16b5706c03bf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -213,19 +213,32 @@ trait CodegenSupport extends SparkPlan { } /** - * For optimization to suppress shouldStop() in a loop of WholeStageCodegen. - * Returning true means we need to insert shouldStop() into the loop producing rows, if any. + * Whether or not the result rows of this operator should be copied before putting into a buffer. + * + * If any operator inside WholeStageCodegen generate multiple rows from a single row (for + * example, Join), this should be true. + * + * If an operator starts a new pipeline, this should be false. */ - def isShouldStopRequired: Boolean = { - return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) + def needCopyResult: Boolean = { + if (children.isEmpty) { + false + } else if (children.length == 1) { + children.head.asInstanceOf[CodegenSupport].needCopyResult + } else { + throw new UnsupportedOperationException + } } /** - * Set to false if this plan consumes all rows produced by children but doesn't output row - * to buffer by calling append(), so the children don't require shouldStop() - * in the loop of producing rows. + * Whether or not the children of this operator should generate a stop check when consuming input + * rows. This is used to suppress shouldStop() in a loop of WholeStageCodegen. + * + * This should be false if an operator starts a new pipeline, which means it consumes all rows + * produced by children but doesn't output row to buffer by calling append(), so the children + * don't require shouldStop() in the loop of producing rows. */ - protected def shouldStopRequired: Boolean = true + def needStopCheck: Boolean = parent.needStopCheck } @@ -278,6 +291,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "") } + + override def needCopyResult: Boolean = false } object WholeStageCodegenExec { @@ -467,7 +482,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val doCopy = if (ctx.copyResult) { + val doCopy = if (needCopyResult) { ".copy()" } else { "" @@ -487,6 +502,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "*") } + + override def needStopCheck: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 43e5ff89afee6..2a208a2722550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -149,6 +149,14 @@ case class HashAggregateExec( child.asInstanceOf[CodegenSupport].inputRDDs() } + // The result rows come from the aggregate buffer, or a single row(no grouping keys), so this + // operator doesn't need to copy its result even if its child does. + override def needCopyResult: Boolean = false + + // Aggregate operator always consumes all the input rows before outputting any result, so we + // don't need a stop check before aggregating. + override def needStopCheck: Boolean = false + protected override def doProduce(ctx: CodegenContext): String = { if (groupingExpressions.isEmpty) { doProduceWithoutKeys(ctx) @@ -246,8 +254,6 @@ case class HashAggregateExec( """.stripMargin } - protected override val shouldStopRequired = false - private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -651,10 +657,6 @@ case class HashAggregateExec( val outputFunc = generateResultFunction(ctx) val numOutput = metricTerm(ctx, "numOutputRows") - // The child could change `copyResult` to true, but we had already consumed all the rows, - // so `copyResult` should be reset to `false`. - ctx.copyResult = false - def outputFromGeneratedMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index e58c3cec2df15..3c7daa0a45844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -279,6 +279,8 @@ case class SampleExec( child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def needCopyResult: Boolean = withReplacement + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") val sampler = ctx.freshName("sampler") @@ -286,7 +288,6 @@ case class SampleExec( if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - ctx.copyResult = true val initSamplerFuncName = ctx.addNewFunction(initSampler, s""" @@ -450,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") - val shouldStop = if (isShouldStopRequired) { + val shouldStop = if (parent.needStopCheck) { s"if (shouldStop()) { $number = $value + ${step}L; return; }" } else { "// shouldStop check is eliminated" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index b09da9bdacb99..837b8525fed55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -76,6 +76,20 @@ case class BroadcastHashJoinExec( streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() } + override def needCopyResult: Boolean = joinType match { + case _: InnerLike | LeftOuter | RightOuter => + // For inner and outer joins, one row from the streamed side may produce multiple result rows, + // if the build side has duplicated keys. Then we need to copy the result rows before putting + // them in a buffer, because these result rows share one UnsafeRow instance. Note that here + // we wait for the broadcast to be finished, which is a no-op because it's already finished + // when we wait it in `doProduce`. + !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique + + // Other joins types(semi, anti, existence) can at most produce one result row for one input + // row from the streamed side, so no need to copy the result rows. + case _ => false + } + override def doProduce(ctx: CodegenContext): String = { streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } @@ -237,7 +251,6 @@ case class BroadcastHashJoinExec( """.stripMargin } else { - ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName s""" @@ -310,7 +323,6 @@ case class BroadcastHashJoinExec( """.stripMargin } else { - ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 4e02803552e82..cf7885f80d9fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -569,8 +569,9 @@ case class SortMergeJoinExec( } } + override def needCopyResult: Boolean = true + override def doProduce(ctx: CodegenContext): String = { - ctx.copyResult = true val leftInput = ctx.freshName("leftInput") ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") val rightInput = ctx.freshName("rightInput") From c7f38e5adb88d43ef60662c5d6ff4e7a95bff580 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 6 Nov 2017 08:45:40 -0600 Subject: [PATCH 1610/1765] [SPARK-20644][core] Initial ground work for kvstore UI backend. There are two somewhat unrelated things going on in this patch, but both are meant to make integration of individual UI pages later on much easier. The first part is some tweaking of the code in the listener so that it does less updates of the kvstore for data that changes fast; for example, it avoids writing changes down to the store for every task-related event, since those can arrive very quickly at times. Instead, for these kinds of events, it chooses to only flush things if a certain interval has passed. The interval is based on how often the current spark-shell code updates the progress bar for jobs, so that users can get reasonably accurate data. The code also delays as much as possible hitting the underlying kvstore when replaying apps in the history server. This is to avoid unnecessary writes to disk. The second set of changes prepare the history server and SparkUI for integrating with the kvstore. A new class, AppStatusStore, is used for translating between the stored data and the types used in the UI / API. The SHS now populates a kvstore with data loaded from event logs when an application UI is requested. Because this store can hold references to disk-based resources, the code was modified to retrieve data from the store under a read lock. This allows the SHS to detect when the store is still being used, and only update it (e.g. because an updated event log was detected) when there is no other thread using the store. This change ended up creating a lot of churn in the ApplicationCache code, which was cleaned up a lot in the process. I also removed some metrics which don't make too much sense with the new code. Tested with existing and added unit tests, and by making sure the SHS still works on a real cluster. Author: Marcelo Vanzin Closes #19582 from vanzin/SPARK-20644. --- .../scala/org/apache/spark/SparkContext.scala | 17 +- .../deploy/history/ApplicationCache.scala | 441 ++++-------------- .../history/ApplicationHistoryProvider.scala | 65 ++- .../deploy/history/FsHistoryProvider.scala | 275 +++++++---- .../spark/deploy/history/HistoryServer.scala | 18 +- .../scheduler/ApplicationEventListener.scala | 67 --- .../spark/status/AppStatusListener.scala | 113 +++-- .../apache/spark/status/AppStatusStore.scala | 239 ++++++++++ .../org/apache/spark/status/LiveEntity.scala | 7 +- .../spark/status/api/v1/ApiRootResource.scala | 19 +- .../org/apache/spark/status/config.scala | 30 ++ .../org/apache/spark/status/storeTypes.scala | 26 +- .../scala/org/apache/spark/ui/SparkUI.scala | 73 +-- .../history/ApplicationCacheSuite.scala | 194 ++------ .../history/FsHistoryProviderSuite.scala | 40 +- .../deploy/history/HistoryServerSuite.scala | 28 +- .../spark/status/AppStatusListenerSuite.scala | 19 +- project/MimaExcludes.scala | 2 + 18 files changed, 878 insertions(+), 795 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusStore.scala create mode 100644 core/src/main/scala/org/apache/spark/status/config.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c7dd635ad4c96..e5aaaf6c155eb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.status.AppStatusStore import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -213,6 +214,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _jars: Seq[String] = _ private var _files: Seq[String] = _ private var _shutdownHookRef: AnyRef = _ + private var _statusStore: AppStatusStore = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -422,6 +424,10 @@ class SparkContext(config: SparkConf) extends Logging { _jobProgressListener = new JobProgressListener(_conf) listenerBus.addToStatusQueue(jobProgressListener) + // Initialize the app status store and listener before SparkEnv is created so that it gets + // all events. + _statusStore = AppStatusStore.createLiveStore(conf, listenerBus) + // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) @@ -443,8 +449,12 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, _conf, _jobProgressListener, - _env.securityManager, appName, startTime = startTime)) + Some(SparkUI.create(Some(this), _statusStore, _conf, + l => listenerBus.addToStatusQueue(l), + _env.securityManager, + appName, + "", + startTime)) } else { // For tests, do not enable the UI None @@ -1940,6 +1950,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + if (_statusStore != null) { + _statusStore.close() + } // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this // `SparkContext` is stopped. localProperties.remove() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala index a370526c46f3d..8c63fa65b40fd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala @@ -18,14 +18,15 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.concurrent.ExecutionException import javax.servlet.{DispatcherType, Filter, FilterChain, FilterConfig, ServletException, ServletRequest, ServletResponse} import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import com.codahale.metrics.{Counter, MetricRegistry, Timer} import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification} +import com.google.common.util.concurrent.UncheckedExecutionException import org.eclipse.jetty.servlet.FilterHolder import org.apache.spark.internal.Logging @@ -34,18 +35,11 @@ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Clock /** - * Cache for applications. + * Cache for application UIs. * - * Completed applications are cached for as long as there is capacity for them. - * Incompleted applications have their update time checked on every - * retrieval; if the cached entry is out of date, it is refreshed. + * Applications are cached for as long as there is capacity for them. See [[LoadedAppUI]] for a + * discussion of the UI lifecycle. * - * @note there must be only one instance of [[ApplicationCache]] in a - * JVM at a time. This is because a static field in [[ApplicationCacheCheckFilterRelay]] - * keeps a reference to the cache so that HTTP requests on the attempt-specific web UIs - * can probe the current cache to see if the attempts have changed. - * - * Creating multiple instances will break this routing. * @param operations implementation of record access operations * @param retainedApplications number of retained applications * @param clock time source @@ -55,9 +49,6 @@ private[history] class ApplicationCache( val retainedApplications: Int, val clock: Clock) extends Logging { - /** - * Services the load request from the cache. - */ private val appLoader = new CacheLoader[CacheKey, CacheEntry] { /** the cache key doesn't match a cached entry, or the entry is out-of-date, so load it. */ @@ -67,9 +58,6 @@ private[history] class ApplicationCache( } - /** - * Handler for callbacks from the cache of entry removal. - */ private val removalListener = new RemovalListener[CacheKey, CacheEntry] { /** @@ -80,16 +68,11 @@ private[history] class ApplicationCache( metrics.evictionCount.inc() val key = rm.getKey logDebug(s"Evicting entry ${key}") - operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().ui) + operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().loadedUI.ui) } } - /** - * The cache of applications. - * - * Tagged as `protected` so as to allow subclasses in tests to access it directly - */ - protected val appCache: LoadingCache[CacheKey, CacheEntry] = { + private val appCache: LoadingCache[CacheKey, CacheEntry] = { CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(removalListener) @@ -101,151 +84,51 @@ private[history] class ApplicationCache( */ val metrics = new CacheMetrics("history.cache") - init() - - /** - * Perform any startup operations. - * - * This includes declaring this instance as the cache to use in the - * [[ApplicationCacheCheckFilterRelay]]. - */ - private def init(): Unit = { - ApplicationCacheCheckFilterRelay.setApplicationCache(this) - } - - /** - * Stop the cache. - * This will reset the relay in [[ApplicationCacheCheckFilterRelay]]. - */ - def stop(): Unit = { - ApplicationCacheCheckFilterRelay.resetApplicationCache() - } - - /** - * Get an entry. - * - * Cache fetch/refresh will have taken place by the time this method returns. - * @param appAndAttempt application to look up in the format needed by the history server web UI, - * `appId/attemptId` or `appId`. - * @return the entry - */ - def get(appAndAttempt: String): SparkUI = { - val parts = splitAppAndAttemptKey(appAndAttempt) - get(parts._1, parts._2) - } - - /** - * Get the Spark UI, converting a lookup failure from an exception to `None`. - * @param appAndAttempt application to look up in the format needed by the history server web UI, - * `appId/attemptId` or `appId`. - * @return the entry - */ - def getSparkUI(appAndAttempt: String): Option[SparkUI] = { + def get(appId: String, attemptId: Option[String] = None): CacheEntry = { try { - val ui = get(appAndAttempt) - Some(ui) + appCache.get(new CacheKey(appId, attemptId)) } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - None - case cause: Exception => throw cause - } + case e @ (_: ExecutionException | _: UncheckedExecutionException) => + throw Option(e.getCause()).getOrElse(e) } } /** - * Get the associated spark UI. - * - * Cache fetch/refresh will have taken place by the time this method returns. - * @param appId application ID - * @param attemptId optional attempt ID - * @return the entry + * Run a closure while holding an application's UI read lock. This prevents the history server + * from closing the UI data store while it's being used. */ - def get(appId: String, attemptId: Option[String]): SparkUI = { - lookupAndUpdate(appId, attemptId)._1.ui - } + def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + var entry = get(appId, attemptId) - /** - * Look up the entry; update it if needed. - * @param appId application ID - * @param attemptId optional attempt ID - * @return the underlying cache entry -which can have its timestamp changed, and a flag to - * indicate that the entry has changed - */ - private def lookupAndUpdate(appId: String, attemptId: Option[String]): (CacheEntry, Boolean) = { - metrics.lookupCount.inc() - val cacheKey = CacheKey(appId, attemptId) - var entry = appCache.getIfPresent(cacheKey) - var updated = false - if (entry == null) { - // no entry, so fetch without any post-fetch probes for out-of-dateness - // this will trigger a callback to loadApplicationEntry() - entry = appCache.get(cacheKey) - } else if (!entry.completed) { - val now = clock.getTimeMillis() - log.debug(s"Probing at time $now for updated application $cacheKey -> $entry") - metrics.updateProbeCount.inc() - updated = time(metrics.updateProbeTimer) { - entry.updateProbe() + // If the entry exists, we need to make sure we run the closure with a valid entry. So + // we need to re-try until we can lock a valid entry for read. + entry.loadedUI.lock.readLock().lock() + try { + while (!entry.loadedUI.valid) { + entry.loadedUI.lock.readLock().unlock() + entry = null + try { + invalidate(new CacheKey(appId, attemptId)) + entry = get(appId, attemptId) + metrics.loadCount.inc() + } finally { + if (entry != null) { + entry.loadedUI.lock.readLock().lock() + } + } } - if (updated) { - logDebug(s"refreshing $cacheKey") - metrics.updateTriggeredCount.inc() - appCache.refresh(cacheKey) - // and repeat the lookup - entry = appCache.get(cacheKey) - } else { - // update the probe timestamp to the current time - entry.probeTime = now + + fn(entry.loadedUI.ui) + } finally { + if (entry != null) { + entry.loadedUI.lock.readLock().unlock() } } - (entry, updated) } - /** - * This method is visible for testing. - * - * It looks up the cached entry *and returns a clone of it*. - * This ensures that the cached entries never leak - * @param appId application ID - * @param attemptId optional attempt ID - * @return a new entry with shared SparkUI, but copies of the other fields. - */ - def lookupCacheEntry(appId: String, attemptId: Option[String]): CacheEntry = { - val entry = lookupAndUpdate(appId, attemptId)._1 - new CacheEntry(entry.ui, entry.completed, entry.updateProbe, entry.probeTime) - } - - /** - * Probe for an application being updated. - * @param appId application ID - * @param attemptId attempt ID - * @return true if an update has been triggered - */ - def checkForUpdates(appId: String, attemptId: Option[String]): Boolean = { - val (entry, updated) = lookupAndUpdate(appId, attemptId) - updated - } - - /** - * Size probe, primarily for testing. - * @return size - */ + /** @return Number of cached UIs. */ def size(): Long = appCache.size() - /** - * Emptiness predicate, primarily for testing. - * @return true if the cache is empty - */ - def isEmpty: Boolean = appCache.size() == 0 - - /** - * Time a closure, returning its output. - * @param t timer - * @param f function - * @tparam T type of return value of time - * @return the result of the function. - */ private def time[T](t: Timer)(f: => T): T = { val timeCtx = t.time() try { @@ -272,27 +155,15 @@ private[history] class ApplicationCache( * @throws NoSuchElementException if there is no matching element */ @throws[NoSuchElementException] - def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { - + private def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { logDebug(s"Loading application Entry $appId/$attemptId") metrics.loadCount.inc() - time(metrics.loadTimer) { + val loadedUI = time(metrics.loadTimer) { + metrics.lookupCount.inc() operations.getAppUI(appId, attemptId) match { - case Some(LoadedAppUI(ui, updateState)) => - val completed = ui.getApplicationInfoList.exists(_.attempts.last.completed) - if (completed) { - // completed spark UIs are attached directly - operations.attachSparkUI(appId, attemptId, ui, completed) - } else { - // incomplete UIs have the cache-check filter put in front of them. - ApplicationCacheCheckFilterRelay.registerFilter(ui, appId, attemptId) - operations.attachSparkUI(appId, attemptId, ui, completed) - } - // build the cache entry - val now = clock.getTimeMillis() - val entry = new CacheEntry(ui, completed, updateState, now) - logDebug(s"Loaded application $appId/$attemptId -> $entry") - entry + case Some(loadedUI) => + logDebug(s"Loaded application $appId/$attemptId") + loadedUI case None => metrics.lookupFailureCount.inc() // guava's cache logs via java.util log, so is of limited use. Hence: our own message @@ -301,32 +172,20 @@ private[history] class ApplicationCache( attemptId.map { id => s" attemptId '$id'" }.getOrElse(" and no attempt Id")) } } - } - - /** - * Split up an `applicationId/attemptId` or `applicationId` key into the separate pieces. - * - * @param appAndAttempt combined key - * @return a tuple of the application ID and, if present, the attemptID - */ - def splitAppAndAttemptKey(appAndAttempt: String): (String, Option[String]) = { - val parts = appAndAttempt.split("/") - require(parts.length == 1 || parts.length == 2, s"Invalid app key $appAndAttempt") - val appId = parts(0) - val attemptId = if (parts.length > 1) Some(parts(1)) else None - (appId, attemptId) - } - - /** - * Merge an appId and optional attempt Id into a key of the form `applicationId/attemptId`. - * - * If there is an `attemptId`; `applicationId` if not. - * @param appId application ID - * @param attemptId optional attempt ID - * @return a unified string - */ - def mergeAppAndAttemptToKey(appId: String, attemptId: Option[String]): String = { - appId + attemptId.map { id => s"/$id" }.getOrElse("") + try { + val completed = loadedUI.ui.getApplicationInfoList.exists(_.attempts.last.completed) + if (!completed) { + // incomplete UIs have the cache-check filter put in front of them. + registerFilter(new CacheKey(appId, attemptId), loadedUI) + } + operations.attachSparkUI(appId, attemptId, loadedUI.ui, completed) + new CacheEntry(loadedUI, completed) + } catch { + case e: Exception => + logWarning(s"Failed to initialize application UI for $appId/$attemptId", e) + operations.detachSparkUI(appId, attemptId, loadedUI.ui) + throw e + } } /** @@ -347,6 +206,26 @@ private[history] class ApplicationCache( sb.append("----\n") sb.toString() } + + /** + * Register a filter for the web UI which checks for updates to the given app/attempt + * @param ui Spark UI to attach filters to + * @param appId application ID + * @param attemptId attempt ID + */ + private def registerFilter(key: CacheKey, loadedUI: LoadedAppUI): Unit = { + require(loadedUI != null) + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) + val filter = new ApplicationCacheCheckFilter(key, loadedUI, this) + val holder = new FilterHolder(filter) + require(loadedUI.ui.getHandlers != null, "null handlers") + loadedUI.ui.getHandlers.foreach { handler => + handler.addFilter(holder, "/*", enumDispatcher) + } + } + + def invalidate(key: CacheKey): Unit = appCache.invalidate(key) + } /** @@ -355,19 +234,14 @@ private[history] class ApplicationCache( * @param ui Spark UI * @param completed Flag to indicated that the application has completed (and so * does not need refreshing). - * @param updateProbe function to call to see if the application has been updated and - * therefore that the cached value needs to be refreshed. - * @param probeTime Times in milliseconds when the probe was last executed. */ private[history] final class CacheEntry( - val ui: SparkUI, - val completed: Boolean, - val updateProbe: () => Boolean, - var probeTime: Long) { + val loadedUI: LoadedAppUI, + val completed: Boolean) { /** string value is for test assertions */ override def toString: String = { - s"UI $ui, completed=$completed, probeTime=$probeTime" + s"UI ${loadedUI.ui}, completed=$completed" } } @@ -396,23 +270,17 @@ private[history] class CacheMetrics(prefix: String) extends Source { val evictionCount = new Counter() val loadCount = new Counter() val loadTimer = new Timer() - val updateProbeCount = new Counter() - val updateProbeTimer = new Timer() - val updateTriggeredCount = new Counter() /** all the counters: for registration and string conversion. */ private val counters = Seq( ("lookup.count", lookupCount), ("lookup.failure.count", lookupFailureCount), ("eviction.count", evictionCount), - ("load.count", loadCount), - ("update.probe.count", updateProbeCount), - ("update.triggered.count", updateTriggeredCount)) + ("load.count", loadCount)) /** all metrics, including timers */ private val allMetrics = counters ++ Seq( - ("load.timer", loadTimer), - ("update.probe.timer", updateProbeTimer)) + ("load.timer", loadTimer)) /** * Name of metric source @@ -498,23 +366,11 @@ private[history] trait ApplicationCacheOperations { * Implementation note: there's some abuse of a shared global entry here because * the configuration data passed to the servlet is just a string:string map. */ -private[history] class ApplicationCacheCheckFilter() extends Filter with Logging { - - import ApplicationCacheCheckFilterRelay._ - var appId: String = _ - var attemptId: Option[String] = _ - - /** - * Bind the app and attempt ID, throwing an exception if no application ID was provided. - * @param filterConfig configuration - */ - override def init(filterConfig: FilterConfig): Unit = { - - appId = Option(filterConfig.getInitParameter(APP_ID)) - .getOrElse(throw new ServletException(s"Missing Parameter $APP_ID")) - attemptId = Option(filterConfig.getInitParameter(ATTEMPT_ID)) - logDebug(s"initializing filter $this") - } +private[history] class ApplicationCacheCheckFilter( + key: CacheKey, + loadedUI: LoadedAppUI, + cache: ApplicationCache) + extends Filter with Logging { /** * Filter the request. @@ -543,123 +399,24 @@ private[history] class ApplicationCacheCheckFilter() extends Filter with Logging // if the request is for an attempt, check to see if it is in need of delete/refresh // and have the cache update the UI if so - if (operation=="HEAD" || operation=="GET" - && checkForUpdates(requestURI, appId, attemptId)) { - // send a redirect back to the same location. This will be routed - // to the *new* UI - logInfo(s"Application Attempt $appId/$attemptId updated; refreshing") + loadedUI.lock.readLock().lock() + if (loadedUI.valid) { + try { + chain.doFilter(request, response) + } finally { + loadedUI.lock.readLock.unlock() + } + } else { + loadedUI.lock.readLock.unlock() + cache.invalidate(key) val queryStr = Option(httpRequest.getQueryString).map("?" + _).getOrElse("") val redirectUrl = httpResponse.encodeRedirectURL(requestURI + queryStr) httpResponse.sendRedirect(redirectUrl) - } else { - chain.doFilter(request, response) } } - override def destroy(): Unit = { - } + override def init(config: FilterConfig): Unit = { } - override def toString: String = s"ApplicationCacheCheckFilter for $appId/$attemptId" -} + override def destroy(): Unit = { } -/** - * Global state for the [[ApplicationCacheCheckFilter]] instances, so that they can relay cache - * probes to the cache. - * - * This is an ugly workaround for the limitation of servlets and filters in the Java servlet - * API; they are still configured on the model of a list of classnames and configuration - * strings in a `web.xml` field, rather than a chain of instances wired up by hand or - * via an injection framework. There is no way to directly configure a servlet filter instance - * with a reference to the application cache which is must use: some global state is needed. - * - * Here, [[ApplicationCacheCheckFilter]] is that global state; it relays all requests - * to the singleton [[ApplicationCache]] - * - * The field `applicationCache` must be set for the filters to work - - * this is done during the construction of [[ApplicationCache]], which requires that there - * is only one cache serving requests through the WebUI. - * - * *Important* In test runs, if there is more than one [[ApplicationCache]], the relay logic - * will break: filters may not find instances. Tests must not do that. - * - */ -private[history] object ApplicationCacheCheckFilterRelay extends Logging { - // name of the app ID entry in the filter configuration. Mandatory. - val APP_ID = "appId" - - // name of the attempt ID entry in the filter configuration. Optional. - val ATTEMPT_ID = "attemptId" - - // name of the filter to register - val FILTER_NAME = "org.apache.spark.deploy.history.ApplicationCacheCheckFilter" - - /** the application cache to relay requests to */ - @volatile - private var applicationCache: Option[ApplicationCache] = None - - /** - * Set the application cache. Logs a warning if it is overwriting an existing value - * @param cache new cache - */ - def setApplicationCache(cache: ApplicationCache): Unit = { - applicationCache.foreach( c => logWarning(s"Overwriting application cache $c")) - applicationCache = Some(cache) - } - - /** - * Reset the application cache - */ - def resetApplicationCache(): Unit = { - applicationCache = None - } - - /** - * Check to see if there has been an update - * @param requestURI URI the request came in on - * @param appId application ID - * @param attemptId attempt ID - * @return true if an update was loaded for the app/attempt - */ - def checkForUpdates(requestURI: String, appId: String, attemptId: Option[String]): Boolean = { - - logDebug(s"Checking $appId/$attemptId from $requestURI") - applicationCache match { - case Some(cache) => - try { - cache.checkForUpdates(appId, attemptId) - } catch { - case ex: Exception => - // something went wrong. Keep going with the existing UI - logWarning(s"When checking for $appId/$attemptId from $requestURI", ex) - false - } - - case None => - logWarning("No application cache instance defined") - false - } - } - - - /** - * Register a filter for the web UI which checks for updates to the given app/attempt - * @param ui Spark UI to attach filters to - * @param appId application ID - * @param attemptId attempt ID - */ - def registerFilter( - ui: SparkUI, - appId: String, - attemptId: Option[String] ): Unit = { - require(ui != null) - val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) - val holder = new FilterHolder() - holder.setClassName(FILTER_NAME) - holder.setInitParameter(APP_ID, appId) - attemptId.foreach( id => holder.setInitParameter(ATTEMPT_ID, id)) - require(ui.getHandlers != null, "null handlers") - ui.getHandlers.foreach { handler => - handler.addFilter(holder, "/*", enumDispatcher) - } - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 5cb48ca3e60b0..38f0d6f2afa5e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.history +import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.zip.ZipOutputStream import scala.xml.Node @@ -48,30 +49,44 @@ private[spark] case class ApplicationHistoryInfo( } /** - * A probe which can be invoked to see if a loaded Web UI has been updated. - * The probe is expected to be relative purely to that of the UI returned - * in the same [[LoadedAppUI]] instance. That is, whenever a new UI is loaded, - * the probe returned with it is the one that must be used to check for it - * being out of date; previous probes must be discarded. - */ -private[history] abstract class HistoryUpdateProbe { - /** - * Return true if the history provider has a later version of the application - * attempt than the one against this probe was constructed. - * @return - */ - def isUpdated(): Boolean -} - -/** - * All the information returned from a call to `getAppUI()`: the new UI - * and any required update state. + * A loaded UI for a Spark application. + * + * Loaded UIs are valid once created, and can be invalidated once the history provider detects + * changes in the underlying app data (e.g. an updated event log). Invalidating a UI does not + * unload it; it just signals the [[ApplicationCache]] that the UI should not be used to serve + * new requests. + * + * Reloading of the UI with new data requires collaboration between the cache and the provider; + * the provider invalidates the UI when it detects updated information, and the cache invalidates + * the cache entry when it detects the UI has been invalidated. That will trigger a callback + * on the provider to finally clean up any UI state. The cache should hold read locks when + * using the UI, and the provider should grab the UI's write lock before making destructive + * operations. + * + * Note that all this means that an invalidated UI will still stay in-memory, and any resources it + * references will remain open, until the cache either sees that it's invalidated, or evicts it to + * make room for another UI. + * * @param ui Spark UI - * @param updateProbe probe to call to check on the update state of this application attempt */ -private[history] case class LoadedAppUI( - ui: SparkUI, - updateProbe: () => Boolean) +private[history] case class LoadedAppUI(ui: SparkUI) { + + val lock = new ReentrantReadWriteLock() + + @volatile private var _valid = true + + def valid: Boolean = _valid + + def invalidate(): Unit = { + lock.writeLock().lock() + try { + _valid = false + } finally { + lock.writeLock().unlock() + } + } + +} private[history] abstract class ApplicationHistoryProvider { @@ -145,4 +160,10 @@ private[history] abstract class ApplicationHistoryProvider { * @return html text to display when the application list is empty */ def getEmptyListingHtml(): Seq[Node] = Seq.empty + + /** + * Called when an application UI is unloaded from the history server. + */ + def onUIDetached(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index cf97597b484d8..f16dddea9f784 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, UUID} +import java.util.{Date, ServiceLoader, UUID} import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -26,8 +26,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.xml.Node -import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude} -import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, Path} @@ -42,6 +41,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.{AppStatusListener, AppStatusStore, AppStatusStoreMetadata, KVUtils} import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI @@ -61,9 +61,6 @@ import org.apache.spark.util.kvstore._ * and update or create a matching application info element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. - * - When [[updateProbe()]] is invoked to check if a loaded [[SparkUI]] - * instance is out of date, the log size of the cached instance is checked against the app last - * loaded by [[checkForLogs]]. * * The use of log size, rather than simply relying on modification times, is needed to * address the following issues @@ -125,23 +122,30 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) - private val storePath = conf.get(LOCAL_STORE_DIR) + private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => + require(path.isDirectory(), s"Configured store directory ($path) does not exist.") val dbPath = new File(path, "listing.ldb") - val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, logDir.toString()) + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, + AppStatusStore.CURRENT_VERSION, logDir.toString()) try { open(new File(path, "listing.ldb"), metadata) } catch { + // If there's an error, remove the listing database and any existing UI database + // from the store directory, since it's extremely likely that they'll all contain + // incompatible information. case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") - Utils.deleteRecursively(dbPath) + path.listFiles().foreach(Utils.deleteRecursively) open(new File(path, "listing.ldb"), metadata) } }.getOrElse(new InMemoryStore()) + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -165,7 +169,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - // Conf option used for testing the initialization code. val initThread = initialize() private[history] def initialize(): Thread = { @@ -268,42 +271,100 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getLastUpdatedTime(): Long = lastScanTime.get() override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { - try { - val appInfo = load(appId) - appInfo.attempts - .find(_.info.attemptId == attemptId) - .map { attempt => - val replayBus = new ReplayListenerBus() - val ui = { - val conf = this.conf.clone() - val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.info.name, - HistoryServer.getAttemptURI(appId, attempt.info.attemptId), - Some(attempt.info.lastUpdated.getTime()), attempt.info.startTime.getTime()) - // Do not call ui.bind() to avoid creating a new server for each application - } + val app = try { + load(appId) + } catch { + case _: NoSuchElementException => + return None + } + + val attempt = app.attempts.find(_.info.attemptId == attemptId).orNull + if (attempt == null) { + return None + } - val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) - - val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - assert(appListener.appId.isDefined) - ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") - ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) - // make sure to set admin acls before view acls so they are properly picked up - val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") - ui.getSecurityManager.setAdminAcls(adminAcls) - ui.getSecurityManager.setViewAcls(attempt.info.sparkUser, - appListener.viewAcls.getOrElse("")) - val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + - appListener.adminAclsGroups.getOrElse("") - ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) - ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) - LoadedAppUI(ui, () => updateProbe(appId, attemptId, attempt.fileSize)) + val conf = this.conf.clone() + val secManager = new SecurityManager(conf) + + secManager.setAcls(HISTORY_UI_ACLS_ENABLE) + // make sure to set admin acls before view acls so they are properly picked up + secManager.setAdminAcls(HISTORY_UI_ADMIN_ACLS + "," + attempt.adminAcls.getOrElse("")) + secManager.setViewAcls(attempt.info.sparkUser, attempt.viewAcls.getOrElse("")) + secManager.setAdminAclsGroups(HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + attempt.adminAclsGroups.getOrElse("")) + secManager.setViewAclsGroups(attempt.viewAclsGroups.getOrElse("")) + + val replayBus = new ReplayListenerBus() + + val uiStorePath = storePath.map { path => getStorePath(path, appId, attemptId) } + + val (kvstore, needReplay) = uiStorePath match { + case Some(path) => + try { + val _replay = !path.isDirectory() + (createDiskStore(path, conf), _replay) + } catch { + case e: Exception => + // Get rid of the old data and re-create it. The store is either old or corrupted. + logWarning(s"Failed to load disk store $uiStorePath for $appId.", e) + Utils.deleteRecursively(path) + (createDiskStore(path, conf), true) } + + case _ => + (new InMemoryStore(), true) + } + + val listener = if (needReplay) { + val _listener = new AppStatusListener(kvstore, conf, false) + replayBus.addListener(_listener) + Some(_listener) + } else { + None + } + + val loadedUI = { + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, + l => replayBus.addListener(l), + secManager, + app.info.name, + HistoryServer.getAttemptURI(appId, attempt.info.attemptId), + attempt.info.startTime.getTime(), + appSparkVersion = attempt.info.appSparkVersion) + LoadedAppUI(ui) + } + + try { + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, loadedUI.ui) + listeners.foreach(replayBus.addListener) + } + + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + listener.foreach(_.flush()) } catch { - case _: FileNotFoundException => None - case _: NoSuchElementException => None + case e: Exception => + try { + kvstore.close() + } catch { + case _e: Exception => logInfo("Error closing store.", _e) + } + uiStorePath.foreach(Utils.deleteRecursively) + if (e.isInstanceOf[FileNotFoundException]) { + return None + } else { + throw e + } + } + + synchronized { + activeUIs((appId, attemptId)) = loadedUI } + + Some(loadedUI) } override def getEmptyListingHtml(): Seq[Node] = { @@ -332,11 +393,40 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) initThread.interrupt() initThread.join() } + Seq(pool, replayExecutor).foreach { executor => + executor.shutdown() + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + } } finally { + activeUIs.foreach { case (_, loadedUI) => loadedUI.ui.store.close() } + activeUIs.clear() listing.close() } } + override def onUIDetached(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { + val uiOption = synchronized { + activeUIs.remove((appId, attemptId)) + } + uiOption.foreach { loadedUI => + loadedUI.lock.writeLock().lock() + try { + loadedUI.ui.store.close() + } finally { + loadedUI.lock.writeLock().unlock() + } + + // If the UI is not valid, delete its files from disk, if any. This relies on the fact that + // ApplicationCache will never call this method concurrently with getAppUI() for the same + // appId / attemptId. + if (!loadedUI.valid && storePath.isDefined) { + Utils.deleteRecursively(getStorePath(storePath.get, appId, attemptId)) + } + } + } + /** * Builds the application list based on the current contents of the log directory. * Tries to reuse as much of the data already in memory as possible, by not reading @@ -475,7 +565,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || - eventString.startsWith(LOG_START_EVENT_PREFIX) + eventString.startsWith(LOG_START_EVENT_PREFIX) || + eventString.startsWith(ENV_UPDATE_EVENT_PREFIX) } val logPath = fileStatus.getPath() @@ -486,8 +577,19 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.addListener(listener) replay(fileStatus, isApplicationCompleted(fileStatus), bus, eventsFilter) - listener.applicationInfo.foreach(addListing) - listing.write(LogInfo(logPath.toString(), fileStatus.getLen())) + listener.applicationInfo.foreach { app => + // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a + // discussion on the UI lifecycle. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + } + + addListing(app) + } + listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) } /** @@ -546,16 +648,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. Returns - * an `ApplicationEventListener` instance with event data captured from the replay. - * `ReplayEventsFilter` determines what events are replayed and can therefore limit the - * data captured in the returned `ApplicationEventListener` instance. + * Replays the events in the specified log file on the supplied `ReplayListenerBus`. + * `ReplayEventsFilter` determines what events are replayed. */ private def replay( eventLog: FileStatus, appCompleted: Boolean, bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): ApplicationEventListener = { + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, @@ -566,11 +666,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // after it's created, so we get a file size that is no bigger than what is actually read. val logInput = EventLoggingListener.openEventLog(logPath, fs) try { - val appListener = new ApplicationEventListener - bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted, eventsFilter) logInfo(s"Finished replaying $logPath") - appListener } finally { logInput.close() } @@ -613,32 +710,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return true iff a newer version of the UI is available. The check is based on whether the - * fileSize for the currently loaded UI is smaller than the file size the last time - * the logs were loaded. - * - * This is a very cheap operation -- the work of loading the new attempt was already done - * by [[checkForLogs]]. - * @param appId application to probe - * @param attemptId attempt to probe - * @param prevFileSize the file size of the logs for the currently displayed UI - */ - private def updateProbe( - appId: String, - attemptId: Option[String], - prevFileSize: Long)(): Boolean = { - try { - val attempt = getAttempt(appId, attemptId) - val logPath = fs.makeQualified(new Path(logDir, attempt.logPath)) - recordedFileSize(logPath) > prevFileSize - } catch { - case _: NoSuchElementException => - logDebug(s"Application Attempt $appId/$attemptId not found") - false - } - } - /** * Return the last known size of the given event log, recorded the last time the file * system scanner detected a change in the file. @@ -682,6 +753,16 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.write(newAppInfo) } + private def createDiskStore(path: File, conf: SparkConf): KVStore = { + val metadata = new AppStatusStoreMetadata(AppStatusStore.CURRENT_VERSION) + KVUtils.open(path, metadata) + } + + private def getStorePath(path: File, appId: String, attemptId: Option[String]): File = { + val fileName = appId + attemptId.map("_" + _).getOrElse("") + ".ldb" + new File(path, fileName) + } + /** For testing. Returns internal data about a single attempt. */ private[history] def getAttempt(appId: String, attemptId: Option[String]): AttemptInfoWrapper = { load(appId).attempts.find(_.info.attemptId == attemptId).getOrElse( @@ -699,6 +780,8 @@ private[history] object FsHistoryProvider { private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" + private val ENV_UPDATE_EVENT_PREFIX = "{\"Event\":\"SparkListenerEnvironmentUpdate\"," + /** * Current version of the data written to the listing database. When opening an existing * db, if the version does not match this value, the FsHistoryProvider will throw away @@ -708,17 +791,22 @@ private[history] object FsHistoryProvider { } private[history] case class FsHistoryProviderMetadata( - version: Long, - logDir: String) + version: Long, + uiVersion: Long, + logDir: String) private[history] case class LogInfo( - @KVIndexParam logPath: String, - fileSize: Long) + @KVIndexParam logPath: String, + fileSize: Long) private[history] class AttemptInfoWrapper( val info: v1.ApplicationAttemptInfo, val logPath: String, - val fileSize: Long) { + val fileSize: Long, + val adminAcls: Option[String], + val viewAcls: Option[String], + val adminAclsGroups: Option[String], + val viewAclsGroups: Option[String]) { def toAppAttemptInfo(): ApplicationAttemptInfo = { ApplicationAttemptInfo(info.attemptId, info.startTime.getTime(), @@ -769,6 +857,14 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.completed = true } + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + } + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { case SparkListenerLogStart(sparkVersion) => attempt.appSparkVersion = sparkVersion @@ -809,6 +905,11 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends var completed = false var appSparkVersion = "" + var adminAcls: Option[String] = None + var viewAcls: Option[String] = None + var adminAclsGroups: Option[String] = None + var viewAclsGroups: Option[String] = None + def toView(): AttemptInfoWrapper = { val apiInfo = new v1.ApplicationAttemptInfo( attemptId, @@ -822,7 +923,11 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends new AttemptInfoWrapper( apiInfo, logPath, - fileSize) + fileSize, + adminAcls, + viewAcls, + adminAclsGroups, + viewAclsGroups) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d9c8fda99ef97..b822a48e98e91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -106,8 +106,8 @@ class HistoryServer( } } - def getSparkUI(appKey: String): Option[SparkUI] = { - appCache.getSparkUI(appKey) + override def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + appCache.withSparkUI(appId, attemptId)(fn) } initialize() @@ -140,7 +140,6 @@ class HistoryServer( override def stop() { super.stop() provider.stop() - appCache.stop() } /** Attach a reconstructed UI to this server. Only valid after bind(). */ @@ -158,6 +157,7 @@ class HistoryServer( override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") ui.getHandlers.foreach(detachHandler) + provider.onUIDetached(appId, attemptId, ui) } /** @@ -224,15 +224,13 @@ class HistoryServer( */ private def loadAppUi(appId: String, attemptId: Option[String]): Boolean = { try { - appCache.get(appId, attemptId) + appCache.withSparkUI(appId, attemptId) { _ => + // Do nothing, just force the UI to load. + } true } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - false - - case cause: Exception => throw cause - } + case NonFatal(e: NoSuchElementException) => + false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala deleted file mode 100644 index 6da8865cd10d3..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -/** - * A simple listener for application events. - * - * This listener expects to hear events from a single application only. If events - * from multiple applications are seen, the behavior is unspecified. - */ -private[spark] class ApplicationEventListener extends SparkListener { - var appName: Option[String] = None - var appId: Option[String] = None - var appAttemptId: Option[String] = None - var sparkUser: Option[String] = None - var startTime: Option[Long] = None - var endTime: Option[Long] = None - var viewAcls: Option[String] = None - var adminAcls: Option[String] = None - var viewAclsGroups: Option[String] = None - var adminAclsGroups: Option[String] = None - var appSparkVersion: Option[String] = None - - override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { - appName = Some(applicationStart.appName) - appId = applicationStart.appId - appAttemptId = applicationStart.appAttemptId - startTime = Some(applicationStart.time) - sparkUser = Some(applicationStart.sparkUser) - } - - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { - endTime = Some(applicationEnd.time) - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - val environmentDetails = environmentUpdate.environmentDetails - val allProperties = environmentDetails("Spark Properties").toMap - viewAcls = allProperties.get("spark.ui.view.acls") - adminAcls = allProperties.get("spark.admin.acls") - viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - adminAclsGroups = allProperties.get("spark.admin.acls.groups") - } - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerLogStart(sparkVersion) => - appSparkVersion = Some(sparkVersion) - case _ => - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index f120685c941df..cd43612fae357 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -34,12 +34,22 @@ import org.apache.spark.util.kvstore.KVStore * A Spark listener that writes application information to a data store. The types written to the * store are defined in the `storeTypes.scala` file and are based on the public REST API. */ -private class AppStatusListener(kvstore: KVStore) extends SparkListener with Logging { +private[spark] class AppStatusListener( + kvstore: KVStore, + conf: SparkConf, + live: Boolean) extends SparkListener with Logging { + + import config._ private var sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null private var coresPerTask: Int = 1 + // How often to update live entities. -1 means "never update" when replaying applications, + // meaning only the last write will happen. For live applications, this avoids a few + // operations that we can live without when rapidly processing incoming task events. + private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + // Keep track of live entities, so that task metrics can be efficiently updated (without // causing too many writes to the underlying store, and other expensive operations). private val liveStages = new HashMap[(Int, Int), LiveStage]() @@ -110,13 +120,13 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log exec.totalCores = event.executorInfo.totalCores exec.maxTasks = event.executorInfo.totalCores / coresPerTask exec.executorLogs = event.executorInfo.logUrlMap - update(exec) + liveUpdate(exec, System.nanoTime()) } override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => exec.isActive = false - update(exec) + update(exec, System.nanoTime()) } } @@ -139,21 +149,25 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log private def updateBlackListStatus(execId: String, blacklisted: Boolean): Unit = { liveExecutors.get(execId).foreach { exec => exec.isBlacklisted = blacklisted - update(exec) + liveUpdate(exec, System.nanoTime()) } } private def updateNodeBlackList(host: String, blacklisted: Boolean): Unit = { + val now = System.nanoTime() + // Implicitly (un)blacklist every executor associated with the node. liveExecutors.values.foreach { exec => if (exec.hostname == host) { exec.isBlacklisted = blacklisted - update(exec) + liveUpdate(exec, now) } } } override def onJobStart(event: SparkListenerJobStart): Unit = { + val now = System.nanoTime() + // Compute (a potential over-estimate of) the number of tasks that will be run by this job. // This may be an over-estimate because the job start event references all of the result // stages' transitive stage dependencies, but some of these stages might be skipped if their @@ -178,7 +192,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log jobGroup, numTasks) liveJobs.put(event.jobId, job) - update(job) + liveUpdate(job, now) event.stageInfos.foreach { stageInfo => // A new job submission may re-use an existing stage, so this code needs to do an update @@ -186,7 +200,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log val stage = getOrCreateStage(stageInfo) stage.jobs :+= job stage.jobIds += event.jobId - update(stage) + liveUpdate(stage, now) } } @@ -198,11 +212,12 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } job.completionTime = Some(new Date(event.time)) - update(job) + update(job, System.nanoTime()) } } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + val now = System.nanoTime() val stage = getOrCreateStage(event.stageInfo) stage.status = v1.StageStatus.ACTIVE stage.schedulingPool = Option(event.properties).flatMap { p => @@ -218,38 +233,39 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log stage.jobs.foreach { job => job.completedStages = job.completedStages - event.stageInfo.stageId job.activeStages += 1 - update(job) + liveUpdate(job, now) } event.stageInfo.rddInfos.foreach { info => if (info.storageLevel.isValid) { - update(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info))) + liveUpdate(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info)), now) } } - update(stage) + liveUpdate(stage, now) } override def onTaskStart(event: SparkListenerTaskStart): Unit = { + val now = System.nanoTime() val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) liveTasks.put(event.taskInfo.taskId, task) - update(task) + liveUpdate(task, now) liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) - update(stage) + maybeUpdate(stage, now) stage.jobs.foreach { job => job.activeTasks += 1 - update(job) + maybeUpdate(job, now) } } liveExecutors.get(event.taskInfo.executorId).foreach { exec => exec.activeTasks += 1 exec.totalTasks += 1 - update(exec) + maybeUpdate(exec, now) } } @@ -257,7 +273,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log // Call update on the task so that the "getting result" time is written to the store; the // value is part of the mutable TaskInfo state that the live entity already references. liveTasks.get(event.taskInfo.taskId).foreach { task => - update(task) + maybeUpdate(task, System.nanoTime()) } } @@ -267,6 +283,8 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log return } + val now = System.nanoTime() + val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => val errorMessage = event.reason match { case Success => @@ -283,7 +301,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task) + update(task, now) delta }.orNull @@ -301,13 +319,13 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log stage.activeTasks -= 1 stage.completedTasks += completedDelta stage.failedTasks += failedDelta - update(stage) + maybeUpdate(stage, now) stage.jobs.foreach { job => job.activeTasks -= 1 job.completedTasks += completedDelta job.failedTasks += failedDelta - update(job) + maybeUpdate(job, now) } val esummary = stage.executorSummary(event.taskInfo.executorId) @@ -317,7 +335,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log if (metricsDelta != null) { esummary.metrics.update(metricsDelta) } - update(esummary) + maybeUpdate(esummary, now) } liveExecutors.get(event.taskInfo.executorId).foreach { exec => @@ -333,12 +351,13 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log exec.completedTasks += completedDelta exec.failedTasks += failedDelta exec.totalDuration += event.taskInfo.duration - update(exec) + maybeUpdate(exec, now) } } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + val now = System.nanoTime() stage.info = event.stageInfo // Because of SPARK-20205, old event logs may contain valid stages without a submission time @@ -349,7 +368,6 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log case _ if event.stageInfo.submissionTime.isDefined => v1.StageStatus.COMPLETE case _ => v1.StageStatus.SKIPPED } - update(stage) stage.jobs.foreach { job => stage.status match { @@ -362,11 +380,11 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log job.failedStages += 1 } job.activeStages -= 1 - update(job) + liveUpdate(job, now) } - stage.executorSummaries.values.foreach(update) - update(stage) + stage.executorSummaries.values.foreach(update(_, now)) + update(stage, now) } } @@ -381,7 +399,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } exec.isActive = true exec.maxMemory = event.maxMem - update(exec) + liveUpdate(exec, System.nanoTime()) } override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { @@ -394,19 +412,21 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + val now = System.nanoTime() + event.accumUpdates.foreach { case (taskId, sid, sAttempt, accumUpdates) => liveTasks.get(taskId).foreach { task => val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) val delta = task.updateMetrics(metrics) - update(task) + maybeUpdate(task, now) liveStages.get((sid, sAttempt)).foreach { stage => stage.metrics.update(delta) - update(stage) + maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) esummary.metrics.update(delta) - update(esummary) + maybeUpdate(esummary, now) } } } @@ -419,7 +439,18 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } } + /** Flush all live entities' data to the underlying store. */ + def flush(): Unit = { + val now = System.nanoTime() + liveStages.values.foreach(update(_, now)) + liveJobs.values.foreach(update(_, now)) + liveExecutors.values.foreach(update(_, now)) + liveTasks.values.foreach(update(_, now)) + liveRDDs.values.foreach(update(_, now)) + } + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { + val now = System.nanoTime() val executorId = event.blockUpdatedInfo.blockManagerId.executorId // Whether values are being added to or removed from the existing accounting. @@ -494,7 +525,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) - update(rdd) + update(rdd, now) } maybeExec.foreach { exec => @@ -508,9 +539,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) exec.diskUsed = newValue(exec.diskUsed, diskDelta) exec.rddBlocks += rddBlocksDelta - if (exec.hasMemoryInfo || rddBlocksDelta != 0) { - update(exec) - } + maybeUpdate(exec, now) } } @@ -524,8 +553,22 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log stage } - private def update(entity: LiveEntity): Unit = { - entity.write(kvstore) + private def update(entity: LiveEntity, now: Long): Unit = { + entity.write(kvstore, now) + } + + /** Update a live entity only if it hasn't been updated in the last configured period. */ + private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { + if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + update(entity, now) + } + } + + /** Update an entity only if in a live app; avoids redundant writes when replaying logs. */ + private def liveUpdate(entity: LiveEntity, now: Long): Unit = { + if (live) { + update(entity, now) + } } } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala new file mode 100644 index 0000000000000..2927a3227cbef --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.util.{Arrays, List => JList} + +import scala.collection.JavaConverters._ + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.status.api.v1 +import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} + +/** + * A wrapper around a KVStore that provides methods for accessing the API data stored within. + */ +private[spark] class AppStatusStore(store: KVStore) { + + def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { + val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) + if (!statuses.isEmpty()) { + it.filter { job => statuses.contains(job.status) }.toSeq + } else { + it.toSeq + } + } + + def job(jobId: Int): v1.JobData = { + store.read(classOf[JobDataWrapper], jobId).info + } + + def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { + store.view(classOf[ExecutorSummaryWrapper]).index("active").reverse().first(true) + .last(true).asScala.map(_.info).toSeq + } + + def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { + val it = store.view(classOf[StageDataWrapper]).asScala.map(_.info) + if (!statuses.isEmpty) { + it.filter { s => statuses.contains(s.status) }.toSeq + } else { + it.toSeq + } + } + + def stageData(stageId: Int): Seq[v1.StageData] = { + store.view(classOf[StageDataWrapper]).index("stageId").first(stageId).last(stageId) + .asScala.map(_.info).toSeq + } + + def stageAttempt(stageId: Int, stageAttemptId: Int): v1.StageData = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).info + } + + def taskSummary( + stageId: Int, + stageAttemptId: Int, + quantiles: Array[Double]): v1.TaskMetricDistributions = { + + val stage = Array(stageId, stageAttemptId) + + val rawMetrics = store.view(classOf[TaskDataWrapper]) + .index("stage") + .first(stage) + .last(stage) + .asScala + .flatMap(_.info.taskMetrics) + .toList + .view + + def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = + Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) + + // We need to do a lot of similar munging to nested metrics here. For each one, + // we want (a) extract the values for nested metrics (b) make a distribution for each metric + // (c) shove the distribution into the right field in our return type and (d) only return + // a result if the option is defined for any of the tasks. MetricHelper is a little util + // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just + // implement one "build" method, which just builds the quantiles for each field. + + val inputMetrics = + new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics + + def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( + bytesRead = submetricQuantiles(_.bytesRead), + recordsRead = submetricQuantiles(_.recordsRead) + ) + }.build + + val outputMetrics = + new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics + + def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( + bytesWritten = submetricQuantiles(_.bytesWritten), + recordsWritten = submetricQuantiles(_.recordsWritten) + ) + }.build + + val shuffleReadMetrics = + new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = + raw.shuffleReadMetrics + + def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( + readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, + readRecords = submetricQuantiles(_.recordsRead), + remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), + remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), + localBlocksFetched = submetricQuantiles(_.localBlocksFetched), + totalBlocksFetched = submetricQuantiles { s => + s.localBlocksFetched + s.remoteBlocksFetched + }, + fetchWaitTime = submetricQuantiles(_.fetchWaitTime) + ) + }.build + + val shuffleWriteMetrics = + new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = + raw.shuffleWriteMetrics + + def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( + writeBytes = submetricQuantiles(_.bytesWritten), + writeRecords = submetricQuantiles(_.recordsWritten), + writeTime = submetricQuantiles(_.writeTime) + ) + }.build + + new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), + executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), + executorRunTime = metricQuantiles(_.executorRunTime), + executorCpuTime = metricQuantiles(_.executorCpuTime), + resultSize = metricQuantiles(_.resultSize), + jvmGcTime = metricQuantiles(_.jvmGcTime), + resultSerializationTime = metricQuantiles(_.resultSerializationTime), + memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), + diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), + inputMetrics = inputMetrics, + outputMetrics = outputMetrics, + shuffleReadMetrics = shuffleReadMetrics, + shuffleWriteMetrics = shuffleWriteMetrics + ) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val stageKey = Array(stageId, stageAttemptId) + val base = store.view(classOf[TaskDataWrapper]) + val indexed = sortBy match { + case v1.TaskSorting.ID => + base.index("stage").first(stageKey).last(stageKey) + case v1.TaskSorting.INCREASING_RUNTIME => + base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) + case v1.TaskSorting.DECREASING_RUNTIME => + base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) + .reverse() + } + indexed.skip(offset).max(length).asScala.map(_.info).toSeq + } + + def rddList(): Seq[v1.RDDStorageInfo] = { + store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).toSeq + } + + def rdd(rddId: Int): v1.RDDStorageInfo = { + store.read(classOf[RDDStorageInfoWrapper], rddId).info + } + + def close(): Unit = { + store.close() + } + +} + +private[spark] object AppStatusStore { + + val CURRENT_VERSION = 1L + + /** + * Create an in-memory store for a live application. + * + * @param conf Configuration. + * @param bus Where to attach the listener to populate the store. + */ + def createLiveStore(conf: SparkConf, bus: LiveListenerBus): AppStatusStore = { + val store = new InMemoryStore() + val stateStore = new AppStatusStore(store) + bus.addToStatusQueue(new AppStatusListener(store, conf, true)) + stateStore + } + +} + +/** + * Helper for getting distributions from nested metric types. + */ +private abstract class MetricHelper[I, O]( + rawMetrics: Seq[v1.TaskMetrics], + quantiles: Array[Double]) { + + def getSubmetrics(raw: v1.TaskMetrics): I + + def build: O + + val data: Seq[I] = rawMetrics.map(getSubmetrics) + + /** applies the given function to all input metrics, and returns the quantiles */ + def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { + Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 63fa36580bc7d..041dfe1ef915e 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -37,8 +37,11 @@ import org.apache.spark.util.kvstore.KVStore */ private[spark] abstract class LiveEntity { - def write(store: KVStore): Unit = { + var lastWriteTime = 0L + + def write(store: KVStore, now: Long): Unit = { store.write(doUpdate()) + lastWriteTime = now } /** @@ -204,7 +207,7 @@ private class LiveTask( newAccumulatorInfos(info.accumulables), errorMessage, Option(recordedMetrics)) - new TaskDataWrapper(task) + new TaskDataWrapper(task, stageId, stageAttemptId) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index f17b637754826..9d3833086172f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -248,7 +248,13 @@ private[spark] object ApiRootResource { * interface needed for them all to expose application info as json. */ private[spark] trait UIRoot { - def getSparkUI(appKey: String): Option[SparkUI] + /** + * Runs some code with the current SparkUI instance for the app / attempt. + * + * @throws NoSuchElementException If the app / attempt pair does not exist. + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T + def getApplicationInfoList: Iterator[ApplicationInfo] def getApplicationInfo(appId: String): Option[ApplicationInfo] @@ -293,15 +299,18 @@ private[v1] trait ApiRequestContext { * to it. If there is no such app, throw an appropriate exception */ def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - uiRoot.getSparkUI(appKey) match { - case Some(ui) => + try { + uiRoot.withSparkUI(appId, attemptId) { ui => val user = httpRequest.getRemoteUser() if (!ui.securityManager.checkUIViewPermissions(user)) { throw new ForbiddenException(raw"""user "$user" is not authorized""") } f(ui) - case None => throw new NotFoundException("no such app: " + appId) + } + } catch { + case _: NoSuchElementException => + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + throw new NotFoundException(s"no such app: $appKey") } } } diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/status/config.scala new file mode 100644 index 0000000000000..49144fc883e69 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/config.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config._ + +private[spark] object config { + + val LIVE_ENTITY_UPDATE_PERIOD = ConfigBuilder("spark.ui.liveUpdate.period") + .timeConf(TimeUnit.NANOSECONDS) + .createWithDefaultString("100ms") + +} diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 9579accd2cba7..a445435809f3a 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,12 +17,16 @@ package org.apache.spark.status +import java.lang.{Integer => JInteger, Long => JLong} + import com.fasterxml.jackson.annotation.JsonIgnore import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ import org.apache.spark.util.kvstore.KVIndex +private[spark] case class AppStatusStoreMetadata(version: Long) + private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { @JsonIgnore @KVIndex @@ -64,13 +68,33 @@ private[spark] class StageDataWrapper( @JsonIgnore @KVIndex def id: Array[Int] = Array(info.stageId, info.attemptId) + @JsonIgnore @KVIndex("stageId") + def stageId: Int = info.stageId + } -private[spark] class TaskDataWrapper(val info: TaskData) { +/** + * The task information is always indexed with the stage ID, since that is how the UI and API + * consume it. That means every indexed value has the stage ID and attempt ID included, aside + * from the actual data being indexed. + */ +private[spark] class TaskDataWrapper( + val info: TaskData, + val stageId: Int, + val stageAttemptId: Int) { @JsonIgnore @KVIndex def id: Long = info.taskId + @JsonIgnore @KVIndex("stage") + def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore @KVIndex("runtime") + def runtime: Array[AnyRef] = { + val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) + Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + } + } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6e94073238a56..ee645f6bf8a7a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ +import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} import org.apache.spark.storage.StorageStatusListener @@ -39,6 +40,7 @@ import org.apache.spark.util.Utils * Top level user interface for a Spark application. */ private[spark] class SparkUI private ( + val store: AppStatusStore, val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, @@ -51,7 +53,8 @@ private[spark] class SparkUI private ( var appName: String, val basePath: String, val lastUpdateTime: Option[Long] = None, - val startTime: Long) + val startTime: Long, + val appSparkVersion: String) extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging @@ -61,8 +64,6 @@ private[spark] class SparkUI private ( var appId: String = _ - var appSparkVersion = org.apache.spark.SPARK_VERSION - private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. */ @@ -104,8 +105,12 @@ private[spark] class SparkUI private ( logInfo(s"Stopped Spark web UI at $webUrl") } - def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == this.appId) Some(this) else None + override def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + if (appId == this.appId) { + fn(this) + } else { + throw new NoSuchElementException() + } } def getApplicationInfoList: Iterator[ApplicationInfo] = { @@ -159,63 +164,26 @@ private[spark] object SparkUI { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) } - def createLiveUI( - sc: SparkContext, - conf: SparkConf, - jobProgressListener: JobProgressListener, - securityManager: SecurityManager, - appName: String, - startTime: Long): SparkUI = { - create(Some(sc), conf, - sc.listenerBus.addToStatusQueue, - securityManager, appName, jobProgressListener = Some(jobProgressListener), - startTime = startTime) - } - - def createHistoryUI( - conf: SparkConf, - listenerBus: SparkListenerBus, - securityManager: SecurityManager, - appName: String, - basePath: String, - lastUpdateTime: Option[Long], - startTime: Long): SparkUI = { - val sparkUI = create(None, conf, listenerBus.addListener, securityManager, appName, basePath, - lastUpdateTime = lastUpdateTime, startTime = startTime) - - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, sparkUI) - listeners.foreach(listenerBus.addListener) - } - sparkUI - } - /** - * Create a new Spark UI. - * - * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs. - * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the - * web UI will create and register its own JobProgressListener. + * Create a new UI backed by an AppStatusStore. */ - private def create( + def create( sc: Option[SparkContext], + store: AppStatusStore, conf: SparkConf, addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, - basePath: String = "", - jobProgressListener: Option[JobProgressListener] = None, + basePath: String, + startTime: Long, lastUpdateTime: Option[Long] = None, - startTime: Long): SparkUI = { + appSparkVersion: String = org.apache.spark.SPARK_VERSION): SparkUI = { - val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { + val jobProgressListener = sc.map(_.jobProgressListener).getOrElse { val listener = new JobProgressListener(conf) addListenerFn(listener) listener } - val environmentListener = new EnvironmentListener val storageStatusListener = new StorageStatusListener(conf) val executorsListener = new ExecutorsListener(storageStatusListener, conf) @@ -228,8 +196,9 @@ private[spark] object SparkUI { addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, _jobProgressListener, storageListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime) + new SparkUI(store, sc, conf, securityManager, environmentListener, storageStatusListener, + executorsListener, jobProgressListener, storageListener, operationGraphListener, + appName, basePath, lastUpdateTime, startTime, appSparkVersion) } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index 6e50e84549047..44f9c566a380d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -18,15 +18,11 @@ package org.apache.spark.deploy.history import java.util.{Date, NoSuchElementException} -import javax.servlet.Filter import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.collection.mutable -import scala.collection.mutable.ListBuffer import com.codahale.metrics.Counter -import com.google.common.cache.LoadingCache -import com.google.common.util.concurrent.UncheckedExecutionException import org.eclipse.jetty.servlet.ServletContextHandler import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -39,23 +35,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.status.api.v1.{ApplicationAttemptInfo => AttemptInfo, ApplicationInfo} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.util.ManualClock class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar with Matchers { - /** - * subclass with access to the cache internals - * @param retainedApplications number of retained applications - */ - class TestApplicationCache( - operations: ApplicationCacheOperations = new StubCacheOperations(), - retainedApplications: Int, - clock: Clock = new ManualClock(0)) - extends ApplicationCache(operations, retainedApplications, clock) { - - def cache(): LoadingCache[CacheKey, CacheEntry] = appCache - } - /** * Stub cache operations. * The state is kept in a map of [[CacheKey]] to [[CacheEntry]], @@ -77,8 +60,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { logDebug(s"getAppUI($appId, $attemptId)") getAppUICount += 1 - instances.get(CacheKey(appId, attemptId)).map( e => - LoadedAppUI(e.ui, () => updateProbe(appId, attemptId, e.probeTime))) + instances.get(CacheKey(appId, attemptId)).map { e => e.loadedUI } } override def attachSparkUI( @@ -96,10 +78,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attemptId: Option[String], completed: Boolean, started: Long, - ended: Long, - timestamp: Long): SparkUI = { - val ui = putAppUI(appId, attemptId, completed, started, ended, timestamp) - attachSparkUI(appId, attemptId, ui, completed) + ended: Long): LoadedAppUI = { + val ui = putAppUI(appId, attemptId, completed, started, ended) + attachSparkUI(appId, attemptId, ui.ui, completed) ui } @@ -108,23 +89,12 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attemptId: Option[String], completed: Boolean, started: Long, - ended: Long, - timestamp: Long): SparkUI = { - val ui = newUI(appId, attemptId, completed, started, ended) - putInstance(appId, attemptId, ui, completed, timestamp) + ended: Long): LoadedAppUI = { + val ui = LoadedAppUI(newUI(appId, attemptId, completed, started, ended)) + instances(CacheKey(appId, attemptId)) = new CacheEntry(ui, completed) ui } - def putInstance( - appId: String, - attemptId: Option[String], - ui: SparkUI, - completed: Boolean, - timestamp: Long): Unit = { - instances += (CacheKey(appId, attemptId) -> - new CacheEntry(ui, completed, () => updateProbe(appId, attemptId, timestamp), timestamp)) - } - /** * Detach a reconstructed UI * @@ -146,23 +116,6 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attached.get(CacheKey(appId, attemptId)) } - /** - * The update probe. - * @param appId application to probe - * @param attemptId attempt to probe - * @param updateTime timestamp of this UI load - */ - private[history] def updateProbe( - appId: String, - attemptId: Option[String], - updateTime: Long)(): Boolean = { - updateProbeCount += 1 - logDebug(s"isUpdated($appId, $attemptId, ${updateTime})") - val entry = instances.get(CacheKey(appId, attemptId)).get - val updated = entry.probeTime > updateTime - logDebug(s"entry = $entry; updated = $updated") - updated - } } /** @@ -210,15 +163,13 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val now = clock.getTimeMillis() // add the entry - operations.putAppUI(app1, None, true, now, now, now) + operations.putAppUI(app1, None, true, now, now) // make sure its local operations.getAppUI(app1, None).get operations.getAppUICount = 0 // now expect it to be found - val cacheEntry = cache.lookupCacheEntry(app1, None) - assert(1 === cacheEntry.probeTime) - assert(cacheEntry.completed) + cache.withSparkUI(app1, None) { _ => } // assert about queries made of the operations assert(1 === operations.getAppUICount, "getAppUICount") assert(1 === operations.attachCount, "attachCount") @@ -236,8 +187,8 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar assert(0 === operations.detachCount, "attachCount") // evict the entry - operations.putAndAttach("2", None, true, time2, time2, time2) - operations.putAndAttach("3", None, true, time2, time2, time2) + operations.putAndAttach("2", None, true, time2, time2) + operations.putAndAttach("3", None, true, time2, time2) cache.get("2") cache.get("3") @@ -248,7 +199,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val appId = "app1" val attemptId = Some("_01") val time3 = clock.getTimeMillis() - operations.putAppUI(appId, attemptId, false, time3, 0, time3) + operations.putAppUI(appId, attemptId, false, time3, 0) // expect an error here assertNotFound(appId, None) } @@ -256,10 +207,11 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Test that if an attempt ID is set, it must be used in lookups") { val operations = new StubCacheOperations() val clock = new ManualClock(1) - implicit val cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) + implicit val cache = new ApplicationCache(operations, retainedApplications = 10, + clock = clock) val appId = "app1" val attemptId = Some("_01") - operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0, 0) + operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0) assertNotFound(appId, None) } @@ -271,50 +223,29 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Incomplete apps refreshed") { val operations = new StubCacheOperations() val clock = new ManualClock(50) - val window = 500 - implicit val cache = new ApplicationCache(operations, retainedApplications = 5, clock = clock) + implicit val cache = new ApplicationCache(operations, 5, clock) val metrics = cache.metrics // add the incomplete app // add the entry val started = clock.getTimeMillis() val appId = "app1" val attemptId = Some("001") - operations.putAppUI(appId, attemptId, false, started, 0, started) - val firstEntry = cache.lookupCacheEntry(appId, attemptId) - assert(started === firstEntry.probeTime, s"timestamp in $firstEntry") - assert(!firstEntry.completed, s"entry is complete: $firstEntry") - assertMetric("lookupCount", metrics.lookupCount, 1) + val initialUI = operations.putAndAttach(appId, attemptId, false, started, 0) + val firstUI = cache.withSparkUI(appId, attemptId) { ui => ui } + assertMetric("lookupCount", metrics.lookupCount, 1) assert(0 === operations.updateProbeCount, "expected no update probe on that first get") - val checkTime = window * 2 - clock.setTime(checkTime) - val entry3 = cache.lookupCacheEntry(appId, attemptId) - assert(firstEntry !== entry3, s"updated entry test from $cache") + // Invalidate the first entry to trigger a re-load. + initialUI.invalidate() + + // Update the UI in the stub so that a new one is provided to the cache. + operations.putAppUI(appId, attemptId, true, started, started + 10) + + val updatedUI = cache.withSparkUI(appId, attemptId) { ui => ui } + assert(firstUI !== updatedUI, s"expected updated UI") assertMetric("lookupCount", metrics.lookupCount, 2) - assertMetric("updateProbeCount", metrics.updateProbeCount, 1) - assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 0) - assert(1 === operations.updateProbeCount, s"refresh count in $cache") - assert(0 === operations.detachCount, s"detach count") - assert(entry3.probeTime === checkTime) - - val updateTime = window * 3 - // update the cached value - val updatedApp = operations.putAppUI(appId, attemptId, true, started, updateTime, updateTime) - val endTime = window * 10 - clock.setTime(endTime) - logDebug(s"Before operation = $cache") - val entry5 = cache.lookupCacheEntry(appId, attemptId) - assertMetric("lookupCount", metrics.lookupCount, 3) - assertMetric("updateProbeCount", metrics.updateProbeCount, 2) - // the update was triggered - assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 1) - assert(updatedApp === entry5.ui, s"UI {$updatedApp} did not match entry {$entry5} in $cache") - - // at which point, the refreshes stop - clock.setTime(window * 20) - assertCacheEntryEquals(appId, attemptId, entry5) - assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + assert(1 === operations.detachCount, s"detach count") } /** @@ -337,27 +268,6 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } } - /** - * Look up the cache entry and assert that it matches in the expected value. - * This assertion works if the two CacheEntries are different -it looks at the fields. - * UI are compared on object equality; the timestamp and completed flags directly. - * @param appId application ID - * @param attemptId attempt ID - * @param expected expected value - * @param cache app cache - */ - def assertCacheEntryEquals( - appId: String, - attemptId: Option[String], - expected: CacheEntry) - (implicit cache: ApplicationCache): Unit = { - val actual = cache.lookupCacheEntry(appId, attemptId) - val errorText = s"Expected get($appId, $attemptId) -> $expected, but got $actual from $cache" - assert(expected.ui === actual.ui, errorText + " SparkUI reference") - assert(expected.completed === actual.completed, errorText + " -completed flag") - assert(expected.probeTime === actual.probeTime, errorText + " -timestamp") - } - /** * Assert that a key wasn't found in cache or loaded. * @@ -370,14 +280,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar appId: String, attemptId: Option[String]) (implicit cache: ApplicationCache): Unit = { - val ex = intercept[UncheckedExecutionException] { + val ex = intercept[NoSuchElementException] { cache.get(appId, attemptId) } - var cause = ex.getCause - assert(cause !== null) - if (!cause.isInstanceOf[NoSuchElementException]) { - throw cause - } } test("Large Scale Application Eviction") { @@ -385,12 +290,12 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val clock = new ManualClock(0) val size = 5 // only two entries are retained, so we expect evictions to occur on lookups - implicit val cache: ApplicationCache = new TestApplicationCache(operations, - retainedApplications = size, clock = clock) + implicit val cache = new ApplicationCache(operations, retainedApplications = size, + clock = clock) val attempt1 = Some("01") - val ids = new ListBuffer[String]() + val ids = new mutable.ListBuffer[String]() // build a list of applications val count = 100 for (i <- 1 to count ) { @@ -398,7 +303,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar ids += appId clock.advance(10) val t = clock.getTimeMillis() - operations.putAppUI(appId, attempt1, true, t, t, t) + operations.putAppUI(appId, attempt1, true, t, t) } // now go through them in sequence reading them, expect evictions ids.foreach { id => @@ -413,20 +318,19 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Attempts are Evicted") { val operations = new StubCacheOperations() - implicit val cache: ApplicationCache = new TestApplicationCache(operations, - retainedApplications = 4) + implicit val cache = new ApplicationCache(operations, 4, new ManualClock()) val metrics = cache.metrics val appId = "app1" val attempt1 = Some("01") val attempt2 = Some("02") val attempt3 = Some("03") - operations.putAppUI(appId, attempt1, true, 100, 110, 110) - operations.putAppUI(appId, attempt2, true, 200, 210, 210) - operations.putAppUI(appId, attempt3, true, 300, 310, 310) + operations.putAppUI(appId, attempt1, true, 100, 110) + operations.putAppUI(appId, attempt2, true, 200, 210) + operations.putAppUI(appId, attempt3, true, 300, 310) val attempt4 = Some("04") - operations.putAppUI(appId, attempt4, true, 400, 410, 410) + operations.putAppUI(appId, attempt4, true, 400, 410) val attempt5 = Some("05") - operations.putAppUI(appId, attempt5, true, 500, 510, 510) + operations.putAppUI(appId, attempt5, true, 500, 510) def expectLoadAndEvictionCounts(expectedLoad: Int, expectedEvictionCount: Int): Unit = { assertMetric("loadCount", metrics.loadCount, expectedLoad) @@ -457,20 +361,14 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } - test("Instantiate Filter") { - // this is a regression test on the filter being constructable - val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) - val instance = clazz.newInstance() - instance shouldBe a [Filter] - } - test("redirect includes query params") { - val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) - val filter = clazz.newInstance().asInstanceOf[ApplicationCacheCheckFilter] - filter.appId = "local-123" + val operations = new StubCacheOperations() + val ui = operations.putAndAttach("foo", None, true, 0, 10) val cache = mock[ApplicationCache] - when(cache.checkForUpdates(any(), any())).thenReturn(true) - ApplicationCacheCheckFilterRelay.setApplicationCache(cache) + when(cache.operations).thenReturn(operations) + val filter = new ApplicationCacheCheckFilter(new CacheKey("foo", None), ui, cache) + ui.invalidate() + val request = mock[HttpServletRequest] when(request.getMethod()).thenReturn("GET") when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 03bd3eaf579f3..86c8cdf43258c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.security.GroupMappingServiceProvider +import org.apache.spark.status.AppStatusStore import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -612,7 +613,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Manually overwrite the version in the listing db; this should cause the new provider to // discard all data because the versions don't match. val meta = new FsHistoryProviderMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, - conf.get(LOCAL_STORE_DIR).get) + AppStatusStore.CURRENT_VERSION, conf.get(LOCAL_STORE_DIR).get) oldProvider.listing.setMetadata(meta) oldProvider.stop() @@ -620,6 +621,43 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(mistatchedVersionProvider.listing.count(classOf[ApplicationInfoWrapper]) === 0) } + test("invalidate cached UI") { + val provider = new FsHistoryProvider(createTestConf()) + val appId = "new1" + + // Write an incomplete app log. + val appLog = newLogFile(appId, None, inProgress = true) + writeFile(appLog, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None) + ) + provider.checkForLogs() + + // Load the app UI. + val oldUI = provider.getAppUI(appId, None) + assert(oldUI.isDefined) + intercept[NoSuchElementException] { + oldUI.get.ui.store.job(0) + } + + // Add more info to the app log, and trigger the provider to update things. + writeFile(appLog, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + provider.checkForLogs() + + // Manually detach the old UI; ApplicationCache would do this automatically in a real SHS + // when the app's UI was requested. + provider.onUIDetached(appId, None, oldUI.get.ui) + + // Load the UI again and make sure we can get the new info added to the logs. + val freshUI = provider.getAppUI(appId, None) + assert(freshUI.isDefined) + assert(freshUI != oldUI) + freshUI.get.ui.store.job(0) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index c11543a4b3ba2..010a8dd004d4f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -72,6 +72,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private var port: Int = -1 def init(extraConf: (String, String)*): Unit = { + Utils.deleteRecursively(storeDir) + assert(storeDir.mkdir()) val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") @@ -292,21 +294,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) - server.stop() - - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir) - .set("spark.history.fs.update.interval", "0") - .set("spark.testing", "true") - .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - - provider = new FsHistoryProvider(conf) - provider.checkForLogs() - val securityManager = HistoryServer.createSecurityManager(conf) - - server = new HistoryServer(conf, provider, securityManager, 18080) - server.initialize() - server.bind() + stop() + init() val port = server.boundPort @@ -375,8 +364,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } test("incomplete apps get refreshed") { - server.stop() - implicit val webDriver: WebDriver = new HtmlUnitDriver implicit val formats = org.json4s.DefaultFormats @@ -386,6 +373,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // a new conf is used with the background thread set and running at its fastest // allowed refresh rate (1Hz) + stop() val myConf = new SparkConf() .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) .set("spark.eventLog.dir", logDir.getAbsolutePath) @@ -418,7 +406,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - server = new HistoryServer(myConf, provider, securityManager, 18080) + server = new HistoryServer(myConf, provider, securityManager, 0) server.initialize() server.bind() val port = server.boundPort @@ -464,7 +452,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers rootAppPage should not be empty def getAppUI: SparkUI = { - provider.getAppUI(appId, None).get.ui + server.withSparkUI(appId, None) { ui => ui } } // selenium isn't that useful on failures...add our own reporting @@ -519,7 +507,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getNumJobs("") should be (1) getNumJobs("/jobs") should be (1) getNumJobsRestful() should be (1) - assert(metrics.lookupCount.getCount > 1, s"lookup count too low in $metrics") + assert(metrics.lookupCount.getCount > 0, s"lookup count too low in $metrics") // dump state before the next bit of test, which is where update // checking really gets stressed diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 6f7a0c14dd684..7ac1ce19f8ddf 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.status import java.io.File -import java.util.{Date, Properties} +import java.lang.{Integer => JInteger, Long => JLong} +import java.util.{Arrays, Date, Properties} import scala.collection.JavaConverters._ import scala.reflect.{classTag, ClassTag} @@ -36,6 +37,10 @@ import org.apache.spark.util.kvstore._ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { + import config._ + + private val conf = new SparkConf().set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + private var time: Long = _ private var testDir: File = _ private var store: KVStore = _ @@ -52,7 +57,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } test("scheduler events") { - val listener = new AppStatusListener(store) + val listener = new AppStatusListener(store, conf, true) // Start the application. time += 1 @@ -174,6 +179,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.info.taskId === task.taskId) + assert(wrapper.stageId === stages.head.stageId) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + + val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + -1L: JLong) + assert(Arrays.equals(wrapper.runtime, runtime)) + assert(wrapper.info.index === task.index) assert(wrapper.info.attempt === task.attemptNumber) assert(wrapper.info.launchTime === new Date(task.launchTime)) @@ -510,7 +523,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } test("storage events") { - val listener = new AppStatusListener(store) + val listener = new AppStatusListener(store, conf, true) val maxMemory = 42L // Register a couple of block managers. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 45b8870f3b62f..99cac34c85ebc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,8 @@ object MimaExcludes { lazy val v23excludes = v22excludes ++ Seq( // SPARK-18085: Better History Server scalability for many / large applications ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), + // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 65a8bf6036fe41a53b4b1e4298fa35d7fa4e9970 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 6 Nov 2017 08:58:42 -0800 Subject: [PATCH 1611/1765] [SPARK-22315][SPARKR] Warn if SparkR package version doesn't match SparkContext ## What changes were proposed in this pull request? This PR adds a check between the R package version used and the version reported by SparkContext running in the JVM. The goal here is to warn users when they have a R package downloaded from CRAN and are using that to connect to an existing Spark cluster. This is raised as a warning rather than an error as users might want to use patch versions interchangeably (e.g., 2.1.3 with 2.1.2 etc.) ## How was this patch tested? Manually by changing the `DESCRIPTION` file Author: Shivaram Venkataraman Closes #19624 from shivaram/sparkr-version-check. --- R/pkg/R/sparkR.R | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 81507ea7186af..fb5f1d21fc723 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -420,6 +420,18 @@ sparkR.session <- function( enableHiveSupport) assign(".sparkRsession", sparkSession, envir = .sparkREnv) } + + # Check if version number of SparkSession matches version number of SparkR package + jvmVersion <- callJMethod(sparkSession, "version") + # Remove -SNAPSHOT from jvm versions + jvmVersionStrip <- gsub("-SNAPSHOT", "", jvmVersion) + rPackageVersion <- paste0(packageVersion("SparkR")) + + if (jvmVersionStrip != rPackageVersion) { + warning(paste("Version mismatch between Spark JVM and SparkR package. JVM version was", + jvmVersion, ", while R package version was", rPackageVersion)) + } + sparkSession } From 5014d6e2568021e6958eddc9cfb4c512ed7424a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Nov 2017 22:25:11 +0100 Subject: [PATCH 1612/1765] [SPARK-22078][SQL] clarify exception behaviors for all data source v2 interfaces ## What changes were proposed in this pull request? clarify exception behaviors for all data source v2 interfaces. ## How was this patch tested? document change only Author: Wenchen Fan Closes #19623 from cloud-fan/data-source-exception. --- .../spark/sql/sources/v2/ReadSupport.java | 3 ++ .../sql/sources/v2/ReadSupportWithSchema.java | 3 ++ .../spark/sql/sources/v2/WriteSupport.java | 3 ++ .../sql/sources/v2/reader/DataReader.java | 11 +++++++- .../sources/v2/reader/DataSourceV2Reader.java | 9 ++++++ .../spark/sql/sources/v2/reader/ReadTask.java | 12 +++++++- .../SupportsPushDownCatalystFilters.java | 1 - .../v2/reader/SupportsScanUnsafeRow.java | 1 - .../sources/v2/writer/DataSourceV2Writer.java | 28 ++++++++++++------- .../sql/sources/v2/writer/DataWriter.java | 20 +++++++++---- .../sources/v2/writer/DataWriterFactory.java | 3 ++ .../v2/writer/SupportsWriteInternalRow.java | 3 -- 12 files changed, 74 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ee489ad0f608f..948e20bacf4a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,6 +30,9 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 74e81a2c84d68..b69c6bed8d1b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,6 +35,9 @@ public interface ReadSupportWithSchema { /** * Create a {@link DataSourceV2Reader} to scan the data from this data source. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param schema the full schema of this data source reader. Full schema usually maps to the * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 8fdfdfd19ea1e..1e3b644d8c4ae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,6 +35,9 @@ public interface WriteSupport { * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param jobId A unique string for the writing job. It's possible that there are many writing * jobs running at the same time, and the returned {@link DataSourceV2Writer} can * use this job id to distinguish itself from other jobs. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 52bb138673fc9..8f58c865b6201 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import java.io.Closeable; +import java.io.IOException; import org.apache.spark.annotation.InterfaceStability; @@ -34,11 +35,19 @@ public interface DataReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + * + * @throws IOException if failure happens during disk/network IO like reading files. */ - boolean next(); + boolean next() throws IOException; /** * Return the current record. This method should return same value until `next` is called. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. */ T get(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 88c3219a75c1d..95ee4a8278322 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -40,6 +40,9 @@ * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. * Names of these interfaces start with `SupportsScan`. * + * If an exception was throw when applying any of these query optimizations, the action would fail + * and no Spark job was submitted. + * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark * issues the scan request and does the actual data reading. @@ -50,6 +53,9 @@ public interface DataSourceV2Reader { /** * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ StructType readSchema(); @@ -61,6 +67,9 @@ public interface DataSourceV2Reader { * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ List> createReadTasks(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 44786db419a32..fa161cdb8b347 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -36,7 +36,14 @@ public interface ReadTask extends Serializable { /** * The preferred locations where this read task can run faster, but Spark does not guarantee that * this task will always run on these locations. The implementations should make sure that it can - * be run on any location. The location is a string representing the host name of an executor. + * be run on any location. The location is a string representing the host name. + * + * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in + * the returned locations. By default this method returns empty string array, which means this + * task has no location preference. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ default String[] preferredLocations() { return new String[0]; @@ -44,6 +51,9 @@ default String[] preferredLocations() { /** * Returns a data reader to do the actual reading work for this read task. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. */ DataReader createDataReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index efc42242f4421..f76c687f450c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.expressions.Expression; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 6008fb5f71cc1..b90ec880dc85e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -19,7 +19,6 @@ import java.util.List; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 37bb15f87c59a..fc37b9a516f82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -30,6 +30,9 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * + * If an exception was throw when applying any of these writing optimizations, the action would fail + * and no Spark job was submitted. + * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the * partitions of the input data(RDD). @@ -50,28 +53,33 @@ public interface DataSourceV2Writer { /** * Creates a writer factory which will be serialized and sent to executors. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ DataWriterFactory createWriterFactory(); /** * Commits this writing job with a list of commit messages. The commit messages are collected from - * successful data writers and are produced by {@link DataWriter#commit()}. If this method - * fails(throw exception), this writing job is considered to be failed, and - * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible - * to data source readers if this method succeeds. + * successful data writers and are produced by {@link DataWriter#commit()}. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination + * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that, one partition may have multiple committed data writers because of speculative tasks. * Spark will pick the first successful one and get its commit message. Implementations should be - * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer - * can commit successfully, or have a way to clean up the data of already-committed writers. + * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data + * writer can commit, or have a way to clean up the data of already-committed writers. */ void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed to write the records and aborted, - * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} - * fails. If this method fails(throw exception), the underlying data source may have garbage that - * need to be cleaned manually, but these garbage should not be visible to data source readers. + * Aborts this writing job because some data writers are failed and keep failing when retry, or + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * + * If this method fails (by throwing an exception), the underlying data source may require manual + * cleanup. * * Unless the abort is triggered by the failure of commit, the given messages should have some * null slots as there maybe only a few data writers that are committed before the abort diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index dc1aab33bdcef..04b03e63de500 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources.v2.writer; +import java.io.IOException; + import org.apache.spark.annotation.InterfaceStability; /** @@ -59,8 +61,10 @@ public interface DataWriter { * * If this method fails (by throwing an exception), {@link #abort()} will be called and this * data writer is considered to have been failed. + * + * @throws IOException if failure happens during disk/network IO like writing files. */ - void write(T record); + void write(T record) throws IOException; /** * Commits this writer after all records are written successfully, returns a commit message which @@ -74,8 +78,10 @@ public interface DataWriter { * * If this method fails (by throwing an exception), {@link #abort()} will be called and this * data writer is considered to have been failed. + * + * @throws IOException if failure happens during disk/network IO like writing files. */ - WriterCommitMessage commit(); + WriterCommitMessage commit() throws IOException; /** * Aborts this writer if it is failed. Implementations should clean up the data for already @@ -84,9 +90,11 @@ public interface DataWriter { * This method will only be called if there is one record failed to write, or {@link #commit()} * failed. * - * If this method fails(throw exception), the underlying data source may have garbage that need - * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but - * these garbage should not be visible to data source readers. + * If this method fails(by throwing an exception), the underlying data source may have garbage + * that need to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, + * but these garbage should not be visible to data source readers. + * + * @throws IOException if failure happens during disk/network IO like writing files. */ - void abort(); + void abort() throws IOException; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index fe56cc00d1c7a..18ec792f5a2c9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,6 +35,9 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java index a8e95901f3b07..3e0518814f458 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources.v2.writer; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; @@ -29,8 +28,6 @@ * changed in the future Spark versions. */ -@InterfaceStability.Evolving -@Experimental @InterfaceStability.Unstable public interface SupportsWriteInternalRow extends DataSourceV2Writer { From 14a32a647a019ed45793784069fbf077b9c45e60 Mon Sep 17 00:00:00 2001 From: Alexander Istomin Date: Tue, 7 Nov 2017 00:47:16 +0100 Subject: [PATCH 1613/1765] [SPARK-22330][CORE] Linear containsKey operation for serialized maps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …alization. ## What changes were proposed in this pull request? Use non-linear containsKey operation for serialized maps, lookup into underlying map. ## How was this patch tested? unit tests Author: Alexander Istomin Closes #19553 from Whoosh/SPARK-22330. --- .../org/apache/spark/api/java/JavaUtils.scala | 9 +++- .../spark/api/java/JavaUtilsSuite.scala | 49 +++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index d6506231b8d74..fd96052f95d3f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -43,10 +43,17 @@ private[spark] object JavaUtils { override def size: Int = underlying.size + // Delegate to implementation because AbstractMap implementation iterates over whole key set + override def containsKey(key: AnyRef): Boolean = try { + underlying.contains(key.asInstanceOf[A]) + } catch { + case _: ClassCastException => false + } + override def get(key: AnyRef): B = try { underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { - case ex: ClassCastException => null.asInstanceOf[B] + case _: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { diff --git a/core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala new file mode 100644 index 0000000000000..8e6e3e0968617 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import java.io.Serializable + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite + + +class JavaUtilsSuite extends SparkFunSuite { + + test("containsKey implementation without iteratively entrySet call") { + val src = new scala.collection.mutable.HashMap[Double, String] + val key: Double = 42.5 + val key2 = "key" + + src.put(key, "42") + + val map: java.util.Map[Double, String] = spy(JavaUtils.mapAsSerializableJavaMap(src)) + + assert(map.containsKey(key)) + + // ClassCast checking, shouldn't throw exception + assert(!map.containsKey(key2)) + assert(map.get(key2) == null) + + assert(map.get(key).eq("42")) + assert(map.isInstanceOf[Serializable]) + + verify(map, never()).entrySet() + } +} From 9df08e218cfd4dd91bc407b98528b74f452f34f8 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 7 Nov 2017 08:30:58 +0000 Subject: [PATCH 1614/1765] [SPARK-22454][CORE] ExternalShuffleClient.close() should check clientFactory null ## What changes were proposed in this pull request? `ExternalShuffleClient.close()` should check `clientFactory` null. otherwise it will throw NPE sometimes: ``` 17/11/06 20:08:05 ERROR Utils: Uncaught exception in thread main java.lang.NullPointerException at org.apache.spark.network.shuffle.ExternalShuffleClient.close(ExternalShuffleClient.java:152) at org.apache.spark.storage.BlockManager.stop(BlockManager.scala:1407) at org.apache.spark.SparkEnv.stop(SparkEnv.scala:89) at org.apache.spark.SparkContext$$anonfun$stop$11.apply$mcV$sp(SparkContext.scala:1849) ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #19670 from wangyum/SPARK-22454. --- .../apache/spark/network/shuffle/ExternalShuffleClient.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 510017fee2db5..7ed0b6e93a7a8 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -148,6 +148,9 @@ public void registerWithShuffleServer( @Override public void close() { checkInit(); - clientFactory.close(); + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } } } From ed1478cfe173a44876b586943d61cdc93b0869a2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 7 Nov 2017 17:57:47 +0900 Subject: [PATCH 1615/1765] [BUILD] Close stale PRs Closes #11494 Closes #14158 Closes #16803 Closes #16864 Closes #17455 Closes #17936 Closes #19377 Added: Closes #19380 Closes #18642 Closes #18377 Closes #19632 Added: Closes #14471 Closes #17402 Closes #17953 Closes #18607 Also cc srowen vanzin HyukjinKwon gatorsmile cloud-fan to see if you have other PRs to close. Author: Xingbo Jiang Closes #19669 from jiangxb1987/stale-prs. From 160a540610051ee3233ea24102533b08b69f03fc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Nov 2017 19:45:34 +0900 Subject: [PATCH 1616/1765] [SPARK-22376][TESTS] Makes dev/run-tests.py script compatible with Python 3 ## What changes were proposed in this pull request? This PR proposes to fix `dev/run-tests.py` script to support Python 3. Here are some backgrounds. Up to my knowledge, In Python 2, - `unicode` is NOT `str` in Python 2 (`type("foo") != type(u"foo")`). - `str` has an alias, `bytes` in Python 2 (`type("foo") == type(b"foo")`). In Python 3, - `unicode` was (roughly) replaced by `str` in Python 3 (`type("foo") == type(u"foo")`). - `str` is NOT `bytes` in Python 3 (`type("foo") != type(b"foo")`). So, this PR fixes: 1. Use `b''` instead of `''` so that both `str` in Python 2 and `bytes` in Python 3 can be hanlded. `sbt_proc.stdout.readline()` returns `str` (which has an alias, `bytes`) in Python 2 and `bytes` in Python 3 2. Similarily, use `b''` instead of `''` so that both `str` in Python 2 and `bytes` in Python 3 can be hanlded. `re.compile` with `str` pattern does not seem supporting to match `bytes` in Python 3: Actually, this change is recommended up to my knowledge - https://docs.python.org/3/howto/pyporting.html#text-versus-binary-data: > Mark all binary literals with a b prefix, textual literals with a u prefix ## How was this patch tested? I manually tested this via Python 3 with few additional changes to reduce the elapsed time. Author: hyukjinkwon Closes #19665 from HyukjinKwon/SPARK-22376. --- dev/run-tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 72d148d7ea0fb..ef0e788a91606 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -276,9 +276,9 @@ def exec_sbt(sbt_args=()): sbt_cmd = [os.path.join(SPARK_HOME, "build", "sbt")] + sbt_args - sbt_output_filter = re.compile("^.*[info].*Resolving" + "|" + - "^.*[warn].*Merging" + "|" + - "^.*[info].*Including") + sbt_output_filter = re.compile(b"^.*[info].*Resolving" + b"|" + + b"^.*[warn].*Merging" + b"|" + + b"^.*[info].*Including") # NOTE: echo "q" is needed because sbt on encountering a build file # with failure (either resolution or compilation) prompts the user for @@ -289,7 +289,7 @@ def exec_sbt(sbt_args=()): stdin=echo_proc.stdout, stdout=subprocess.PIPE) echo_proc.wait() - for line in iter(sbt_proc.stdout.readline, ''): + for line in iter(sbt_proc.stdout.readline, b''): if not sbt_output_filter.match(line): print(line, end='') retcode = sbt_proc.wait() From d5202259d9aa9ad95d572af253bf4a722b7b437a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 7 Nov 2017 09:33:52 -0800 Subject: [PATCH 1617/1765] [SPARK-21127][SQL][FOLLOWUP] fix a config name typo ## What changes were proposed in this pull request? `spark.sql.statistics.autoUpdate.size` should be `spark.sql.statistics.size.autoUpdate.enabled`. The previous name is confusing as users may treat it as a size config. This config is in master branch only, no backward compatibility issue. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19667 from cloud-fan/minor. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- .../spark/sql/execution/command/CommandUtils.scala | 2 +- .../org/apache/spark/sql/execution/command/ddl.scala | 2 +- .../apache/spark/sql/StatisticsCollectionSuite.scala | 10 +++++----- .../org/apache/spark/sql/hive/StatisticsSuite.scala | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ede116e964a03..a04f8778079de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -812,8 +812,8 @@ object SQLConf { .doubleConf .createWithDefault(0.05) - val AUTO_UPDATE_SIZE = - buildConf("spark.sql.statistics.autoUpdate.size") + val AUTO_SIZE_UPDATE_ENABLED = + buildConf("spark.sql.statistics.size.autoUpdate.enabled") .doc("Enables automatic update for table size once table's data is changed. Note that if " + "the total number of files of the table is very large, this can be expensive and slow " + "down data change commands.") @@ -1206,7 +1206,7 @@ class SQLConf extends Serializable with Logging { def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) - def autoUpdateSize: Boolean = getConf(SQLConf.AUTO_UPDATE_SIZE) + def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index b22958d59336c..1a0d67fc71fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -36,7 +36,7 @@ object CommandUtils extends Logging { def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { if (table.stats.nonEmpty) { val catalog = sparkSession.sessionState.catalog - if (sparkSession.sessionState.conf.autoUpdateSize) { + if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val newTable = catalog.getTableMetadata(table.identifier) val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) val newStats = CatalogStatistics(sizeInBytes = newSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index a9cd65e3242c9..568567aa8ea88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -442,7 +442,7 @@ case class AlterTableAddPartitionCommand( catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) if (table.stats.nonEmpty) { - if (sparkSession.sessionState.conf.autoUpdateSize) { + if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val addedSize = parts.map { part => CommandUtils.calculateLocationSize(sparkSession.sessionState, table.identifier, part.storage.locationUri) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 2fc92f4aff92e..7247c3a876df3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -216,7 +216,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) // analyze to get initial stats @@ -252,7 +252,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after insert command for datasource table") { val table = "change_stats_insert_datasource_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") // analyze to get initial stats @@ -285,7 +285,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("invalidation of tableRelationCache after inserts") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { spark.range(100).write.saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") @@ -302,7 +302,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("invalidation of tableRelationCache after table truncation") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { spark.range(100).write.saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") @@ -318,7 +318,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("invalidation of tableRelationCache after alter table add partition") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTempDir { dir => withTable(table) { val path = dir.getCanonicalPath diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index b9a5ad7657134..9e8fc32a05471 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -755,7 +755,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after insert command for hive table") { val table = s"change_stats_insert_hive_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i int, j string)") // analyze to get initial stats @@ -783,7 +783,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after load data command") { val table = "change_stats_load_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") // analyze to get initial stats @@ -817,7 +817,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after add/drop partition command") { val table = "change_stats_part_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") // table has two partitions initially From 1d341042d6948e636643183da9bf532268592c6a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 7 Nov 2017 21:32:37 +0100 Subject: [PATCH 1618/1765] [SPARK-22417][PYTHON] Fix for createDataFrame from pandas.DataFrame with timestamp ## What changes were proposed in this pull request? Currently, a pandas.DataFrame that contains a timestamp of type 'datetime64[ns]' when converted to a Spark DataFrame with `createDataFrame` will interpret the values as LongType. This fix will check for a timestamp type and convert it to microseconds which will allow Spark to read as TimestampType. ## How was this patch tested? Added unit test to verify Spark schema is expected for TimestampType and DateType when created from pandas Author: Bryan Cutler Closes #19646 from BryanCutler/pyspark-non-arrow-createDataFrame-ts-fix-SPARK-22417. --- python/pyspark/sql/session.py | 49 ++++++++++++++++++++++++++++++++--- python/pyspark/sql/tests.py | 15 +++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index c3dc1a46fd3c1..d1d0b8b8fe5d9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -23,6 +23,7 @@ if sys.version >= '3': basestring = unicode = str + xrange = range else: from itertools import imap as map @@ -416,6 +417,50 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema + def _get_numpy_record_dtypes(self, rec): + """ + Used when converting a pandas.DataFrame to Spark using to_records(), this will correct + the dtypes of records so they can be properly loaded into Spark. + :param rec: a numpy record to check dtypes + :return corrected dtypes for a numpy.record or None if no correction needed + """ + import numpy as np + cur_dtypes = rec.dtype + col_names = cur_dtypes.names + record_type_list = [] + has_rec_fix = False + for i in xrange(len(cur_dtypes)): + curr_type = cur_dtypes[i] + # If type is a datetime64 timestamp, convert to microseconds + # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs, + # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417 + if curr_type == np.dtype('datetime64[ns]'): + curr_type = 'datetime64[us]' + has_rec_fix = True + record_type_list.append((str(col_names[i]), curr_type)) + return record_type_list if has_rec_fix else None + + def _convert_from_pandas(self, pdf, schema): + """ + Convert a pandas.DataFrame to list of records that can be used to make a DataFrame + :return tuple of list of records and schema + """ + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) for x in pdf.columns] + + # Convert pandas.DataFrame to list of numpy records + np_records = pdf.to_records(index=False) + + # Check if any columns need to be fixed for Spark to infer properly + if len(np_records) > 0: + record_type_list = self._get_numpy_record_dtypes(np_records[0]) + if record_type_list is not None: + return [r.astype(record_type_list).tolist() for r in np_records], schema + + # Convert list of numpy records to python lists + return [r.tolist() for r in np_records], schema + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): @@ -512,9 +557,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] + data, schema = self._convert_from_pandas(data, schema) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 483f39aeef66a..eb0d4e29a5978 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2592,6 +2592,21 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_create_dataframe_from_pandas_with_timestamp(self): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) + # test types are inferred correctly without specifying schema + df = self.spark.createDataFrame(pdf) + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + # test with schema will accept pdf as input + df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp") + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + class HiveSparkSubmitTests(SparkSubmitTests): From 0846a44736d9a71ba3234ad5de4c8de9e7fe9f6c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 7 Nov 2017 21:57:43 +0100 Subject: [PATCH 1619/1765] [SPARK-22464][SQL] No pushdown for Hive metastore partition predicates containing null-safe equality ## What changes were proposed in this pull request? `<=>` is not supported by Hive metastore partition predicate pushdown. We should not push down it to Hive metastore when they are be using in partition predicates. ## How was this patch tested? Added a test case Author: gatorsmile Closes #19682 from gatorsmile/fixLimitPushDown. --- .../spark/sql/hive/client/HiveShim.scala | 29 ++++++++++++++----- .../sql/hive/client/HiveClientSuite.scala | 9 ++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 5c1ff2b76fdaa..bd1b300416990 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -592,6 +592,19 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } } + + /** + * An extractor that matches all binary comparison operators except null-safe equality. + * + * Null-safe equality is not supported by Hive metastore partition predicate pushdown + */ + object SpecialBinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case _: EqualNullSafe => None + case _ => Some((e.left, e.right)) + } + } + private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. lazy val varcharKeys = table.getPartitionKeys.asScala @@ -600,14 +613,14 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { .map(col => col.getName).toSet filters.collect { - case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => s"${a.name} ${op.symbol} $v" - case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + case op @ SpecialBinaryComparison(Literal(v, _: IntegralType), a: Attribute) => s"$v ${op.symbol} ${a.name}" - case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: StringType)) if !varcharKeys.contains(a.name) => s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" - case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + case op @ SpecialBinaryComparison(Literal(v, _: StringType), a: Attribute) if !varcharKeys.contains(a.name) => s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" }.mkString(" and ") @@ -666,16 +679,16 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { case InSet(a: Attribute, ExtractableValues(values)) if !varcharKeys.contains(a.name) && values.nonEmpty => convertInToOr(a, values) - case op @ BinaryComparison(a: Attribute, ExtractableLiteral(value)) + case op @ SpecialBinaryComparison(a: Attribute, ExtractableLiteral(value)) if !varcharKeys.contains(a.name) => s"${a.name} ${op.symbol} $value" - case op @ BinaryComparison(ExtractableLiteral(value), a: Attribute) + case op @ SpecialBinaryComparison(ExtractableLiteral(value), a: Attribute) if !varcharKeys.contains(a.name) => s"$value ${op.symbol} ${a.name}" - case op @ And(expr1, expr2) + case And(expr1, expr2) if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) => (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")") - case op @ Or(expr1, expr2) + case Or(expr1, expr2) if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) => s"(${convert(expr1)} or ${convert(expr2)})" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 3eedcf7e0874e..ce53acef51503 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -78,6 +78,15 @@ class HiveClientSuite(version: String) assert(filteredPartitions.size == testPartitionCount) } + test("getPartitionsByFilter: ds<=>20170101") { + // Should return all partitions where <=> is not supported + testMetastorePartitionFiltering( + "ds<=>20170101", + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( "ds=20170101", From 7475a9655cd23dbf667e9543cf0818243e59f998 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 7 Nov 2017 16:03:24 -0600 Subject: [PATCH 1620/1765] [SPARK-20645][CORE] Port environment page to new UI backend. This change modifies the status listener to collect the information needed to render the envionment page, and populates that page and the API with information collected by the listener. Tested with existing and added unit tests. Author: Marcelo Vanzin Closes #19677 from vanzin/SPARK-20645. --- .../spark/status/AppStatusListener.scala | 18 ++ .../apache/spark/status/AppStatusStore.scala | 9 + .../v1/ApplicationEnvironmentResource.scala | 15 +- .../org/apache/spark/status/storeTypes.scala | 11 + .../scala/org/apache/spark/ui/SparkUI.scala | 28 +- .../apache/spark/ui/env/EnvironmentPage.scala | 31 +- .../apache/spark/ui/env/EnvironmentTab.scala | 56 ---- .../app_environment_expectation.json | 281 ++++++++++++++++++ .../deploy/history/HistoryServerSuite.scala | 4 +- .../spark/status/AppStatusListenerSuite.scala | 40 +++ project/MimaExcludes.scala | 1 + 11 files changed, 402 insertions(+), 92 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala create mode 100644 core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index cd43612fae357..424a1159a875c 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -88,6 +88,24 @@ private[spark] class AppStatusListener( kvstore.write(new ApplicationInfoWrapper(appInfo)) } + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { + val details = event.environmentDetails + + val jvmInfo = Map(details("JVM Information"): _*) + val runtime = new v1.RuntimeInfo( + jvmInfo.get("Java Version").orNull, + jvmInfo.get("Java Home").orNull, + jvmInfo.get("Scala Version").orNull) + + val envInfo = new v1.ApplicationEnvironmentInfo( + runtime, + details.getOrElse("Spark Properties", Nil), + details.getOrElse("System Properties", Nil), + details.getOrElse("Classpath Entries", Nil)) + + kvstore.write(new ApplicationEnvironmentInfoWrapper(envInfo)) + } + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { val old = appInfo.attempts.head val attempt = new v1.ApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 2927a3227cbef..d6b5d2661e8ee 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -33,6 +33,15 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} */ private[spark] class AppStatusStore(store: KVStore) { + def applicationInfo(): v1.ApplicationInfo = { + store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info + } + + def environmentInfo(): v1.ApplicationEnvironmentInfo = { + val klass = classOf[ApplicationEnvironmentInfoWrapper] + store.read(klass, klass.getName()).info + } + def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) if (!statuses.isEmpty()) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala index 739a8aceae861..e702f8aa2ef2d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala @@ -26,20 +26,7 @@ private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { @GET def getEnvironmentInfo(): ApplicationEnvironmentInfo = { - val listener = ui.environmentListener - listener.synchronized { - val jvmInfo = Map(listener.jvmInformation: _*) - val runtime = new RuntimeInfo( - jvmInfo("Java Version"), - jvmInfo("Java Home"), - jvmInfo("Scala Version")) - - new ApplicationEnvironmentInfo( - runtime, - listener.sparkProperties, - listener.systemProperties, - listener.classpathEntries) - } + ui.store.environmentInfo() } } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index a445435809f3a..23e9a360ddc02 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -34,6 +34,17 @@ private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { } +private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvironmentInfo) { + + /** + * There's always a single ApplicationEnvironmentInfo object per application, so this + * ID doesn't need to be dynamic. But the KVStore API requires an ID. + */ + @JsonIgnore @KVIndex + def id: String = classOf[ApplicationEnvironmentInfoWrapper].getName() + +} + private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index ee645f6bf8a7a..43b57a1630aa9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -25,11 +25,10 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, - UIRoot} +import org.apache.spark.status.api.v1._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} +import org.apache.spark.ui.env.EnvironmentTab import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener @@ -44,7 +43,6 @@ private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, - val environmentListener: EnvironmentListener, val storageStatusListener: StorageStatusListener, val executorsListener: ExecutorsListener, val jobProgressListener: JobProgressListener, @@ -73,7 +71,7 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) attachTab(stagesTab) attachTab(new StorageTab(this)) - attachTab(new EnvironmentTab(this)) + attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) @@ -88,9 +86,13 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.sparkUser - .orElse(environmentListener.systemProperties.toMap.get("user.name")) - .getOrElse("") + try { + Option(store.applicationInfo().attempts.head.sparkUser) + .orElse(store.environmentInfo().systemProperties.toMap.get("user.name")) + .getOrElse("") + } catch { + case _: NoSuchElementException => "" + } } def getAppName: String = appName @@ -143,6 +145,7 @@ private[spark] class SparkUI private ( def setStreamingJobProgressListener(sparkListener: SparkListener): Unit = { streamingJobProgressListener = Option(sparkListener) } + } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) @@ -184,21 +187,20 @@ private[spark] object SparkUI { addListenerFn(listener) listener } - val environmentListener = new EnvironmentListener + val storageStatusListener = new StorageStatusListener(conf) val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - addListenerFn(environmentListener) addListenerFn(storageStatusListener) addListenerFn(executorsListener) addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(store, sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, jobProgressListener, storageListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime, appSparkVersion) + new SparkUI(store, sc, conf, securityManager, storageStatusListener, executorsListener, + jobProgressListener, storageListener, operationGraphListener, appName, basePath, + lastUpdateTime, startTime, appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index b11f8f1555f17..43adab7a35d65 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -21,22 +21,31 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.SparkConf +import org.apache.spark.status.AppStatusStore +import org.apache.spark.ui._ import org.apache.spark.util.Utils -private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { - private val listener = parent.listener +private[ui] class EnvironmentPage( + parent: EnvironmentTab, + conf: SparkConf, + store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + val appEnv = store.environmentInfo() + val jvmInformation = Map( + "Java Version" -> appEnv.runtime.javaVersion, + "Java Home" -> appEnv.runtime.javaHome, + "Scala Version" -> appEnv.runtime.scalaVersion) + val runtimeInformationTable = UIUtils.listingTable( - propertyHeader, jvmRow, listener.jvmInformation, fixedWidth = true) + propertyHeader, jvmRow, jvmInformation, fixedWidth = true) val sparkPropertiesTable = UIUtils.listingTable(propertyHeader, propertyRow, - Utils.redact(parent.conf, listener.sparkProperties), fixedWidth = true) - + Utils.redact(conf, appEnv.sparkProperties.toSeq), fixedWidth = true) val systemPropertiesTable = UIUtils.listingTable( - propertyHeader, propertyRow, listener.systemProperties, fixedWidth = true) + propertyHeader, propertyRow, appEnv.systemProperties, fixedWidth = true) val classpathEntriesTable = UIUtils.listingTable( - classPathHeaders, classPathRow, listener.classpathEntries, fixedWidth = true) + classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content =

    Runtime Information

    {runtimeInformationTable} @@ -54,3 +63,9 @@ private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") private def propertyRow(kv: (String, String)) = {kv._1}{kv._2} private def classPathRow(data: (String, String)) = {data._1}{data._2} } + +private[ui] class EnvironmentTab( + parent: SparkUI, + store: AppStatusStore) extends SparkUITab(parent, "environment") { + attachPage(new EnvironmentPage(this, parent.conf, store)) +} diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala deleted file mode 100644 index 61b12aaa32bb6..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.env - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.ui._ - -private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { - val listener = parent.environmentListener - val conf = parent.conf - attachPage(new EnvironmentPage(this)) -} - -/** - * :: DeveloperApi :: - * A SparkListener that prepares information to be displayed on the EnvironmentTab - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class EnvironmentListener extends SparkListener { - var sparkUser: Option[String] = None - var jvmInformation = Seq[(String, String)]() - var sparkProperties = Seq[(String, String)]() - var systemProperties = Seq[(String, String)]() - var classpathEntries = Seq[(String, String)]() - - override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { - sparkUser = Some(event.sparkUser) - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - val environmentDetails = environmentUpdate.environmentDetails - jvmInformation = environmentDetails("JVM Information") - sparkProperties = environmentDetails("Spark Properties") - systemProperties = environmentDetails("System Properties") - classpathEntries = environmentDetails("Classpath Entries") - } - } -} diff --git a/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json b/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json new file mode 100644 index 0000000000000..4ed053899ee6c --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json @@ -0,0 +1,281 @@ +{ + "runtime" : { + "javaVersion" : "1.8.0_92 (Oracle Corporation)", + "javaHome" : "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre", + "scalaVersion" : "version 2.11.8" + }, + "sparkProperties" : [ + [ "spark.blacklist.task.maxTaskAttemptsPerExecutor", "3" ], + [ "spark.blacklist.enabled", "TRUE" ], + [ "spark.driver.host", "172.22.0.167" ], + [ "spark.blacklist.task.maxTaskAttemptsPerNode", "3" ], + [ "spark.eventLog.enabled", "TRUE" ], + [ "spark.driver.port", "51459" ], + [ "spark.repl.class.uri", "spark://172.22.0.167:51459/classes" ], + [ "spark.jars", "" ], + [ "spark.repl.class.outputDir", "/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-1cbc97d0-7fe6-4c9f-8c2c-f6fe51ee3cf2/repl-39929169-ac4c-4c6d-b116-f648e4dd62ed" ], + [ "spark.app.name", "Spark shell" ], + [ "spark.blacklist.stage.maxFailedExecutorsPerNode", "3" ], + [ "spark.scheduler.mode", "FIFO" ], + [ "spark.eventLog.overwrite", "TRUE" ], + [ "spark.blacklist.stage.maxFailedTasksPerExecutor", "3" ], + [ "spark.executor.id", "driver" ], + [ "spark.blacklist.application.maxFailedExecutorsPerNode", "2" ], + [ "spark.submit.deployMode", "client" ], + [ "spark.master", "local-cluster[4,4,1024]" ], + [ "spark.home", "/Users/Jose/IdeaProjects/spark" ], + [ "spark.eventLog.dir", "/Users/jose/logs" ], + [ "spark.sql.catalogImplementation", "in-memory" ], + [ "spark.eventLog.compress", "FALSE" ], + [ "spark.blacklist.application.maxFailedTasksPerExecutor", "1" ], + [ "spark.blacklist.timeout", "1000000" ], + [ "spark.app.id", "app-20161116163331-0000" ], + [ "spark.task.maxFailures", "4" ] + ], + "systemProperties" : [ + [ "java.io.tmpdir", "/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/" ], + [ "line.separator", "\n" ], + [ "path.separator", ":" ], + [ "sun.management.compiler", "HotSpot 64-Bit Tiered Compilers" ], + [ "SPARK_SUBMIT", "true" ], + [ "sun.cpu.endian", "little" ], + [ "java.specification.version", "1.8" ], + [ "java.vm.specification.name", "Java Virtual Machine Specification" ], + [ "java.vendor", "Oracle Corporation" ], + [ "java.vm.specification.version", "1.8" ], + [ "user.home", "/Users/Jose" ], + [ "file.encoding.pkg", "sun.io" ], + [ "sun.nio.ch.bugLevel", "" ], + [ "ftp.nonProxyHosts", "local|*.local|169.254/16|*.169.254/16" ], + [ "sun.arch.data.model", "64" ], + [ "sun.boot.library.path", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib" ], + [ "user.dir", "/Users/Jose/IdeaProjects/spark" ], + [ "java.library.path", "/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:." ], + [ "sun.cpu.isalist", "" ], + [ "os.arch", "x86_64" ], + [ "java.vm.version", "25.92-b14" ], + [ "java.endorsed.dirs", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed" ], + [ "java.runtime.version", "1.8.0_92-b14" ], + [ "java.vm.info", "mixed mode" ], + [ "java.ext.dirs", "/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java" ], + [ "java.runtime.name", "Java(TM) SE Runtime Environment" ], + [ "file.separator", "/" ], + [ "io.netty.maxDirectMemory", "0" ], + [ "java.class.version", "52.0" ], + [ "scala.usejavacp", "true" ], + [ "java.specification.name", "Java Platform API Specification" ], + [ "sun.boot.class.path", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes" ], + [ "file.encoding", "UTF-8" ], + [ "user.timezone", "America/Chicago" ], + [ "java.specification.vendor", "Oracle Corporation" ], + [ "sun.java.launcher", "SUN_STANDARD" ], + [ "os.version", "10.11.6" ], + [ "sun.os.patch.level", "unknown" ], + [ "gopherProxySet", "false" ], + [ "java.vm.specification.vendor", "Oracle Corporation" ], + [ "user.country", "US" ], + [ "sun.jnu.encoding", "UTF-8" ], + [ "http.nonProxyHosts", "local|*.local|169.254/16|*.169.254/16" ], + [ "user.language", "en" ], + [ "socksNonProxyHosts", "local|*.local|169.254/16|*.169.254/16" ], + [ "java.vendor.url", "http://java.oracle.com/" ], + [ "java.awt.printerjob", "sun.lwawt.macosx.CPrinterJob" ], + [ "java.awt.graphicsenv", "sun.awt.CGraphicsEnvironment" ], + [ "awt.toolkit", "sun.lwawt.macosx.LWCToolkit" ], + [ "os.name", "Mac OS X" ], + [ "java.vm.vendor", "Oracle Corporation" ], + [ "java.vendor.url.bug", "http://bugreport.sun.com/bugreport/" ], + [ "user.name", "jose" ], + [ "java.vm.name", "Java HotSpot(TM) 64-Bit Server VM" ], + [ "sun.java.command", "org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=1000000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i /Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala" ], + [ "java.home", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre" ], + [ "java.version", "1.8.0_92" ], + [ "sun.io.unicode.encoding", "UnicodeBig" ] + ], + "classpathEntries" : [ + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/conf/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/core/target/jars/*", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar", "System Classpath" ] + ] +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 010a8dd004d4f..6a1abceaeb63c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -158,7 +158,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", - "executor memory usage" -> "applications/app-20161116163331-0000/executors" + "executor memory usage" -> "applications/app-20161116163331-0000/executors", + + "app environment" -> "applications/app-20161116163331-0000/environment" // Todo: enable this test when logging the even of onBlockUpdated. See: SPARK-13845 // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 7ac1ce19f8ddf..867d35f231dc0 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -56,6 +56,46 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(testDir) } + test("environment info") { + val listener = new AppStatusListener(store, conf, true) + + val details = Map( + "JVM Information" -> Seq( + "Java Version" -> sys.props("java.version"), + "Java Home" -> sys.props("java.home"), + "Scala Version" -> scala.util.Properties.versionString + ), + "Spark Properties" -> Seq( + "spark.conf.1" -> "1", + "spark.conf.2" -> "2" + ), + "System Properties" -> Seq( + "sys.prop.1" -> "1", + "sys.prop.2" -> "2" + ), + "Classpath Entries" -> Seq( + "/jar1" -> "System", + "/jar2" -> "User" + ) + ) + + listener.onEnvironmentUpdate(SparkListenerEnvironmentUpdate(details)) + + val appEnvKey = classOf[ApplicationEnvironmentInfoWrapper].getName() + check[ApplicationEnvironmentInfoWrapper](appEnvKey) { env => + val info = env.info + + val runtimeInfo = Map(details("JVM Information"): _*) + assert(info.runtime.javaVersion == runtimeInfo("Java Version")) + assert(info.runtime.javaHome == runtimeInfo("Java Home")) + assert(info.runtime.scalaVersion == runtimeInfo("Scala Version")) + + assert(info.sparkProperties === details("Spark Properties")) + assert(info.systemProperties === details("System Properties")) + assert(info.classpathEntries === details("Classpath Entries")) + } + } + test("scheduler events") { val listener = new AppStatusListener(store, conf, true) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 99cac34c85ebc..62930e2e3a931 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,6 +39,7 @@ object MimaExcludes { // SPARK-18085: Better History Server scalability for many / large applications ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 3da3d76352cc471252a54088cc55208bb4ea5b3a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 7 Nov 2017 20:07:30 -0800 Subject: [PATCH 1621/1765] [SPARK-14516][ML][FOLLOW-UP] Move ClusteringEvaluatorSuite test data to data/mllib. ## What changes were proposed in this pull request? Move ```ClusteringEvaluatorSuite``` test data(iris) to data/mllib, to prevent from re-creating a new folder. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #19648 from yanboliang/spark-14516. --- .../iris.libsvm => data/mllib/iris_libsvm.txt | 0 .../evaluation/ClusteringEvaluatorSuite.scala | 30 +++++++------------ 2 files changed, 11 insertions(+), 19 deletions(-) rename mllib/src/test/resources/test-data/iris.libsvm => data/mllib/iris_libsvm.txt (100%) diff --git a/mllib/src/test/resources/test-data/iris.libsvm b/data/mllib/iris_libsvm.txt similarity index 100% rename from mllib/src/test/resources/test-data/iris.libsvm rename to data/mllib/iris_libsvm.txt diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index e60ebbd7c852d..677ce49a903ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -22,8 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.Dataset class ClusteringEvaluatorSuite @@ -31,6 +30,13 @@ class ClusteringEvaluatorSuite import testImplicits._ + @transient var irisDataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + } + test("params") { ParamsSuite.checkParams(new ClusteringEvaluator) } @@ -53,37 +59,23 @@ class ClusteringEvaluatorSuite 0.6564679231 */ test("squared euclidean Silhouette") { - val iris = ClusteringEvaluatorSuite.irisDataset(spark) val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") - assert(evaluator.evaluate(iris) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } test("number of clusters must be greater than one") { - val iris = ClusteringEvaluatorSuite.irisDataset(spark) - .where($"label" === 0.0) + val singleClusterDataset = irisDataset.where($"label" === 0.0) val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") val e = intercept[AssertionError]{ - evaluator.evaluate(iris) + evaluator.evaluate(singleClusterDataset) } assert(e.getMessage.contains("Number of clusters must be greater than one")) } } - -object ClusteringEvaluatorSuite { - def irisDataset(spark: SparkSession): DataFrame = { - - val irisPath = Thread.currentThread() - .getContextClassLoader - .getResource("test-data/iris.libsvm") - .toString - - spark.read.format("libsvm").load(irisPath) - } -} From 2ca5aae47a25dc6bc9e333fb592025ff14824501 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 7 Nov 2017 21:02:14 -0800 Subject: [PATCH 1622/1765] [SPARK-22281][SPARKR] Handle R method breaking signature changes ## What changes were proposed in this pull request? This is to fix the code for the latest R changes in R-devel, when running CRAN check ``` checking for code/documentation mismatches ... WARNING Codoc mismatches from documentation object 'attach': attach Code: function(what, pos = 2L, name = deparse(substitute(what), backtick = FALSE), warn.conflicts = TRUE) Docs: function(what, pos = 2L, name = deparse(substitute(what)), warn.conflicts = TRUE) Mismatches in argument default values: Name: 'name' Code: deparse(substitute(what), backtick = FALSE) Docs: deparse(substitute(what)) Codoc mismatches from documentation object 'glm': glm Code: function(formula, family = gaussian, data, weights, subset, na.action, start = NULL, etastart, mustart, offset, control = list(...), model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, singular.ok = TRUE, contrasts = NULL, ...) Docs: function(formula, family = gaussian, data, weights, subset, na.action, start = NULL, etastart, mustart, offset, control = list(...), model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, contrasts = NULL, ...) Argument names in code not in docs: singular.ok Mismatches in argument names: Position: 16 Code: singular.ok Docs: contrasts Position: 17 Code: contrasts Docs: ... ``` With attach, it's pulling in the function definition from base::attach. We need to disable that but we would still need a function signature for roxygen2 to build with. With glm it's pulling in the function definition (ie. "usage") from the stats::glm function. Since this is "compiled in" when we build the source package into the .Rd file, when it changes at runtime or in CRAN check it won't match the latest signature. The solution is not to pull in from stats::glm since there isn't much value in doing that (none of the param we actually use, the ones we do use we have explicitly documented them) Also with attach we are changing to call dynamically. ## How was this patch tested? Manually. - [x] check documentation output - yes - [x] check help `?attach` `?glm` - yes - [x] check on other platforms, r-hub, on r-devel etc.. Author: Felix Cheung Closes #19557 from felixcheung/rattachglmdocerror. --- R/pkg/R/DataFrame.R | 11 +++++++---- R/pkg/R/generics.R | 10 ++++------ R/pkg/R/mllib_regression.R | 1 + R/run-tests.sh | 6 +++--- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index aaa3349d57506..763c8d2548580 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3236,7 +3236,7 @@ setMethod("as.data.frame", #' #' @family SparkDataFrame functions #' @rdname attach -#' @aliases attach,SparkDataFrame-method +#' @aliases attach attach,SparkDataFrame-method #' @param what (SparkDataFrame) The SparkDataFrame to attach #' @param pos (integer) Specify position in search() where to attach. #' @param name (character) Name to use for the attached SparkDataFrame. Names @@ -3252,9 +3252,12 @@ setMethod("as.data.frame", #' @note attach since 1.6.0 setMethod("attach", signature(what = "SparkDataFrame"), - function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { - newEnv <- assignNewEnv(what) - attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) + function(what, pos = 2L, name = deparse(substitute(what), backtick = FALSE), + warn.conflicts = TRUE) { + args <- as.list(environment()) # capture all parameters - this must be the first line + newEnv <- assignNewEnv(args$what) + args$what <- newEnv + do.call(attach, args) }) #' Evaluate a R expression in an environment constructed from a SparkDataFrame diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4e427489f6860..8312d417b99d2 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -409,7 +409,8 @@ setGeneric("as.data.frame", standardGeneric("as.data.frame") }) -#' @rdname attach +# Do not document the generic because of signature changes across R versions +#' @noRd #' @export setGeneric("attach") @@ -1569,12 +1570,9 @@ setGeneric("year", function(x) { standardGeneric("year") }) #' @export setGeneric("fitted") -#' @param x,y For \code{glm}: logical values indicating whether the response vector -#' and model matrix used in the fitting process should be returned as -#' components of the returned value. -#' @inheritParams stats::glm -#' @rdname glm +# Do not carry stats::glm usage and param here, and do not document the generic #' @export +#' @noRd setGeneric("glm") #' @param object a fitted ML model object. diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index f734a0865ec3b..545be5e1d89f0 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -210,6 +210,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm +#' @aliases glm #' @export #' @examples #' \dontrun{ diff --git a/R/run-tests.sh b/R/run-tests.sh index f38c86e3e6b1d..86bd8aad5f113 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -47,10 +47,10 @@ if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then echo -en "\033[0m" # No color exit -1 else - # We have 2 NOTEs for RoxygenNote, attach(); and one in Jenkins only "No repository set" + # We have 2 NOTEs: for RoxygenNote and one in Jenkins only "No repository set" # For non-latest version branches, one WARNING for package version - if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3) && - ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) ]]; then + if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) && + ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 1) ]]; then cat $CRAN_CHECK_LOG_FILE echo -en "\033[31m" # Red echo "Had CRAN check errors; see logs." From 11eea1a4ce32c9018218d4dfc9f46b744eb82991 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 7 Nov 2017 23:14:29 -0600 Subject: [PATCH 1623/1765] [SPARK-20646][CORE] Port executors page to new UI backend. The executors page is built on top of the REST API, so the page itself was easy to hook up to the new code. Some other pages depend on the `ExecutorListener` class that is being removed, though, so they needed to be modified to use data from the new store. Fortunately, all they seemed to need is the map of executor logs, so that was somewhat easy too. The executor timeline graph required adding some properties to the ExecutorSummary API type. Instead of following the previous code, which stored all the listener events in memory, the timeline is now created based on the data available from the API. I had to change some of the test golden files because the old code would return executors in "random" order (since it used a mutable Map instead of something that returns a sorted list), and the new code returns executors in id order. Tested with existing unit tests. Author: Marcelo Vanzin Closes #19678 from vanzin/SPARK-20646. --- .../spark/status/AppStatusListener.scala | 36 ++- .../apache/spark/status/AppStatusStore.scala | 18 +- .../org/apache/spark/status/LiveEntity.scala | 9 +- .../api/v1/AllExecutorListResource.scala | 15 +- .../status/api/v1/ExecutorListResource.scala | 14 +- .../org/apache/spark/status/api/v1/api.scala | 3 + .../scala/org/apache/spark/ui/SparkUI.scala | 13 +- .../ui/exec/ExecutorThreadDumpPage.scala | 9 +- .../apache/spark/ui/exec/ExecutorsPage.scala | 154 ------------- .../apache/spark/ui/exec/ExecutorsTab.scala | 206 +++--------------- .../apache/spark/ui/jobs/AllJobsPage.scala | 60 +++-- .../apache/spark/ui/jobs/ExecutorTable.scala | 11 +- .../org/apache/spark/ui/jobs/JobPage.scala | 59 +++-- .../org/apache/spark/ui/jobs/JobsTab.scala | 3 +- .../org/apache/spark/ui/jobs/StagePage.scala | 18 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 8 +- .../executor_list_json_expectation.json | 1 + .../executor_memory_usage_expectation.json | 127 +++++------ ...xecutor_node_blacklisting_expectation.json | 117 +++++----- ...acklisting_unblacklisting_expectation.json | 89 ++++---- .../org/apache/spark/ui/StagePageSuite.scala | 11 +- project/MimaExcludes.scala | 1 + 22 files changed, 354 insertions(+), 628 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 424a1159a875c..0469c871362c0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -103,6 +103,9 @@ private[spark] class AppStatusListener( details.getOrElse("System Properties", Nil), details.getOrElse("Classpath Entries", Nil)) + coresPerTask = envInfo.sparkProperties.toMap.get("spark.task.cpus").map(_.toInt) + .getOrElse(coresPerTask) + kvstore.write(new ApplicationEnvironmentInfoWrapper(envInfo)) } @@ -132,7 +135,7 @@ private[spark] class AppStatusListener( override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { // This needs to be an update in case an executor re-registers after the driver has // marked it as "dead". - val exec = getOrCreateExecutor(event.executorId) + val exec = getOrCreateExecutor(event.executorId, event.time) exec.host = event.executorInfo.executorHost exec.isActive = true exec.totalCores = event.executorInfo.totalCores @@ -144,6 +147,8 @@ private[spark] class AppStatusListener( override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => exec.isActive = false + exec.removeTime = new Date(event.time) + exec.removeReason = event.reason update(exec, System.nanoTime()) } } @@ -357,18 +362,25 @@ private[spark] class AppStatusListener( } liveExecutors.get(event.taskInfo.executorId).foreach { exec => - if (event.taskMetrics != null) { - val readMetrics = event.taskMetrics.shuffleReadMetrics - exec.totalGcTime += event.taskMetrics.jvmGCTime - exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead - exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead - exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten - } - exec.activeTasks -= 1 exec.completedTasks += completedDelta exec.failedTasks += failedDelta exec.totalDuration += event.taskInfo.duration + + // Note: For resubmitted tasks, we continue to use the metrics that belong to the + // first attempt of this task. This may not be 100% accurate because the first attempt + // could have failed half-way through. The correct fix would be to keep track of the + // metrics added by each attempt, but this is much more complicated. + if (event.reason != Resubmitted) { + if (event.taskMetrics != null) { + val readMetrics = event.taskMetrics.shuffleReadMetrics + exec.totalGcTime += event.taskMetrics.jvmGCTime + exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead + exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead + exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten + } + } + maybeUpdate(exec, now) } } @@ -409,7 +421,7 @@ private[spark] class AppStatusListener( override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. - val exec = getOrCreateExecutor(event.blockManagerId.executorId) + val exec = getOrCreateExecutor(event.blockManagerId.executorId, event.time) exec.hostPort = event.blockManagerId.hostPort event.maxOnHeapMem.foreach { _ => exec.totalOnHeap = event.maxOnHeapMem.get @@ -561,8 +573,8 @@ private[spark] class AppStatusListener( } } - private def getOrCreateExecutor(executorId: String): LiveExecutor = { - liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId)) + private def getOrCreateExecutor(executorId: String, addTime: Long): LiveExecutor = { + liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId, addTime)) } private def getOrCreateStage(info: StageInfo): LiveStage = { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index d6b5d2661e8ee..334407829f9fe 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -56,8 +56,22 @@ private[spark] class AppStatusStore(store: KVStore) { } def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { - store.view(classOf[ExecutorSummaryWrapper]).index("active").reverse().first(true) - .last(true).asScala.map(_.info).toSeq + val base = store.view(classOf[ExecutorSummaryWrapper]) + val filtered = if (activeOnly) { + base.index("active").reverse().first(true).last(true) + } else { + base + } + filtered.asScala.map(_.info).toSeq + } + + def executorSummary(executorId: String): Option[v1.ExecutorSummary] = { + try { + Some(store.read(classOf[ExecutorSummaryWrapper], executorId).info) + } catch { + case _: NoSuchElementException => + None + } } def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 041dfe1ef915e..8c48020e246b4 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -212,13 +212,17 @@ private class LiveTask( } -private class LiveExecutor(val executorId: String) extends LiveEntity { +private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveEntity { var hostPort: String = null var host: String = null var isActive = true var totalCores = 0 + val addTime = new Date(_addTime) + var removeTime: Date = null + var removeReason: String = null + var rddBlocks = 0 var memoryUsed = 0L var diskUsed = 0L @@ -276,6 +280,9 @@ private class LiveExecutor(val executorId: String) extends LiveEntity { totalShuffleWrite, isBlacklisted, maxMemory, + addTime, + Option(removeTime), + Option(removeReason), executorLogs, memoryMetrics) new ExecutorSummaryWrapper(info) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala index eb5cc1b9a3bd0..5522f4cebd773 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala @@ -20,22 +20,11 @@ import javax.ws.rs.{GET, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.exec.ExecutorsPage @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AllExecutorListResource(ui: SparkUI) { @GET - def executorList(): Seq[ExecutorSummary] = { - val listener = ui.executorsListener - listener.synchronized { - // The follow codes should be protected by `listener` to make sure no executors will be - // removed before we query their status. See SPARK-12784. - (0 until listener.activeStorageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId, isActive = true) - } ++ (0 until listener.deadStorageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId, isActive = false) - } - } - } + def executorList(): Seq[ExecutorSummary] = ui.store.executorList(false) + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index 2f3b5e984002a..975101c33c59c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -20,21 +20,11 @@ import javax.ws.rs.{GET, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.exec.ExecutorsPage @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class ExecutorListResource(ui: SparkUI) { @GET - def executorList(): Seq[ExecutorSummary] = { - val listener = ui.executorsListener - listener.synchronized { - // The follow codes should be protected by `listener` to make sure no executors will be - // removed before we query their status. See SPARK-12784. - val storageStatusList = listener.activeStorageStatusList - (0 until storageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId, isActive = true) - } - } - } + def executorList(): Seq[ExecutorSummary] = ui.store.executorList(true) + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index bff6f90823f40..b338b1f3fd073 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -85,6 +85,9 @@ class ExecutorSummary private[spark]( val totalShuffleWrite: Long, val isBlacklisted: Boolean, val maxMemory: Long, + val addTime: Date, + val removeTime: Option[Date], + val removeReason: Option[String], val executorLogs: Map[String, String], val memoryMetrics: Option[MemoryMetrics]) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 43b57a1630aa9..79d40b6a90c3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -29,7 +29,7 @@ import org.apache.spark.status.api.v1._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.EnvironmentTab -import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} +import org.apache.spark.ui.exec.ExecutorsTab import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener import org.apache.spark.ui.storage.{StorageListener, StorageTab} @@ -44,7 +44,6 @@ private[spark] class SparkUI private ( val conf: SparkConf, securityManager: SecurityManager, val storageStatusListener: StorageStatusListener, - val executorsListener: ExecutorsListener, val jobProgressListener: JobProgressListener, val storageListener: StorageListener, val operationGraphListener: RDDOperationGraphListener, @@ -68,7 +67,7 @@ private[spark] class SparkUI private ( def initialize() { val jobsTab = new JobsTab(this) attachTab(jobsTab) - val stagesTab = new StagesTab(this) + val stagesTab = new StagesTab(this, store) attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this, store)) @@ -189,18 +188,16 @@ private[spark] object SparkUI { } val storageStatusListener = new StorageStatusListener(conf) - val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) addListenerFn(storageStatusListener) - addListenerFn(executorsListener) addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(store, sc, conf, securityManager, storageStatusListener, executorsListener, - jobProgressListener, storageListener, operationGraphListener, appName, basePath, - lastUpdateTime, startTime, appSparkVersion) + new SparkUI(store, sc, conf, securityManager, storageStatusListener, jobProgressListener, + storageListener, operationGraphListener, appName, basePath, lastUpdateTime, startTime, + appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 7b211ea5199c3..f4686ea3cf91f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -22,11 +22,12 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.SparkContext +import org.apache.spark.ui.{SparkUITab, UIUtils, WebUIPage} -private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { - - private val sc = parent.sc +private[ui] class ExecutorThreadDumpPage( + parent: SparkUITab, + sc: Option[SparkContext]) extends WebUIPage("threadDump") { // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala deleted file mode 100644 index 7b2767f0be3cd..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.exec - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} -import org.apache.spark.ui.{UIUtils, WebUIPage} - -// This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive -private[ui] case class ExecutorSummaryInfo( - id: String, - hostPort: String, - rddBlocks: Int, - memoryUsed: Long, - diskUsed: Long, - activeTasks: Int, - failedTasks: Int, - completedTasks: Int, - totalTasks: Int, - totalDuration: Long, - totalInputBytes: Long, - totalShuffleRead: Long, - totalShuffleWrite: Long, - isBlacklisted: Int, - maxOnHeapMem: Long, - maxOffHeapMem: Long, - executorLogs: Map[String, String]) - - -private[ui] class ExecutorsPage( - parent: ExecutorsTab, - threadDumpEnabled: Boolean) - extends WebUIPage("") { - - def render(request: HttpServletRequest): Seq[Node] = { - val content = -
    - { -
    - - - Show Additional Metrics - - -
    ++ -
    ++ - ++ - ++ - - } -
    - - UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) - } -} - -private[spark] object ExecutorsPage { - private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " + - "storage of data like RDD partitions cached in memory." - private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " + - "storage of data like RDD partitions cached in memory." - - /** Represent an executor's info as a map given a storage status index */ - def getExecInfo( - listener: ExecutorsListener, - statusId: Int, - isActive: Boolean): ExecutorSummary = { - val status = if (isActive) { - listener.activeStorageStatusList(statusId) - } else { - listener.deadStorageStatusList(statusId) - } - val execId = status.blockManagerId.executorId - val hostPort = status.blockManagerId.hostPort - val rddBlocks = status.numBlocks - val memUsed = status.memUsed - val maxMem = status.maxMem - val memoryMetrics = for { - onHeapUsed <- status.onHeapMemUsed - offHeapUsed <- status.offHeapMemUsed - maxOnHeap <- status.maxOnHeapMem - maxOffHeap <- status.maxOffHeapMem - } yield { - new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) - } - - - val diskUsed = status.diskUsed - val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) - - new ExecutorSummary( - execId, - hostPort, - isActive, - rddBlocks, - memUsed, - diskUsed, - taskSummary.totalCores, - taskSummary.tasksMax, - taskSummary.tasksActive, - taskSummary.tasksFailed, - taskSummary.tasksComplete, - taskSummary.tasksActive + taskSummary.tasksFailed + taskSummary.tasksComplete, - taskSummary.duration, - taskSummary.jvmGCTime, - taskSummary.inputBytes, - taskSummary.shuffleRead, - taskSummary.shuffleWrite, - taskSummary.isBlacklisted, - maxMem, - taskSummary.executorLogs, - memoryMetrics - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 64a1a292a3840..843486f4a70d2 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -17,192 +17,44 @@ package org.apache.spark.ui.exec -import scala.collection.mutable.{LinkedHashMap, ListBuffer} +import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Resubmitted, SparkConf, SparkContext} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage.{StorageStatus, StorageStatusListener} -import org.apache.spark.ui.{SparkUI, SparkUITab} +import scala.xml.Node -private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { - val listener = parent.executorsListener - val sc = parent.sc - val threadDumpEnabled = - sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - - attachPage(new ExecutorsPage(this, threadDumpEnabled)) - if (threadDumpEnabled) { - attachPage(new ExecutorThreadDumpPage(this)) - } -} - -private[ui] case class ExecutorTaskSummary( - var executorId: String, - var totalCores: Int = 0, - var tasksMax: Int = 0, - var tasksActive: Int = 0, - var tasksFailed: Int = 0, - var tasksComplete: Int = 0, - var duration: Long = 0L, - var jvmGCTime: Long = 0L, - var inputBytes: Long = 0L, - var inputRecords: Long = 0L, - var outputBytes: Long = 0L, - var outputRecords: Long = 0L, - var shuffleRead: Long = 0L, - var shuffleWrite: Long = 0L, - var executorLogs: Map[String, String] = Map.empty, - var isAlive: Boolean = true, - var isBlacklisted: Boolean = false -) - -/** - * :: DeveloperApi :: - * A SparkListener that prepares information to be displayed on the ExecutorsTab - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) - extends SparkListener { - val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() - var executorEvents = new ListBuffer[SparkListenerEvent]() - - private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) - private val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) - - def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils, WebUIPage} - def deadStorageStatusList: Seq[StorageStatus] = storageStatusListener.deadStorageStatusList - - override def onExecutorAdded( - executorAdded: SparkListenerExecutorAdded): Unit = synchronized { - val eid = executorAdded.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskSummary.executorLogs = executorAdded.executorInfo.logUrlMap - taskSummary.totalCores = executorAdded.executorInfo.totalCores - taskSummary.tasksMax = taskSummary.totalCores / conf.getInt("spark.task.cpus", 1) - executorEvents += executorAdded - if (executorEvents.size > maxTimelineExecutors) { - executorEvents.remove(0) - } - - val deadExecutors = executorToTaskSummary.filter(e => !e._2.isAlive) - if (deadExecutors.size > retainedDeadExecutors) { - val head = deadExecutors.head - executorToTaskSummary.remove(head._1) - } - } - - override def onExecutorRemoved( - executorRemoved: SparkListenerExecutorRemoved): Unit = synchronized { - executorEvents += executorRemoved - if (executorEvents.size > maxTimelineExecutors) { - executorEvents.remove(0) - } - executorToTaskSummary.get(executorRemoved.executorId).foreach(e => e.isAlive = false) - } - - override def onApplicationStart( - applicationStart: SparkListenerApplicationStart): Unit = { - applicationStart.driverLogs.foreach { logs => - val storageStatus = activeStorageStatusList.find { s => - s.blockManagerId.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER || - s.blockManagerId.executorId == SparkContext.DRIVER_IDENTIFIER - } - storageStatus.foreach { s => - val eid = s.blockManagerId.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskSummary.executorLogs = logs.toMap - } - } - } - - override def onTaskStart( - taskStart: SparkListenerTaskStart): Unit = synchronized { - val eid = taskStart.taskInfo.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskSummary.tasksActive += 1 - } +private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { - override def onTaskEnd( - taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val info = taskEnd.taskInfo - if (info != null) { - val eid = info.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - // Note: For resubmitted tasks, we continue to use the metrics that belong to the - // first attempt of this task. This may not be 100% accurate because the first attempt - // could have failed half-way through. The correct fix would be to keep track of the - // metrics added by each attempt, but this is much more complicated. - if (taskEnd.reason == Resubmitted) { - return - } - if (info.successful) { - taskSummary.tasksComplete += 1 - } else { - taskSummary.tasksFailed += 1 - } - if (taskSummary.tasksActive >= 1) { - taskSummary.tasksActive -= 1 - } - taskSummary.duration += info.duration + init() - // Update shuffle read/write - val metrics = taskEnd.taskMetrics - if (metrics != null) { - taskSummary.inputBytes += metrics.inputMetrics.bytesRead - taskSummary.inputRecords += metrics.inputMetrics.recordsRead - taskSummary.outputBytes += metrics.outputMetrics.bytesWritten - taskSummary.outputRecords += metrics.outputMetrics.recordsWritten + private def init(): Unit = { + val threadDumpEnabled = + parent.sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - taskSummary.shuffleRead += metrics.shuffleReadMetrics.remoteBytesRead - taskSummary.shuffleWrite += metrics.shuffleWriteMetrics.bytesWritten - taskSummary.jvmGCTime += metrics.jvmGCTime - } + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this, parent.sc)) } } - private def updateExecutorBlacklist( - eid: String, - isBlacklisted: Boolean): Unit = { - val execTaskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - execTaskSummary.isBlacklisted = isBlacklisted - } - - override def onExecutorBlacklisted( - executorBlacklisted: SparkListenerExecutorBlacklisted) - : Unit = synchronized { - updateExecutorBlacklist(executorBlacklisted.executorId, true) - } - - override def onExecutorUnblacklisted( - executorUnblacklisted: SparkListenerExecutorUnblacklisted) - : Unit = synchronized { - updateExecutorBlacklist(executorUnblacklisted.executorId, false) - } - - override def onNodeBlacklisted( - nodeBlacklisted: SparkListenerNodeBlacklisted) - : Unit = synchronized { - // Implicitly blacklist every executor associated with this node, and show this in the UI. - activeStorageStatusList.foreach { status => - if (status.blockManagerId.host == nodeBlacklisted.hostId) { - updateExecutorBlacklist(status.blockManagerId.executorId, true) - } - } - } +} - override def onNodeUnblacklisted( - nodeUnblacklisted: SparkListenerNodeUnblacklisted) - : Unit = synchronized { - // Implicitly unblacklist every executor associated with this node, regardless of how - // they may have been blacklisted initially (either explicitly through executor blacklisting - // or implicitly through node blacklisting). Show this in the UI. - activeStorageStatusList.foreach { status => - if (status.blockManagerId.host == nodeUnblacklisted.hostId) { - updateExecutorBlacklist(status.blockManagerId.executorId, false) - } - } +private[ui] class ExecutorsPage( + parent: SparkUITab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { + + def render(request: HttpServletRequest): Seq[Node] = { + val content = +
    + { +
    ++ + ++ + ++ + + } +
    + + UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index a7f2caafe04b8..a647a1173a8cb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1.ExecutorSummary import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData.{JobUIData, StageUIData} import org.apache.spark.util.Utils @@ -123,55 +124,53 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } } - private def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): + private def makeExecutorEvent(executors: Seq[ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() - executorUIDatas.foreach { - case a: SparkListenerExecutorAdded => - val addedEvent = - s""" - |{ - | 'className': 'executor added', - | 'group': 'executors', - | 'start': new Date(${a.time}), - | 'content': '
    Executor ${a.executorId} added
    ' - |} - """.stripMargin - events += addedEvent - case e: SparkListenerExecutorRemoved => + executors.foreach { e => + val addedEvent = + s""" + |{ + | 'className': 'executor added', + | 'group': 'executors', + | 'start': new Date(${e.addTime.getTime()}), + | 'content': '
    Executor ${e.id} added
    ' + |} + """.stripMargin + events += addedEvent + + e.removeTime.foreach { removeTime => val removedEvent = s""" |{ | 'className': 'executor removed', | 'group': 'executors', - | 'start': new Date(${e.time}), + | 'start': new Date(${removeTime.getTime()}), | 'content': '
    ' + + | 'Removed at ${UIUtils.formatDate(removeTime)}' + | '${ - if (e.reason != null) { - s"""
    Reason: ${e.reason.replace("\n", " ")}""" - } else { - "" - } + e.removeReason.map { reason => + s"""
    Reason: ${reason.replace("\n", " ")}""" + }.getOrElse("") }"' + - | 'data-html="true">Executor ${e.executorId} removed
    ' + | 'data-html="true">Executor ${e.id} removed' |} """.stripMargin events += removedEvent - + } } events.toSeq } private def makeTimeline( jobs: Seq[JobUIData], - executors: Seq[SparkListenerEvent], + executors: Seq[ExecutorSummary], startTime: Long): Seq[Node] = { val jobEventJsonAsStrSeq = makeJobEvent(jobs) @@ -360,9 +359,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { var content = summary - val executorListener = parent.executorListener content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, - executorListener.executorEvents, startTime) + parent.parent.store.executorList(false), startTime) if (shouldShowActiveJobs) { content ++=

    Active Jobs ({activeJobs.size})

    ++ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 382a6f979f2e6..07a41d195a191 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -20,12 +20,17 @@ package org.apache.spark.ui.jobs import scala.collection.mutable import scala.xml.{Node, Unparsed} +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: StagesTab) { +private[ui] class ExecutorTable( + stageId: Int, + stageAttemptId: Int, + parent: StagesTab, + store: AppStatusStore) { private val listener = parent.progressListener def toNodeSeq: Seq[Node] = { @@ -123,9 +128,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
    {k}
    { - val logs = parent.executorsListener.executorToTaskSummary.get(k) - .map(_.executorLogs).getOrElse(Map.empty) - logs.map { + store.executorSummary(k).map(_.executorLogs).getOrElse(Map.empty).map { case (logName, logUrl) => } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 9fb011a049b7e..7ed01646f3621 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1.ExecutorSummary import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} /** Page showing statistics and stage list for a given job */ @@ -92,55 +93,52 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } } - def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): Seq[String] = { + def makeExecutorEvent(executors: Seq[ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() - executorUIDatas.foreach { - case a: SparkListenerExecutorAdded => - val addedEvent = - s""" - |{ - | 'className': 'executor added', - | 'group': 'executors', - | 'start': new Date(${a.time}), - | 'content': '
    Executor ${a.executorId} added
    ' - |} - """.stripMargin - events += addedEvent + executors.foreach { e => + val addedEvent = + s""" + |{ + | 'className': 'executor added', + | 'group': 'executors', + | 'start': new Date(${e.addTime.getTime()}), + | 'content': '
    Executor ${e.id} added
    ' + |} + """.stripMargin + events += addedEvent - case e: SparkListenerExecutorRemoved => + e.removeTime.foreach { removeTime => val removedEvent = s""" |{ | 'className': 'executor removed', | 'group': 'executors', - | 'start': new Date(${e.time}), + | 'start': new Date(${removeTime.getTime()}), | 'content': '
    ' + + | 'Removed at ${UIUtils.formatDate(removeTime)}' + | '${ - if (e.reason != null) { - s"""
    Reason: ${e.reason.replace("\n", " ")}""" - } else { - "" - } + e.removeReason.map { reason => + s"""
    Reason: ${reason.replace("\n", " ")}""" + }.getOrElse("") }"' + - | 'data-html="true">Executor ${e.executorId} removed
    ' + | 'data-html="true">Executor ${e.id} removed
    ' |} """.stripMargin events += removedEvent - + } } events.toSeq } private def makeTimeline( stages: Seq[StageInfo], - executors: Seq[SparkListenerEvent], + executors: Seq[ExecutorSummary], appStartTime: Long): Seq[Node] = { val stageEventJsonAsStrSeq = makeStageEvent(stages) @@ -322,11 +320,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { var content = summary val appStartTime = listener.startTime - val executorListener = parent.executorListener val operationGraphListener = parent.operationGraphListener content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, - executorListener.executorEvents, appStartTime) + parent.parent.store.executorList(false), appStartTime) content ++= UIUtils.showDagVizForJob( jobId, operationGraphListener.getOperationGraphForJob(jobId)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index cc173381879a6..81ffe04aca49a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -23,11 +23,10 @@ import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { +private[ui] class JobsTab(val parent: SparkUI) extends SparkUITab(parent, "jobs") { val sc = parent.sc val killEnabled = parent.killEnabled val jobProgresslistener = parent.jobProgressListener - val executorListener = parent.executorsListener val operationGraphListener = parent.operationGraphListener def isFairScheduler: Boolean = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4d80308eb0a6d..3151b8d554658 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -29,18 +29,17 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui._ -import org.apache.spark.ui.exec.ExecutorsListener import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import StagePage._ private val progressListener = parent.progressListener private val operationGraphListener = parent.operationGraphListener - private val executorsListener = parent.executorsListener private val TIMELINE_LEGEND = {
    @@ -304,7 +303,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { pageSize = taskPageSize, sortColumn = taskSortColumn, desc = taskSortDesc, - executorsListener = executorsListener + store = store ) (_taskTable, _taskTable.table(page)) } catch { @@ -563,7 +562,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { stripeRowsWithCss = false)) } - val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) + val executorTable = new ExecutorTable(stageId, stageAttemptId, parent, store) val maybeAccumulableTable: Seq[Node] = if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq.empty @@ -869,7 +868,7 @@ private[ui] class TaskDataSource( pageSize: Int, sortColumn: String, desc: Boolean, - executorsListener: ExecutorsListener) extends PagedDataSource[TaskTableRowData](pageSize) { + store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { import StagePage._ // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table @@ -1012,8 +1011,7 @@ private[ui] class TaskDataSource( None } - val logs = executorsListener.executorToTaskSummary.get(info.executorId) - .map(_.executorLogs).getOrElse(Map.empty) + val logs = store.executorSummary(info.executorId).map(_.executorLogs).getOrElse(Map.empty) new TaskTableRowData( info.index, info.taskId, @@ -1162,7 +1160,7 @@ private[ui] class TaskPagedTable( pageSize: Int, sortColumn: String, desc: Boolean, - executorsListener: ExecutorsListener) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskTableRowData] { override def tableId: String = "task-table" @@ -1188,7 +1186,7 @@ private[ui] class TaskPagedTable( pageSize, sortColumn, desc, - executorsListener) + store) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 0787ea6625903..65446f967ad76 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -20,20 +20,22 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. */ -private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { +private[ui] class StagesTab(val parent: SparkUI, store: AppStatusStore) + extends SparkUITab(parent, "stages") { + val sc = parent.sc val conf = parent.conf val killEnabled = parent.killEnabled val progressListener = parent.jobProgressListener val operationGraphListener = parent.operationGraphListener - val executorsListener = parent.executorsListener val lastUpdateTime = parent.lastUpdateTime attachPage(new AllStagesPage(this)) - attachPage(new StagePage(this)) + attachPage(new StagePage(this, store)) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = progressListener.schedulingMode == Some(SchedulingMode.FAIR) diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 6b9f29e1a230e..942e6d8f04363 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -18,5 +18,6 @@ "totalShuffleWrite" : 13180, "isBlacklisted" : false, "maxMemory" : 278302556, + "addTime" : "2015-02-03T16:43:00.906GMT", "executorLogs" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index 0f94e3b255dbc..ed33c90dd39ba 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -1,6 +1,34 @@ [ { - "id" : "2", - "hostPort" : "172.22.0.167:51487", + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:31.477GMT", + "executorLogs" : { }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, @@ -8,52 +36,57 @@ "totalCores" : 4, "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 4, - "completedTasks" : 0, - "totalTasks" : 4, - "totalDuration" : 2537, - "totalGCTime" : 88, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.320GMT", "executorLogs" : { - "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", - "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } -}, { - "id" : "driver", - "hostPort" : "172.22.0.167:51475", +} ,{ + "id" : "2", + "hostPort" : "172.22.0.167:51487", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, "diskUsed" : 0, - "totalCores" : 0, - "maxTasks" : 0, + "totalCores" : 4, + "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 0, + "failedTasks" : 4, "completedTasks" : 0, - "totalTasks" : 0, - "totalDuration" : 0, - "totalGCTime" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, - "executorLogs" : { }, + "addTime" : "2016-11-16T22:33:35.393GMT", + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } }, { "id" : "1", @@ -75,6 +108,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.443GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" @@ -105,44 +139,15 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.462GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 - } -}, { - "id" : "3", - "hostPort" : "172.22.0.167:51485", - "isActive" : true, - "rddBlocks" : 0, - "memoryUsed" : 0, - "diskUsed" : 0, - "totalCores" : 4, - "maxTasks" : 4, - "activeTasks" : 0, - "failedTasks" : 0, - "completedTasks" : 12, - "totalTasks" : 12, - "totalDuration" : 2453, - "totalGCTime" : 72, - "totalInputBytes" : 0, - "totalShuffleRead" : 0, - "totalShuffleWrite" : 0, - "isBlacklisted" : true, - "maxMemory" : 908381388, - "executorLogs" : { - "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", - "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" - }, - "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 0f94e3b255dbc..73519f1d9e2e4 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -1,6 +1,34 @@ [ { - "id" : "2", - "hostPort" : "172.22.0.167:51487", + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:31.477GMT", + "executorLogs" : { }, + "memoryMetrics": { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, @@ -8,52 +36,57 @@ "totalCores" : 4, "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 4, - "completedTasks" : 0, - "totalTasks" : 4, - "totalDuration" : 2537, - "totalGCTime" : 88, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.320GMT", "executorLogs" : { - "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", - "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } }, { - "id" : "driver", - "hostPort" : "172.22.0.167:51475", + "id" : "2", + "hostPort" : "172.22.0.167:51487", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, "diskUsed" : 0, - "totalCores" : 0, - "maxTasks" : 0, + "totalCores" : 4, + "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 0, + "failedTasks" : 4, "completedTasks" : 0, - "totalTasks" : 0, - "totalDuration" : 0, - "totalGCTime" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, - "executorLogs" : { }, + "addTime" : "2016-11-16T22:33:35.393GMT", + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } }, { "id" : "1", @@ -75,6 +108,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.443GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" @@ -105,6 +139,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.462GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" @@ -115,34 +150,4 @@ "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 } -}, { - "id" : "3", - "hostPort" : "172.22.0.167:51485", - "isActive" : true, - "rddBlocks" : 0, - "memoryUsed" : 0, - "diskUsed" : 0, - "totalCores" : 4, - "maxTasks" : 4, - "activeTasks" : 0, - "failedTasks" : 0, - "completedTasks" : 12, - "totalTasks" : 12, - "totalDuration" : 2453, - "totalGCTime" : 72, - "totalInputBytes" : 0, - "totalShuffleRead" : 0, - "totalShuffleWrite" : 0, - "isBlacklisted" : true, - "maxMemory" : 908381388, - "executorLogs" : { - "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", - "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" - }, - "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 - } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 92e249c851116..6931fead3d2ff 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -1,6 +1,28 @@ [ { - "id" : "2", - "hostPort" : "172.22.0.111:64539", + "id" : "driver", + "hostPort" : "172.22.0.111:64527", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:38.836GMT", + "executorLogs" : { } +}, { + "id" : "3", + "hostPort" : "172.22.0.111:64543", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, @@ -8,41 +30,46 @@ "totalCores" : 4, "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 6, - "completedTasks" : 0, - "totalTasks" : 6, - "totalDuration" : 2792, - "totalGCTime" : 128, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3457, + "totalGCTime" : 72, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:42.711GMT", "executorLogs" : { - "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", - "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" } }, { - "id" : "driver", - "hostPort" : "172.22.0.111:64527", + "id" : "2", + "hostPort" : "172.22.0.111:64539", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, "diskUsed" : 0, - "totalCores" : 0, - "maxTasks" : 0, + "totalCores" : 4, + "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 0, + "failedTasks" : 6, "completedTasks" : 0, - "totalTasks" : 0, - "totalDuration" : 0, - "totalGCTime" : 0, + "totalTasks" : 6, + "totalDuration" : 2792, + "totalGCTime" : 128, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, - "executorLogs" : { } + "addTime" : "2016-11-15T23:20:42.589GMT", + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + } }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -63,6 +90,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:42.629GMT", "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" @@ -87,32 +115,9 @@ "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:42.593GMT", "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" } -}, { - "id" : "3", - "hostPort" : "172.22.0.111:64543", - "isActive" : true, - "rddBlocks" : 0, - "memoryUsed" : 0, - "diskUsed" : 0, - "totalCores" : 4, - "maxTasks" : 4, - "activeTasks" : 0, - "failedTasks" : 0, - "completedTasks" : 4, - "totalTasks" : 4, - "totalDuration" : 3457, - "totalGCTime" : 72, - "totalInputBytes" : 0, - "totalShuffleRead" : 0, - "totalShuffleWrite" : 0, - "isBlacklisted" : false, - "maxMemory" : 384093388, - "executorLogs" : { - "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", - "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } } ] diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 499d47b13d702..1c51c148ae61b 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -22,13 +22,13 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.mockito.Matchers.anyString import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener -import org.apache.spark.ui.exec.ExecutorsListener +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener @@ -55,20 +55,21 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * This also runs a dummy stage to populate the page with useful content. */ private def renderStagePage(conf: SparkConf): Seq[Node] = { + val store = mock(classOf[AppStatusStore]) + when(store.executorSummary(anyString())).thenReturn(None) + val jobListener = new JobProgressListener(conf) val graphListener = new RDDOperationGraphListener(conf) - val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf) val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) val request = mock(classOf[HttpServletRequest]) when(tab.conf).thenReturn(conf) when(tab.progressListener).thenReturn(jobListener) when(tab.operationGraphListener).thenReturn(graphListener) - when(tab.executorsListener).thenReturn(executorsListener) when(tab.appName).thenReturn("testing") when(tab.headerTabs).thenReturn(Seq.empty) when(request.getParameter("id")).thenReturn("0") when(request.getParameter("attempt")).thenReturn("0") - val page = new StagePage(tab) + val page = new StagePage(tab, store) // Simulate a stage in job progress listener val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 62930e2e3a931..0c31b2b4a9402 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -40,6 +40,7 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 51debf8b1f4d479bc7f81e2759ba28e526367d70 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 8 Nov 2017 10:24:40 +0000 Subject: [PATCH 1624/1765] [SPARK-14540][BUILD] Support Scala 2.12 closures and Java 8 lambdas in ClosureCleaner (step 0) ## What changes were proposed in this pull request? Preliminary changes to get ClosureCleaner to work with Scala 2.12. Makes many usages just work, but not all. This does _not_ resolve the JIRA. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19675 from srowen/SPARK-14540.0. --- .../apache/spark/util/ClosureCleaner.scala | 28 +++++++++++-------- .../spark/util/ClosureCleanerSuite2.scala | 10 +++++-- .../spark/graphx/lib/ShortestPaths.scala | 2 +- .../streaming/BasicOperationsSuite.scala | 4 +-- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index dfece5dd0670b..40616421b5bca 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -38,12 +38,13 @@ private[spark] object ClosureCleaner extends Logging { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + if (resourceStream == null) { + null + } else { + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + } } // Check whether a class represents a Scala closure @@ -81,11 +82,13 @@ private[spark] object ClosureCleaner extends Logging { val stack = Stack[Class[_]](obj.getClass) while (!stack.isEmpty) { val cr = getClassReader(stack.pop()) - val set = Set.empty[Class[_]] - cr.accept(new InnerClosureFinder(set), 0) - for (cls <- set -- seen) { - seen += cls - stack.push(cls) + if (cr != null) { + val set = Set.empty[Class[_]] + cr.accept(new InnerClosureFinder(set), 0) + for (cls <- set -- seen) { + seen += cls + stack.push(cls) + } } } (seen - obj.getClass).toList @@ -366,7 +369,8 @@ private[spark] class ReturnStatementInClosureException private class ReturnStatementFinder extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - if (name.contains("apply")) { + // $anonfun$ covers Java 8 lambdas + if (name.contains("apply") || name.contains("$anonfun$")) { new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 934385fbcad1b..278fada83d78c 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -117,9 +117,13 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri findTransitively: Boolean): Map[Class[_], Set[String]] = { val fields = new mutable.HashMap[Class[_], mutable.Set[String]] outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] } - ClosureCleaner.getClassReader(closure.getClass) - .accept(new FieldAccessFinder(fields, findTransitively), 0) - fields.mapValues(_.toSet).toMap + val cr = ClosureCleaner.getClassReader(closure.getClass) + if (cr == null) { + Map.empty + } else { + cr.accept(new FieldAccessFinder(fields, findTransitively), 0) + fields.mapValues(_.toSet).toMap + } } // Accessors for private methods diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 4cac633aed008..aff0b932e9429 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -25,7 +25,7 @@ import org.apache.spark.graphx._ * Computes shortest paths to the given set of landmark vertices, returning a graph where each * vertex attribute is a map containing the shortest-path distance to each reachable landmark. */ -object ShortestPaths { +object ShortestPaths extends Serializable { /** Stores a map from the vertex id of a landmark to the distance to that landmark. */ type SPMap = Map[VertexId, Int] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6f62c7a88dc3c..0a764f61c0cd9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -596,8 +596,6 @@ class BasicOperationsSuite extends TestSuiteBase { ) val updateStateOperation = (s: DStream[String]) => { - class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable - // updateFunc clears a state when a StateObject is seen without new values twice in a row val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { val stateObj = state.getOrElse(new StateObject) @@ -817,3 +815,5 @@ class BasicOperationsSuite extends TestSuiteBase { } } } + +class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable From 87343e15566da870d1b8e49a78ef16e08ddfd406 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 8 Nov 2017 12:17:52 +0100 Subject: [PATCH 1625/1765] [SPARK-22446][SQL][ML] Declare StringIndexerModel indexer udf as nondeterministic ## What changes were proposed in this pull request? UDFs that can cause runtime exception on invalid data are not safe to pushdown, because its behavior depends on its position in the query plan. Pushdown of it will risk to change its original behavior. The example reported in the JIRA and taken as test case shows this issue. We should declare UDFs that can cause runtime exception on invalid data as non-determinstic. This updates the document of `deterministic` property in `Expression` and states clearly an UDF that can cause runtime exception on some specific input, should be declared as non-determinstic. ## How was this patch tested? Added test. Manually test. Author: Liang-Chi Hsieh Closes #19662 from viirya/SPARK-22446. --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../spark/ml/feature/StringIndexerSuite.scala | 17 ++++++++++++++ .../ml/feature/VectorAssemblerSuite.scala | 23 ++++++++++++++++++- .../sql/catalyst/expressions/Expression.scala | 1 + 5 files changed, 42 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 2679ec310c470..1cdcdfcaeab78 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -261,7 +261,7 @@ class StringIndexerModel ( s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") } } - } + }.asNondeterministic() filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 73f27d1a423d9..b373ae921ed38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -97,7 +97,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) // Data transformation. val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) - } + }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { case DoubleType => dataset(c) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 027b1fbc6657c..775a04d3df050 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -314,4 +314,21 @@ class StringIndexerSuite idx += 1 } } + + test("SPARK-22446: StringIndexerModel's indexer UDF should not apply on filtered data") { + val df = List( + ("A", "London", "StrA"), + ("B", "Bristol", null), + ("C", "New York", "StrC")).toDF("ID", "CITY", "CONTENT") + + val dfNoBristol = df.filter($"CONTENT".isNotNull) + + val model = new StringIndexer() + .setInputCol("CITY") + .setOutputCol("CITYIndexed") + .fit(dfNoBristol) + + val dfWithIndex = model.transform(dfNoBristol) + assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 6aef1c6837025..eca065f7e775d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -126,4 +126,25 @@ class VectorAssemblerSuite .setOutputCol("myOutputCol") testDefaultReadWrite(t) } + + test("SPARK-22446: VectorAssembler's UDF should not apply on filtered data") { + val df = Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L), + (0, 1.0, null, "b", null, 20L) + ).toDF("id", "x", "y", "name", "z", "n") + + val assembler = new VectorAssembler() + .setInputCols(Array("x", "z", "n")) + .setOutputCol("features") + + val filteredDF = df.filter($"y".isNotNull) + + val vectorUDF = udf { vector: Vector => + vector.numActives + } + + assert(assembler.transform(filteredDF).select("features") + .filter(vectorUDF($"features") > 1) + .count() == 1) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0e75ac88dc2b8..a3b722a47d688 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -75,6 +75,7 @@ abstract class Expression extends TreeNode[Expression] { * - it relies on some mutable internal state, or * - it relies on some implicit input that is not part of the children expression list. * - it has non-deterministic child or children. + * - it assumes the input satisfies some certain condition via the child operator. * * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. * By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. From 6447d7bc1de4ab1d99a8dfcd3fea07f5a2da363d Mon Sep 17 00:00:00 2001 From: "Li, YanKit | Wilson | RIT" Date: Wed, 8 Nov 2017 17:55:21 +0000 Subject: [PATCH 1626/1765] [SPARK-22133][DOCS] Documentation for Mesos Reject Offer Configurations ## What changes were proposed in this pull request? Documentation about Mesos Reject Offer Configurations ## Related PR https://github.com/apache/spark/pull/19510 for `spark.mem.max` Author: Li, YanKit | Wilson | RIT Closes #19555 from windkit/spark_22133. --- docs/running-on-mesos.md | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index b7e3e6473c338..7a443ffddc5f0 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -203,7 +203,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the -Spark application will reserve all resources offered to it by Mesos, +Spark application will consume all resources offered to it by Mesos, so we of course urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -680,6 +680,30 @@ See the [configuration page](configuration.html) for information on Spark config driver disconnects, the master immediately tears down the framework. + + spark.mesos.rejectOfferDuration + 120s + + Time to consider unused resources refused, serves as a fallback of + `spark.mesos.rejectOfferDurationForUnmetConstraints`, + `spark.mesos.rejectOfferDurationForReachedMaxCores` + + + + spark.mesos.rejectOfferDurationForUnmetConstraints + spark.mesos.rejectOfferDuration + + Time to consider unused resources refused with unmet constraints + + + + spark.mesos.rejectOfferDurationForReachedMaxCores + spark.mesos.rejectOfferDuration + + Time to consider unused resources refused when maximum number of cores + spark.cores.max is reached + + # Troubleshooting and Debugging From ee571d79e52dc7e0ad7ae80619c12a5c0b90b5a5 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 Nov 2017 14:33:08 +0900 Subject: [PATCH 1627/1765] [SPARK-22466][SPARK SUBMIT] export SPARK_CONF_DIR while conf is default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? We use SPARK_CONF_DIR to switch spark conf directory and can be visited if we explicitly export it in spark-env.sh, but with default settings, it can't be done. This PR export SPARK_CONF_DIR while it is default. ### Before ``` KentKentsMacBookPro  ~/Documents/spark-packages/spark-2.3.0-SNAPSHOT-bin-master  bin/spark-shell --master local Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). 17/11/08 10:28:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 17/11/08 10:28:45 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041. Spark context Web UI available at http://169.254.168.63:4041 Spark context available as 'sc' (master = local, app id = local-1510108125770). Spark session available as 'spark'. Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.3.0-SNAPSHOT /_/ Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_65) Type in expressions to have them evaluated. Type :help for more information. scala> sys.env.get("SPARK_CONF_DIR") res0: Option[String] = None ``` ### After ``` scala> sys.env.get("SPARK_CONF_DIR") res0: Option[String] = Some(/Users/Kent/Documents/spark/conf) ``` ## How was this patch tested? vanzin Author: Kent Yao Closes #19688 from yaooqinn/SPARK-22466. --- bin/load-spark-env.cmd | 12 +++++------- bin/load-spark-env.sh | 9 +++------ conf/spark-env.sh.template | 3 ++- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index f946197b02d55..cefa513b6fb77 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -19,15 +19,13 @@ rem rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's -rem conf/ subdirectory. +rem conf\ subdirectory. if [%SPARK_ENV_LOADED%] == [] ( set SPARK_ENV_LOADED=1 - if not [%SPARK_CONF_DIR%] == [] ( - set user_conf_dir=%SPARK_CONF_DIR% - ) else ( - set user_conf_dir=..\conf + if [%SPARK_CONF_DIR%] == [] ( + set SPARK_CONF_DIR=%~dp0..\conf ) call :LoadSparkEnv @@ -54,6 +52,6 @@ if [%SPARK_SCALA_VERSION%] == [] ( exit /b 0 :LoadSparkEnv -if exist "%user_conf_dir%\spark-env.cmd" ( - call "%user_conf_dir%\spark-env.cmd" +if exist "%SPARK_CONF_DIR%\spark-env.cmd" ( + call "%SPARK_CONF_DIR%\spark-env.cmd" ) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index d05d94e68c81b..0b5006dbd63ac 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -29,15 +29,12 @@ fi if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 - # Returns the parent of the directory this script lives in. - parent_dir="${SPARK_HOME}" + export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}"/conf}" - user_conf_dir="${SPARK_CONF_DIR:-"$parent_dir"/conf}" - - if [ -f "${user_conf_dir}/spark-env.sh" ]; then + if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then # Promote all variable declarations to environment (exported) variables set -a - . "${user_conf_dir}/spark-env.sh" + . "${SPARK_CONF_DIR}/spark-env.sh" set +a fi fi diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index f8c895f5303b9..bc92c78f0f8f3 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -32,7 +32,8 @@ # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos -# Options read in YARN client mode +# Options read in YARN client/cluster mode +# - SPARK_CONF_DIR, Alternate conf dir. (Default: ${SPARK_HOME}/conf) # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). From d01044233c7065a20ef7f8e7ab052b88eb34eaa5 Mon Sep 17 00:00:00 2001 From: ptkool Date: Thu, 9 Nov 2017 14:44:39 +0900 Subject: [PATCH 1628/1765] [SPARK-22456][SQL] Add support for dayofweek function ## What changes were proposed in this pull request? This PR adds support for a new function called `dayofweek` that returns the day of the week of the given argument as an integer value in the range 1-7, where 1 represents Sunday. ## How was this patch tested? Unit tests and manual tests. Author: ptkool Closes #19672 from ptkool/day_of_week_function. --- python/pyspark/sql/functions.py | 13 +++++++++++++ python/pyspark/sql/tests.py | 7 +++++++ .../main/scala/org/apache/spark/sql/functions.scala | 7 +++++++ 3 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 39815497f3956..087ce7caa89c8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -887,6 +887,19 @@ def month(col): return Column(sc._jvm.functions.month(_to_java_column(col))) +@since(2.3) +def dayofweek(col): + """ + Extract the day of the week of a given date as integer. + + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofweek('dt').alias('day')).collect() + [Row(day=4)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.dayofweek(_to_java_column(col))) + + @since(1.5) def dayofmonth(col): """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb0d4e29a5978..4819f629c5310 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1765,6 +1765,13 @@ def test_datetime_at_epoch(self): self.assertEqual(first['date'], epoch) self.assertEqual(first['lit_date'], epoch) + def test_dayofweek(self): + from pyspark.sql.functions import dayofweek + dt = datetime.datetime(2017, 11, 6) + df = self.spark.createDataFrame([Row(date=dt)]) + row = df.select(dayofweek(df.date)).first() + self.assertEqual(row[0], 2) + def test_decimal(self): from decimal import Decimal schema = StructType([StructField("decimal", DecimalType(10, 5))]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6bbdfa3ad1893..3e4659b9eae60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2601,6 +2601,13 @@ object functions { */ def month(e: Column): Column = withExpr { Month(e.expr) } + /** + * Extracts the day of the week as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 2.3.0 + */ + def dayofweek(e: Column): Column = withExpr { DayOfWeek(e.expr) } + /** * Extracts the day of the month as an integer from a given date/timestamp/string. * @group datetime_funcs From 695647bf2ebda56f9effb7fcdd875490132ea012 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 9 Nov 2017 15:00:31 +0900 Subject: [PATCH 1629/1765] [SPARK-21640][SQL][PYTHON][R][FOLLOWUP] Add errorifexists in SparkR and other documentations ## What changes were proposed in this pull request? This PR proposes to add `errorifexists` to SparkR API and fix the rest of them describing the mode, mainly, in API documentations as well. This PR also replaces `convertToJSaveMode` to `setWriteMode` so that string as is is passed to JVM and executes: https://github.com/apache/spark/blob/b034f2565f72aa73c9f0be1e49d148bb4cf05153/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala#L72-L82 and remove the duplication here: https://github.com/apache/spark/blob/3f958a99921d149fb9fdf7ba7e78957afdad1405/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala#L187-L194 ## How was this patch tested? Manually checked the built documentation. These were mainly found by `` grep -r `error` `` and `grep -r 'error'`. Also, unit tests added in `test_sparkSQL.R`. Author: hyukjinkwon Closes #19673 from HyukjinKwon/SPARK-21640-followup. --- R/pkg/R/DataFrame.R | 79 +++++++++++-------- R/pkg/R/utils.R | 9 --- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++ R/pkg/tests/fulltests/test_utils.R | 8 -- python/pyspark/sql/readwriter.py | 25 +++--- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 9 --- 7 files changed, 71 insertions(+), 69 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 763c8d2548580..b8d732a485862 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -58,14 +58,23 @@ setMethod("initialize", "SparkDataFrame", function(.Object, sdf, isCached) { #' Set options/mode and then return the write object #' @noRd setWriteOptions <- function(write, path = NULL, mode = "error", ...) { - options <- varargsToStrEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } - jmode <- convertToJSaveMode(mode) - write <- callJMethod(write, "mode", jmode) - write <- callJMethod(write, "options", options) - write + options <- varargsToStrEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + write <- setWriteMode(write, mode) + write <- callJMethod(write, "options", options) + write +} + +#' Set mode and then return the write object +#' @noRd +setWriteMode <- function(write, mode) { + if (!is.character(mode)) { + stop("mode should be character or omitted. It is 'error' by default.") + } + write <- handledCallJMethod(write, "mode", mode) + write } #' @export @@ -556,9 +565,8 @@ setMethod("registerTempTable", setMethod("insertInto", signature(x = "SparkDataFrame", tableName = "character"), function(x, tableName, overwrite = FALSE) { - jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) write <- callJMethod(x@sdf, "write") - write <- callJMethod(write, "mode", jmode) + write <- setWriteMode(write, ifelse(overwrite, "overwrite", "append")) invisible(callJMethod(write, "insertInto", tableName)) }) @@ -810,7 +818,8 @@ setMethod("toJSON", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -841,7 +850,8 @@ setMethod("write.json", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -872,7 +882,8 @@ setMethod("write.orc", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -917,7 +928,8 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2871,18 +2883,19 @@ setMethod("except", #' Additionally, mode is used to specify the behavior of the save operation when data already #' exists in the data source. There are four modes: #' \itemize{ -#' \item append: Contents of this SparkDataFrame are expected to be appended to existing data. -#' \item overwrite: Existing data is expected to be overwritten by the contents of this +#' \item 'append': Contents of this SparkDataFrame are expected to be appended to existing data. +#' \item 'overwrite': Existing data is expected to be overwritten by the contents of this #' SparkDataFrame. -#' \item error: An exception is expected to be thrown. -#' \item ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' \item 'error' or 'errorifexists': An exception is expected to be thrown. +#' \item 'ignore': The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. #' } #' #' @param df a SparkDataFrame. #' @param path a name for the table. #' @param source a name for external data source. -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2940,17 +2953,18 @@ setMethod("saveDF", #' #' Additionally, mode is used to specify the behavior of the save operation when #' data already exists in the data source. There are four modes: \cr -#' append: Contents of this SparkDataFrame are expected to be appended to existing data. \cr -#' overwrite: Existing data is expected to be overwritten by the contents of this +#' 'append': Contents of this SparkDataFrame are expected to be appended to existing data. \cr +#' 'overwrite': Existing data is expected to be overwritten by the contents of this #' SparkDataFrame. \cr -#' error: An exception is expected to be thrown. \cr -#' ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' 'error' or 'errorifexists': An exception is expected to be thrown. \cr +#' 'ignore': The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. \cr #' #' @param df a SparkDataFrame. #' @param tableName a name for the table. #' @param source a name for external data source. -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional option(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2972,12 +2986,11 @@ setMethod("saveAsTable", if (is.null(source)) { source <- getDefaultSqlSource() } - jmode <- convertToJSaveMode(mode) options <- varargsToStrEnv(...) write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) - write <- callJMethod(write, "mode", jmode) + write <- setWriteMode(write, mode) write <- callJMethod(write, "options", options) invisible(callJMethod(write, "saveAsTable", tableName)) }) @@ -3544,18 +3557,19 @@ setMethod("histogram", #' Also, mode is used to specify the behavior of the save operation when #' data already exists in the data source. There are four modes: #' \itemize{ -#' \item append: Contents of this SparkDataFrame are expected to be appended to existing data. -#' \item overwrite: Existing data is expected to be overwritten by the contents of this +#' \item 'append': Contents of this SparkDataFrame are expected to be appended to existing data. +#' \item 'overwrite': Existing data is expected to be overwritten by the contents of this #' SparkDataFrame. -#' \item error: An exception is expected to be thrown. -#' \item ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' \item 'error' or 'errorifexists': An exception is expected to be thrown. +#' \item 'ignore': The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. #' } #' #' @param x a SparkDataFrame. #' @param url JDBC database url of the form \code{jdbc:subprotocol:subname}. #' @param tableName yhe name of the table in the external database. -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional JDBC database connection properties. #' @family SparkDataFrame functions #' @rdname write.jdbc @@ -3572,10 +3586,9 @@ setMethod("histogram", setMethod("write.jdbc", signature(x = "SparkDataFrame", url = "character", tableName = "character"), function(x, url, tableName, mode = "error", ...) { - jmode <- convertToJSaveMode(mode) jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") - write <- callJMethod(write, "mode", jmode) + write <- setWriteMode(write, mode) invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 4b716995f2c46..fa4099231ca8d 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -736,15 +736,6 @@ splitString <- function(input) { Filter(nzchar, unlist(strsplit(input, ",|\\s"))) } -convertToJSaveMode <- function(mode) { - allModes <- c("append", "overwrite", "error", "ignore") - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') # nolint - } - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) - jmode -} - varargsToJProperties <- function(...) { pairs <- list(...) props <- newJObject("java.util.Properties") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 0c8118a7c73f3..a0dbd475f78e6 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -630,6 +630,10 @@ test_that("read/write json files", { jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") write.df(df, jsonPath2, "json", mode = "overwrite") + # Test errorifexists + expect_error(write.df(df, jsonPath2, "json", mode = "errorifexists"), + "analysis error - path file:.*already exists") + # Test write.json jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") write.json(df, jsonPath3) @@ -1371,6 +1375,9 @@ test_that("test HiveContext", { expect_equal(count(df5), 3) unlink(parquetDataPath) + # Invalid mode + expect_error(saveAsTable(df, "parquetest", "parquet", mode = "abc", path = parquetDataPath), + "illegal argument - Unknown save mode: abc") unsetHiveContext() } }) @@ -3303,6 +3310,7 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume "Error in orc : analysis error - path file:.*already exists") expect_error(write.parquet(df, jsonPath), "Error in parquet : analysis error - path file:.*already exists") + expect_error(write.parquet(df, jsonPath, mode = 123), "mode should be character or omitted.") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index af81423aa8dd0..fb394b8069c1c 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -158,14 +158,6 @@ test_that("varargsToJProperties", { expect_equal(callJMethod(jprops, "size"), 0L) }) -test_that("convertToJSaveMode", { - s <- convertToJSaveMode("error") - expect_true(class(s) == "jobj") - expect_match(capture.output(print.jobj(s)), "Java ref type org.apache.spark.sql.SaveMode id ") - expect_error(convertToJSaveMode("foo"), - 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint -}) - test_that("captureJVMException", { method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3d87567ab673d..a75bdf8078dd5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -540,7 +540,7 @@ def mode(self, saveMode): * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. + * `error` or `errorifexists`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) @@ -675,7 +675,8 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param partitionBy: names of partitioning columns :param options: all other string options @@ -713,12 +714,13 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. + * `error` or `errorifexists`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. :param name: the table name :param format: the format used to save - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: one of `append`, `overwrite`, `error`, `errorifexists`, `ignore` \ + (default: error) :param partitionBy: names of partitioning columns :param options: all other string options """ @@ -741,7 +743,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). @@ -771,7 +774,8 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, gzip, and lzo). @@ -814,7 +818,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, @@ -874,7 +879,8 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, zlib, and lzo). @@ -905,7 +911,8 @@ def jdbc(self, url, table, mode=None, properties=None): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param properties: a dictionary of JDBC database connection arguments. Normally at least properties "user" and "password" with their corresponding values. For example { 'user' : 'SYSTEM', 'password' : 'mypassword' } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 8d95b24c00619..e3fa2ced760e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -65,7 +65,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * - `overwrite`: overwrite the existing data. * - `append`: append the data. * - `ignore`: ignore the operation (i.e. no-op). - * - `error`: default option, throw an exception at runtime. + * - `error` or `errorifexists`: default option, throw an exception at runtime. * * @since 1.4.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 872ef773e8a3a..af20764f9a968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -184,15 +184,6 @@ private[sql] object SQLUtils extends Logging { colArray } - def saveMode(mode: String): SaveMode = { - mode match { - case "append" => SaveMode.Append - case "overwrite" => SaveMode.Overwrite - case "error" => SaveMode.ErrorIfExists - case "ignore" => SaveMode.Ignore - } - } - def readSqlObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 's' => From 98be55c0fafc3577ab4b106316666b6807bc928f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 9 Nov 2017 16:34:38 +0900 Subject: [PATCH 1630/1765] [SPARK-22222][CORE][TEST][FOLLOW-UP] Remove redundant and deprecated `Timeouts` ## What changes were proposed in this pull request? Since SPARK-21939, Apache Spark uses `TimeLimits` instead of the deprecated `Timeouts`. This PR fixes the build warning `BufferHolderSparkSubmitSuite.scala` introduced at [SPARK-22222](https://github.com/apache/spark/pull/19460/files#diff-d8cf6e0c229969db94ec8ffc31a9239cR36) by removing the redundant `Timeouts`. ```scala trait Timeouts in package concurrent is deprecated: Please use org.scalatest.concurrent.TimeLimits instead [warn] with Timeouts { ``` ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #19697 from dongjoon-hyun/SPARK-22222. --- .../expressions/codegen/BufferHolderSparkSubmitSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala index 1167d2f3f3891..85682cf6ea670 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.Timeouts import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite @@ -32,8 +31,7 @@ class BufferHolderSparkSubmitSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { + with ResetSystemProperties { test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) From c755b0d910d68e7921807f2f2ac1e3fac7a8f357 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 Nov 2017 09:22:33 +0100 Subject: [PATCH 1631/1765] [SPARK-22463][YARN][SQL][HIVE] add hadoop/hive/hbase/etc configuration files in SPARK_CONF_DIR to distribute archive ## What changes were proposed in this pull request? When I ran self contained sql apps, such as ```scala import org.apache.spark.sql.SparkSession object ShowHiveTables { def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("Show Hive Tables") .enableHiveSupport() .getOrCreate() spark.sql("show tables").show() spark.stop() } } ``` with **yarn cluster** mode and `hive-site.xml` correctly within `$SPARK_HOME/conf`,they failed to connect the right hive metestore for not seeing hive-site.xml in AM/Driver's classpath. Although submitting them with `--files/--jars local/path/to/hive-site.xml` or puting it to `$HADOOP_CONF_DIR/YARN_CONF_DIR` can make these apps works well in cluster mode as client mode, according to the official doc, see http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables > Configuration of Hive is done by placing your hive-site.xml, core-site.xml (for security configuration), and hdfs-site.xml (for HDFS configuration) file in conf/. We may respect these configuration files too or modify the doc for hive-tables in cluster mode. ## How was this patch tested? cc cloud-fan gatorsmile Author: Kent Yao Closes #19663 from yaooqinn/SPARK-21888. --- .../org/apache/spark/deploy/yarn/Client.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1fe25c4ddaabf..99e7d46ca5c96 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} +import java.io.{FileSystem => _, _} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -687,6 +687,19 @@ private[spark] class Client( private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() + // SPARK_CONF_DIR shows up in the classpath before HADOOP_CONF_DIR/YARN_CONF_DIR + sys.env.get("SPARK_CONF_DIR").foreach { localConfDir => + val dir = new File(localConfDir) + if (dir.isDirectory) { + val files = dir.listFiles(new FileFilter { + override def accept(pathname: File): Boolean = { + pathname.isFile && pathname.getName.endsWith(".xml") + } + }) + files.foreach { f => hadoopConfFiles(f.getName) = f } + } + } + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) From fe93c0bf6115a62486ebf9f494da5dc368c24418 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Thu, 9 Nov 2017 11:46:01 +0100 Subject: [PATCH 1632/1765] [DOC] update the API doc and modify the stage API description ## What changes were proposed in this pull request? **1.stage api modify the description format** A list of all stages for a given application.
    ?status=[active|complete|pending|failed] list only stages in the state. content should be included in fix before: ![1](https://user-images.githubusercontent.com/26266482/31753100-201f3432-b4c1-11e7-9e8d-54b62b96c17f.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/31753102-23b174de-b4c1-11e7-96ad-fd79d10440b9.png) **2.add version api doc '/api/v1/version' in monitoring.md** fix after: ![3](https://user-images.githubusercontent.com/26266482/31753087-0fd3a036-b4c1-11e7-802f-a6dc86a2a4b0.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19532 from guoxiaolongzte/SPARK-22311. --- docs/monitoring.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 1ae43185d22f8..f8d3ce91a0691 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -311,8 +311,10 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/stages - A list of all stages for a given application. -
    ?status=[active|complete|pending|failed] list only stages in the state. + + A list of all stages for a given application. +
    ?status=[active|complete|pending|failed] list only stages in the state. + /applications/[app-id]/stages/[stage-id] @@ -398,7 +400,11 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/environment Environment details of the given application. - + + + /version + Get the current spark version. + The number of jobs and stages which can retrieved is constrained by the same retention From 40a8aefaf3e97e80b23fb05d4afdcc30e1922312 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Nov 2017 11:54:50 +0100 Subject: [PATCH 1633/1765] [SPARK-22442][SQL] ScalaReflection should produce correct field names for special characters ## What changes were proposed in this pull request? For a class with field name of special characters, e.g.: ```scala case class MyType(`field.1`: String, `field 2`: String) ``` Although we can manipulate DataFrame/Dataset, the field names are encoded: ```scala scala> val df = Seq(MyType("a", "b"), MyType("c", "d")).toDF df: org.apache.spark.sql.DataFrame = [field$u002E1: string, field$u00202: string] scala> df.as[MyType].collect res7: Array[MyType] = Array(MyType(a,b), MyType(c,d)) ``` It causes resolving problem when we try to convert the data with non-encoded field names: ```scala spark.read.json(path).as[MyType] ... [info] org.apache.spark.sql.AnalysisException: cannot resolve '`field$u002E1`' given input columns: [field 2, fie ld.1]; [info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) ... ``` We should use decoded field name in Dataset schema. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #19664 from viirya/SPARK-22442. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 +++++---- .../catalyst/expressions/objects/objects.scala | 11 +++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 18 +++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++ 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 17e595f9c5d8d..f62553ddd3971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection { def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + .getOrElse(UnresolvedAttribute.quoted(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -675,7 +675,7 @@ object ScalaReflection extends ScalaReflection { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - constructParams(t).map(_.name.toString) + constructParams(t).map(_.name.decodedName.toString) } /** @@ -855,11 +855,12 @@ trait ScalaReflection { // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) if (actualTypeArgs.nonEmpty) { params.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + p.name.decodedName.toString -> + p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) } } else { params.map { p => - p.name.toString -> p.typeSignature + p.name.decodedName.toString -> p.typeSignature } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6ae3490a3f863..f2eee991c9865 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -214,11 +215,13 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val encodedFunctionName = TermName(functionName).encodedName.toString + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find(_.getName == functionName) + val m = cls.getMethods.find(_.getName == encodedFunctionName) if (m.isEmpty) { - sys.error(s"Couldn't find $functionName on $cls") + sys.error(s"Couldn't find $encodedFunctionName on $cls") } else { m } @@ -247,7 +250,7 @@ case class Invoke( } val evaluate = if (returnPrimitive) { - getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") + getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)") } else { val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still @@ -265,7 +268,7 @@ case class Invoke( } s""" Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")} $assignResult """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index a5b9855e959d4..f77af5db3279b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -79,6 +80,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } +case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String) + object TestingUDT { @SQLUserDefinedType(udt = classOf[NestedStructUDT]) class NestedStruct(val a: Integer, val b: Long, val c: Double) @@ -335,4 +338,17 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } + test("SPARK-22442: Generate correct field names for special characters") { + val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val deserializer = deserializerFor[SpecialCharAsFieldData] + assert(serializer.dataType(0).name == "field.1") + assert(serializer.dataType(1).name == "field 2") + + val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { + case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts + }} + assert(argumentsFields(0) == Seq("field.1")) + assert(argumentsFields(1) == Seq("field 2")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 1537ce3313c09..c67165c7abca6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1398,6 +1398,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val actual = kvDataset.toString assert(expected === actual) } + + test("SPARK-22442: Generate correct field names for special characters") { + withTempPath { dir => + val path = dir.getCanonicalPath + val data = """{"field.1": 1, "field 2": 2}""" + Seq(data).toDF().repartition(1).write.text(path) + val ds = spark.read.json(path).as[SpecialCharClass] + checkDataset(ds, SpecialCharClass("1", "2")) + } + } } case class SingleData(id: Int) @@ -1487,3 +1497,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA) case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) + +case class SpecialCharClass(`field.1`: String, `field 2`: String) From 6793a3dac0a44570625044e1eb30fa578fa2f142 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 9 Nov 2017 11:57:56 +0100 Subject: [PATCH 1634/1765] [SPARK-22405][SQL] Add new alter table and alter database related ExternalCatalogEvent ## What changes were proposed in this pull request? We're building a data lineage tool in which we need to monitor the metadata changes in ExternalCatalog, current ExternalCatalog already provides several useful events like "CreateDatabaseEvent" for custom SparkListener to use. But still there's some event missing, like alter database event and alter table event. So here propose to and new ExternalCatalogEvent. ## How was this patch tested? Enrich the current UT and tested on local cluster. CC hvanhovell please let me know your comments about current proposal, thanks. Author: jerryshao Closes #19649 from jerryshao/SPARK-22405. --- .../catalyst/catalog/ExternalCatalog.scala | 35 +++++++++++++++-- .../catalyst/catalog/InMemoryCatalog.scala | 8 ++-- .../spark/sql/catalyst/catalog/events.scala | 38 ++++++++++++++++++- .../catalog/ExternalCatalogEventSuite.scala | 22 +++++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 18 ++++++--- 5 files changed, 107 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 223094d485936..45b4f013620c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -87,7 +87,14 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - def alterDatabase(dbDefinition: CatalogDatabase): Unit + final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + doAlterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -147,7 +154,15 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - def alterTable(tableDefinition: CatalogTable): Unit + final def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + doAlterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + protected def doAlterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -158,10 +173,22 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + doAlterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + doAlterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9504140d51e99..8eacfa058bd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def alterTableDataSchema( + override def doAlterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def alterTableStats( + override def doAlterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala index 742a51e640383..e7d41644392d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -61,6 +61,16 @@ case class DropDatabasePreEvent(database: String) extends DatabaseEvent */ case class DropDatabaseEvent(database: String) extends DatabaseEvent +/** + * Event fired before a database is altered. + */ +case class AlterDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database is altered. + */ +case class AlterDatabaseEvent(database: String) extends DatabaseEvent + /** * Event fired when a table is created, dropped or renamed. */ @@ -110,7 +120,33 @@ case class RenameTableEvent( extends TableEvent /** - * Event fired when a function is created, dropped or renamed. + * String to indicate which part of table is altered. If a plain alterTable API is called, then + * type will generally be Table. + */ +object AlterTableKind extends Enumeration { + val TABLE = "table" + val DATASCHEMA = "dataSchema" + val STATS = "stats" +} + +/** + * Event fired before a table is altered. + */ +case class AlterTablePreEvent( + database: String, + name: String, + kind: String) extends TableEvent + +/** + * Event fired after a table is altered. + */ +case class AlterTableEvent( + database: String, + name: String, + kind: String) extends TableEvent + +/** + * Event fired when a function is created, dropped, altered or renamed. */ trait FunctionEvent extends DatabaseEvent { /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 087c26f23f383..1acbe34d9a075 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -75,6 +75,11 @@ class ExternalCatalogEventSuite extends SparkFunSuite { } checkEvents(CreateDatabasePreEvent("db5") :: Nil) + // ALTER + val newDbDefinition = dbDefinition.copy(description = "test") + catalog.alterDatabase(newDbDefinition) + checkEvents(AlterDatabasePreEvent("db5") :: AlterDatabaseEvent("db5") :: Nil) + // DROP intercept[AnalysisException] { catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) @@ -119,6 +124,23 @@ class ExternalCatalogEventSuite extends SparkFunSuite { } checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + // ALTER + val newTableDefinition = tableDefinition.copy(tableType = CatalogTableType.EXTERNAL) + catalog.alterTable(newTableDefinition) + checkEvents(AlterTablePreEvent("db5", "tbl1", AlterTableKind.TABLE) :: + AlterTableEvent("db5", "tbl1", AlterTableKind.TABLE) :: Nil) + + // ALTER schema + val newSchema = new StructType().add("id", "long", nullable = false) + catalog.alterTableDataSchema("db5", "tbl1", newSchema) + checkEvents(AlterTablePreEvent("db5", "tbl1", AlterTableKind.DATASCHEMA) :: + AlterTableEvent("db5", "tbl1", AlterTableKind.DATASCHEMA) :: Nil) + + // ALTER stats + catalog.alterTableStats("db5", "tbl1", None) + checkEvents(AlterTablePreEvent("db5", "tbl1", AlterTableKind.STATS) :: + AlterTableEvent("db5", "tbl1", AlterTableKind.STATS) :: Nil) + // RENAME catalog.renameTable("db5", "tbl1", "tbl2") checkEvents( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f8a947bf527e7..7cd772544a96a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def alterTable(tableDefinition: CatalogTable): Unit = withClient { + override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -619,8 +619,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def alterTableDataSchema( - db: String, table: String, newDataSchema: StructType): Unit = withClient { + /** + * Alter the data schema of a table identified by the provided database and table name. The new + * data schema should not have conflict column names with the existing partition columns, and + * should still contain all the existing data columns. + */ + override def doAlterTableDataSchema( + db: String, + table: String, + newDataSchema: StructType): Unit = withClient { requireTableExists(db, table) val oldTable = getTable(db, table) verifyDataSchema(oldTable.identifier, oldTable.tableType, newDataSchema) @@ -648,7 +655,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def alterTableStats( + /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ + override def doAlterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { From 77f74539ec7a445e24736029fb198b48ffd50ea9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Nov 2017 16:35:06 +0200 Subject: [PATCH 1635/1765] [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns ## What changes were proposed in this pull request? Current ML's Bucketizer can only bin a column of continuous features. If a dataset has thousands of of continuous columns needed to bin, we will result in thousands of ML stages. It is inefficient regarding query planning and execution. We should have a type of bucketizer that can bin a lot of columns all at once. It would need to accept an list of arrays of split points to correspond to the columns to bin, but it might make things more efficient by replacing thousands of stages with just one. This current approach in this patch is to add a new `MultipleBucketizerInterface` for this purpose. `Bucketizer` now extends this new interface. ### Performance Benchmarking using the test dataset provided in JIRA SPARK-20392 (blockbuster.csv). The ML pipeline includes 2 `StringIndexer`s and 1 `MultipleBucketizer` or 137 `Bucketizer`s to bin 137 input columns with the same splits. Then count the time to transform the dataset. MultipleBucketizer: 3352 ms Bucketizer: 51512 ms ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17819 from viirya/SPARK-20542. --- .../examples/ml/JavaBucketizerExample.java | 41 +++ .../spark/examples/ml/BucketizerExample.scala | 36 ++- .../apache/spark/ml/feature/Bucketizer.scala | 122 +++++++-- .../org/apache/spark/ml/param/params.scala | 39 +++ .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 17 ++ .../spark/ml/feature/JavaBucketizerSuite.java | 35 +++ .../spark/ml/feature/BucketizerSuite.scala | 239 +++++++++++++++++- .../apache/spark/ml/param/ParamsSuite.scala | 38 ++- .../scala/org/apache/spark/sql/Dataset.scala | 21 +- .../org/apache/spark/sql/DataFrameSuite.scala | 28 ++ 11 files changed, 592 insertions(+), 25 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java index f00993833321d..3e49bf04ac892 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -33,6 +33,13 @@ import org.apache.spark.sql.types.StructType; // $example off$ +/** + * An example for Bucketizer. + * Run with + *
    + * bin/run-example ml.JavaBucketizerExample
    + * 
    + */ public class JavaBucketizerExample { public static void main(String[] args) { SparkSession spark = SparkSession @@ -68,6 +75,40 @@ public static void main(String[] args) { bucketedData.show(); // $example off$ + // $example on$ + // Bucketize multiple columns at one pass. + double[][] splitsArray = { + {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}, + {Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY} + }; + + List data2 = Arrays.asList( + RowFactory.create(-999.9, -999.9), + RowFactory.create(-0.5, -0.2), + RowFactory.create(-0.3, -0.1), + RowFactory.create(0.0, 0.0), + RowFactory.create(0.2, 0.4), + RowFactory.create(999.9, 999.9) + ); + StructType schema2 = new StructType(new StructField[]{ + new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features2", DataTypes.DoubleType, false, Metadata.empty()) + }); + Dataset dataFrame2 = spark.createDataFrame(data2, schema2); + + Bucketizer bucketizer2 = new Bucketizer() + .setInputCols(new String[] {"features1", "features2"}) + .setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"}) + .setSplitsArray(splitsArray); + // Transform original data into its bucket index. + Dataset bucketedData2 = bucketizer2.transform(dataFrame2); + + System.out.println("Bucketizer output with [" + + (bucketizer2.getSplitsArray()[0].length-1) + ", " + + (bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column"); + bucketedData2.show(); + // $example off$ + spark.stop(); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 04e4eccd436ed..7e65f9c88907d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -22,7 +22,13 @@ package org.apache.spark.examples.ml import org.apache.spark.ml.feature.Bucketizer // $example off$ import org.apache.spark.sql.SparkSession - +/** + * An example for Bucketizer. + * Run with + * {{{ + * bin/run-example ml.BucketizerExample + * }}} + */ object BucketizerExample { def main(args: Array[String]): Unit = { val spark = SparkSession @@ -48,6 +54,34 @@ object BucketizerExample { bucketedData.show() // $example off$ + // $example on$ + val splitsArray = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity)) + + val data2 = Array( + (-999.9, -999.9), + (-0.5, -0.2), + (-0.3, -0.1), + (0.0, 0.0), + (0.2, 0.4), + (999.9, 999.9)) + val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2") + + val bucketizer2 = new Bucketizer() + .setInputCols(Array("features1", "features2")) + .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2")) + .setSplitsArray(splitsArray) + + // Transform original data into its bucket index. + val bucketedData2 = bucketizer2.transform(dataFrame2) + + println(s"Bucketizer output with [" + + s"${bucketizer2.getSplitsArray(0).length-1}, " + + s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column") + bucketedData2.show() + // $example off$ + spark.stop() } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6a11a75d1d569..e07f2a107badb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.expressions.UserDefinedFunction @@ -32,12 +32,16 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that + * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and + * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is + * only used for single column usage, and `splitsArray` is for multiple columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol - with DefaultParamsWritable { + with HasInputCols with HasOutputCols with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("bucketizer")) @@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Param for how to handle invalid entries. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special - * additional bucket). + * additional bucket). Note that in the multiple column case, the invalid handling is applied + * to all columns. That said for 'error' it will throw an error if any invalids are found in + * any column, for 'skip' it will skip rows with any invalids in any columns, etc. * Default: "error" * @group param */ @@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setHandleInvalid(value: String): this.type = set(handleInvalid, value) setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + /** + * Parameter for specifying multiple splits parameters. Each element in this array can be used to + * map continuous features into buckets. + * + * @group param + */ + @Since("2.3.0") + val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray", + "The array of split points for mapping continuous features into buckets for multiple " + + "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " + + "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " + + "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " + + "explicitly provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.", + Bucketizer.checkSplitsArray) + + /** @group getParam */ + @Since("2.3.0") + def getSplitsArray: Array[Array[Double]] = $(splitsArray) + + /** @group setParam */ + @Since("2.3.0") + def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** + * Determines whether this `Bucketizer` is going to map multiple columns. If and only if + * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified + * by `inputCol`. A warning will be printed if both are set. + */ + private[feature] def isBucketizeMultipleColumns(): Boolean = { + if (isSet(inputCols) && isSet(inputCol)) { + logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + + "`Bucketizer` only map one column specified by `inputCol`") + false + } else if (isSet(inputCols)) { + true + } else { + false + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema) + val transformedSchema = transformSchema(dataset.schema) + val (filteredDataset, keepInvalid) = { if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset @@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val bucketizer: UserDefinedFunction = udf { (feature: Double) => - Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - }.withName("bucketizer") + val seqOfSplits = if (isBucketizeMultipleColumns()) { + $(splitsArray).toSeq + } else { + Seq($(splits)) + } - val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) - val newField = prepOutputField(filteredDataset.schema) - filteredDataset.withColumn($(outputCol), newCol, newField.metadata) + val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) => + udf { (feature: Double) => + Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid) + }.withName(s"bucketizer_$idx") + } + + val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } + val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) => + bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType)) + } + val metadata = outputColumns.map { col => + transformedSchema(col).metadata + } + filteredDataset.withColumns(outputColumns, newCols, metadata) } - private def prepOutputField(schema: StructType): StructField = { - val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray - val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), + private def prepOutputField(splits: Array[Double], outputCol: String): StructField = { + val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray + val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true), values = Some(buckets)) attr.toStructField() } @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - SchemaUtils.appendColumn(schema, prepOutputField(schema)) + if (isBucketizeMultipleColumns()) { + var transformedSchema = schema + $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + SchemaUtils.checkNumericType(transformedSchema, inputCol) + transformedSchema = SchemaUtils.appendColumn(transformedSchema, + prepOutputField($(splitsArray)(idx), outputCol)) + } + transformedSchema + } else { + SchemaUtils.checkNumericType(schema, $(inputCol)) + SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol))) + } } @Since("1.4.1") @@ -163,6 +246,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + /** + * Check each splits in the splits array. + */ + private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = { + splitsArray.forall(checkSplits(_)) + } + /** * Binary searching in several buckets to place each data point. * @param splits array of split points diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ac68b825af537..8985f2af90a9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array } } +/** + * :: DeveloperApi :: + * Specialized version of `Param[Array[Array[Double]]]` for Java. + */ +@DeveloperApi +class DoubleArrayArrayParam( + parent: Params, + name: String, + doc: String, + isValid: Array[Array[Double]] => Boolean) + extends Param[Array[Array[Double]]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a `java.util.List` of values (for Java and Python). */ + def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] = + w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray) + + override def jsonEncode(value: Array[Array[Double]]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode)))) + } + + override def jsonDecode(json: String): Array[Array[Double]] = { + parse(json) match { + case JArray(values) => + values.map { + case JArray(values) => + values.map(DoubleParam.jValueDecode).toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].") + }.toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].") + } + } +} + /** * :: DeveloperApi :: * Specialized version of `Param[Array[Int]]` for Java. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a932d28fadbd8..20a1db854e3a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), + ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e6bdf5236e72d..0d5fb28ae783c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -257,6 +257,23 @@ trait HasOutputCol extends Params { final def getOutputCol: String = $(outputCol) } +/** + * Trait for shared param outputCols. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasOutputCols extends Params { + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) +} + /** * Trait for shared param checkpointInterval. This trait may be changed or * removed between minor versions. diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 87639380bdcf4..e65265bf74a88 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -61,4 +61,39 @@ public void bucketizerTest() { Assert.assertTrue((index >= 0) && (index <= 1)); } } + + @Test + public void bucketizerMultipleColumnsTest() { + double[][] splitsArray = { + {-0.5, 0.0, 0.5}, + {-0.5, 0.0, 0.2, 0.5} + }; + + StructType schema = new StructType(new StructField[]{ + new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()), + }); + Dataset dataset = spark.createDataFrame( + Arrays.asList( + RowFactory.create(-0.5, -0.5), + RowFactory.create(-0.3, -0.3), + RowFactory.create(0.0, 0.0), + RowFactory.create(0.2, 0.3)), + schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCols(new String[] {"feature1", "feature2"}) + .setOutputCols(new String[] {"result1", "result2"}) + .setSplitsArray(splitsArray); + + List result = bucketizer.transform(dataset).select("result1", "result2").collectAsList(); + + for (Row r : result) { + double index1 = r.getDouble(0); + Assert.assertTrue((index1 >= 0) && (index1 <= 1)); + + double index2 = r.getDouble(1); + Assert.assertTrue((index2 >= 0) && (index2 <= 2)); + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 420fb17ddce8c..748dbd1b995d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.ml.feature import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -187,6 +188,220 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } } + + test("multiple columns: Bucket continuous features, without -inf,inf") { + // Check a set of valid feature values. + val splits = Array(Array(-0.5, 0.0, 0.5), Array(-0.1, 0.3, 0.5)) + val validData1 = Array(-0.5, -0.3, 0.0, 0.2) + val validData2 = Array(0.5, 0.3, 0.0, -0.1) + val expectedBuckets1 = Array(0.0, 0.0, 1.0, 1.0) + val expectedBuckets2 = Array(1.0, 1.0, 0.0, 0.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer1: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer1.isBucketizeMultipleColumns()) + + bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") + BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + + // Check for exceptions when using a set of invalid feature values. + val invalidData1 = Array(-0.9) ++ validData1 + val invalidData2 = Array(0.51) ++ validData1 + val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx") + + val bucketizer2: Bucketizer = new Bucketizer() + .setInputCols(Array("feature")) + .setOutputCols(Array("result")) + .setSplitsArray(Array(splits(0))) + + assert(bucketizer2.isBucketizeMultipleColumns()) + + withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer2.transform(badDF1).collect() + } + } + val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx") + withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer2.transform(badDF2).collect() + } + } + } + + test("multiple columns: Bucket continuous features, with -inf,inf") { + val splits = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.3, 0.2, 0.5, Double.PositiveInfinity)) + + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) + val validData2 = Array(-0.1, -0.5, -0.2, 0.0, 0.1, 0.3, 0.5) + val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) + val expectedBuckets2 = Array(1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 3.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer.isBucketizeMultipleColumns()) + + BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + } + + test("multiple columns: Bucket continuous features, with NaN data but non-NaN splits") { + val splits = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.1, 0.2, 0.6, Double.PositiveInfinity)) + + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN) + val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0) + val expectedBuckets2 = Array(2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer.isBucketizeMultipleColumns()) + + bucketizer.setHandleInvalid("keep") + BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + + bucketizer.setHandleInvalid("skip") + val skipResults1: Array[Double] = bucketizer.transform(dataFrame) + .select("result1").as[Double].collect() + assert(skipResults1.length === 7) + assert(skipResults1.forall(_ !== 4.0)) + + val skipResults2: Array[Double] = bucketizer.transform(dataFrame) + .select("result2").as[Double].collect() + assert(skipResults2.length === 7) + assert(skipResults2.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } + } + + test("multiple columns: Bucket continuous features, with NaN splits") { + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) + withClue("Invalid NaN split was not caught during Bucketizer initialization") { + intercept[IllegalArgumentException] { + new Bucketizer().setSplitsArray(Array(splits)) + } + } + } + + test("multiple columns: read/write") { + val t = new Bucketizer() + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) + assert(t.isBucketizeMultipleColumns()) + testDefaultReadWrite(t) + } + + test("Bucketizer in a pipeline") { + val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0)) + .toDF("feature1", "feature2", "expected1", "expected2") + + val bucket = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) + + assert(bucket.isBucketizeMultipleColumns()) + + val pl = new Pipeline() + .setStages(Array(bucket)) + .fit(df) + pl.transform(df).select("result1", "expected1", "result2", "expected2") + + BucketizerSuite.checkBucketResults(pl.transform(df), + Seq("result1", "result2"), Seq("expected1", "expected2")) + } + + test("Compare single/multiple column(s) Bucketizer in pipeline") { + val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0)) + .toDF("feature1", "feature2", "expected1", "expected2") + + val multiColsBucket = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) + + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsBucket)) + .fit(df) + + val bucketForCol1 = new Bucketizer() + .setInputCol("feature1") + .setOutputCol("result1") + .setSplits(Array(-0.5, 0.0, 0.5)) + val bucketForCol2 = new Bucketizer() + .setInputCol("feature2") + .setOutputCol("result2") + .setSplits(Array(-0.5, 0.0, 0.5)) + + val plForSingleCol = new Pipeline() + .setStages(Array(bucketForCol1, bucketForCol2)) + .fit(df) + + val resultForSingleCol = plForSingleCol.transform(df) + .select("result1", "expected1", "result2", "expected2") + .collect() + val resultForMultiCols = plForMultiCols.transform(df) + .select("result1", "expected1", "result2", "expected2") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && + rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && + rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) && + rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3)) + } + } + + test("Both inputCol and inputCols are set") { + val bucket = new Bucketizer() + .setInputCol("feature1") + .setOutputCol("result") + .setSplits(Array(-0.5, 0.0, 0.5)) + .setInputCols(Array("feature1", "feature2")) + + // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. + assert(bucket.isBucketizeMultipleColumns() == false) + } } private object BucketizerSuite extends SparkFunSuite { @@ -220,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite { i += 1 } } + + /** Checks if bucketized results match expected ones. */ + def checkBucketResults( + bucketResult: DataFrame, + resultColumns: Seq[String], + expectedColumns: Seq[String]): Unit = { + assert(resultColumns.length == expectedColumns.length, + s"Given ${resultColumns.length} result columns doesn't match " + + s"${expectedColumns.length} expected columns.") + assert(resultColumns.length > 0, "At least one result and expected columns are needed.") + + val allColumns = resultColumns ++ expectedColumns + bucketResult.select(allColumns.head, allColumns.tail: _*).collect().foreach { + case row => + for (idx <- 0 until row.length / 2) { + val result = row.getDouble(idx) + val expected = row.getDouble(idx + row.length / 2) + assert(result === expected, "The feature value is not correct after bucketing. " + + s"Expected $expected but found $result.") + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 78a33e05e0e48..85198ad4c913a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -121,10 +121,10 @@ class ParamsSuite extends SparkFunSuite { { // DoubleArrayParam val param = new DoubleArrayParam(dummy, "name", "doc") val values: Seq[Array[Double]] = Seq( - Array(), - Array(1.0), - Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, - Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) + Array(), + Array(1.0), + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) for (value <- values) { val json = param.jsonEncode(value) val decoded = param.jsonDecode(json) @@ -139,6 +139,36 @@ class ParamsSuite extends SparkFunSuite { } } + { // DoubleArrayArrayParam + val param = new DoubleArrayArrayParam(dummy, "name", "doc") + val values: Seq[Array[Array[Double]]] = Seq( + Array(Array()), + Array(Array(1.0)), + Array(Array(1.0), Array(2.0)), + Array( + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity), + Array(Double.MaxValue, Double.PositiveInfinity, Double.MinPositiveValue, 1.0, + Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0) + )) + + for (value <- values) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + assert(decoded.length === value.length) + decoded.zip(value).foreach { case (actualArray, expectedArray) => + assert(actualArray.length === expectedArray.length) + actualArray.zip(expectedArray).foreach { case (actual, expected) => + if (expected.isNaN) { + assert(actual.isNaN) + } else { + assert(actual === expected) + } + } + } + } + } + { // StringArrayParam val param = new StringArrayParam(dummy, "name", "doc") val values: Seq[Array[String]] = Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bd99ec52ce93f..5eb2affa0bd8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2135,12 +2135,27 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by adding a column with metadata. + * Returns a new Dataset by adding columns with metadata. */ - private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { - withColumn(colName, col.as(colName, metadata)) + private[spark] def withColumns( + colNames: Seq[String], + cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = { + require(colNames.size == metadata.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of metadata elements: ${metadata.size}") + val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) => + col.as(colName, metadata) + } + withColumns(colNames, newCols) } + /** + * Returns a new Dataset by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = + withColumns(Seq(colName), Seq(col), Seq(metadata)) + /** * Returns a new Dataset with a column renamed. * This is a no-op if schema doesn't contain existingName. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 17c88b0690800..31bfa77e76329 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -686,6 +686,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("withColumns: given metadata") { + def buildMetadata(num: Int): Seq[Metadata] = { + (0 until num).map { n => + val builder = new MetadataBuilder + builder.putLong("key", n.toLong) + builder.build() + } + } + + val df = testData.toDF().withColumns( + Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2), + buildMetadata(2)) + + df.select("newCol1", "newCol2").schema.zipWithIndex.foreach { case (col, idx) => + assert(col.metadata.getLong("key").toInt === idx) + } + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns( + Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2), + buildMetadata(1)) + } + assert(err.getMessage.contains( + "The size of column names: 2 isn't equal to the size of metadata elements: 1")) + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) From 6b19c0735d5dc9287c13dfc255eb353d827800e2 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 9 Nov 2017 09:53:41 -0800 Subject: [PATCH 1636/1765] [MINOR][CORE] Fix nits in MetricsSystemSuite ## What changes were proposed in this pull request? Fixing nits in MetricsSystemSuite file 1) Using Sink instead of Source while casting 2) Using meaningful naming for variables, which reflect their usage ## How was this patch tested? Ran the tests locally and all of them passing Author: Srinivasa Reddy Vundela Closes #19699 from vundela/master. --- .../spark/metrics/MetricsSystemSuite.scala | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 61db6af830cc5..a7a24114f17e2 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.internal.config._ +import org.apache.spark.metrics.sink.Sink import org.apache.spark.metrics.source.{Source, StaticSources} class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ @@ -42,7 +43,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) - val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) + val sinks = PrivateMethod[ArrayBuffer[Sink]]('sinks) assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 0) @@ -53,7 +54,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) - val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) + val sinks = PrivateMethod[ArrayBuffer[Sink]]('sinks) assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 1) @@ -126,9 +127,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === s"$appId.$executorId.${source.sourceName}") } @@ -142,9 +143,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) } @@ -158,9 +159,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.app.id", appId) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) } @@ -176,9 +177,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "testInstance" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = testMetricsSystem.buildRegistryName(source) // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name. assert(metricName != s"$appId.$executorId.${source.sourceName}") @@ -200,9 +201,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, "${spark.app.name}") val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === s"$appName.$executorId.${source.sourceName}") } @@ -218,9 +219,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, namespaceToResolve) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) // If the user set the spark.metrics.namespace property to an expansion of another property // (say ${spark.doesnotexist}, the unresolved name (i.e. literally ${spark.doesnotexist}) // is used as the root logger name. @@ -238,9 +239,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, "${spark.app.name}") val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) } @@ -259,9 +260,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "testInstance" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = testMetricsSystem.buildRegistryName(source) // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name. assert(metricName != s"$appId.$executorId.${source.sourceName}") From 6ae12715c7286abc87cc7a22b1fa955e844077d5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 9 Nov 2017 15:46:16 -0600 Subject: [PATCH 1637/1765] [SPARK-20647][CORE] Port StorageTab to the new UI backend. This required adding information about StreamBlockId to the store, which is not available yet via the API. So an internal type was added until there's a need to expose that information in the API. The UI only lists RDDs that have cached partitions, and that information wasn't being correctly captured in the listener, so that's also fixed, along with some minor (internal) API adjustments so that the UI can get the correct data. Because of the way partitions are cached, some optimizations w.r.t. how often the data is flushed to the store could not be applied to this code; because of that, some different ways to make the code more performant were added to the data structures tracking RDD blocks, with the goal of avoiding expensive copies when lots of blocks are being updated. Tested with existing and updated unit tests. Author: Marcelo Vanzin Closes #19679 from vanzin/SPARK-20647. --- .../spark/status/AppStatusListener.scala | 97 ++++++--- .../apache/spark/status/AppStatusStore.scala | 10 +- .../org/apache/spark/status/LiveEntity.scala | 170 ++++++++++++--- .../spark/status/api/v1/AllRDDResource.scala | 81 +------ .../spark/status/api/v1/OneRDDResource.scala | 10 +- .../org/apache/spark/status/storeTypes.scala | 16 ++ .../spark/storage/BlockStatusListener.scala | 100 --------- .../spark/storage/StorageStatusListener.scala | 111 ---------- .../scala/org/apache/spark/ui/SparkUI.scala | 17 +- .../org/apache/spark/ui/storage/RDDPage.scala | 16 +- .../apache/spark/ui/storage/StoragePage.scala | 78 ++++--- .../apache/spark/ui/storage/StorageTab.scala | 68 +----- .../spark/status/AppStatusListenerSuite.scala | 186 +++++++++++----- .../apache/spark/status/LiveEntitySuite.scala | 68 ++++++ .../storage/BlockStatusListenerSuite.scala | 113 ---------- .../storage/StorageStatusListenerSuite.scala | 167 -------------- .../spark/ui/storage/StoragePageSuite.scala | 128 +++++++---- .../spark/ui/storage/StorageTabSuite.scala | 205 ------------------ project/MimaExcludes.scala | 2 + 19 files changed, 586 insertions(+), 1057 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala create mode 100644 core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 0469c871362c0..7f2c00c09d43d 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -146,10 +146,19 @@ private[spark] class AppStatusListener( override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => + val now = System.nanoTime() exec.isActive = false exec.removeTime = new Date(event.time) exec.removeReason = event.reason - update(exec, System.nanoTime()) + update(exec, now) + + // Remove all RDD distributions that reference the removed executor, in case there wasn't + // a corresponding event. + liveRDDs.values.foreach { rdd => + if (rdd.removeDistribution(exec)) { + update(rdd, now) + } + } } } @@ -465,7 +474,8 @@ private[spark] class AppStatusListener( override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { event.blockUpdatedInfo.blockId match { case block: RDDBlockId => updateRDDBlock(event, block) - case _ => // TODO: API only covers RDD storage. + case stream: StreamBlockId => updateStreamBlock(event, stream) + case _ => } } @@ -502,29 +512,47 @@ private[spark] class AppStatusListener( val maybeExec = liveExecutors.get(executorId) var rddBlocksDelta = 0 + // Update the executor stats first, since they are used to calculate the free memory + // on tracked RDD distributions. + maybeExec.foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + } else { + exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + } + } + exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = newValue(exec.diskUsed, diskDelta) + } + // Update the block entry in the RDD info, keeping track of the deltas above so that we // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => + if (updatedStorageLevel.isDefined) { + rdd.storageLevel = updatedStorageLevel.get + } + val partition = rdd.partition(block.name) val executors = if (updatedStorageLevel.isDefined) { - if (!partition.executors.contains(executorId)) { + val current = partition.executors + if (current.contains(executorId)) { + current + } else { rddBlocksDelta = 1 + current :+ executorId } - partition.executors + executorId } else { rddBlocksDelta = -1 - partition.executors - executorId + partition.executors.filter(_ != executorId) } // Only update the partition if it's still stored in some executor, otherwise get rid of it. if (executors.nonEmpty) { - if (updatedStorageLevel.isDefined) { - partition.storageLevel = updatedStorageLevel.get - } - partition.memoryUsed = newValue(partition.memoryUsed, memoryDelta) - partition.diskUsed = newValue(partition.diskUsed, diskDelta) - partition.executors = executors + partition.update(executors, rdd.storageLevel, + newValue(partition.memoryUsed, memoryDelta), + newValue(partition.diskUsed, diskDelta)) } else { rdd.removePartition(block.name) } @@ -532,42 +560,39 @@ private[spark] class AppStatusListener( maybeExec.foreach { exec => if (exec.rddBlocks + rddBlocksDelta > 0) { val dist = rdd.distribution(exec) - dist.memoryRemaining = newValue(dist.memoryRemaining, -memoryDelta) dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) dist.diskUsed = newValue(dist.diskUsed, diskDelta) if (exec.hasMemoryInfo) { if (storageLevel.useOffHeap) { dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) - dist.offHeapRemaining = newValue(dist.offHeapRemaining, -memoryDelta) } else { dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) - dist.onHeapRemaining = newValue(dist.onHeapRemaining, -memoryDelta) } } + dist.lastUpdate = null } else { rdd.removeDistribution(exec) } - } - if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + // Trigger an update on other RDDs so that the free memory information is updated. + liveRDDs.values.foreach { otherRdd => + if (otherRdd.info.id != block.rddId) { + otherRdd.distributionOpt(exec).foreach { dist => + dist.lastUpdate = null + update(otherRdd, now) + } + } + } } + rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) update(rdd, now) } + // Finish updating the executor now that we know the delta in the number of blocks. maybeExec.foreach { exec => - if (exec.hasMemoryInfo) { - if (storageLevel.useOffHeap) { - exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) - } else { - exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) - } - } - exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) - exec.diskUsed = newValue(exec.diskUsed, diskDelta) exec.rddBlocks += rddBlocksDelta maybeUpdate(exec, now) } @@ -577,6 +602,26 @@ private[spark] class AppStatusListener( liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId, addTime)) } + private def updateStreamBlock(event: SparkListenerBlockUpdated, stream: StreamBlockId): Unit = { + val storageLevel = event.blockUpdatedInfo.storageLevel + if (storageLevel.isValid) { + val data = new StreamBlockData( + stream.name, + event.blockUpdatedInfo.blockManagerId.executorId, + event.blockUpdatedInfo.blockManagerId.hostPort, + storageLevel.description, + storageLevel.useMemory, + storageLevel.useDisk, + storageLevel.deserialized, + event.blockUpdatedInfo.memSize, + event.blockUpdatedInfo.diskSize) + kvstore.write(data) + } else { + kvstore.delete(classOf[StreamBlockData], + Array(stream.name, event.blockUpdatedInfo.blockManagerId.executorId)) + } + } + private def getOrCreateStage(info: StageInfo): LiveStage = { val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) stage.info = info diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 334407829f9fe..80c8d7d11a3c2 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -209,14 +209,20 @@ private[spark] class AppStatusStore(store: KVStore) { indexed.skip(offset).max(length).asScala.map(_.info).toSeq } - def rddList(): Seq[v1.RDDStorageInfo] = { - store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).toSeq + def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { + store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).filter { rdd => + !cachedOnly || rdd.numCachedPartitions > 0 + }.toSeq } def rdd(rddId: Int): v1.RDDStorageInfo = { store.read(classOf[RDDStorageInfoWrapper], rddId).info } + def streamBlocksList(): Seq[StreamBlockData] = { + store.view(classOf[StreamBlockData]).asScala.toSeq + } + def close(): Unit = { store.close() } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 8c48020e246b4..706d94c3a59b9 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -420,79 +420,101 @@ private class LiveStage extends LiveEntity { private class LiveRDDPartition(val blockName: String) { - var executors = Set[String]() - var storageLevel: String = null - var memoryUsed = 0L - var diskUsed = 0L + // Pointers used by RDDPartitionSeq. + @volatile var prev: LiveRDDPartition = null + @volatile var next: LiveRDDPartition = null + + var value: v1.RDDPartitionInfo = null + + def executors: Seq[String] = value.executors + + def memoryUsed: Long = value.memoryUsed + + def diskUsed: Long = value.diskUsed - def toApi(): v1.RDDPartitionInfo = { - new v1.RDDPartitionInfo( + def update( + executors: Seq[String], + storageLevel: String, + memoryUsed: Long, + diskUsed: Long): Unit = { + value = new v1.RDDPartitionInfo( blockName, storageLevel, memoryUsed, diskUsed, - executors.toSeq.sorted) + executors) } } -private class LiveRDDDistribution(val exec: LiveExecutor) { +private class LiveRDDDistribution(exec: LiveExecutor) { - var memoryRemaining = exec.maxMemory + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L var onHeapUsed = 0L var offHeapUsed = 0L - var onHeapRemaining = 0L - var offHeapRemaining = 0L + + // Keep the last update handy. This avoids recomputing the API view when not needed. + var lastUpdate: v1.RDDDataDistribution = null def toApi(): v1.RDDDataDistribution = { - new v1.RDDDataDistribution( - exec.hostPort, - memoryUsed, - memoryRemaining, - diskUsed, - if (exec.hasMemoryInfo) Some(onHeapUsed) else None, - if (exec.hasMemoryInfo) Some(offHeapUsed) else None, - if (exec.hasMemoryInfo) Some(onHeapRemaining) else None, - if (exec.hasMemoryInfo) Some(offHeapRemaining) else None) + if (lastUpdate == null) { + lastUpdate = new v1.RDDDataDistribution( + exec.hostPort, + memoryUsed, + exec.maxMemory - exec.memoryUsed, + diskUsed, + if (exec.hasMemoryInfo) Some(onHeapUsed) else None, + if (exec.hasMemoryInfo) Some(offHeapUsed) else None, + if (exec.hasMemoryInfo) Some(exec.totalOnHeap - exec.usedOnHeap) else None, + if (exec.hasMemoryInfo) Some(exec.totalOffHeap - exec.usedOffHeap) else None) + } + lastUpdate } } -private class LiveRDD(info: RDDInfo) extends LiveEntity { +private class LiveRDD(val info: RDDInfo) extends LiveEntity { var storageLevel: String = info.storageLevel.description var memoryUsed = 0L var diskUsed = 0L private val partitions = new HashMap[String, LiveRDDPartition]() + private val partitionSeq = new RDDPartitionSeq() + private val distributions = new HashMap[String, LiveRDDDistribution]() def partition(blockName: String): LiveRDDPartition = { - partitions.getOrElseUpdate(blockName, new LiveRDDPartition(blockName)) + partitions.getOrElseUpdate(blockName, { + val part = new LiveRDDPartition(blockName) + part.update(Nil, storageLevel, 0L, 0L) + partitionSeq.addPartition(part) + part + }) } - def removePartition(blockName: String): Unit = partitions.remove(blockName) + def removePartition(blockName: String): Unit = { + partitions.remove(blockName).foreach(partitionSeq.removePartition) + } def distribution(exec: LiveExecutor): LiveRDDDistribution = { - distributions.getOrElseUpdate(exec.hostPort, new LiveRDDDistribution(exec)) + distributions.getOrElseUpdate(exec.executorId, new LiveRDDDistribution(exec)) } - def removeDistribution(exec: LiveExecutor): Unit = { - distributions.remove(exec.hostPort) + def removeDistribution(exec: LiveExecutor): Boolean = { + distributions.remove(exec.executorId).isDefined } - override protected def doUpdate(): Any = { - val parts = if (partitions.nonEmpty) { - Some(partitions.values.toList.sortBy(_.blockName).map(_.toApi())) - } else { - None - } + def distributionOpt(exec: LiveExecutor): Option[LiveRDDDistribution] = { + distributions.get(exec.executorId) + } + override protected def doUpdate(): Any = { val dists = if (distributions.nonEmpty) { - Some(distributions.values.toList.sortBy(_.exec.executorId).map(_.toApi())) + Some(distributions.values.map(_.toApi()).toSeq) } else { None } @@ -506,7 +528,7 @@ private class LiveRDD(info: RDDInfo) extends LiveEntity { memoryUsed, diskUsed, dists, - parts) + Some(partitionSeq)) new RDDStorageInfoWrapper(rdd) } @@ -526,7 +548,7 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.map(_.intern()).orNull, + acc.name.orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } @@ -534,3 +556,81 @@ private object LiveEntityHelpers { } } + +/** + * A custom sequence of partitions based on a mutable linked list. + * + * The external interface is an immutable Seq, which is thread-safe for traversal. There are no + * guarantees about consistency though - iteration might return elements that have been removed + * or miss added elements. + * + * Internally, the sequence is mutable, and elements can modify the data they expose. Additions and + * removals are O(1). It is not safe to do multiple writes concurrently. + */ +private class RDDPartitionSeq extends Seq[v1.RDDPartitionInfo] { + + @volatile private var _head: LiveRDDPartition = null + @volatile private var _tail: LiveRDDPartition = null + @volatile var count = 0 + + override def apply(idx: Int): v1.RDDPartitionInfo = { + var curr = 0 + var e = _head + while (curr < idx && e != null) { + curr += 1 + e = e.next + } + if (e != null) e.value else throw new IndexOutOfBoundsException(idx.toString) + } + + override def iterator: Iterator[v1.RDDPartitionInfo] = { + new Iterator[v1.RDDPartitionInfo] { + var current = _head + + override def hasNext: Boolean = current != null + + override def next(): v1.RDDPartitionInfo = { + if (current != null) { + val tmp = current + current = tmp.next + tmp.value + } else { + throw new NoSuchElementException() + } + } + } + } + + override def length: Int = count + + def addPartition(part: LiveRDDPartition): Unit = { + part.prev = _tail + if (_tail != null) { + _tail.next = part + } + if (_head == null) { + _head = part + } + _tail = part + count += 1 + } + + def removePartition(part: LiveRDDPartition): Unit = { + count -= 1 + // Remove the partition from the list, but leave the pointers unchanged. That ensures a best + // effort at returning existing elements when iterations still reference the removed partition. + if (part.prev != null) { + part.prev.next = part.next + } + if (part eq _head) { + _head = part.next + } + if (part.next != null) { + part.next.prev = part.prev + } + if (part eq _tail) { + _tail = part.prev + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 1279b281ad8d8..2189e1da91841 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -21,90 +21,11 @@ import javax.ws.rs.core.MediaType import org.apache.spark.storage.{RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.storage.StorageListener @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AllRDDResource(ui: SparkUI) { @GET - def rddList(): Seq[RDDStorageInfo] = { - val storageStatusList = ui.storageListener.activeStorageStatusList - val rddInfos = ui.storageListener.rddInfoList - rddInfos.map{rddInfo => - AllRDDResource.getRDDStorageInfo(rddInfo.id, rddInfo, storageStatusList, - includeDetails = false) - } - } + def rddList(): Seq[RDDStorageInfo] = ui.store.rddList() } - -private[spark] object AllRDDResource { - - def getRDDStorageInfo( - rddId: Int, - listener: StorageListener, - includeDetails: Boolean): Option[RDDStorageInfo] = { - val storageStatusList = listener.activeStorageStatusList - listener.rddInfoList.find { _.id == rddId }.map { rddInfo => - getRDDStorageInfo(rddId, rddInfo, storageStatusList, includeDetails) - } - } - - def getRDDStorageInfo( - rddId: Int, - rddInfo: RDDInfo, - storageStatusList: Seq[StorageStatus], - includeDetails: Boolean): RDDStorageInfo = { - val workers = storageStatusList.map { (rddId, _) } - val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList) - val blocks = storageStatusList - .flatMap { _.rddBlocksById(rddId) } - .sortWith { _._1.name < _._1.name } - .map { case (blockId, status) => - (blockId, status, blockLocations.getOrElse(blockId, Seq[String]("Unknown"))) - } - - val dataDistribution = if (includeDetails) { - Some(storageStatusList.map { status => - new RDDDataDistribution( - address = status.blockManagerId.hostPort, - memoryUsed = status.memUsedByRdd(rddId), - memoryRemaining = status.memRemaining, - diskUsed = status.diskUsedByRdd(rddId), - onHeapMemoryUsed = Some( - if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), - offHeapMemoryUsed = Some( - if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), - onHeapMemoryRemaining = status.onHeapMemRemaining, - offHeapMemoryRemaining = status.offHeapMemRemaining - ) } ) - } else { - None - } - val partitions = if (includeDetails) { - Some(blocks.map { case (id, block, locations) => - new RDDPartitionInfo( - blockName = id.name, - storageLevel = block.storageLevel.description, - memoryUsed = block.memSize, - diskUsed = block.diskSize, - executors = locations - ) - } ) - } else { - None - } - - new RDDStorageInfo( - id = rddId, - name = rddInfo.name, - numPartitions = rddInfo.numPartitions, - numCachedPartitions = rddInfo.numCachedPartitions, - storageLevel = rddInfo.storageLevel.description, - memoryUsed = rddInfo.memSize, - diskUsed = rddInfo.diskSize, - dataDistribution = dataDistribution, - partitions = partitions - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 237aeac185877..ca9758cf0d109 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.NoSuchElementException import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType @@ -26,9 +27,12 @@ private[v1] class OneRDDResource(ui: SparkUI) { @GET def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { - AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( - throw new NotFoundException(s"no rdd found w/ id $rddId") - ) + try { + ui.store.rdd(rddId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException(s"no rdd found w/ id $rddId") + } } } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 23e9a360ddc02..f44b7935bfaa3 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -131,3 +131,19 @@ private[spark] class ExecutorStageSummaryWrapper( private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) } + +private[spark] class StreamBlockData( + val name: String, + val executorId: String, + val hostPort: String, + val storageLevel: String, + val useMemory: Boolean, + val useDisk: Boolean, + val deserialized: Boolean, + val memSize: Long, + val diskSize: Long) { + + @JsonIgnore @KVIndex + def key: Array[String] = Array(name, executorId) + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala deleted file mode 100644 index 0a14fcadf53e0..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import scala.collection.mutable - -import org.apache.spark.scheduler._ - -private[spark] case class BlockUIData( - blockId: BlockId, - location: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long) - -/** - * The aggregated status of stream blocks in an executor - */ -private[spark] case class ExecutorStreamBlockStatus( - executorId: String, - location: String, - blocks: Seq[BlockUIData]) { - - def totalMemSize: Long = blocks.map(_.memSize).sum - - def totalDiskSize: Long = blocks.map(_.diskSize).sum - - def numStreamBlocks: Int = blocks.size - -} - -private[spark] class BlockStatusListener extends SparkListener { - - private val blockManagers = - new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] - - override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { - val blockId = blockUpdated.blockUpdatedInfo.blockId - if (!blockId.isInstanceOf[StreamBlockId]) { - // Now we only monitor StreamBlocks - return - } - val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId - val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel - val memSize = blockUpdated.blockUpdatedInfo.memSize - val diskSize = blockUpdated.blockUpdatedInfo.diskSize - - synchronized { - // Drop the update info if the block manager is not registered - blockManagers.get(blockManagerId).foreach { blocksInBlockManager => - if (storageLevel.isValid) { - blocksInBlockManager.put(blockId, - BlockUIData( - blockId, - blockManagerId.hostPort, - storageLevel, - memSize, - diskSize) - ) - } else { - // If isValid is not true, it means we should drop the block. - blocksInBlockManager -= blockId - } - } - } - } - - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { - synchronized { - blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) - } - } - - override def onBlockManagerRemoved( - blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { - blockManagers -= blockManagerRemoved.blockManagerId - } - - def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { - blockManagers.map { case (blockManagerId, blocks) => - ExecutorStreamBlockStatus( - blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) - }.toSeq - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala deleted file mode 100644 index ac60f795915a3..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import scala.collection.mutable - -import org.apache.spark.SparkConf -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ - -/** - * :: DeveloperApi :: - * A SparkListener that maintains executor storage status. - * - * This class is thread-safe (unlike JobProgressListener) - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class StorageStatusListener(conf: SparkConf) extends SparkListener { - // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) - private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - private[storage] val deadExecutorStorageStatus = new mutable.ListBuffer[StorageStatus]() - private[this] val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) - - def storageStatusList: Seq[StorageStatus] = synchronized { - executorIdToStorageStatus.values.toSeq - } - - def deadStorageStatusList: Seq[StorageStatus] = synchronized { - deadExecutorStorageStatus - } - - /** Update storage status list to reflect updated block statuses */ - private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { - executorIdToStorageStatus.get(execId).foreach { storageStatus => - updatedBlocks.foreach { case (blockId, updatedStatus) => - if (updatedStatus.storageLevel == StorageLevel.NONE) { - storageStatus.removeBlock(blockId) - } else { - storageStatus.updateBlock(blockId, updatedStatus) - } - } - } - } - - /** Update storage status list to reflect the removal of an RDD from the cache */ - private def updateStorageStatus(unpersistedRDDId: Int) { - storageStatusList.foreach { storageStatus => - storageStatus.rddBlocksById(unpersistedRDDId).foreach { case (blockId, _) => - storageStatus.removeBlock(blockId) - } - } - } - - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { - updateStorageStatus(unpersistRDD.rddId) - } - - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { - synchronized { - val blockManagerId = blockManagerAdded.blockManagerId - val executorId = blockManagerId.executorId - // The onHeap and offHeap memory are always defined for new applications, - // but they can be missing if we are replaying old event logs. - val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem, - blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem) - executorIdToStorageStatus(executorId) = storageStatus - - // Try to remove the dead storage status if same executor register the block manager twice. - deadExecutorStorageStatus.zipWithIndex.find(_._1.blockManagerId.executorId == executorId) - .foreach(toRemoveExecutor => deadExecutorStorageStatus.remove(toRemoveExecutor._2)) - } - } - - override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { - synchronized { - val executorId = blockManagerRemoved.blockManagerId.executorId - executorIdToStorageStatus.remove(executorId).foreach { status => - deadExecutorStorageStatus += status - } - if (deadExecutorStorageStatus.size > retainedDeadExecutors) { - deadExecutorStorageStatus.trimStart(1) - } - } - } - - override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { - val executorId = blockUpdated.blockUpdatedInfo.blockManagerId.executorId - val blockId = blockUpdated.blockUpdatedInfo.blockId - val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel - val memSize = blockUpdated.blockUpdatedInfo.memSize - val diskSize = blockUpdated.blockUpdatedInfo.diskSize - val blockStatus = BlockStatus(storageLevel, memSize, diskSize) - updateStorageStatus(executorId, Seq((blockId, blockStatus))) - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 79d40b6a90c3c..e93ade001c607 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -26,13 +26,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1._ -import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.EnvironmentTab import org.apache.spark.ui.exec.ExecutorsTab import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener -import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.storage.StorageTab import org.apache.spark.util.Utils /** @@ -43,9 +42,7 @@ private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, - val storageStatusListener: StorageStatusListener, val jobProgressListener: JobProgressListener, - val storageListener: StorageListener, val operationGraphListener: RDDOperationGraphListener, var appName: String, val basePath: String, @@ -69,7 +66,7 @@ private[spark] class SparkUI private ( attachTab(jobsTab) val stagesTab = new StagesTab(this, store) attachTab(stagesTab) - attachTab(new StorageTab(this)) + attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) @@ -186,18 +183,12 @@ private[spark] object SparkUI { addListenerFn(listener) listener } - - val storageStatusListener = new StorageStatusListener(conf) - val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - addListenerFn(storageStatusListener) - addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(store, sc, conf, securityManager, storageStatusListener, jobProgressListener, - storageListener, operationGraphListener, appName, basePath, lastUpdateTime, startTime, - appSparkVersion) + new SparkUI(store, sc, conf, securityManager, jobProgressListener, operationGraphListener, + appName, basePath, lastUpdateTime, startTime, appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index e8ff08f7d88ff..02cee7f8c5b33 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -22,13 +22,13 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} -import org.apache.spark.status.api.v1.{AllRDDResource, RDDDataDistribution, RDDPartitionInfo} -import org.apache.spark.ui.{PagedDataSource, PagedTable, UIUtils, WebUIPage} +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.ui._ import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ -private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { - private val listener = parent.listener +private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("rdd") { def render(request: HttpServletRequest): Seq[Node] = { // stripXSS is called first to remove suspicious characters used in XSS attacks @@ -48,11 +48,13 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize) val rddId = parameterId.toInt - val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) - .getOrElse { + val rddStorageInfo = try { + store.rdd(rddId) + } catch { + case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) - } + } // Worker table val workerTable = UIUtils.listingTable(workerHeader, workerRow, diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b6c764d1728e4..b8aec9890247a 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -19,23 +19,23 @@ package org.apache.spark.ui.storage import javax.servlet.http.HttpServletRequest +import scala.collection.SortedMap import scala.xml.Node -import org.apache.spark.storage._ -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.status.{AppStatusStore, StreamBlockData} +import org.apache.spark.status.api.v1 +import org.apache.spark.ui._ import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ -private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - private val listener = parent.listener +private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(listener.rddInfoList) ++ - receiverBlockTables(listener.allExecutorStreamBlockStatus.sortBy(_.executorId)) + val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) UIUtils.headerSparkPage("Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[RDDInfo]): Seq[Node] = { + private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. Nil @@ -58,7 +58,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: RDDInfo): Seq[Node] = { + private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} @@ -67,35 +67,40 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.name} - {rdd.storageLevel.description} + {rdd.storageLevel} {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.diskSize)} + {Utils.bytesToString(rdd.memoryUsed)} + {Utils.bytesToString(rdd.diskUsed)} // scalastyle:on } - private[storage] def receiverBlockTables(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { - if (statuses.map(_.numStreamBlocks).sum == 0) { + private[storage] def receiverBlockTables(blocks: Seq[StreamBlockData]): Seq[Node] = { + if (blocks.isEmpty) { // Don't show the tables if there is no stream block Nil } else { - val blocks = statuses.flatMap(_.blocks).groupBy(_.blockId).toSeq.sortBy(_._1.toString) + val sorted = blocks.groupBy(_.name).toSeq.sortBy(_._1.toString)

    Receiver Blocks

    - {executorMetricsTable(statuses)} - {streamBlockTable(blocks)} + {executorMetricsTable(blocks)} + {streamBlockTable(sorted)}
    } } - private def executorMetricsTable(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { + private def executorMetricsTable(blocks: Seq[StreamBlockData]): Seq[Node] = { + val blockManagers = SortedMap(blocks.groupBy(_.executorId).toSeq: _*) + .map { case (id, blocks) => + new ExecutorStreamSummary(blocks) + } +
    Aggregated Block Metrics by Executor
    - {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, statuses, + {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, blockManagers, id = Some("storage-by-executor-stream-blocks"))}
    } @@ -107,7 +112,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Total Size on Disk", "Stream Blocks") - private def executorMetricsTableRow(status: ExecutorStreamBlockStatus): Seq[Node] = { + private def executorMetricsTableRow(status: ExecutorStreamSummary): Seq[Node] = { {status.executorId} @@ -127,7 +132,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { } - private def streamBlockTable(blocks: Seq[(BlockId, Seq[BlockUIData])]): Seq[Node] = { + private def streamBlockTable(blocks: Seq[(String, Seq[StreamBlockData])]): Seq[Node] = { if (blocks.isEmpty) { Nil } else { @@ -151,7 +156,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Size") /** Render a stream block */ - private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { + private def streamBlockTableRow(block: (String, Seq[StreamBlockData])): Seq[Node] = { val replications = block._2 assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { @@ -163,33 +168,36 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { } private def streamBlockTableSubrow( - blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = { + blockId: String, + block: StreamBlockData, + replication: Int, + firstSubrow: Boolean): Seq[Node] = { val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block) { if (firstSubrow) { - {block.blockId.toString} + {block.name} {replication.toString} } } - {block.location} + {block.hostPort} {storageLevel} {Utils.bytesToString(size)} } private[storage] def streamBlockStorageLevelDescriptionAndSize( - block: BlockUIData): (String, Long) = { - if (block.storageLevel.useDisk) { + block: StreamBlockData): (String, Long) = { + if (block.useDisk) { ("Disk", block.diskSize) - } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) { + } else if (block.useMemory && block.deserialized) { ("Memory", block.memSize) - } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { + } else if (block.useMemory && !block.deserialized) { ("Memory Serialized", block.memSize) } else { throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") @@ -197,3 +205,17 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { } } + +private class ExecutorStreamSummary(blocks: Seq[StreamBlockData]) { + + def executorId: String = blocks.head.executorId + + def location: String = blocks.head.hostPort + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def numStreamBlocks: Int = blocks.size + +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 148efb134e14f..688efa24ade0c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -17,71 +17,13 @@ package org.apache.spark.ui.storage -import scala.collection.mutable - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage._ +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { - val listener = parent.storageListener - - attachPage(new StoragePage(this)) - attachPage(new RDDPage(this)) -} - -/** - * :: DeveloperApi :: - * A SparkListener that prepares information to be displayed on the BlockManagerUI. - * - * This class is thread-safe (unlike JobProgressListener) - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { - - private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - - def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList - - /** Filter RDD info to include only those with cached partitions */ - def rddInfoList: Seq[RDDInfo] = synchronized { - _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq - } - - /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ - private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { - val rddIdsToUpdate = updatedBlocks.flatMap { case (bid, _) => bid.asRDDId.map(_.rddId) }.toSet - val rddInfosToUpdate = _rddInfoMap.values.toSeq.filter { s => rddIdsToUpdate.contains(s.id) } - StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList) - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val rddInfos = stageSubmitted.stageInfo.rddInfos - rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info).name = info.name } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - // Remove all partitions that are no longer cached in current completed stage - val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet - _rddInfoMap.retain { case (id, info) => - !completedRddIds.contains(id) || info.numCachedPartitions > 0 - } - } - - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { - _rddInfoMap.remove(unpersistRDD.rddId) - } +private[ui] class StorageTab(parent: SparkUI, store: AppStatusStore) + extends SparkUITab(parent, "storage") { - override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { - super.onBlockUpdated(blockUpdated) - val blockId = blockUpdated.blockUpdatedInfo.blockId - val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel - val memSize = blockUpdated.blockUpdatedInfo.memSize - val diskSize = blockUpdated.blockUpdatedInfo.diskSize - val blockStatus = BlockStatus(storageLevel, memSize, diskSize) - updateRDDInfo(Seq((blockId, blockStatus))) - } + attachPage(new StoragePage(this, store)) + attachPage(new RDDPage(this, store)) } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 867d35f231dc0..ba082bc93dd42 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -578,145 +578,160 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - val rdd1b1 = RDDBlockId(1, 1) + val rdd1b1 = RddBlock(1, 1, 1L, 2L) + val rdd1b2 = RddBlock(1, 2, 3L, 4L) + val rdd2b1 = RddBlock(2, 1, 5L, 6L) val level = StorageLevel.MEMORY_AND_DISK - // Submit a stage and make sure the RDD is recorded. - val rddInfo = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) - val stage = new StageInfo(1, 0, "stage1", 4, Seq(rddInfo), Nil, "details1") + // Submit a stage and make sure the RDDs are recorded. + val rdd1Info = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) + val rdd2Info = new RDDInfo(rdd2b1.rddId, "rdd2", 1, level, Nil) + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rdd1Info, rdd2Info), Nil, "details1") listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.name === rddInfo.name) - assert(wrapper.info.numPartitions === rddInfo.numPartitions) - assert(wrapper.info.storageLevel === rddInfo.storageLevel.description) + assert(wrapper.info.name === rdd1Info.name) + assert(wrapper.info.numPartitions === rdd1Info.numPartitions) + assert(wrapper.info.storageLevel === rdd1Info.storageLevel.description) } // Add partition 1 replicated on two block managers. - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, level, 1L, 1L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 1L) - assert(wrapper.info.diskUsed === 1L) + assert(wrapper.info.numCachedPartitions === 1L) + assert(wrapper.info.memoryUsed === rdd1b1.memSize) + assert(wrapper.info.diskUsed === rdd1b1.diskSize) assert(wrapper.info.dataDistribution.isDefined) assert(wrapper.info.dataDistribution.get.size === 1) val dist = wrapper.info.dataDistribution.get.head assert(dist.address === bm1.hostPort) - assert(dist.memoryUsed === 1L) - assert(dist.diskUsed === 1L) + assert(dist.memoryUsed === rdd1b1.memSize) + assert(dist.diskUsed === rdd1b1.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) assert(wrapper.info.partitions.isDefined) assert(wrapper.info.partitions.get.size === 1) val part = wrapper.info.partitions.get.head - assert(part.blockName === rdd1b1.name) + assert(part.blockName === rdd1b1.blockId.name) assert(part.storageLevel === level.description) - assert(part.memoryUsed === 1L) - assert(part.diskUsed === 1L) + assert(part.memoryUsed === rdd1b1.memSize) + assert(part.diskUsed === rdd1b1.diskSize) assert(part.executors === Seq(bm1.executorId)) } check[ExecutorSummaryWrapper](bm1.executorId) { exec => assert(exec.info.rddBlocks === 1L) - assert(exec.info.memoryUsed === 1L) - assert(exec.info.diskUsed === 1L) + assert(exec.info.memoryUsed === rdd1b1.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize) } - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, level, 1L, 1L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm2, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 2L) - assert(wrapper.info.diskUsed === 2L) + assert(wrapper.info.numCachedPartitions === 1L) + assert(wrapper.info.memoryUsed === rdd1b1.memSize * 2) + assert(wrapper.info.diskUsed === rdd1b1.diskSize * 2) assert(wrapper.info.dataDistribution.get.size === 2L) assert(wrapper.info.partitions.get.size === 1L) val dist = wrapper.info.dataDistribution.get.find(_.address == bm2.hostPort).get - assert(dist.memoryUsed === 1L) - assert(dist.diskUsed === 1L) + assert(dist.memoryUsed === rdd1b1.memSize) + assert(dist.diskUsed === rdd1b1.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) val part = wrapper.info.partitions.get(0) - assert(part.memoryUsed === 2L) - assert(part.diskUsed === 2L) + assert(part.memoryUsed === rdd1b1.memSize * 2) + assert(part.diskUsed === rdd1b1.diskSize * 2) assert(part.executors === Seq(bm1.executorId, bm2.executorId)) } check[ExecutorSummaryWrapper](bm2.executorId) { exec => assert(exec.info.rddBlocks === 1L) - assert(exec.info.memoryUsed === 1L) - assert(exec.info.diskUsed === 1L) + assert(exec.info.memoryUsed === rdd1b1.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize) } // Add a second partition only to bm 1. - val rdd1b2 = RDDBlockId(1, 2) - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b2, level, - 3L, 3L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b2.blockId, level, rdd1b2.memSize, rdd1b2.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 5L) - assert(wrapper.info.diskUsed === 5L) + assert(wrapper.info.numCachedPartitions === 2L) + assert(wrapper.info.memoryUsed === 2 * rdd1b1.memSize + rdd1b2.memSize) + assert(wrapper.info.diskUsed === 2 * rdd1b1.diskSize + rdd1b2.diskSize) assert(wrapper.info.dataDistribution.get.size === 2L) assert(wrapper.info.partitions.get.size === 2L) val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get - assert(dist.memoryUsed === 4L) - assert(dist.diskUsed === 4L) + assert(dist.memoryUsed === rdd1b1.memSize + rdd1b2.memSize) + assert(dist.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) - val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.name).get + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.blockId.name).get assert(part.storageLevel === level.description) - assert(part.memoryUsed === 3L) - assert(part.diskUsed === 3L) + assert(part.memoryUsed === rdd1b2.memSize) + assert(part.diskUsed === rdd1b2.diskSize) assert(part.executors === Seq(bm1.executorId)) } check[ExecutorSummaryWrapper](bm1.executorId) { exec => assert(exec.info.rddBlocks === 2L) - assert(exec.info.memoryUsed === 4L) - assert(exec.info.diskUsed === 4L) + assert(exec.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize) } // Remove block 1 from bm 1. - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, - StorageLevel.NONE, 1L, 1L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b1.blockId, StorageLevel.NONE, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 4L) - assert(wrapper.info.diskUsed === 4L) + assert(wrapper.info.numCachedPartitions === 2L) + assert(wrapper.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize) + assert(wrapper.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize) assert(wrapper.info.dataDistribution.get.size === 2L) assert(wrapper.info.partitions.get.size === 2L) val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get - assert(dist.memoryUsed === 3L) - assert(dist.diskUsed === 3L) + assert(dist.memoryUsed === rdd1b2.memSize) + assert(dist.diskUsed === rdd1b2.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) - val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.name).get + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.blockId.name).get assert(part.storageLevel === level.description) - assert(part.memoryUsed === 1L) - assert(part.diskUsed === 1L) + assert(part.memoryUsed === rdd1b1.memSize) + assert(part.diskUsed === rdd1b1.diskSize) assert(part.executors === Seq(bm2.executorId)) } check[ExecutorSummaryWrapper](bm1.executorId) { exec => assert(exec.info.rddBlocks === 1L) - assert(exec.info.memoryUsed === 3L) - assert(exec.info.diskUsed === 3L) + assert(exec.info.memoryUsed === rdd1b2.memSize) + assert(exec.info.diskUsed === rdd1b2.diskSize) } - // Remove block 2 from bm 2. This should leave only block 2 info in the store. - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, - StorageLevel.NONE, 1L, 1L))) + // Remove block 1 from bm 2. This should leave only block 2's info in the store. + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm2, rdd1b1.blockId, StorageLevel.NONE, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 3L) - assert(wrapper.info.diskUsed === 3L) + assert(wrapper.info.numCachedPartitions === 1L) + assert(wrapper.info.memoryUsed === rdd1b2.memSize) + assert(wrapper.info.diskUsed === rdd1b2.diskSize) assert(wrapper.info.dataDistribution.get.size === 1L) assert(wrapper.info.partitions.get.size === 1L) - assert(wrapper.info.partitions.get(0).blockName === rdd1b2.name) + assert(wrapper.info.partitions.get(0).blockName === rdd1b2.blockId.name) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === rdd1b2.memSize) + assert(exec.info.diskUsed === rdd1b2.diskSize) } check[ExecutorSummaryWrapper](bm2.executorId) { exec => @@ -725,12 +740,61 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.diskUsed === 0L) } + // Add a block from a different RDD. Verify the executor is updated correctly and also that + // the distribution data for both rdds is updated to match the remaining memory. + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd2b1.blockId, level, rdd2b1.memSize, rdd2b1.diskSize))) + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 2L) + assert(exec.info.memoryUsed === rdd1b2.memSize + rdd2b1.memSize) + assert(exec.info.diskUsed === rdd1b2.diskSize + rdd2b1.diskSize) + } + + check[RDDStorageInfoWrapper](rdd1b2.rddId) { wrapper => + assert(wrapper.info.dataDistribution.get.size === 1L) + val dist = wrapper.info.dataDistribution.get(0) + assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize ) + } + + check[RDDStorageInfoWrapper](rdd2b1.rddId) { wrapper => + assert(wrapper.info.dataDistribution.get.size === 1L) + + val dist = wrapper.info.dataDistribution.get(0) + assert(dist.memoryUsed === rdd2b1.memSize) + assert(dist.diskUsed === rdd2b1.diskSize) + assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize ) + } + // Unpersist RDD1. listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) - intercept[NoSuchElementException] { + intercept[NoSuchElementException] { check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } } + // Update a StreamBlock. + val stream1 = StreamBlockId(1, 1L) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, stream1, level, 1L, 1L))) + + check[StreamBlockData](Array(stream1.name, bm1.executorId)) { stream => + assert(stream.name === stream1.name) + assert(stream.executorId === bm1.executorId) + assert(stream.hostPort === bm1.hostPort) + assert(stream.storageLevel === level.description) + assert(stream.useMemory === level.useMemory) + assert(stream.useDisk === level.useDisk) + assert(stream.deserialized === level.deserialized) + assert(stream.memSize === 1L) + assert(stream.diskSize === 1L) + } + + // Drop a StreamBlock. + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, stream1, StorageLevel.NONE, 0L, 0L))) + intercept[NoSuchElementException] { + check[StreamBlockData](stream1.name) { _ => () } + } } private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) @@ -740,4 +804,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { fn(value) } + private case class RddBlock( + rddId: Int, + partId: Int, + memSize: Long, + diskSize: Long) { + + def blockId: BlockId = RDDBlockId(rddId, partId) + + } + } diff --git a/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala new file mode 100644 index 0000000000000..bb2d2633001f0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.RDDPartitionInfo + +class LiveEntitySuite extends SparkFunSuite { + + test("partition seq") { + val seq = new RDDPartitionSeq() + val items = (1 to 10).map { i => + val part = newPartition(i) + seq.addPartition(part) + part + }.toList + + checkSize(seq, 10) + + val added = newPartition(11) + seq.addPartition(added) + checkSize(seq, 11) + assert(seq.last.blockName === added.blockName) + + seq.removePartition(items(0)) + assert(seq.head.blockName === items(1).blockName) + assert(!seq.exists(_.blockName == items(0).blockName)) + checkSize(seq, 10) + + seq.removePartition(added) + assert(seq.last.blockName === items.last.blockName) + assert(!seq.exists(_.blockName == added.blockName)) + checkSize(seq, 9) + + seq.removePartition(items(5)) + checkSize(seq, 8) + assert(!seq.exists(_.blockName == items(5).blockName)) + } + + private def checkSize(seq: Seq[_], expected: Int): Unit = { + assert(seq.length === expected) + var count = 0 + seq.iterator.foreach { _ => count += 1 } + assert(count === expected) + } + + private def newPartition(i: Int): LiveRDDPartition = { + val part = new LiveRDDPartition(i.toString) + part.update(Seq(i.toString), i.toString, i, i) + part + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala deleted file mode 100644 index 06acca3943c20..0000000000000 --- a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import org.apache.spark.SparkFunSuite -import org.apache.spark.scheduler._ - -class BlockStatusListenerSuite extends SparkFunSuite { - - test("basic functions") { - val blockManagerId = BlockManagerId("0", "localhost", 10000) - val listener = new BlockStatusListener() - - // Add a block manager and a new block status - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0)) - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId, - StreamBlockId(0, 100), - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100))) - // The new block status should be added to the listener - val expectedBlock = BlockUIData( - StreamBlockId(0, 100), - "localhost:10000", - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100 - ) - val expectedExecutorStreamBlockStatus = Seq( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) - ) - assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus) - - // Add the second block manager - val blockManagerId2 = BlockManagerId("1", "localhost", 10001) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0)) - // Add a new replication of the same block id from the second manager - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId2, - StreamBlockId(0, 100), - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100))) - val expectedBlock2 = BlockUIData( - StreamBlockId(0, 100), - "localhost:10001", - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100 - ) - // Each block manager should contain one block - val expectedExecutorStreamBlockStatus2 = Set( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), - ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2)) - ) - assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2) - - // Remove a replication of the same block - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId2, - StreamBlockId(0, 100), - StorageLevel.NONE, // StorageLevel.NONE means removing it - memSize = 0, - diskSize = 0))) - // Only the first block manager contains a block - val expectedExecutorStreamBlockStatus3 = Set( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), - ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) - ) - assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3) - - // Remove the second block manager at first but add a new block status - // from this removed block manager - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2)) - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId2, - StreamBlockId(0, 100), - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100))) - // The second block manager is removed so we should not see the new block - val expectedExecutorStreamBlockStatus4 = Seq( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) - ) - assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4) - - // Remove the last block manager - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId)) - // No block manager now so we should dop all block managers - assert(listener.allExecutorStreamBlockStatus.isEmpty) - } - -} diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala deleted file mode 100644 index 9835f11a2f7ed..0000000000000 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import org.apache.spark.{SparkConf, SparkFunSuite, Success} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler._ - -/** - * Test the behavior of StorageStatusListener in response to all relevant events. - */ -class StorageStatusListenerSuite extends SparkFunSuite { - private val bm1 = BlockManagerId("big", "dog", 1) - private val bm2 = BlockManagerId("fat", "duck", 2) - private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) - private val taskInfo2 = new TaskInfo(0, 0, 0, 0, "fat", "duck", TaskLocality.ANY, false) - private val conf = new SparkConf() - - test("block manager added/removed") { - conf.set("spark.ui.retainedDeadExecutors", "1") - val listener = new StorageStatusListener(conf) - - // Block manager add - assert(listener.executorIdToStorageStatus.size === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - assert(listener.executorIdToStorageStatus.size === 1) - assert(listener.executorIdToStorageStatus.get("big").isDefined) - assert(listener.executorIdToStorageStatus("big").blockManagerId === bm1) - assert(listener.executorIdToStorageStatus("big").maxMem === 1000L) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - assert(listener.executorIdToStorageStatus.size === 2) - assert(listener.executorIdToStorageStatus.get("fat").isDefined) - assert(listener.executorIdToStorageStatus("fat").blockManagerId === bm2) - assert(listener.executorIdToStorageStatus("fat").maxMem === 2000L) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - - // Block manager remove - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm1)) - assert(listener.executorIdToStorageStatus.size === 1) - assert(!listener.executorIdToStorageStatus.get("big").isDefined) - assert(listener.executorIdToStorageStatus.get("fat").isDefined) - assert(listener.deadExecutorStorageStatus.size === 1) - assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("big")) - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2)) - assert(listener.executorIdToStorageStatus.size === 0) - assert(!listener.executorIdToStorageStatus.get("big").isDefined) - assert(!listener.executorIdToStorageStatus.get("fat").isDefined) - assert(listener.deadExecutorStorageStatus.size === 1) - assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("fat")) - } - - test("task end without updated blocks") { - val listener = new StorageStatusListener(conf) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - val taskMetrics = new TaskMetrics - - // Task end with no updated blocks - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - } - - test("updated blocks") { - val listener = new StorageStatusListener(conf) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - - val blockUpdateInfos1 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L) - ) - val blockUpdateInfos2 = - Seq(BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L)) - - // Add some new blocks - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - postUpdateBlock(listener, blockUpdateInfos1) - assert(listener.executorIdToStorageStatus("big").numBlocks === 2) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - postUpdateBlock(listener, blockUpdateInfos2) - assert(listener.executorIdToStorageStatus("big").numBlocks === 2) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - - // Dropped the blocks - val droppedBlockInfo1 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.NONE, 0L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L) - ) - val droppedBlockInfo2 = Seq( - BlockUpdatedInfo(bm2, RDDBlockId(1, 2), StorageLevel.NONE, 0L, 0L), - BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L) - ) - - postUpdateBlock(listener, droppedBlockInfo1) - assert(listener.executorIdToStorageStatus("big").numBlocks === 1) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) - assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - postUpdateBlock(listener, droppedBlockInfo2) - assert(listener.executorIdToStorageStatus("big").numBlocks === 1) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - } - - test("unpersist RDD") { - val listener = new StorageStatusListener(conf) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - val blockUpdateInfos1 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L) - ) - val blockUpdateInfos2 = - Seq(BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L)) - postUpdateBlock(listener, blockUpdateInfos1) - postUpdateBlock(listener, blockUpdateInfos2) - assert(listener.executorIdToStorageStatus("big").numBlocks === 3) - - // Unpersist RDD - listener.onUnpersistRDD(SparkListenerUnpersistRDD(9090)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 3) - listener.onUnpersistRDD(SparkListenerUnpersistRDD(4)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 2) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - listener.onUnpersistRDD(SparkListenerUnpersistRDD(1)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - } - - private def postUpdateBlock( - listener: StorageStatusListener, updateBlockInfos: Seq[BlockUpdatedInfo]): Unit = { - updateBlockInfos.foreach { updateBlockInfo => - listener.onBlockUpdated(SparkListenerBlockUpdated(updateBlockInfo)) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 4a48b3c686725..a71521c91d2f2 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -20,39 +20,46 @@ package org.apache.spark.ui.storage import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite +import org.apache.spark.status.StreamBlockData +import org.apache.spark.status.api.v1.RDDStorageInfo import org.apache.spark.storage._ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") - val storagePage = new StoragePage(storageTab) + val storagePage = new StoragePage(storageTab, null) test("rddTable") { - val rdd1 = new RDDInfo(1, + val rdd1 = new RDDStorageInfo(1, "rdd1", 10, - StorageLevel.MEMORY_ONLY, - Seq.empty) - rdd1.memSize = 100 - rdd1.numCachedPartitions = 10 + 10, + StorageLevel.MEMORY_ONLY.description, + 100L, + 0L, + None, + None) - val rdd2 = new RDDInfo(2, + val rdd2 = new RDDStorageInfo(2, "rdd2", 10, - StorageLevel.DISK_ONLY, - Seq.empty) - rdd2.diskSize = 200 - rdd2.numCachedPartitions = 5 - - val rdd3 = new RDDInfo(3, + 5, + StorageLevel.DISK_ONLY.description, + 0L, + 200L, + None, + None) + + val rdd3 = new RDDStorageInfo(3, "rdd3", 10, - StorageLevel.MEMORY_AND_DISK_SER, - Seq.empty) - rdd3.memSize = 400 - rdd3.diskSize = 500 - rdd3.numCachedPartitions = 10 + 10, + StorageLevel.MEMORY_AND_DISK_SER.description, + 400L, + 500L, + None, + None) val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) @@ -91,58 +98,85 @@ class StoragePageSuite extends SparkFunSuite { } test("streamBlockStorageLevelDescriptionAndSize") { - val memoryBlock = BlockUIData(StreamBlockId(0, 0), + val memoryBlock = new StreamBlockData("0", + "0", "localhost:1111", - StorageLevel.MEMORY_ONLY, - memSize = 100, - diskSize = 0) + StorageLevel.MEMORY_ONLY.description, + true, + false, + true, + 100, + 0) assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) - val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), + val memorySerializedBlock = new StreamBlockData("0", + "0", "localhost:1111", - StorageLevel.MEMORY_ONLY_SER, + StorageLevel.MEMORY_ONLY_SER.description, + true, + false, + false, memSize = 100, diskSize = 0) assert(("Memory Serialized", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) - val diskBlock = BlockUIData(StreamBlockId(0, 0), + val diskBlock = new StreamBlockData("0", + "0", "localhost:1111", - StorageLevel.DISK_ONLY, - memSize = 0, - diskSize = 100) + StorageLevel.DISK_ONLY.description, + false, + true, + false, + 0, + 100) assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) } test("receiverBlockTables") { val blocksForExecutor0 = Seq( - BlockUIData(StreamBlockId(0, 0), + new StreamBlockData(StreamBlockId(0, 0).name, + "0", "localhost:10000", - StorageLevel.MEMORY_ONLY, - memSize = 100, - diskSize = 0), - BlockUIData(StreamBlockId(1, 1), + StorageLevel.MEMORY_ONLY.description, + true, + false, + true, + 100, + 0), + new StreamBlockData(StreamBlockId(1, 1).name, + "0", "localhost:10000", - StorageLevel.DISK_ONLY, - memSize = 0, - diskSize = 100) + StorageLevel.DISK_ONLY.description, + false, + true, + false, + 0, + 100) ) - val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) val blocksForExecutor1 = Seq( - BlockUIData(StreamBlockId(0, 0), + new StreamBlockData(StreamBlockId(0, 0).name, + "1", "localhost:10001", - StorageLevel.MEMORY_ONLY, + StorageLevel.MEMORY_ONLY.description, + true, + false, + true, memSize = 100, diskSize = 0), - BlockUIData(StreamBlockId(1, 1), + new StreamBlockData(StreamBlockId(1, 1).name, + "1", "localhost:10001", - StorageLevel.MEMORY_ONLY_SER, - memSize = 100, - diskSize = 0) + StorageLevel.MEMORY_ONLY_SER.description, + true, + false, + false, + 100, + 0) ) - val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) - val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) + + val xmlNodes = storagePage.receiverBlockTables(blocksForExecutor0 ++ blocksForExecutor1) val executorTable = (xmlNodes \\ "table")(0) val executorHeaders = Seq( @@ -190,8 +224,6 @@ class StoragePageSuite extends SparkFunSuite { test("empty receiverBlockTables") { assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) - val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) - val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) - assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) } + } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala deleted file mode 100644 index 79f02f2e50bbd..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.storage - -import org.scalatest.BeforeAndAfter - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.storage._ - -/** - * Test various functionality in the StorageListener that supports the StorageTab. - */ -class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { - private var bus: SparkListenerBus = _ - private var storageStatusListener: StorageStatusListener = _ - private var storageListener: StorageListener = _ - private val memAndDisk = StorageLevel.MEMORY_AND_DISK - private val memOnly = StorageLevel.MEMORY_ONLY - private val none = StorageLevel.NONE - private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) - private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false) - private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly, Seq(10)) - private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly, Seq(10)) - private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk, Seq(10)) - private def rddInfo3 = new RDDInfo(3, "grace", 400, memAndDisk, Seq(10)) - private val bm1 = BlockManagerId("big", "dog", 1) - - before { - val conf = new SparkConf() - bus = new ReplayListenerBus() - storageStatusListener = new StorageStatusListener(conf) - storageListener = new StorageListener(storageStatusListener) - bus.addListener(storageStatusListener) - bus.addListener(storageListener) - } - - test("stage submitted / completed") { - assert(storageListener._rddInfoMap.isEmpty) - assert(storageListener.rddInfoList.isEmpty) - - // 2 RDDs are known, but none are cached - val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0, rddInfo1), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 2) - assert(storageListener.rddInfoList.isEmpty) - - // 4 RDDs are known, but only 2 are cached - val rddInfo2Cached = rddInfo2 - val rddInfo3Cached = rddInfo3 - rddInfo2Cached.numCachedPartitions = 1 - rddInfo3Cached.numCachedPartitions = 1 - val stageInfo1 = new StageInfo( - 1, 0, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) - assert(storageListener._rddInfoMap.size === 4) - assert(storageListener.rddInfoList.size === 2) - - // Submitting RDDInfos with duplicate IDs does nothing - val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY, Seq(10)) - rddInfo0Cached.numCachedPartitions = 1 - val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) - assert(storageListener._rddInfoMap.size === 4) - assert(storageListener.rddInfoList.size === 2) - - // We only keep around the RDDs that are cached - bus.postToAll(SparkListenerStageCompleted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 2) - assert(storageListener.rddInfoList.size === 2) - } - - test("unpersist") { - val rddInfo0Cached = rddInfo0 - val rddInfo1Cached = rddInfo1 - rddInfo0Cached.numCachedPartitions = 1 - rddInfo1Cached.numCachedPartitions = 1 - val stageInfo0 = new StageInfo( - 0, 0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 2) - assert(storageListener.rddInfoList.size === 2) - bus.postToAll(SparkListenerUnpersistRDD(0)) - assert(storageListener._rddInfoMap.size === 1) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerUnpersistRDD(4)) // doesn't exist - assert(storageListener._rddInfoMap.size === 1) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerUnpersistRDD(1)) - assert(storageListener._rddInfoMap.size === 0) - assert(storageListener.rddInfoList.size === 0) - } - - test("block update") { - val myRddInfo0 = rddInfo0 - val myRddInfo1 = rddInfo1 - val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo( - 0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), Seq.empty, "details") - bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 3) - assert(storageListener.rddInfoList.size === 0) // not cached - assert(!storageListener._rddInfoMap(0).isCached) - assert(!storageListener._rddInfoMap(1).isCached) - assert(!storageListener._rddInfoMap(2).isCached) - - // Some blocks updated - val blockUpdateInfos = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(0, 100), memAndDisk, 400L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(0, 101), memAndDisk, 0L, 400L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 20), memAndDisk, 0L, 240L) - ) - postUpdateBlocks(bus, blockUpdateInfos) - assert(storageListener._rddInfoMap(0).memSize === 400L) - assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) - assert(storageListener._rddInfoMap(0).isCached) - assert(storageListener._rddInfoMap(1).memSize === 0L) - assert(storageListener._rddInfoMap(1).diskSize === 240L) - assert(storageListener._rddInfoMap(1).numCachedPartitions === 1) - assert(storageListener._rddInfoMap(1).isCached) - assert(!storageListener._rddInfoMap(2).isCached) - assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - - // Drop some blocks - val blockUpdateInfos2 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(0, 100), none, 0L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 20), none, 0L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(2, 40), none, 0L, 0L), // doesn't actually exist - BlockUpdatedInfo(bm1, RDDBlockId(4, 80), none, 0L, 0L) // doesn't actually exist - ) - postUpdateBlocks(bus, blockUpdateInfos2) - assert(storageListener._rddInfoMap(0).memSize === 0L) - assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 1) - assert(storageListener._rddInfoMap(0).isCached) - assert(!storageListener._rddInfoMap(1).isCached) - assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - assert(!storageListener._rddInfoMap(2).isCached) - assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - } - - test("verify StorageTab contains all cached rdds") { - - val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly, Seq(4)) - val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly, Seq(4)) - val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details") - val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details") - val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L)) - val blockUpdateInfos2 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(1, 1), memOnly, 200L, 0L)) - bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener.rddInfoList.size === 0) - postUpdateBlocks(bus, blockUpdateInfos1) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerStageCompleted(stageInfo0)) - assert(storageListener.rddInfoList.size === 1) - postUpdateBlocks(bus, blockUpdateInfos2) - assert(storageListener.rddInfoList.size === 2) - bus.postToAll(SparkListenerStageCompleted(stageInfo1)) - assert(storageListener.rddInfoList.size === 2) - } - - test("verify StorageTab still contains a renamed RDD") { - val rddInfo = new RDDInfo(0, "original_name", 1, memOnly, Seq(4)) - val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo), Seq.empty, "details") - bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L)) - postUpdateBlocks(bus, blockUpdateInfos1) - assert(storageListener.rddInfoList.size == 1) - - val newName = "new_name" - val rddInfoRenamed = new RDDInfo(0, newName, 1, memOnly, Seq(4)) - val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfoRenamed), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) - assert(storageListener.rddInfoList.size == 1) - assert(storageListener.rddInfoList.head.name == newName) - } - - private def postUpdateBlocks( - bus: SparkListenerBus, blockUpdateInfos: Seq[BlockUpdatedInfo]): Unit = { - blockUpdateInfos.foreach { blockUpdateInfo => - bus.postToAll(SparkListenerBlockUpdated(blockUpdateInfo)) - } - } -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0c31b2b4a9402..090375f65552e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -41,6 +41,8 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.storage.StorageListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 9eb7096c47a7e5f98de19512515386b3d0698039 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 9 Nov 2017 16:05:47 -0800 Subject: [PATCH 1638/1765] [SPARK-22483][CORE] Exposing java.nio bufferedPool memory metrics to Metric System ## What changes were proposed in this pull request? Adds java.nio bufferedPool memory metrics to metrics system which includes both direct and mapped memory. ## How was this patch tested? Manually tested and checked direct and mapped memory metrics too available in metrics system using Console sink. Here is the sample console output application_1509655862825_0016.2.jvm.direct.capacity value = 19497 application_1509655862825_0016.2.jvm.direct.count value = 6 application_1509655862825_0016.2.jvm.direct.used value = 19498 application_1509655862825_0016.2.jvm.mapped.capacity value = 0 application_1509655862825_0016.2.jvm.mapped.count value = 0 application_1509655862825_0016.2.jvm.mapped.used value = 0 Author: Srinivasa Reddy Vundela Closes #19709 from vundela/SPARK-22483. --- .../scala/org/apache/spark/metrics/source/JvmSource.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index 635bff2cd7ec8..dcaa0f19295e8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -17,8 +17,10 @@ package org.apache.spark.metrics.source +import java.lang.management.ManagementFactory + import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} +import com.codahale.metrics.jvm.{BufferPoolMetricSet, GarbageCollectorMetricSet, MemoryUsageGaugeSet} private[spark] class JvmSource extends Source { override val sourceName = "jvm" @@ -26,4 +28,6 @@ private[spark] class JvmSource extends Source { metricRegistry.registerAll(new GarbageCollectorMetricSet) metricRegistry.registerAll(new MemoryUsageGaugeSet) + metricRegistry.registerAll( + new BufferPoolMetricSet(ManagementFactory.getPlatformMBeanServer)) } From 11c4021044f3a302449a2ea76811e73f5c99a26a Mon Sep 17 00:00:00 2001 From: Wing Yew Poon Date: Thu, 9 Nov 2017 16:20:55 -0800 Subject: [PATCH 1639/1765] [SPARK-22403][SS] Add optional checkpointLocation argument to StructuredKafkaWordCount example ## What changes were proposed in this pull request? When run in YARN cluster mode, the StructuredKafkaWordCount example fails because Spark tries to create a temporary checkpoint location in a subdirectory of the path given by java.io.tmpdir, and YARN sets java.io.tmpdir to a path in the local filesystem that usually does not correspond to an existing path in the distributed filesystem. Add an optional checkpointLocation argument to the StructuredKafkaWordCount example so that users can specify the checkpoint location and avoid this issue. ## How was this patch tested? Built and ran the example manually on YARN client and cluster mode. Author: Wing Yew Poon Closes #19703 from wypoon/SPARK-22403. --- .../sql/streaming/StructuredKafkaWordCount.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala index c26f73e788814..2aab49c8891d1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala @@ -18,11 +18,14 @@ // scalastyle:off println package org.apache.spark.examples.sql.streaming +import java.util.UUID + import org.apache.spark.sql.SparkSession /** * Consumes messages from one or more topics in Kafka and does wordcount. * Usage: StructuredKafkaWordCount + * [] * The Kafka "bootstrap.servers" configuration. A * comma-separated list of host:port. * There are three kinds of type, i.e. 'assign', 'subscribe', @@ -36,6 +39,8 @@ import org.apache.spark.sql.SparkSession * |- Only one of "assign, "subscribe" or "subscribePattern" options can be * | specified for Kafka source. * Different value format depends on the value of 'subscribe-type'. + * Directory in which to create checkpoints. If not + * provided, defaults to a randomized directory in /tmp. * * Example: * `$ bin/run-example \ @@ -46,11 +51,13 @@ object StructuredKafkaWordCount { def main(args: Array[String]): Unit = { if (args.length < 3) { System.err.println("Usage: StructuredKafkaWordCount " + - " ") + " []") System.exit(1) } - val Array(bootstrapServers, subscribeType, topics) = args + val Array(bootstrapServers, subscribeType, topics, _*) = args + val checkpointLocation = + if (args.length > 3) args(3) else "/tmp/temporary-" + UUID.randomUUID.toString val spark = SparkSession .builder @@ -76,6 +83,7 @@ object StructuredKafkaWordCount { val query = wordCounts.writeStream .outputMode("complete") .format("console") + .option("checkpointLocation", checkpointLocation) .start() query.awaitTermination() From 64c989495a4ac7e23a22672448a622862a63486a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 9 Nov 2017 16:40:19 -0800 Subject: [PATCH 1640/1765] [SPARK-22485][BUILD] Use `exclude[Problem]` instead `excludePackage` in MiMa ## What changes were proposed in this pull request? `excludePackage` is deprecated like the [following](https://github.com/lightbend/migration-manager/blob/master/core/src/main/scala/com/typesafe/tools/mima/core/Filters.scala#L33-L36) and shows deprecation warnings now. This PR uses `exclude[Problem](packageName + ".*")` instead. ```scala deprecated("Replace with ProblemFilters.exclude[Problem](\"my.package.*\")", "0.1.15") def excludePackage(packageName: String): ProblemFilter = { exclude[Problem](packageName + ".*") } ``` ## How was this patch tested? Pass the Jenkins MiMa. Author: Dongjoon Hyun Closes #19710 from dongjoon-hyun/SPARK-22485. --- project/MimaBuild.scala | 4 ++-- project/MimaExcludes.scala | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index de0655b6cb357..2ef0e7b40d940 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -44,7 +44,7 @@ object MimaBuild { // Exclude a single class def excludeClass(className: String) = Seq( - excludePackage(className), + ProblemFilters.exclude[Problem](className + ".*"), ProblemFilters.exclude[MissingClassProblem](className), ProblemFilters.exclude[MissingTypesProblem](className) ) @@ -56,7 +56,7 @@ object MimaBuild { // Exclude a Spark package, that is in the package org.apache.spark def excludeSparkPackage(packageName: String) = { - excludePackage("org.apache.spark." + packageName) + ProblemFilters.exclude[Problem]("org.apache.spark." + packageName + ".*") } def ignoredABIProblems(base: File, currentSparkVersion: String) = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 090375f65552e..e6f136c7c8b0a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -242,17 +242,17 @@ object MimaExcludes { // Exclude rules for 2.0.x lazy val v20excludes = { Seq( - excludePackage("org.apache.spark.rpc"), - excludePackage("org.spark-project.jetty"), - excludePackage("org.spark_project.jetty"), - excludePackage("org.apache.spark.internal"), - excludePackage("org.apache.spark.unused"), - excludePackage("org.apache.spark.unsafe"), - excludePackage("org.apache.spark.memory"), - excludePackage("org.apache.spark.util.collection.unsafe"), - excludePackage("org.apache.spark.sql.catalyst"), - excludePackage("org.apache.spark.sql.execution"), - excludePackage("org.apache.spark.sql.internal"), + ProblemFilters.exclude[Problem]("org.apache.spark.rpc.*"), + ProblemFilters.exclude[Problem]("org.spark-project.jetty.*"), + ProblemFilters.exclude[Problem]("org.spark_project.jetty.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unused.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.memory.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.util.collection.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), ProblemFilters.exclude[MissingMethodProblem]( From f5fe63f7b8546b0102d7bfaf3dde77379f58a4d1 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Thu, 9 Nov 2017 16:42:33 -0800 Subject: [PATCH 1641/1765] =?UTF-8?q?[SPARK-22287][MESOS]=20SPARK=5FDAEMON?= =?UTF-8?q?=5FMEMORY=20not=20honored=20by=20MesosClusterD=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ispatcher ## What changes were proposed in this pull request? Allow JVM max heap size to be controlled for MesosClusterDispatcher via SPARK_DAEMON_MEMORY environment variable. ## How was this patch tested? Tested on local Mesos cluster Author: Paul Mackles Closes #19515 from pmackles/SPARK-22287. --- .../java/org/apache/spark/launcher/SparkClassCommandBuilder.java | 1 + 1 file changed, 1 insertion(+) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 32724acdc362c..fd056bb90e0c4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -81,6 +81,7 @@ public List buildCommand(Map env) case "org.apache.spark.deploy.mesos.MesosClusterDispatcher": javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); extraClassPath = getenv("SPARK_DAEMON_CLASSPATH"); + memKey = "SPARK_DAEMON_MEMORY"; break; case "org.apache.spark.deploy.ExternalShuffleService": case "org.apache.spark.deploy.mesos.MesosExternalShuffleService": From b57ed2245c705fb0964462cf4492b809ade836c6 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 9 Nov 2017 19:11:30 -0800 Subject: [PATCH 1642/1765] [SPARK-22308][TEST-MAVEN] Support alternative unit testing styles in external applications Continuation of PR#19528 (https://github.com/apache/spark/pull/19529#issuecomment-340252119) The problem with the maven build in the previous PR was the new tests.... the creation of a spark session outside the tests meant there was more than one spark session around at a time. I was using the spark session outside the tests so that the tests could share data; I've changed it so that each test creates the data anew. Author: Nathan Kronenfeld Author: Nathan Kronenfeld Closes #19705 from nkronenfeld/alternative-style-tests-2. --- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 47 +++++ .../spark/sql/test/GenericFunSpecSuite.scala | 49 +++++ .../spark/sql/test/GenericWordSpecSuite.scala | 53 ++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++++-------- .../spark/sql/test/SharedSQLContext.scala | 84 +-------- .../spark/sql/test/SharedSparkSession.scala | 119 ++++++++++++ 8 files changed, 387 insertions(+), 165 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 6aedcb1271ff6..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) + /** + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'. + */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PlanTestBase + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..14ac479e89754 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +import org.apache.spark.sql.Dataset + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..e8971e36d112d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +import org.apache.spark.sql.Dataset + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..44655a5345ca4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +import org.apache.spark.sql.Dataset + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. @@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually with BeforeAndAfterAll with SQLTestData - with PlanTest { self => + with PlanTestBase { self: Suite => protected def sparkContext = spark.sparkContext - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils protected override def _sqlContext: SQLContext = self.spark.sqlContext } - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils Dataset.ofRows(spark, plan) } - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,4 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ - protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} From 0025ddeb1dd4fd6951ecd8456457f6b94124f84e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Nov 2017 21:56:20 -0800 Subject: [PATCH 1643/1765] [SPARK-22472][SQL] add null check for top-level primitive values ## What changes were proposed in this pull request? One powerful feature of `Dataset` is, we can easily map SQL rows to Scala/Java objects and do runtime null check automatically. For example, let's say we have a parquet file with schema ``, and we have a `case class Data(a: Int, b: String)`. Users can easily read this parquet file into `Data` objects, and Spark will throw NPE if column `a` has null values. However the null checking is left behind for top-level primitive values. For example, let's say we have a parquet file with schema ``, and we read it into Scala `Int`. If column `a` has null values, we will get some weird results. ``` scala> val ds = spark.read.parquet(...).as[Int] scala> ds.show() +----+ |v | +----+ |null| |1 | +----+ scala> ds.collect res0: Array[Long] = Array(0, 1) scala> ds.map(_ * 2).show +-----+ |value| +-----+ |-2 | |2 | +-----+ ``` This is because internally Spark use some special default values for primitive types, but never expect users to see/operate these default value directly. This PR adds null check for top-level primitive values ## How was this patch tested? new test Author: Wenchen Fan Closes #19707 from cloud-fan/bug. --- .../spark/sql/catalyst/ScalaReflection.scala | 8 +++++++- .../sql/catalyst/ScalaReflectionSuite.scala | 7 ++++++- .../org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f62553ddd3971..4e47a5890db9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -134,7 +134,13 @@ object ScalaReflection extends ScalaReflection { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - deserializerFor(tpe, None, walkedTypePath) + val expr = deserializerFor(tpe, None, walkedTypePath) + val Schema(_, nullable) = schemaFor(tpe) + if (nullable) { + expr + } else { + AssertNotNull(expr, walkedTypePath) + } } private def deserializerFor( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f77af5db3279b..23e866cdf4917 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -351,4 +351,9 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(argumentsFields(0) == Seq("field.1")) assert(argumentsFields(1) == Seq("field 2")) } + + test("SPARK-22472: add null check for top-level primitive values") { + assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c67165c7abca6..6e13a5d491e0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide @@ -1408,6 +1409,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(ds, SpecialCharClass("1", "2")) } } + + test("SPARK-22472: add null check for top-level primitive values") { + // If the primitive values are from Option, we need to do runtime null check. + val ds = Seq(Some(1), None).toDS().as[Int] + intercept[NullPointerException](ds.collect()) + val e = intercept[SparkException](ds.map(_ * 2).collect()) + assert(e.getCause.isInstanceOf[NullPointerException]) + + withTempPath { path => + Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath) + // If the primitive values are from files, we need to do runtime null check. + val ds = spark.read.parquet(path.getCanonicalPath).as[Int] + intercept[NullPointerException](ds.collect()) + val e = intercept[SparkException](ds.map(_ * 2).collect()) + assert(e.getCause.isInstanceOf[NullPointerException]) + } + } } case class SingleData(id: Int) From 28ab5bf59766096bb7e1bdab32b05cb5f0799a48 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 10 Nov 2017 12:01:02 +0100 Subject: [PATCH 1644/1765] [SPARK-22487][SQL][HIVE] Remove the unused HIVE_EXECUTION_VERSION property ## What changes were proposed in this pull request? At the beginning https://github.com/apache/spark/pull/2843 added `spark.sql.hive.version` to reveal underlying hive version for jdbc connections. For some time afterwards, it was used as a version identifier for the execution hive client. Actually there is no hive client for executions in spark now and there are no usages of HIVE_EXECUTION_VERSION found in whole spark project. HIVE_EXECUTION_VERSION is set by `spark.sql.hive.version`, which is still set internally in some places or by users, this may confuse developers and users with HIVE_METASTORE_VERSION(spark.sql.hive.metastore.version). It might better to be removed. ## How was this patch tested? modify some existing ut cc cloud-fan gatorsmile Author: Kent Yao Closes #19712 from yaooqinn/SPARK-22487. --- .../sql/hive/thriftserver/SparkSQLEnv.scala | 1 - .../thriftserver/SparkSQLSessionManager.scala | 1 - .../HiveThriftServer2Suites.scala | 27 +++++-------------- .../org/apache/spark/sql/hive/HiveUtils.scala | 25 +++++++---------- .../spark/sql/hive/client/VersionsSuite.scala | 4 +-- 5 files changed, 19 insertions(+), 39 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 01c4eb131a564..5db93b26f550e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -55,7 +55,6 @@ private[hive] object SparkSQLEnv extends Logging { metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - sparkSession.conf.set("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 7adaafe5ad5c1..00920c297d493 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -77,7 +77,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: } else { sqlContext.newSession() } - ctx.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) if (sessionConf != null && sessionConf.containsKey("use:database")) { ctx.sql(s"use ${sessionConf.get("use:database")}") } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 4997d7f96afa2..b80596f55bdea 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -155,10 +155,10 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.version") - assert(resultSet.getString(2) === HiveUtils.hiveExecutionVersion) + assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } @@ -521,20 +521,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { conf += resultSet.getString(1) -> resultSet.getString(2) } - assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) - } - } - - test("Checks Hive version via SET") { - withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET") - - val conf = mutable.Map.empty[String, String] - while (resultSet.next()) { - conf += resultSet.getString(1) -> resultSet.getString(2) - } - - assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + assert(conf.get("spark.sql.hive.metastore.version") === Some("1.2.1")) } } @@ -721,10 +708,10 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.version") - assert(resultSet.getString(2) === HiveUtils.hiveExecutionVersion) + assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 80b9a3dc9605d..d8e08f1f6df50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -58,28 +58,23 @@ private[spark] object HiveUtils extends Logging { } /** The version of hive used internally by Spark SQL. */ - val hiveExecutionVersion: String = "1.2.1" + val builtinHiveVersion: String = "1.2.1" val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through $hiveExecutionVersion.") + s"0.12.0 through 2.1.1.") .stringConf - .createWithDefault(hiveExecutionVersion) - - val HIVE_EXECUTION_VERSION = buildConf("spark.sql.hive.version") - .doc("Version of Hive used internally by Spark SQL.") - .stringConf - .createWithDefault(hiveExecutionVersion) + .createWithDefault(builtinHiveVersion) val HIVE_METASTORE_JARS = buildConf("spark.sql.hive.metastore.jars") .doc(s""" | Location of the jars that should be used to instantiate the HiveMetastoreClient. | This property can be one of three options: " | 1. "builtin" - | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly when + | Use Hive ${builtinHiveVersion}, which is bundled with the Spark assembly when | -Phive is enabled. When this option is chosen, | spark.sql.hive.metastore.version must be either - | ${hiveExecutionVersion} or not defined. + | ${builtinHiveVersion} or not defined. | 2. "maven" | Use Hive jars of specified version downloaded from Maven repositories. | 3. A classpath in the standard format for both Hive and Hadoop. @@ -259,9 +254,9 @@ private[spark] object HiveUtils extends Logging { protected[hive] def newClientForExecution( conf: SparkConf, hadoopConf: Configuration): HiveClientImpl = { - logInfo(s"Initializing execution hive, version $hiveExecutionVersion") + logInfo(s"Initializing execution hive, version $builtinHiveVersion") val loader = new IsolatedClientLoader( - version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + version = IsolatedClientLoader.hiveVersion(builtinHiveVersion), sparkConf = conf, execJars = Seq.empty, hadoopConf = hadoopConf, @@ -297,12 +292,12 @@ private[spark] object HiveUtils extends Logging { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) val isolatedLoader = if (hiveMetastoreJars == "builtin") { - if (hiveExecutionVersion != hiveMetastoreVersion) { + if (builtinHiveVersion != hiveMetastoreVersion) { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + - s"Execution: $hiveExecutionVersion != Metastore: $hiveMetastoreVersion. " + + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. " + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } // We recursively find all jars in the class loader chain, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index edb9a9ffbaaf6..9ed39cc80f50d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -73,7 +73,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) + val badClient = buildClient(HiveUtils.builtinHiveVersion, new Configuration()) val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) badClient.createDatabase(db, ignoreIfExists = true) } @@ -81,7 +81,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test("hadoop configuration preserved") { val hadoopConf = new Configuration() hadoopConf.set("test", "success") - val client = buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) + val client = buildClient(HiveUtils.builtinHiveVersion, hadoopConf) assert("success" === client.getConf("test", null)) } From 9b9827759af2ca3eea146a6032f9165f640ce152 Mon Sep 17 00:00:00 2001 From: Pralabh Kumar Date: Fri, 10 Nov 2017 13:17:25 +0200 Subject: [PATCH 1645/1765] [SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor ## What changes were proposed in this pull request? (Provided featureSubset Strategy to GBTClassifier a) Moved featureSubsetStrategy to TreeEnsembleParams b) Changed GBTClassifier to pass featureSubsetStrategy val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)) ## How was this patch tested? a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy with GBTClassifier b)Added test cases in GBTClassifierSuite and GBTRegressorSuite Author: Pralabh Kumar Closes #18118 from pralabhkumar/develop. --- ...GradientBoostedTreeClassifierExample.scala | 1 + .../ml/classification/GBTClassifier.scala | 9 +- .../RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 8 +- .../spark/ml/regression/GBTRegressor.scala | 9 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../ml/tree/impl/DecisionTreeMetadata.scala | 4 +- .../ml/tree/impl/GradientBoostedTrees.scala | 25 ++++-- .../org/apache/spark/ml/tree/treeParams.scala | 82 ++++++++++--------- .../mllib/tree/GradientBoostedTrees.scala | 4 +- .../spark/mllib/tree/RandomForest.scala | 2 +- .../classification/GBTClassifierSuite.scala | 29 +++++++ .../ml/regression/GBTRegressorSuite.scala | 29 +++++++ .../tree/impl/GradientBoostedTreesSuite.scala | 4 +- 14 files changed, 146 insertions(+), 64 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 9a39acfbf37e5..3656773c8b817 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample { .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures") .setMaxIter(10) + .setFeatureSubsetStrategy("auto") // Convert indexed labels back to original labels. val labelConverter = new IndexToString() diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3da809ce5f77c..f11bc1d8fe415 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = set(stepSize, value) + /** @group setParam */ + @Since("2.3.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + // Parameters from GBTClassifierParams: /** @group setParam */ @@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") ( val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed)) + $(seed), $(featureSubsetStrategy)) val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab4c235209289..78a4972adbdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = - RandomForestParams.supportedFeatureSubsetStrategies + TreeEnsembleParams.supportedFeatureSubsetStrategies @Since("2.0.0") override def load(path: String): RandomForestClassifier = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 01c5cc1c7efa9..0291a57487c47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S } /** (private[ml]) Train a decision tree on an RDD */ - private[ml] def train(data: RDD[LabeledPoint], - oldStrategy: OldStrategy): DecisionTreeRegressionModel = { + private[ml] def train( + data: RDD[LabeledPoint], + oldStrategy: OldStrategy, + featureSubsetStrategy: String): DecisionTreeRegressionModel = { val instr = Instrumentation.create(this, data) instr.logParams(params: _*) - val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 08d175cb94442..f41d15b62dddd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.3.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed)) + $(seed), $(featureSubsetStrategy)) val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a58da50fad972..200b234b79978 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = - RandomForestParams.supportedFeatureSubsetStrategies + TreeEnsembleParams.supportedFeatureSubsetStrategies @Since("2.0.0") override def load(path: String): RandomForestRegressor = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 8a9dcb486b7bf..53189e0797b6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -22,7 +22,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.tree.RandomForestParams +import org.apache.spark.ml.tree.TreeEnsembleParams import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.Strategy @@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging { Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { case Some(value) => math.ceil(value * numFeatures).toInt case _ => throw new IllegalArgumentException(s"Supported values:" + - s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index e32447a79abb8..bd8c9afb5e209 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging { def run( input: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, + seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, - seed) + seed, featureSubsetStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") } @@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging { input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, + validate = true, seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging { val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true, seed) + validate = true, seed, featureSubsetStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging { validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, validate: Boolean, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging { val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. val treeStrategy = boostingStrategy.treeStrategy.copy val validationTol = boostingStrategy.validationTol @@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") val firstTree = new DecisionTreeRegressor().setSeed(seed) - val firstTreeModel = firstTree.train(input, treeStrategy) + val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight @@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") + val dt = new DecisionTreeRegressor().setSeed(seed + m) - val model = dt.train(data, treeStrategy) + val model = dt.train(data, treeStrategy, featureSubsetStrategy) timer.stop(s"building tree $m") // Update partial model baseLearners(m) = model diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 47079d9c6bb1c..81b6222acc7ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams } } +private[spark] object TreeEnsembleParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) +} + /** * Parameters for Decision Tree-based ensemble algorithms. * @@ -359,38 +365,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { oldImpurity: OldImpurity): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) } -} - -/** - * Parameters for Random Forest algorithms. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * - * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) - * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms - * are a bit different. - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) - - setDefault(numTrees -> 20) - - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - - /** @group getParam */ - final def getNumTrees: Int = $(numTrees) /** * The number of features to consider for splits at each tree node. @@ -420,10 +394,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains( + TreeEnsembleParams.supportedFeatureSubsetStrategies.contains( value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,7 +405,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") /** - * @deprecated This method is deprecated and will be removed in 3.0.0. + * @deprecated This method is deprecated and will be removed in 3.0.0 * @group setParam */ @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") @@ -441,10 +415,38 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } -private[spark] object RandomForestParams { - // These options should be lowercase. - final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) + + +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * + * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) + * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms + * are a bit different. + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) + + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) } private[ml] trait RandomForestClassifierParams @@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(featureSubsetStrategy -> "all") + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index df2c1b02f4f40..d24d8da0dab48 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] ( val algo = boostingStrategy.treeStrategy.algo val (trees, treeWeights) = NewGBT.run(input.map { point => NewLabeledPoint(point.label, point.features.asML) - }, boostingStrategy, seed.toLong) + }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } @@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] ( NewLabeledPoint(point.label, point.features.asML) }, validationInput.map { point => NewLabeledPoint(point.label, point.features.asML) - }, boostingStrategy, seed.toLong) + }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index d1331a57de27b..a8c5286f3dc10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams} +import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 8000143d4d142..978f89c459f0a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(gbt.getPredictionCol === "prediction") assert(gbt.getRawPredictionCol === "rawPrediction") assert(gbt.getProbabilityCol === "probability") + assert(gbt.getFeatureSubsetStrategy === "all") val df = trainData.toDF() val model = gbt.fit(df) model.transform(df) @@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.getPredictionCol === "prediction") assert(model.getRawPredictionCol === "rawPrediction") assert(model.getProbabilityCol === "probability") + assert(model.getFeatureSubsetStrategy === "all") assert(model.hasParent) MLTestingUtils.checkCopyAndUids(gbt, model) @@ -356,6 +358,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(importances.toArray.forall(_ >= 0.0)) } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature subset strategy + ///////////////////////////////////////////////////////////////////////////// + test("Tests of feature subset strategy") { + val numClasses = 2 + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(3) + .setMaxIter(5) + .setFeatureSubsetStrategy("all") + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + + // GBT with different featureSubsetStrategy + val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") + val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances + val mostIF = importanceFeatures.argmax + assert(mostImportantFeature !== mostIF) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 2da25f7e0100a..ecbb57126d759 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -165,6 +165,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext assert(importances.toArray.forall(_ >= 0.0)) } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature subset strategy + ///////////////////////////////////////////////////////////////////////////// + test("Tests of feature subset strategy") { + val numClasses = 2 + val gbt = new GBTRegressor() + .setMaxDepth(3) + .setMaxIter(5) + .setSeed(123) + .setFeatureSubsetStrategy("all") + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + + // GBT with different featureSubsetStrategy + val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") + val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances + val mostIF = importanceFeatures.argmax + assert(mostImportantFeature !== mostIF) + } + + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index 4109a299091dc..366d5ec3a53fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val boostingStrategy = new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) val (validateTrees, validateTreeWeights) = GradientBoostedTrees - .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L) + .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all") val numTrees = validateTrees.length assert(numTrees !== numIterations) // Test that it performs better on the validation dataset. - val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L) + val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all") val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) { val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) From 1c923d7d65dd94996f0fe2cf9851a1ae738c5c0c Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Fri, 10 Nov 2017 12:43:29 +0100 Subject: [PATCH 1646/1765] [SPARK-22450][CORE][MLLIB] safely register class for mllib ## What changes were proposed in this pull request? There are still some algorithms based on mllib, such as KMeans. For now, many mllib common class (such as: Vector, DenseVector, SparseVector, Matrix, DenseMatrix, SparseMatrix) are not registered in Kryo. So there are some performance issues for those object serialization or deserialization. Previously dicussed: https://github.com/apache/spark/pull/19586 ## How was this patch tested? New test case. Author: Xianyang Liu Closes #19661 from ConeyLiu/register_vector. --- .../spark/serializer/KryoSerializer.scala | 26 ++++++++++ .../spark/ml/feature/InstanceSuit.scala | 47 +++++++++++++++++++ .../spark/mllib/linalg/MatricesSuite.scala | 28 ++++++++++- .../spark/mllib/linalg/VectorsSuite.scala | 19 +++++++- 4 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 58483c9577d29..2259d1a2d555d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -25,6 +25,7 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} @@ -178,6 +179,31 @@ class KryoSerializer(conf: SparkConf) kryo.register(Utils.classForName("scala.collection.immutable.Map$EmptyMap$")) kryo.register(classOf[ArrayBuffer[Any]]) + // We can't load those class directly in order to avoid unnecessary jar dependencies. + // We load them safely, ignore it if the class not found. + Seq("org.apache.spark.mllib.linalg.Vector", + "org.apache.spark.mllib.linalg.DenseVector", + "org.apache.spark.mllib.linalg.SparseVector", + "org.apache.spark.mllib.linalg.Matrix", + "org.apache.spark.mllib.linalg.DenseMatrix", + "org.apache.spark.mllib.linalg.SparseMatrix", + "org.apache.spark.ml.linalg.Vector", + "org.apache.spark.ml.linalg.DenseVector", + "org.apache.spark.ml.linalg.SparseVector", + "org.apache.spark.ml.linalg.Matrix", + "org.apache.spark.ml.linalg.DenseMatrix", + "org.apache.spark.ml.linalg.SparseMatrix", + "org.apache.spark.ml.feature.Instance", + "org.apache.spark.ml.feature.OffsetInstance" + ).foreach { name => + try { + val clazz = Utils.classForName(name) + kryo.register(clazz) + } catch { + case NonFatal(_) => // do nothing + } + } + kryo.setClassLoader(classLoader) kryo } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala new file mode 100644 index 0000000000000..88c85a9425e78 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.serializer.KryoSerializer + +class InstanceSuit extends SparkFunSuite{ + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf) + val serInstance = new KryoSerializer(conf).newInstance() + + def check[T: ClassTag](t: T) { + assert(serInstance.deserialize[T](serInstance.serialize(t)) === t) + } + + val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)) + val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse) + val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)) + val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse) + check(instance1) + check(instance2) + check(oInstance1) + check(oInstance2) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index c8ac92eecf40b..d76edb940b2bd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,16 +20,42 @@ package org.apache.spark.mllib.linalg import java.util.Random import scala.collection.mutable.{Map => MutableMap} +import scala.reflect.ClassTag import breeze.linalg.{CSCMatrix, Matrix => BM} import org.mockito.Mockito.when import org.scalatest.mockito.MockitoSugar._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer class MatricesSuite extends SparkFunSuite { + test("kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + val m = 3 + val n = 2 + val denseValues = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0) + val denseMat = Matrices.dense(m, n, denseValues).asInstanceOf[DenseMatrix] + + val sparseValues = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val sparseMat = + Matrices.sparse(m, n, colPtrs, rowIndices, sparseValues).asInstanceOf[SparseMatrix] + check(denseMat) + check(sparseMat) + } + test("dense matrix construction") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index a1e3ee54b49ff..4074bead421e6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.mllib.linalg +import scala.reflect.ClassTag import scala.util.Random import breeze.linalg.{squaredDistance => breezeSquaredDistance, DenseMatrix => BDM} import org.json4s.jackson.JsonMethods.{parse => parseJson} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer class VectorsSuite extends SparkFunSuite with Logging { @@ -34,6 +36,21 @@ class VectorsSuite extends SparkFunSuite with Logging { val indices = Array(0, 2, 3) val values = Array(0.1, 0.3, 0.4) + test("kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + val desVec = Vectors.dense(arr).asInstanceOf[DenseVector] + val sparVec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] + check(desVec) + check(sparVec) + } + test("dense vector construction with varargs") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.size === arr.length) From 5b41cbf13b6d6f47b8f8f1ffcc7a3348018627ca Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 10 Nov 2017 11:24:24 -0600 Subject: [PATCH 1647/1765] [SPARK-22473][TEST] Replace deprecated AsyncAssertions.Waiter and methods of java.sql.Date ## What changes were proposed in this pull request? In `spark-sql` module tests there are deprecations warnings caused by the usage of deprecated methods of `java.sql.Date` and the usage of the deprecated `AsyncAssertions.Waiter` class. This PR replace the deprecated methods of `java.sql.Date` with non-deprecated ones (using `Calendar` where needed). It replaces also the deprecated `org.scalatest.concurrent.AsyncAssertions.Waiter` with `org.scalatest.concurrent.Waiters._`. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #19696 from mgaido91/SPARK-22473. --- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 4 ++-- .../sql/execution/streaming/HDFSMetadataLogSuite.scala | 4 ++-- .../spark/sql/streaming/EventTimeWatermarkSuite.scala | 8 ++++++-- .../spark/sql/streaming/StreamingQueryListenerSuite.scala | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6e13a5d491e0f..b02db7721aa7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1274,8 +1274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => scala.math.BigDecimal(1, 18) }.head == scala.math.BigDecimal(1, 18)) - assert(spark.range(1).map { x => new java.sql.Date(2016, 12, 12) }.head == - new java.sql.Date(2016, 12, 12)) + assert(spark.range(1).map { x => java.sql.Date.valueOf("2016-12-12") }.head == + java.sql.Date.valueOf("2016-12-12")) assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 48e70e48b1799..4677769c12a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -26,10 +26,10 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.scalatest.concurrent.AsyncAssertions._ +import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index f3e8cf950a5a4..47bc452bda0d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.{util => ju} import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Calendar, Date} import org.scalatest.{BeforeAndAfter, Matchers} @@ -218,7 +218,11 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .agg(count("*") as 'count) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - def monthsSinceEpoch(date: Date): Int = { date.getYear * 12 + date.getMonth } + def monthsSinceEpoch(date: Date): Int = { + val cal = Calendar.getInstance() + cal.setTime(date) + cal.get(Calendar.YEAR) * 12 + cal.get(Calendar.MONTH) + } testStream(aggWithWatermark)( AddData(input, currentTimeMs / 1000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 1fe639fcf2840..9ff02dee288fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -25,8 +25,8 @@ import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester._ -import org.scalatest.concurrent.AsyncAssertions.Waiter import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.Waiters.Waiter import org.apache.spark.SparkException import org.apache.spark.scheduler._ From b70aa9e08b4476746e912c2c2a8b7bdd102305e8 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 10 Nov 2017 10:22:42 -0800 Subject: [PATCH 1648/1765] [SPARK-22344][SPARKR] clean up install dir if running test as source package ## What changes were proposed in this pull request? remove spark if spark downloaded & installed ## How was this patch tested? manually by building package Jenkins, AppVeyor Author: Felix Cheung Closes #19657 from felixcheung/rinstalldir. --- R/pkg/R/install.R | 37 +++++++++++++++++++++++++++- R/pkg/R/utils.R | 13 ++++++++++ R/pkg/tests/fulltests/test_utils.R | 25 +++++++++++++++++++ R/pkg/tests/run-all.R | 4 ++- R/pkg/vignettes/sparkr-vignettes.Rmd | 6 ++++- 5 files changed, 82 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 492dee68e164d..04dc7562e5346 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -152,6 +152,11 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, }) if (!tarExists || overwrite || !success) { unlink(packageLocalPath) + if (success) { + # if tar file was not there before (or it was, but we are told to overwrite it), + # and untar is successful - set a flag that we have downloaded (and untar) Spark package. + assign(".sparkDownloaded", TRUE, envir = .sparkREnv) + } } if (!success) stop("Extract archive failed.") message("DONE.") @@ -266,6 +271,7 @@ hadoopVersionName <- function(hadoopVersion) { # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and # adapt to Spark context +# see also sparkCacheRelPathLength() sparkCachePath <- function() { if (is_windows()) { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) @@ -282,7 +288,7 @@ sparkCachePath <- function() { } } else if (.Platform$OS.type == "unix") { if (Sys.info()["sysname"] == "Darwin") { - path <- file.path(Sys.getenv("HOME"), "Library/Caches", "spark") + path <- file.path(Sys.getenv("HOME"), "Library", "Caches", "spark") } else { path <- file.path( Sys.getenv("XDG_CACHE_HOME", file.path(Sys.getenv("HOME"), ".cache")), "spark") @@ -293,6 +299,16 @@ sparkCachePath <- function() { normalizePath(path, mustWork = FALSE) } +# Length of the Spark cache specific relative path segments for each platform +# eg. "Apache\Spark\Cache" is 3 in Windows, or "spark" is 1 in unix +# Must match sparkCachePath() exactly. +sparkCacheRelPathLength <- function() { + if (is_windows()) { + 3 + } else { + 1 + } +} installInstruction <- function(mode) { if (mode == "remote") { @@ -310,3 +326,22 @@ installInstruction <- function(mode) { stop(paste0("No instruction found for ", mode, " mode.")) } } + +uninstallDownloadedSpark <- function() { + # clean up if Spark was downloaded + sparkDownloaded <- getOne(".sparkDownloaded", + envir = .sparkREnv, + inherits = TRUE, + ifnotfound = FALSE) + sparkDownloadedDir <- Sys.getenv("SPARK_HOME") + if (sparkDownloaded && nchar(sparkDownloadedDir) > 0) { + unlink(sparkDownloadedDir, recursive = TRUE, force = TRUE) + + dirs <- traverseParentDirs(sparkCachePath(), sparkCacheRelPathLength()) + lapply(dirs, function(d) { + if (length(list.files(d, all.files = TRUE, include.dirs = TRUE, no.. = TRUE)) == 0) { + unlink(d, recursive = TRUE, force = TRUE) + } + }) + } +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fa4099231ca8d..164cd6d01a347 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -910,3 +910,16 @@ hadoop_home_set <- function() { windows_with_hadoop <- function() { !is_windows() || hadoop_home_set() } + +# get0 not supported before R 3.2.0 +getOne <- function(x, envir, inherits = TRUE, ifnotfound = NULL) { + mget(x[1L], envir = envir, inherits = inherits, ifnotfound = list(ifnotfound))[[1L]] +} + +# Returns a vector of parent directories, traversing up count times, starting with a full path +# eg. traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) should return +# this "/Users/user/Library/Caches/spark/spark2.2" +# and "/Users/user/Library/Caches/spark" +traverseParentDirs <- function(x, count) { + if (dirname(x) == x || count <= 0) x else c(x, Recall(dirname(x), count - 1)) +} diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index fb394b8069c1c..f0292ab335592 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -228,4 +228,29 @@ test_that("basenameSansExtFromUrl", { expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive") }) +test_that("getOne", { + dummy <- getOne(".dummyValue", envir = new.env(), ifnotfound = FALSE) + expect_equal(dummy, FALSE) +}) + +test_that("traverseParentDirs", { + if (is_windows()) { + # original path is included as-is, otherwise dirname() replaces \\ with / on windows + dirs <- traverseParentDirs("c:\\Users\\user\\AppData\\Local\\Apache\\Spark\\Cache\\spark2.2", 3) + expect <- c("c:\\Users\\user\\AppData\\Local\\Apache\\Spark\\Cache\\spark2.2", + "c:/Users/user/AppData/Local/Apache/Spark/Cache", + "c:/Users/user/AppData/Local/Apache/Spark", + "c:/Users/user/AppData/Local/Apache") + expect_equal(dirs, expect) + } else { + dirs <- traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) + expect <- c("/Users/user/Library/Caches/spark/spark2.2", "/Users/user/Library/Caches/spark") + expect_equal(dirs, expect) + + dirs <- traverseParentDirs("/home/u/.cache/spark/spark2.2", 1) + expect <- c("/home/u/.cache/spark/spark2.2", "/home/u/.cache/spark") + expect_equal(dirs, expect) + } +}) + sparkR.session.stop() diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index a7f913e5fad11..63812ba70bb50 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -46,7 +46,7 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { tmpDir <- tempdir() tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, - spark.executor.extraJavaOptions = tmpArg) + spark.executor.extraJavaOptions = tmpArg) } test_package("SparkR") @@ -60,3 +60,5 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { NULL, "summary") } + +SparkR:::uninstallDownloadedSpark() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 907bbb3d66018..8c4ea2f2db188 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -37,7 +37,7 @@ opts_hooks$set(eval = function(options) { options }) r_tmp_dir <- tempdir() -tmp_arg <- paste("-Djava.io.tmpdir=", r_tmp_dir, sep = "") +tmp_arg <- paste0("-Djava.io.tmpdir=", r_tmp_dir) sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg, spark.executor.extraJavaOptions = tmp_arg) old_java_opt <- Sys.getenv("_JAVA_OPTIONS") @@ -1183,3 +1183,7 @@ env | map ```{r, echo=FALSE} sparkR.session.stop() ``` + +```{r cleanup, include=FALSE} +SparkR:::uninstallDownloadedSpark() +``` From 5ebdcd185f2108a90e37a1aa4214c3b6c69a97a4 Mon Sep 17 00:00:00 2001 From: Santiago Saavedra Date: Fri, 10 Nov 2017 10:57:58 -0800 Subject: [PATCH 1649/1765] [SPARK-22294][DEPLOY] Reset spark.driver.bindAddress when starting a Checkpoint ## What changes were proposed in this pull request? It seems that recovering from a checkpoint can replace the old driver and executor IP addresses, as the workload can now be taking place in a different cluster configuration. It follows that the bindAddress for the master may also have changed. Thus we should not be keeping the old one, and instead be added to the list of properties to reset and recreate from the new environment. ## How was this patch tested? This patch was tested via manual testing on AWS, using the experimental (not yet merged) Kubernetes scheduler, which uses bindAddress to bind to a Kubernetes service (and thus was how I first encountered the bug too), but it is not a code-path related to the scheduler and this may have slipped through when merging SPARK-4563. Author: Santiago Saavedra Closes #19427 from ssaavedra/fix-checkpointing-master. --- .../src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 3cfbcedd519d6..aed67a5027433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -51,6 +51,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.app.id", "spark.yarn.app.attemptId", "spark.driver.host", + "spark.driver.bindAddress", "spark.driver.port", "spark.master", "spark.yarn.jars", @@ -64,6 +65,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") + .remove("spark.driver.bindAddress") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From 24ea781cd30fbc611c714b5c6931f7d993f3a08d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 10 Nov 2017 11:27:28 -0800 Subject: [PATCH 1650/1765] [SPARK-19644][SQL] Clean up Scala reflection garbage after creating Encoder ## What changes were proposed in this pull request? Because of the memory leak issue in `scala.reflect.api.Types.TypeApi.<:<` (https://github.com/scala/bug/issues/8302), creating an encoder may leak memory. This PR adds `cleanUpReflectionObjects` to clean up these leaking objects for methods calling `scala.reflect.api.Types.TypeApi.<:<`. ## How was this patch tested? The updated unit tests. Author: Shixiong Zhu Closes #19687 from zsxwing/SPARK-19644. --- .../spark/sql/catalyst/ScalaReflection.scala | 25 +++++++++---- .../encoders/ExpressionEncoderSuite.scala | 35 +++++++++++++++++-- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4e47a5890db9c..65040f1af4b04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -61,7 +61,7 @@ object ScalaReflection extends ScalaReflection { */ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - private def dataTypeFor(tpe: `Type`): DataType = { + private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType @@ -93,7 +93,7 @@ object ScalaReflection extends ScalaReflection { * Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - private def arrayClassFor(tpe: `Type`): ObjectType = { + private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects { val cls = tpe.dealias match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection { private def deserializerFor( tpe: `Type`, path: Option[Expression], - walkedTypePath: Seq[String]): Expression = { + walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { @@ -441,7 +441,7 @@ object ScalaReflection extends ScalaReflection { inputObject: Expression, tpe: `Type`, walkedTypePath: Seq[String], - seenTypeSet: Set[`Type`] = Set.empty): Expression = { + seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { @@ -648,7 +648,7 @@ object ScalaReflection extends ScalaReflection { * Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that, * we also treat [[DefinedByConstructorParams]] as product type. */ - def optionOfProductType(tpe: `Type`): Boolean = { + def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects { tpe.dealias match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -710,7 +710,7 @@ object ScalaReflection extends ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = { + def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() @@ -780,7 +780,7 @@ object ScalaReflection extends ScalaReflection { /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ - def definedByConstructorParams(tpe: Type): Boolean = { + def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] } @@ -809,6 +809,17 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + /** + * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to + * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to + * `scala.reflect.runtime.JavaUniverse.undoLog`. + * + * @see https://github.com/scala/bug/issues/8302 + */ + def cleanUpReflectionObjects[T](func: => T): T = { + universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func) + } + /** * Return the Scala Type for `T` in the current classloader mirror. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index bb1955a1ae242..e6d09bdae67d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ClosureCleaner case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -114,7 +115,9 @@ object ReferenceValueClass { class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) - implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = verifyNotLeakingReflectionObjects { + ExpressionEncoder() + } // test flat encoders encodeDecodeTest(false, "primitive boolean") @@ -370,8 +373,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { - test(s"encode/decode for $testName: $input") { + testAndVerifyNotLeakingReflectionObjects(s"encode/decode for $testName: $input") { val encoder = implicitly[ExpressionEncoder[T]] + + // Make sure encoder is serializable. + ClosureCleaner.clean((s: String) => encoder.getClass.getName) + val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolveAndBind() @@ -441,4 +448,28 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } } } + + /** + * Verify the size of scala.reflect.runtime.JavaUniverse.undoLog before and after `func` to + * ensure we don't leak Scala reflection garbage. + * + * @see org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects + */ + private def verifyNotLeakingReflectionObjects[T](func: => T): T = { + def undoLogSize: Int = { + scala.reflect.runtime.universe + .asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.log.size + } + + val previousUndoLogSize = undoLogSize + val r = func + assert(previousUndoLogSize == undoLogSize) + r + } + + private def testAndVerifyNotLeakingReflectionObjects(testName: String)(testFun: => Any) { + test(testName) { + verifyNotLeakingReflectionObjects(testFun) + } + } } From f2da738c76810131045e6c32533a2d13526cdaf6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Nov 2017 21:17:49 +0100 Subject: [PATCH 1651/1765] [SPARK-22284][SQL] Fix 64KB JVM bytecode limit problem in calculating hash for nested structs ## What changes were proposed in this pull request? This PR avoids to generate a huge method for calculating a murmur3 hash for nested structs. This PR splits a huge method (e.g. `apply_4`) into multiple smaller methods. Sample program ``` val structOfString = new StructType().add("str", StringType) var inner = new StructType() for (_ <- 0 until 800) { inner = inner1.add("structOfString", structOfString) } var schema = new StructType() for (_ <- 0 until 50) { schema = schema.add("structOfStructOfStrings", inner) } GenerateMutableProjection.generate(Seq(Murmur3Hash(exprs, 42))) ``` Without this PR ``` /* 005 */ class SpecificMutableProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseMutableProjection { /* 006 */ /* 007 */ private Object[] references; /* 008 */ private InternalRow mutableRow; /* 009 */ private int value; /* 010 */ private int value_0; ... /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ value = 42; /* 040 */ apply_0(i); /* 041 */ apply_1(i); /* 042 */ apply_2(i); /* 043 */ apply_3(i); /* 044 */ apply_4(i); /* 045 */ nestedClassInstance.apply_5(i); ... /* 089 */ nestedClassInstance8.apply_49(i); /* 090 */ value_0 = value; /* 091 */ /* 092 */ // copy all the results into MutableRow /* 093 */ mutableRow.setInt(0, value_0); /* 094 */ return mutableRow; /* 095 */ } /* 096 */ /* 097 */ /* 098 */ private void apply_4(InternalRow i) { /* 099 */ /* 100 */ boolean isNull5 = i.isNullAt(4); /* 101 */ InternalRow value5 = isNull5 ? null : (i.getStruct(4, 800)); /* 102 */ if (!isNull5) { /* 103 */ /* 104 */ if (!value5.isNullAt(0)) { /* 105 */ /* 106 */ final InternalRow element6400 = value5.getStruct(0, 1); /* 107 */ /* 108 */ if (!element6400.isNullAt(0)) { /* 109 */ /* 110 */ final UTF8String element6401 = element6400.getUTF8String(0); /* 111 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6401.getBaseObject(), element6401.getBaseOffset(), element6401.numBytes(), value); /* 112 */ /* 113 */ } /* 114 */ /* 115 */ /* 116 */ } /* 117 */ /* 118 */ /* 119 */ if (!value5.isNullAt(1)) { /* 120 */ /* 121 */ final InternalRow element6402 = value5.getStruct(1, 1); /* 122 */ /* 123 */ if (!element6402.isNullAt(0)) { /* 124 */ /* 125 */ final UTF8String element6403 = element6402.getUTF8String(0); /* 126 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6403.getBaseObject(), element6403.getBaseOffset(), element6403.numBytes(), value); /* 127 */ /* 128 */ } /* 128 */ } /* 129 */ /* 130 */ /* 131 */ } /* 132 */ /* 133 */ /* 134 */ if (!value5.isNullAt(2)) { /* 135 */ /* 136 */ final InternalRow element6404 = value5.getStruct(2, 1); /* 137 */ /* 138 */ if (!element6404.isNullAt(0)) { /* 139 */ /* 140 */ final UTF8String element6405 = element6404.getUTF8String(0); /* 141 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6405.getBaseObject(), element6405.getBaseOffset(), element6405.numBytes(), value); /* 142 */ /* 143 */ } /* 144 */ /* 145 */ /* 146 */ } /* 147 */ ... /* 12074 */ if (!value5.isNullAt(798)) { /* 12075 */ /* 12076 */ final InternalRow element7996 = value5.getStruct(798, 1); /* 12077 */ /* 12078 */ if (!element7996.isNullAt(0)) { /* 12079 */ /* 12080 */ final UTF8String element7997 = element7996.getUTF8String(0); /* 12083 */ } /* 12084 */ /* 12085 */ /* 12086 */ } /* 12087 */ /* 12088 */ /* 12089 */ if (!value5.isNullAt(799)) { /* 12090 */ /* 12091 */ final InternalRow element7998 = value5.getStruct(799, 1); /* 12092 */ /* 12093 */ if (!element7998.isNullAt(0)) { /* 12094 */ /* 12095 */ final UTF8String element7999 = element7998.getUTF8String(0); /* 12096 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element7999.getBaseObject(), element7999.getBaseOffset(), element7999.numBytes(), value); /* 12097 */ /* 12098 */ } /* 12099 */ /* 12100 */ /* 12101 */ } /* 12102 */ /* 12103 */ } /* 12104 */ /* 12105 */ } /* 12106 */ /* 12106 */ /* 12107 */ /* 12108 */ private void apply_1(InternalRow i) { ... ``` With this PR ``` /* 005 */ class SpecificMutableProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseMutableProjection { /* 006 */ /* 007 */ private Object[] references; /* 008 */ private InternalRow mutableRow; /* 009 */ private int value; /* 010 */ private int value_0; /* 011 */ ... /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ value = 42; /* 040 */ nestedClassInstance11.apply50_0(i); /* 041 */ nestedClassInstance11.apply50_1(i); ... /* 088 */ nestedClassInstance11.apply50_48(i); /* 089 */ nestedClassInstance11.apply50_49(i); /* 090 */ value_0 = value; /* 091 */ /* 092 */ // copy all the results into MutableRow /* 093 */ mutableRow.setInt(0, value_0); /* 094 */ return mutableRow; /* 095 */ } /* 096 */ ... /* 37717 */ private void apply4_0(InternalRow value5, InternalRow i) { /* 37718 */ /* 37719 */ if (!value5.isNullAt(0)) { /* 37720 */ /* 37721 */ final InternalRow element6400 = value5.getStruct(0, 1); /* 37722 */ /* 37723 */ if (!element6400.isNullAt(0)) { /* 37724 */ /* 37725 */ final UTF8String element6401 = element6400.getUTF8String(0); /* 37726 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6401.getBaseObject(), element6401.getBaseOffset(), element6401.numBytes(), value); /* 37727 */ /* 37728 */ } /* 37729 */ /* 37730 */ /* 37731 */ } /* 37732 */ /* 37733 */ if (!value5.isNullAt(1)) { /* 37734 */ /* 37735 */ final InternalRow element6402 = value5.getStruct(1, 1); /* 37736 */ /* 37737 */ if (!element6402.isNullAt(0)) { /* 37738 */ /* 37739 */ final UTF8String element6403 = element6402.getUTF8String(0); /* 37740 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6403.getBaseObject(), element6403.getBaseOffset(), element6403.numBytes(), value); /* 37741 */ /* 37742 */ } /* 37743 */ /* 37744 */ /* 37745 */ } /* 37746 */ /* 37747 */ if (!value5.isNullAt(2)) { /* 37748 */ /* 37749 */ final InternalRow element6404 = value5.getStruct(2, 1); /* 37750 */ /* 37751 */ if (!element6404.isNullAt(0)) { /* 37752 */ /* 37753 */ final UTF8String element6405 = element6404.getUTF8String(0); /* 37754 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6405.getBaseObject(), element6405.getBaseOffset(), element6405.numBytes(), value); /* 37755 */ /* 37756 */ } /* 37757 */ /* 37758 */ /* 37759 */ } /* 37760 */ /* 37761 */ } ... /* 218470 */ /* 218471 */ private void apply50_4(InternalRow i) { /* 218472 */ /* 218473 */ boolean isNull5 = i.isNullAt(4); /* 218474 */ InternalRow value5 = isNull5 ? null : (i.getStruct(4, 800)); /* 218475 */ if (!isNull5) { /* 218476 */ apply4_0(value5, i); /* 218477 */ apply4_1(value5, i); /* 218478 */ apply4_2(value5, i); ... /* 218742 */ nestedClassInstance.apply4_266(value5, i); /* 218743 */ } /* 218744 */ /* 218745 */ } ``` ## How was this patch tested? Added new test to `HashExpressionsSuite` Author: Kazuaki Ishizaki Closes #19563 from kiszk/SPARK-22284. --- .../spark/sql/catalyst/expressions/hash.scala | 5 ++-- .../expressions/HashExpressionsSuite.scala | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 85a5f7fb2c6c3..eb3c49f5cf30e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -389,9 +389,10 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { - fields.zipWithIndex.map { case (field, index) => + val hashes = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) - }.mkString("\n") + } + ctx.splitExpressions(input, hashes) } @tailrec diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 59fc8eaf73d61..112a4a09728ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -639,6 +639,35 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval) } + test("SPARK-22284: Compute hash for nested structs") { + val M = 80 + val N = 10 + val L = M * N + val O = 50 + val seed = 42 + + val wideRow = new GenericInternalRow(Seq.tabulate(O)(k => + new GenericInternalRow(Seq.tabulate(M)(j => + new GenericInternalRow(Seq.tabulate(N)(i => + new GenericInternalRow(Array[Any]( + UTF8String.fromString((k * L + j * N + i).toString)))) + .toArray[Any])).toArray[Any])).toArray[Any]) + val inner = new StructType( + (0 until N).map(_ => StructField("structOfString", structOfString)).toArray) + val outer = new StructType( + (0 until M).map(_ => StructField("structOfStructOfString", inner)).toArray) + val schema = new StructType( + (0 until O).map(_ => StructField("structOfStructOfStructOfString", outer)).toArray) + val exprs = schema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val murmur3HashExpr = Murmur3Hash(exprs, 42) + val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) + + val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow) + assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval) + } + private def testHash(inputSchema: StructType): Unit = { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get val encoder = RowEncoder(inputSchema) From 808e886b9638ab2981dac676b594f09cda9722fe Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Fri, 10 Nov 2017 15:18:11 -0800 Subject: [PATCH 1652/1765] [SPARK-21667][STREAMING] ConsoleSink should not fail streaming query with checkpointLocation option ## What changes were proposed in this pull request? Fix to allow recovery on console , avoid checkpoint exception ## How was this patch tested? existing tests manual tests [ Replicating error and seeing no checkpoint error after fix] Author: Rekha Joshi Author: rjoshi2 Closes #19407 from rekhajoshm/SPARK-21667. --- .../apache/spark/sql/streaming/DataStreamWriter.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 14e7df672cc58..0be69b98abc8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -267,12 +267,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val (useTempCheckpointLocation, recoverFromCheckpointLocation) = - if (source == "console") { - (true, false) - } else { - (false, true) - } val dataSource = DataSource( df.sparkSession, @@ -285,8 +279,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { df, dataSource.createSink(outputMode), outputMode, - useTempCheckpointLocation = useTempCheckpointLocation, - recoverFromCheckpointLocation = recoverFromCheckpointLocation, + useTempCheckpointLocation = source == "console", + recoverFromCheckpointLocation = true, trigger = trigger) } } From 3eb315d7141d69ac040dcba498dd863b6d217775 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 11 Nov 2017 04:10:54 -0600 Subject: [PATCH 1653/1765] [SPARK-19759][ML] not using blas in ALSModel.predict for optimization ## What changes were proposed in this pull request? In `ALS.predict` currently we are using `blas.sdot` function to perform a dot product on two `Seq`s. It turns out that this is not the most efficient way. I used the following code to compare the implementations: ``` def time[R](block: => R): Unit = { val t0 = System.nanoTime() block val t1 = System.nanoTime() println("Elapsed time: " + (t1 - t0) + "ns") } val r = new scala.util.Random(100) val input = (1 to 500000).map(_ => (1 to 100).map(_ => r.nextFloat).toSeq) def f(a:Seq[Float], b:Seq[Float]): Float = { var r = 0.0f for(i <- 0 until a.length) { r+=a(i)*b(i) } r } import com.github.fommil.netlib.BLAS.{getInstance => blas} val b = (1 to 100).map(_ => r.nextFloat).toSeq time { input.foreach(a=>blas.sdot(100, a.toArray, 1, b.toArray, 1)) } // on average it takes 2968718815 ns time { input.foreach(a=>f(a,b)) } // on average it takes 515510185 ns ``` Thus this PR proposes the old-style for loop implementation for performance reasons. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #19685 from mgaido91/SPARK-19759. --- .../scala/org/apache/spark/ml/recommendation/ALS.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a8843661c873b..81a8f50761e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -289,9 +289,13 @@ class ALSModel private[ml] ( private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) => if (featuresA != null && featuresB != null) { - // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for - // potential optimization. - blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1) + var dotProduct = 0.0f + var i = 0 + while (i < rank) { + dotProduct += featuresA(i) * featuresB(i) + i += 1 + } + dotProduct } else { Float.NaN } From 223d83ee93e604009afea4af3d13a838d08625a4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 11 Nov 2017 19:16:31 +0900 Subject: [PATCH 1654/1765] [SPARK-22476][R] Add dayofweek function to R ## What changes were proposed in this pull request? This PR adds `dayofweek` to R API: ```r data <- list(list(d = as.Date("2012-12-13")), list(d = as.Date("2013-12-14")), list(d = as.Date("2014-12-15"))) df <- createDataFrame(data) collect(select(df, dayofweek(df$d))) ``` ``` dayofweek(d) 1 5 2 7 3 2 ``` ## How was this patch tested? Manual tests and unit tests in `R/pkg/tests/fulltests/test_sparkSQL.R` Author: hyukjinkwon Closes #19706 from HyukjinKwon/add-dayofweek. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 ++++++++++++++++- R/pkg/R/generics.R | 5 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 1 + 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3fc756b9ef40c..57838f52eac3f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -232,6 +232,7 @@ exportMethods("%<=>%", "date_sub", "datediff", "dayofmonth", + "dayofweek", "dayofyear", "decode", "dense_rank", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0143a3e63ba61..237ef061e8071 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -696,7 +696,7 @@ setMethod("hash", #' #' \dontrun{ #' head(select(df, df$time, year(df$time), quarter(df$time), month(df$time), -#' dayofmonth(df$time), dayofyear(df$time), weekofyear(df$time))) +#' dayofmonth(df$time), dayofweek(df$time), dayofyear(df$time), weekofyear(df$time))) #' head(agg(groupBy(df, year(df$time)), count(df$y), avg(df$y))) #' head(agg(groupBy(df, month(df$time)), avg(df$y)))} #' @note dayofmonth since 1.5.0 @@ -707,6 +707,21 @@ setMethod("dayofmonth", column(jc) }) +#' @details +#' \code{dayofweek}: Extracts the day of the week as an integer from a +#' given date/timestamp/string. +#' +#' @rdname column_datetime_functions +#' @aliases dayofweek dayofweek,Column-method +#' @export +#' @note dayofweek since 2.3.0 +setMethod("dayofweek", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofweek", x@jc) + column(jc) + }) + #' @details #' \code{dayofyear}: Extracts the day of the year as an integer from a #' given date/timestamp/string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8312d417b99d2..8fcf269087c7d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1048,6 +1048,11 @@ setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) + #' @rdname column_datetime_functions #' @export #' @name NULL diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a0dbd475f78e6..8a7fb124bd9db 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1699,6 +1699,7 @@ test_that("date functions on a DataFrame", { list(a = 2L, b = as.Date("2013-12-14")), list(a = 3L, b = as.Date("2014-12-15"))) df <- createDataFrame(l) + expect_equal(collect(select(df, dayofweek(df$b)))[, 1], c(5, 7, 2)) expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) From 154351e6dbd24c4254094477e3f7defcba979b1a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 11 Nov 2017 12:34:30 +0100 Subject: [PATCH 1655/1765] [SPARK-22462][SQL] Make rdd-based actions in Dataset trackable in SQL UI ## What changes were proposed in this pull request? For the few Dataset actions such as `foreach`, currently no SQL metrics are visible in the SQL tab of SparkUI. It is because it binds wrongly to Dataset's `QueryExecution`. As the actions directly evaluate on the RDD which has individual `QueryExecution`, to show correct SQL metrics on UI, we should bind to RDD's `QueryExecution`. ## How was this patch tested? Manually test. Screenshot is attached in the PR. Author: Liang-Chi Hsieh Closes #19689 from viirya/SPARK-22462. --- .../scala/org/apache/spark/sql/Dataset.scala | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5eb2affa0bd8f..1620ab3aa2094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2594,7 +2594,7 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def foreach(f: T => Unit): Unit = withNewExecutionId { + def foreach(f: T => Unit): Unit = withNewRDDExecutionId { rdd.foreach(f) } @@ -2613,7 +2613,7 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId { + def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId { rdd.foreachPartition(f) } @@ -2851,6 +2851,12 @@ class Dataset[T] private[sql]( */ def unpersist(): this.type = unpersist(blocking = false) + // Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`. + @transient private lazy val rddQueryExecution: QueryExecution = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + sparkSession.sessionState.executePlan(deserialized) + } + /** * Represents the content of the Dataset as an `RDD` of `T`. * @@ -2859,8 +2865,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) - sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => + rddQueryExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } } @@ -3113,6 +3118,20 @@ class Dataset[T] private[sql]( SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } + /** + * Wrap an action of the Dataset's RDD to track all Spark jobs in the body so that we can connect + * them with an execution. Before performing the action, the metrics of the executed plan will be + * reset. + */ + private def withNewRDDExecutionId[U](body: => U): U = { + SQLExecution.withNewExecutionId(sparkSession, rddQueryExecution) { + rddQueryExecution.executedPlan.foreach { plan => + plan.resetMetrics() + } + body + } + } + /** * Wrap a Dataset action to track the QueryExecution and time cost, then report to the * user-registered callback functions. From d6ee69e7761a62e47e21e4c2ce1bb20038d745b6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 11 Nov 2017 18:20:11 +0100 Subject: [PATCH 1656/1765] [SPARK-22488][SQL] Fix the view resolution issue in the SparkSession internal table() API ## What changes were proposed in this pull request? The current internal `table()` API of `SparkSession` bypasses the Analyzer and directly calls `sessionState.catalog.lookupRelation` API. This skips the view resolution logics in our Analyzer rule `ResolveRelations`. This internal API is widely used by various DDL commands, public and internal APIs. Users might get the strange error caused by view resolution when the default database is different. ``` Table or view not found: t1; line 1 pos 14 org.apache.spark.sql.AnalysisException: Table or view not found: t1; line 1 pos 14 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) ``` This PR is to fix it by enforcing it to use `ResolveRelations` to resolve the table. ## How was this patch tested? Added a test case and modified the existing test cases Author: gatorsmile Closes #19713 from gatorsmile/viewResolution. --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- .../org/apache/spark/sql/SparkSession.scala | 3 ++- .../spark/sql/execution/command/cache.scala | 4 +--- .../sql/execution/GlobalTempViewSuite.scala | 16 +++++++++++----- .../spark/sql/execution/SQLViewSuite.scala | 15 +++++++++++++++ .../spark/sql/execution/command/DDLSuite.scala | 5 +++-- .../spark/sql/sources/FilteredScanSuite.scala | 2 +- .../apache/spark/sql/hive/CachedTableSuite.scala | 14 +++++++++----- 8 files changed, 43 insertions(+), 18 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 8a7fb124bd9db..00217c892fc8d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -733,7 +733,7 @@ test_that("test cache, uncache and clearCache", { expect_true(dropTempView("table1")) expect_error(uncacheTable("foo"), - "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'") + "Error in uncacheTable : analysis error - Table or view not found: foo") }) test_that("insertInto() on a registered table", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d5ab53ad8fe29..2821f5ee7feee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} @@ -621,7 +622,7 @@ class SparkSession private( } private[sql] def table(tableIdent: TableIdentifier): DataFrame = { - Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent)) + Dataset.ofRows(self, UnresolvedRelation(tableIdent)) } /* ----------------- * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 687994d82a003..6b00426d2fa91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -54,10 +54,8 @@ case class UncacheTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val tableId = tableIdent.quotedString - try { + if (!ifExists || sparkSession.catalog.tableExists(tableId)) { sparkSession.catalog.uncacheTable(tableId) - } catch { - case _: NoSuchTableException if ifExists => // don't throw } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index a3d75b221ec3e..cc943e0356f2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -35,23 +35,27 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { private var globalTempDB: String = _ test("basic semantic") { + val expectedErrorMsg = "not found" try { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in table name, we should try local temp view first, if not found, // try table/view in current database, which is "default" in this case. So we expect // NoSuchTableException here. - intercept[NoSuchTableException](spark.table("src")) + var e = intercept[AnalysisException](spark.table("src")).getMessage + assert(e.contains(expectedErrorMsg)) // Use qualified name to refer to the global temp view explicitly. checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) // Table name without database will never refer to a global temp view. - intercept[NoSuchTableException](sql("DROP VIEW src")) + e = intercept[AnalysisException](sql("DROP VIEW src")).getMessage + assert(e.contains(expectedErrorMsg)) sql(s"DROP VIEW $globalTempDB.src") // The global temp view should be dropped successfully. - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + e = intercept[AnalysisException](spark.table(s"$globalTempDB.src")).getMessage + assert(e.contains(expectedErrorMsg)) // We can also use Dataset API to create global temp view Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") @@ -59,7 +63,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // Use qualified name to rename a global temp view. sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + e = intercept[AnalysisException](spark.table(s"$globalTempDB.src")).getMessage + assert(e.contains(expectedErrorMsg)) checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) // Use qualified name to alter a global temp view. @@ -68,7 +73,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Catalog API to drop global temp view spark.catalog.dropGlobalTempView("src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + e = intercept[AnalysisException](spark.table(s"$globalTempDB.src2")).getMessage + assert(e.contains(expectedErrorMsg)) // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 6761f05bb462a..08a4a21b20f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -679,4 +679,19 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { assert(spark.table("v").schema.head.name == "cBa") } } + + test("sparkSession API view resolution with different default database") { + withDatabase("db2") { + withView("v1") { + withTable("t1") { + sql("USE default") + sql("CREATE TABLE t1 USING parquet AS SELECT 1 AS c0") + sql("CREATE VIEW v1 AS SELECT * FROM t1") + sql("CREATE DATABASE IF NOT EXISTS db2") + sql("USE db2") + checkAnswer(spark.table("default.v1"), Row(1)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 21a2c62929146..878f435c75cb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -825,10 +825,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.range(10).createOrReplaceTempView("tab1") sql("ALTER TABLE tab1 RENAME TO tab2") checkAnswer(spark.table("tab2"), spark.range(10).toDF()) - intercept[NoSuchTableException] { spark.table("tab1") } + val e = intercept[AnalysisException](spark.table("tab1")).getMessage + assert(e.contains("Table or view not found")) sql("ALTER VIEW tab2 RENAME TO tab1") checkAnswer(spark.table("tab1"), spark.range(10).toDF()) - intercept[NoSuchTableException] { spark.table("tab2") } + intercept[AnalysisException] { spark.table("tab2") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index c45b507d2b489..a538b9458177e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -326,7 +326,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic assert(ColumnsRequired.set === requiredColumnNames) val table = spark.table("oneToTenFiltered") - val relation = table.queryExecution.logical.collectFirst { + val relation = table.queryExecution.analyzed.collectFirst { case LogicalRelation(r, _, _, _) => r }.get diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index d3cbf898e2439..48ab4eb9a6178 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -102,14 +102,18 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } test("uncache of nonexistant tables") { + val expectedErrorMsg = "Table or view not found: nonexistantTable" // make sure table doesn't exist - intercept[NoSuchTableException](spark.table("nonexistantTable")) - intercept[NoSuchTableException] { + var e = intercept[AnalysisException](spark.table("nonexistantTable")).getMessage + assert(e.contains(expectedErrorMsg)) + e = intercept[AnalysisException] { spark.catalog.uncacheTable("nonexistantTable") - } - intercept[NoSuchTableException] { + }.getMessage + assert(e.contains(expectedErrorMsg)) + e = intercept[AnalysisException] { sql("UNCACHE TABLE nonexistantTable") - } + }.getMessage + assert(e.contains(expectedErrorMsg)) sql("UNCACHE TABLE IF EXISTS nonexistantTable") } From 21a7bfd5c324e6c82152229f1394f26afeae771c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 11 Nov 2017 22:40:26 +0100 Subject: [PATCH 1657/1765] [SPARK-10365][SQL] Support Parquet logical type TIMESTAMP_MICROS ## What changes were proposed in this pull request? This PR makes Spark to be able to read Parquet TIMESTAMP_MICROS values, and add a new config to allow Spark to write timestamp values to parquet as TIMESTAMP_MICROS type. ## How was this patch tested? new test Author: Wenchen Fan Closes #19702 from cloud-fan/parquet. --- .../apache/spark/sql/internal/SQLConf.scala | 30 ++++- .../SpecificParquetRecordReaderBase.java | 4 +- .../parquet/VectorizedColumnReader.java | 26 ++-- .../VectorizedParquetRecordReader.java | 11 +- .../parquet/ParquetFileFormat.scala | 51 +++----- .../parquet/ParquetReadSupport.scala | 4 +- .../parquet/ParquetRecordMaterializer.scala | 4 +- .../parquet/ParquetRowConverter.scala | 15 ++- .../parquet/ParquetSchemaConverter.scala | 92 ++++++++------- .../parquet/ParquetWriteSupport.scala | 48 ++++---- .../datasources/parquet/ParquetIOSuite.scala | 5 +- .../parquet/ParquetQuerySuite.scala | 22 ++++ .../parquet/ParquetSchemaSuite.scala | 111 +++++++----------- .../datasources/parquet/ParquetTest.scala | 2 +- .../spark/sql/internal/SQLConfSuite.scala | 28 +++++ 15 files changed, 249 insertions(+), 204 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a04f8778079de..831ef62d74c3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -285,8 +285,24 @@ object SQLConf { .booleanConf .createWithDefault(true) + object ParquetOutputTimestampType extends Enumeration { + val INT96, TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value + } + + val PARQUET_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.parquet.outputTimestampType") + .doc("Sets which Parquet timestamp type to use when Spark writes data to Parquet files. " + + "INT96 is a non-standard but commonly used timestamp type in Parquet. TIMESTAMP_MICROS " + + "is a standard timestamp type in Parquet, which stores number of microseconds from the " + + "Unix epoch. TIMESTAMP_MILLIS is also standard, but with millisecond precision, which " + + "means Spark has to truncate the microsecond portion of its timestamp value.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(ParquetOutputTimestampType.values.map(_.toString)) + .createWithDefault(ParquetOutputTimestampType.INT96.toString) + val PARQUET_INT64_AS_TIMESTAMP_MILLIS = buildConf("spark.sql.parquet.int64AsTimestampMillis") - .doc("When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + + .doc(s"(Deprecated since Spark 2.3, please set ${PARQUET_OUTPUT_TIMESTAMP_TYPE.key}.) " + + "When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + "extended type. In this mode, the microsecond portion of the timestamp value will be" + "truncated.") .booleanConf @@ -1143,6 +1159,18 @@ class SQLConf extends Serializable with Logging { def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS) + def parquetOutputTimestampType: ParquetOutputTimestampType.Value = { + val isOutputTimestampTypeSet = settings.containsKey(PARQUET_OUTPUT_TIMESTAMP_TYPE.key) + if (!isOutputTimestampTypeSet && isParquetINT64AsTimestampMillis) { + // If PARQUET_OUTPUT_TIMESTAMP_TYPE is not set and PARQUET_INT64_AS_TIMESTAMP_MILLIS is set, + // respect PARQUET_INT64_AS_TIMESTAMP_MILLIS and use TIMESTAMP_MILLIS. Otherwise, + // PARQUET_OUTPUT_TIMESTAMP_TYPE has higher priority. + ParquetOutputTimestampType.TIMESTAMP_MILLIS + } else { + ParquetOutputTimestampType.withName(getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE)) + } + } + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 5a810cae1e184..80c2f491b48ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -197,8 +197,6 @@ protected void initialize(String path, List columns) throws IOException Configuration config = new Configuration(); config.set("spark.sql.parquet.binaryAsString", "false"); config.set("spark.sql.parquet.int96AsTimestamp", "false"); - config.set("spark.sql.parquet.writeLegacyFormat", "false"); - config.set("spark.sql.parquet.int64AsTimestampMillis", "false"); this.file = new Path(path); long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); @@ -224,7 +222,7 @@ protected void initialize(String path, List columns) throws IOException this.requestedSchema = ParquetSchemaConverter.EMPTY_MESSAGE(); } } - this.sparkSchema = new ParquetSchemaConverter(config).convert(requestedSchema); + this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); for (BlockMetaData block : blocks) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 3c8d766ffad30..0f1f470dc597e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -26,6 +26,7 @@ import org.apache.parquet.column.page.*; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; @@ -91,11 +92,15 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; + private final OriginalType originalType; - public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) - throws IOException { + public VectorizedColumnReader( + ColumnDescriptor descriptor, + OriginalType originalType, + PageReader pageReader) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; + this.originalType = originalType; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); @@ -158,12 +163,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - // Timestamp values encoded as INT64 can't be lazily decoded as we need to post process + // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && - column.dataType() != DataTypes.TimestampType) || + originalType != OriginalType.TIMESTAMP_MILLIS) || descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { @@ -253,21 +258,21 @@ private void decodeDictionaryIds( case INT64: if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { + DecimalType.is64BitDecimalType(column.dataType()) || + originalType == OriginalType.TIMESTAMP_MICROS) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } - } else if (column.dataType() == DataTypes.TimestampType) { + } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, DateTimeUtils.fromMillis(dictionary.decodeToLong(dictionaryIds.getDictId(i)))); } } - } - else { + } else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; @@ -377,10 +382,11 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { private void readLongBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { + DecimalType.is64BitDecimalType(column.dataType()) || + originalType == OriginalType.TIMESTAMP_MICROS) { defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (column.dataType() == DataTypes.TimestampType) { + } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, DateTimeUtils.fromMillis(dataColumn.readLong())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 0cacf0c9c93a5..e827229dceef8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -165,8 +165,10 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch(MemoryMode memMode, StructType partitionColumns, - InternalRow partitionValues) { + public void initBatch( + MemoryMode memMode, + StructType partitionColumns, + InternalRow partitionValues) { StructType batchSchema = new StructType(); for (StructField f: sparkSchema.fields()) { batchSchema = batchSchema.add(f); @@ -281,11 +283,12 @@ private void checkEndOfRowGroup() throws IOException { + rowsReturned + " out of " + totalRowCount); } List columns = requestedSchema.getColumns(); + List types = requestedSchema.asGroupType().getFields(); columnReaders = new VectorizedColumnReader[columns.size()]; for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; - columnReaders[i] = new VectorizedColumnReader(columns.get(i), - pages.getPageReader(columns.get(i))); + columnReaders[i] = new VectorizedColumnReader( + columns.get(i), types.get(i).getOriginalType(), pages.getPageReader(columns.get(i))); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 61bd65dd48144..a48f8d517b6ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -111,23 +111,15 @@ class ParquetFileFormat // This metadata is only useful for detecting optional columns when pushdowning filters. ParquetWriteSupport.setSchema(dataSchema, conf) - // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) - // and `CatalystWriteSupport` (writing actual rows to Parquet files). - conf.set( - SQLConf.PARQUET_BINARY_AS_STRING.key, - sparkSession.sessionState.conf.isParquetBinaryAsString.toString) - - conf.set( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sparkSession.sessionState.conf.isParquetINT96AsTimestamp.toString) - + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. conf.set( SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) conf.set( - SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, - sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis.toString) + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sparkSession.sessionState.conf.parquetOutputTimestampType.toString) // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) @@ -312,16 +304,13 @@ class ParquetFileFormat ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) - // Sets flags for `CatalystSchemaConverter` + // Sets flags for `ParquetToSparkSchemaConverter` hadoopConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, sparkSession.sessionState.conf.isParquetBinaryAsString) hadoopConf.setBoolean( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, - sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) // Try to push down filters when filter push-down is enabled. val pushed = @@ -428,15 +417,9 @@ object ParquetFileFormat extends Logging { private[parquet] def readSchema( footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { - def parseParquetSchema(schema: MessageType): StructType = { - val converter = new ParquetSchemaConverter( - sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.writeLegacyParquetFormat, - sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) - - converter.convert(schema) - } + val converter = new ParquetToSparkSchemaConverter( + sparkSession.sessionState.conf.isParquetBinaryAsString, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) val seen = mutable.HashSet[String]() val finalSchemas: Seq[StructType] = footers.flatMap { footer => @@ -447,7 +430,7 @@ object ParquetFileFormat extends Logging { .get(ParquetReadSupport.SPARK_METADATA_KEY) if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. - Some(parseParquetSchema(metadata.getSchema)) + Some(converter.convert(metadata.getSchema)) } else if (!seen.contains(serializedSchema.get)) { seen += serializedSchema.get @@ -470,7 +453,7 @@ object ParquetFileFormat extends Logging { .map(_.asInstanceOf[StructType]) .getOrElse { // Falls back to Parquet schema if Spark SQL schema can't be parsed. - parseParquetSchema(metadata.getSchema) + converter.convert(metadata.getSchema) }) } else { None @@ -538,8 +521,6 @@ object ParquetFileFormat extends Logging { sparkSession: SparkSession): Option[StructType] = { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp - val writeTimestampInMillis = sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis - val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) // !! HACK ALERT !! @@ -579,13 +560,9 @@ object ParquetFileFormat extends Logging { serializedConf.value, fakeFileStatuses, ignoreCorruptFiles) // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` - val converter = - new ParquetSchemaConverter( - assumeBinaryIsString = assumeBinaryIsString, - assumeInt96IsTimestamp = assumeInt96IsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat, - writeTimestampInMillis = writeTimestampInMillis) - + val converter = new ParquetToSparkSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp) if (footers.isEmpty) { Iterator.empty } else { @@ -625,7 +602,7 @@ object ParquetFileFormat extends Logging { * a [[StructType]] converted from the [[MessageType]] stored in this footer. */ def readSchemaFromFooter( - footer: Footer, converter: ParquetSchemaConverter): StructType = { + footer: Footer, converter: ParquetToSparkSchemaConverter): StructType = { val fileMetaData = footer.getParquetMetadata.getFileMetaData fileMetaData .getKeyValueMetaData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a6200..2854cb1bc0c25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -95,7 +95,7 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo new ParquetRecordMaterializer( parquetRequestedSchema, ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetSchemaConverter(conf)) + new ParquetToSparkSchemaConverter(conf)) } } @@ -270,7 +270,7 @@ private[parquet] object ParquetReadSupport { private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType): Seq[Type] = { val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - val toParquet = new ParquetSchemaConverter(writeLegacyParquetFormat = false) + val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) structType.map { f => parquetFieldMap .get(f.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 4e49a0dac97c0..793755e9aaeb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -31,7 +31,9 @@ import org.apache.spark.sql.types.StructType * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters */ private[parquet] class ParquetRecordMaterializer( - parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetSchemaConverter) + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: ParquetToSparkSchemaConverter) extends RecordMaterializer[UnsafeRow] { private val rootConverter = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 32e6c60cd9766..10f6c3b4f15e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -27,7 +27,7 @@ import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -120,7 +120,7 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( - schemaConverter: ParquetSchemaConverter, + schemaConverter: ParquetToSparkSchemaConverter, parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) @@ -252,6 +252,13 @@ private[parquet] class ParquetRowConverter( case StringType => new ParquetStringConverter(updater) + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(value) + } + } + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { @@ -259,8 +266,8 @@ private[parquet] class ParquetRowConverter( } } - case TimestampType => - // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // INT96 timestamp doesn't have a logical type, here we check the physical type instead. + case TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT96 => new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index cd384d17d0cda..c61be077d309f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -30,49 +30,31 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + /** - * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and - * vice versa. + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]]. * * Parquet format backwards-compatibility rules are respected when converting Parquet * [[MessageType]] schemas. * * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * @constructor + * * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL - * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL - * [[StructType]]. This argument only affects Parquet read path. + * [[StringType]] fields. * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL - * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL - * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which - * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` - * described in Parquet format spec. This argument only affects Parquet read path. - * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4 - * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. - * When set to false, use standard format defined in parquet-format spec. This argument only - * affects Parquet write path. - * @param writeTimestampInMillis Whether to write timestamp values as INT64 annotated by logical - * type TIMESTAMP_MILLIS. - * + * [[TimestampType]] fields. */ -private[parquet] class ParquetSchemaConverter( +class ParquetToSparkSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, - writeTimestampInMillis: Boolean = SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.defaultValue.get) { + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, - assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, - writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - writeTimestampInMillis = conf.isParquetINT64AsTimestampMillis) + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, - assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, - writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean, - writeTimestampInMillis = conf.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean) + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean) /** @@ -165,6 +147,7 @@ private[parquet] class ParquetSchemaConverter( case INT_64 | null => LongType case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() + case TIMESTAMP_MICROS => TimestampType case TIMESTAMP_MILLIS => TimestampType case _ => illegalType() } @@ -310,6 +293,30 @@ private[parquet] class ParquetSchemaConverter( repeatedType.getName == s"${parentName}_tuple" } } +} + +/** + * This converter class is used to convert Spark SQL [[StructType]] to Parquet [[MessageType]]. + * + * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4 + * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. + * When set to false, use standard format defined in parquet-format spec. This argument only + * affects Parquet write path. + * @param outputTimestampType which parquet timestamp type to use when writing. + */ +class SparkToParquetSchemaConverter( + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96) { + + def this(conf: SQLConf) = this( + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + outputTimestampType = conf.parquetOutputTimestampType) + + def this(conf: Configuration) = this( + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, + outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. @@ -363,7 +370,9 @@ private[parquet] class ParquetSchemaConverter( case DateType => Types.primitive(INT32, repetition).as(DATE).named(field.name) - // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec. + // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or + // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the + // behavior same as before. // // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond // timestamp in Impala for some historical reasons. It's not recommended to be used for any @@ -372,23 +381,18 @@ private[parquet] class ParquetSchemaConverter( // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. // // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting - // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store - // a timestamp into a `Long`. This design decision is subject to change though, for example, - // we may resort to microsecond precision in the future. - // - // For Parquet, we plan to write all `TimestampType` value as `TIMESTAMP_MICROS`, but it's - // currently not implemented yet because parquet-mr 1.8.1 (the version we're currently using) - // hasn't implemented `TIMESTAMP_MICROS` yet, however it supports TIMESTAMP_MILLIS. We will - // encode timestamp values as TIMESTAMP_MILLIS annotating INT64 if - // 'spark.sql.parquet.int64AsTimestampMillis' is set. - // - // TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that. - - case TimestampType if writeTimestampInMillis => - Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) - + // from Spark 1.5.0, we resort to a timestamp type with microsecond precision so that we can + // store a timestamp into a `Long`. This design decision is subject to change though, for + // example, we may resort to nanosecond precision in the future. case TimestampType => - Types.primitive(INT96, repetition).named(field.name) + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + Types.primitive(INT96, repetition).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + } case BinaryType => Types.primitive(BINARY, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 63a8666f0d774..af4e1433c876f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -66,8 +66,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions private var writeLegacyParquetFormat: Boolean = _ - // Whether to write timestamp value with milliseconds precision. - private var writeTimestampInMillis: Boolean = _ + // Which parquet timestamp type to use when writing. + private var outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = _ // Reusable byte array used to write timestamps as Parquet INT96 values private val timestampBuffer = new Array[Byte](12) @@ -84,15 +84,15 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean } - this.writeTimestampInMillis = { - assert(configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key) != null) - configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean + this.outputTimestampType = { + val key = SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key + assert(configuration.get(key) != null) + SQLConf.ParquetOutputTimestampType.withName(configuration.get(key)) } - this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter] - val messageType = new ParquetSchemaConverter(configuration).convert(schema) + val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava logInfo( @@ -163,25 +163,23 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit recordConsumer.addBinary( Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) - case TimestampType if writeTimestampInMillis => - (row: SpecializedGetters, ordinal: Int) => - val millis = DateTimeUtils.toMillis(row.getLong(ordinal)) - recordConsumer.addLong(millis) - case TimestampType => - (row: SpecializedGetters, ordinal: Int) => { - // TODO Writes `TimestampType` values as `TIMESTAMP_MICROS` once parquet-mr implements it - // Currently we only support timestamps stored as INT96, which is compatible with Hive - // and Impala. However, INT96 is to be deprecated. We plan to support `TIMESTAMP_MICROS` - // defined in the parquet-format spec. But up until writing, the most recent parquet-mr - // version (1.8.1) hasn't implemented it yet. - - // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond - // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. - val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) - val buf = ByteBuffer.wrap(timestampBuffer) - buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) - recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + (row: SpecializedGetters, ordinal: Int) => + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addLong(row.getLong(ordinal)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + (row: SpecializedGetters, ordinal: Int) => + val millis = DateTimeUtils.toMillis(row.getLong(ordinal)) + recordConsumer.addLong(millis) } case BinaryType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index d76990b482db2..633cfde6ab941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -110,12 +110,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { | required binary h(DECIMAL(32,0)); | required fixed_len_byte_array(32) i(DECIMAL(32,0)); | required int64 j(TIMESTAMP_MILLIS); + | required int64 k(TIMESTAMP_MICROS); |} """.stripMargin) val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0), - TimestampType) + TimestampType, TimestampType) withTempPath { location => val path = new Path(location.getCanonicalPath) @@ -380,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val expectedSchema = new ParquetSchemaConverter().convert(schema) + val expectedSchema = new SparkToParquetSchemaConverter().convert(schema) val actualSchema = readFooter(path, hadoopConf).getFileMetaData.getSchema actualSchema.checkContains(expectedSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e822e40b146ee..4c8c9ef6e0432 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -235,6 +235,28 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-10365 timestamp written and read as INT64 - TIMESTAMP_MICROS") { + val data = (1 to 10).map { i => + val ts = new java.sql.Timestamp(i) + ts.setNanos(2000) + Row(i, ts) + } + val schema = StructType(List(StructField("d", IntegerType, false), + StructField("time", TimestampType, false)).toArray) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "TIMESTAMP_MICROS") { + withTempPath { file => + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + val df2 = spark.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + } + } + } + test("Enabling/disabling merging partfiles when merging parquet schema") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index ce992674d719f..2cd2a600f2b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -24,6 +24,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -52,14 +53,10 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { sqlSchema: StructType, parquetSchema: String, binaryAsString: Boolean, - int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean, - int64AsTimestampMillis: Boolean = false): Unit = { - val converter = new ParquetSchemaConverter( + int96AsTimestamp: Boolean): Unit = { + val converter = new ParquetToSparkSchemaConverter( assumeBinaryIsString = binaryAsString, - assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat, - writeTimestampInMillis = int64AsTimestampMillis) + assumeInt96IsTimestamp = int96AsTimestamp) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -77,15 +74,12 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean, - int96AsTimestamp: Boolean, writeLegacyParquetFormat: Boolean, - int64AsTimestampMillis: Boolean = false): Unit = { - val converter = new ParquetSchemaConverter( - assumeBinaryIsString = binaryAsString, - assumeInt96IsTimestamp = int96AsTimestamp, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96): Unit = { + val converter = new SparkToParquetSchemaConverter( writeLegacyParquetFormat = writeLegacyParquetFormat, - writeTimestampInMillis = int64AsTimestampMillis) + outputTimestampType = outputTimestampType) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -102,25 +96,22 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { binaryAsString: Boolean, int96AsTimestamp: Boolean, writeLegacyParquetFormat: Boolean, - int64AsTimestampMillis: Boolean = false): Unit = { + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96): Unit = { testCatalystToParquet( testName, sqlSchema, parquetSchema, - binaryAsString, - int96AsTimestamp, writeLegacyParquetFormat, - int64AsTimestampMillis) + outputTimestampType) testParquetToCatalyst( testName, sqlSchema, parquetSchema, binaryAsString, - int96AsTimestamp, - writeLegacyParquetFormat, - int64AsTimestampMillis) + int96AsTimestamp) } } @@ -411,8 +402,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with nullable element type - 2", @@ -430,8 +420,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -446,8 +435,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 2", @@ -462,8 +450,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 3", @@ -476,8 +463,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 4", @@ -500,8 +486,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", @@ -522,8 +507,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", @@ -544,8 +528,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 7 - " + @@ -557,8 +540,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 8 - " + @@ -580,8 +562,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) // ======================================================= // Tests for converting Catalyst ArrayType to Parquet LIST @@ -602,8 +583,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -621,8 +600,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) testCatalystToParquet( @@ -640,8 +617,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -657,8 +632,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) // ==================================================== @@ -682,8 +655,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 2", @@ -702,8 +674,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", @@ -722,8 +693,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -742,8 +712,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 2", @@ -762,8 +731,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", @@ -782,8 +750,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) // ==================================================== // Tests for converting Catalyst MapType to Parquet Map @@ -805,8 +772,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -825,8 +790,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) testCatalystToParquet( @@ -845,8 +808,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -865,8 +826,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) // ================================= @@ -982,7 +941,19 @@ class ParquetSchemaSuite extends ParquetSchemaTest { binaryAsString = true, int96AsTimestamp = false, writeLegacyParquetFormat = true, - int64AsTimestampMillis = true) + outputTimestampType = SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS) + + testSchema( + "Timestamp written and read as INT64 with TIMESTAMP_MICROS", + StructType(Seq(StructField("f1", TimestampType))), + """message root { + | optional INT64 f1 (TIMESTAMP_MICROS); + |} + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = false, + writeLegacyParquetFormat = true, + outputTimestampType = SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS) private def testSchemaClipping( testName: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 85efca3c4b24d..f05f5722af51a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -124,7 +124,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def writeMetadata( schema: StructType, path: Path, configuration: Configuration): Unit = { - val parquetSchema = new ParquetSchemaConverter().convert(schema) + val parquetSchema = new SparkToParquetSchemaConverter().convert(schema) val extraMetadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schema.json).asJava val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, createdBy) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 205c303b6cc4b..f9d75fc1788db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -281,4 +281,32 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { assert(null == spark.conf.get("spark.sql.nonexistent", null)) assert("" == spark.conf.get("spark.sql.nonexistent", "")) } + + test("SPARK-10365: PARQUET_OUTPUT_TIMESTAMP_TYPE") { + spark.sessionState.conf.clear() + + // check default value + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.INT96) + + // PARQUET_INT64_AS_TIMESTAMP_MILLIS should be respected. + spark.sessionState.conf.setConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS, true) + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS) + + // PARQUET_OUTPUT_TIMESTAMP_TYPE has higher priority over PARQUET_INT64_AS_TIMESTAMP_MILLIS + spark.sessionState.conf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "timestamp_micros") + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS) + spark.sessionState.conf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "int96") + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.INT96) + + // test invalid conf value + intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, "invalid") + } + + spark.sessionState.conf.clear() + } } From b3f9dbf48ec0938ff5c98833bb6b6855c620ef57 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Sun, 12 Nov 2017 11:21:23 -0800 Subject: [PATCH 1658/1765] [SPARK-19606][MESOS] Support constraints in spark-dispatcher ## What changes were proposed in this pull request? A discussed in SPARK-19606, the addition of a new config property named "spark.mesos.constraints.driver" for constraining drivers running on a Mesos cluster ## How was this patch tested? Corresponding unit test added also tested locally on a Mesos cluster Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Paul Mackles Closes #19543 from pmackles/SPARK-19606. --- docs/running-on-mesos.md | 17 ++++++- .../apache/spark/deploy/mesos/config.scala | 7 +++ .../cluster/mesos/MesosClusterScheduler.scala | 15 ++++-- .../mesos/MesosClusterSchedulerSuite.scala | 47 +++++++++++++++++++ .../spark/scheduler/cluster/mesos/Utils.scala | 31 ++++++++---- 5 files changed, 100 insertions(+), 17 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 7a443ffddc5f0..19ec7c1e0aeee 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -263,7 +263,10 @@ resource offers will be accepted. conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, then the resource offers will +be checked to see if they meet both these constraints and only then will be accepted to start new executors. + +To constrain where driver tasks are run, use `spark.mesos.driver.constraints` # Mesos Docker Support @@ -447,7 +450,9 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + applies only to executors. Refer to Mesos + Attributes & Resources for more information on attributes.
    • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
    • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
    • @@ -457,6 +462,14 @@ See the [configuration page](configuration.html) for information on Spark config
    + + spark.mesos.driver.constraints + (none) + + Same as spark.mesos.constraints except applied to drivers when launched through the dispatcher. By default, + all offers with sufficient resources will be accepted. + + spark.mesos.containerizer docker diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 821534eb4fc38..d134847dc74d2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -122,4 +122,11 @@ package object config { "Example: key1:val1,key2:val2") .stringConf .createOptional + + private[spark] val DRIVER_CONSTRAINTS = + ConfigBuilder("spark.mesos.driver.constraints") + .doc("Attribute based constraints on mesos resource offers. Applied by the dispatcher " + + "when launching drivers. Default is to accept all offers with sufficient resources.") + .stringConf + .createWithDefault("") } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index de846c85d53a6..c41283e4a3e39 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -556,9 +556,10 @@ private[spark] class MesosClusterScheduler( private class ResourceOffer( val offer: Offer, - var remainingResources: JList[Resource]) { + var remainingResources: JList[Resource], + var attributes: JList[Attribute]) { override def toString(): String = { - s"Offer id: ${offer.getId}, resources: ${remainingResources}" + s"Offer id: ${offer.getId}, resources: ${remainingResources}, attributes: ${attributes}" } } @@ -601,10 +602,14 @@ private[spark] class MesosClusterScheduler( for (submission <- candidates) { val driverCpu = submission.cores val driverMem = submission.mem - logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val driverConstraints = + parseConstraintString(submission.conf.get(config.DRIVER_CONSTRAINTS)) + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem, " + + s"driverConstraints: $driverConstraints") val offerOption = currentOffers.find { offer => getResource(offer.remainingResources, "cpus") >= driverCpu && - getResource(offer.remainingResources, "mem") >= driverMem + getResource(offer.remainingResources, "mem") >= driverMem && + matchesAttributeRequirements(driverConstraints, toAttributeMap(offer.attributes)) } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + @@ -652,7 +657,7 @@ private[spark] class MesosClusterScheduler( val currentTime = new Date() val currentOffers = offers.asScala.map { - offer => new ResourceOffer(offer, offer.getResourcesList) + offer => new ResourceOffer(offer, offer.getResourcesList, offer.getAttributesList) }.toList stateLock.synchronized { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 77acee608f25f..e534b9d7e3ed9 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -254,6 +254,53 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(networkInfos.get(0).getLabels.getLabels(1).getValue == "val2") } + test("accept/decline offers with driver constraints") { + setScheduler() + + val mem = 1000 + val cpu = 1 + val s2Attributes = List(Utils.createTextAttribute("c1", "a")) + val s3Attributes = List( + Utils.createTextAttribute("c1", "a"), + Utils.createTextAttribute("c2", "b")) + val offers = List( + Utils.createOffer("o1", "s1", mem, cpu, None, 0), + Utils.createOffer("o2", "s2", mem, cpu, None, 0, s2Attributes), + Utils.createOffer("o3", "s3", mem, cpu, None, 0, s3Attributes)) + + def submitDriver(driverConstraints: String): Unit = { + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", mem, cpu, true, + command, + Map("spark.mesos.executor.home" -> "test", + "spark.app.name" -> "test", + config.DRIVER_CONSTRAINTS.key -> driverConstraints), + "s1", + new Date())) + assert(response.success) + } + + submitDriver("c1:x") + scheduler.resourceOffers(driver, offers.asJava) + offers.foreach(o => Utils.verifyTaskNotLaunched(driver, o.getId.getValue)) + + submitDriver("c1:y;c2:z") + scheduler.resourceOffers(driver, offers.asJava) + offers.foreach(o => Utils.verifyTaskNotLaunched(driver, o.getId.getValue)) + + submitDriver("") + scheduler.resourceOffers(driver, offers.asJava) + Utils.verifyTaskLaunched(driver, "o1") + + submitDriver("c1:a") + scheduler.resourceOffers(driver, offers.asJava) + Utils.verifyTaskLaunched(driver, "o2") + + submitDriver("c1:a;c2:b") + scheduler.resourceOffers(driver, offers.asJava) + Utils.verifyTaskLaunched(driver, "o3") + } + test("supports spark.mesos.driver.labels") { setScheduler() diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 5636ac52bd4a7..c9f47471cd75e 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -43,12 +43,13 @@ object Utils { .build() def createOffer( - offerId: String, - slaveId: String, - mem: Int, - cpus: Int, - ports: Option[(Long, Long)] = None, - gpus: Int = 0): Offer = { + offerId: String, + slaveId: String, + mem: Int, + cpus: Int, + ports: Option[(Long, Long)] = None, + gpus: Int = 0, + attributes: List[Attribute] = List.empty): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -63,7 +64,7 @@ object Utils { .setName("ports") .setType(Value.Type.RANGES) .setRanges(Ranges.newBuilder().addRange(MesosRange.newBuilder() - .setBegin(resourcePorts._1).setEnd(resourcePorts._2).build())) + .setBegin(resourcePorts._1).setEnd(resourcePorts._2).build())) } if (gpus > 0) { builder.addResourcesBuilder() @@ -73,9 +74,10 @@ object Utils { } builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() - .setValue("f1")) + .setValue("f1")) .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) .setHostname(s"host${slaveId}") + .addAllAttributes(attributes.asJava) .build() } @@ -125,7 +127,7 @@ object Utils { .getVariablesList .asScala assert(envVars - .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars val variableOne = envVars.filter(_.getName == "SECRET_ENV_KEY").head assert(variableOne.getSecret.isInitialized) assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) @@ -154,7 +156,7 @@ object Utils { .getVariablesList .asScala assert(envVars - .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars val variableOne = envVars.filter(_.getName == "USER").head assert(variableOne.getSecret.isInitialized) assert(variableOne.getSecret.getType == Secret.Type.VALUE) @@ -212,4 +214,13 @@ object Utils { assert(secretVolTwo.getSource.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) } + + def createTextAttribute(name: String, value: String): Attribute = { + Attribute.newBuilder() + .setName(name) + .setType(Value.Type.TEXT) + .setText(Value.Text.newBuilder().setValue(value)) + .build() + } } + From 9bf696dbece6b1993880efba24a6d32c54c4d11c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 12 Nov 2017 22:44:47 +0100 Subject: [PATCH 1659/1765] [SPARK-21720][SQL] Fix 64KB JVM bytecode limit problem with AND or OR ## What changes were proposed in this pull request? This PR changes `AND` or `OR` code generation to place condition and then expressions' generated code into separated methods if these size could be large. When the method is newly generated, variables for `isNull` and `value` are declared as an instance variable to pass these values (e.g. `isNull1409` and `value1409`) to the callers of the generated method. This PR resolved two cases: * large code size of left expression * large code size of right expression ## How was this patch tested? Added a new test case into `CodeGenerationSuite` Author: Kazuaki Ishizaki Closes #18972 from kiszk/SPARK-21720. --- .../expressions/codegen/CodeGenerator.scala | 30 +++++++ .../expressions/conditionalExpressions.scala | 29 +------ .../sql/catalyst/expressions/predicates.scala | 82 ++++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 39 +++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 +- 5 files changed, 157 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 98eda2a1ba92c..3dc3f8e4adac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -908,6 +908,36 @@ class CodegenContext { } } + /** + * Wrap the generated code of expression, which was created from a row object in INPUT_ROW, + * by a function. ev.isNull and ev.value are passed by global variables + * + * @param ev the code to evaluate expressions. + * @param dataType the data type of ev.value. + * @param baseFuncName the split function name base. + */ + def createAndAddFunction( + ev: ExprCode, + dataType: DataType, + baseFuncName: String): (String, String, String) = { + val globalIsNull = freshName("isNull") + addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalValue = freshName("value") + addMutableState(javaType(dataType), globalValue, + s"$globalValue = ${defaultValue(dataType)};") + val funcName = freshName(baseFuncName) + val funcBody = + s""" + |private void $funcName(InternalRow ${INPUT_ROW}) { + | ${ev.code.trim} + | $globalIsNull = ${ev.isNull}; + | $globalValue = ${ev.value}; + |} + """.stripMargin + val fullFuncName = addNewFunction(funcName, funcBody) + (fullFuncName, globalIsNull, globalValue) + } + /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d95b59d5ec423..c41a10c7b0f87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -72,11 +72,11 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi (ctx.INPUT_ROW != null && ctx.currentVars == null)) { val (condFuncName, condGlobalIsNull, condGlobalValue) = - createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr") + ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr") val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = - createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr") + ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr") val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = - createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr") + ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr") s""" $condFuncName(${ctx.INPUT_ROW}); boolean ${ev.isNull} = false; @@ -112,29 +112,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi ev.copy(code = generatedCode) } - private def createAndAddFunction( - ctx: CodegenContext, - ev: ExprCode, - dataType: DataType, - baseFuncName: String): (String, String, String) = { - val globalIsNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") - val globalValue = ctx.freshName("value") - ctx.addMutableState(ctx.javaType(dataType), globalValue, - s"$globalValue = ${ctx.defaultValue(dataType)};") - val funcName = ctx.freshName(baseFuncName) - val funcBody = - s""" - |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${ev.code.trim} - | $globalIsNull = ${ev.isNull}; - | $globalValue = ${ev.value}; - |} - """.stripMargin - val fullFuncName = ctx.addNewFunction(funcName, funcBody) - (fullFuncName, globalIsNull, globalValue) - } - override def toString: String = s"if ($predicate) $trueValue else $falseValue" override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index efcd45fad779c..61df5e053a374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -368,7 +368,46 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with val eval2 = right.genCode(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. - if (!left.nullable && !right.nullable) { + + // place generated code of eval1 and eval2 in separate methods if their code combined is large + val combinedLength = eval1.code.length + eval2.code.length + if (combinedLength > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + + val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = + ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") + val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = + ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") + if (!left.nullable && !right.nullable) { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.value} = false; + if (${eval1GlobalValue}) { + $eval2FuncName(${ctx.INPUT_ROW}); + ${ev.value} = ${eval2GlobalValue}; + } + """ + ev.copy(code = generatedCode, isNull = "false") + } else { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + boolean ${ev.value} = false; + if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) { + } else { + $eval2FuncName(${ctx.INPUT_ROW}); + if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) { + } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { + ${ev.value} = true; + } else { + ${ev.isNull} = true; + } + } + """ + ev.copy(code = generatedCode) + } + } else if (!left.nullable && !right.nullable) { ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = false; @@ -431,7 +470,46 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P val eval2 = right.genCode(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. - if (!left.nullable && !right.nullable) { + + // place generated code of eval1 and eval2 in separate methods if their code combined is large + val combinedLength = eval1.code.length + eval2.code.length + if (combinedLength > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + + val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = + ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") + val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = + ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") + if (!left.nullable && !right.nullable) { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.value} = true; + if (!${eval1GlobalValue}) { + $eval2FuncName(${ctx.INPUT_ROW}); + ${ev.value} = ${eval2GlobalValue}; + } + """ + ev.copy(code = generatedCode, isNull = "false") + } else { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + boolean ${ev.value} = true; + if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) { + } else { + $eval2FuncName(${ctx.INPUT_ROW}); + if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) { + } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { + ${ev.value} = false; + } else { + ${ev.isNull} = true; + } + } + """ + ev.copy(code = generatedCode) + } + } else if (!left.nullable && !right.nullable) { ev.isNull = "false" ev.copy(code = s""" ${eval1.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e6f7b65e7e72..8f6289f00571c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -341,4 +341,43 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // should not throw exception projection(row) } + + test("SPARK-21720: split large predications into blocks due to JVM code size limit") { + val length = 600 + + val input = new GenericInternalRow(length) + val utf8Str = UTF8String.fromString(s"abc") + for (i <- 0 until length) { + input.update(i, utf8Str) + } + + var exprOr: Expression = Literal(false) + for (i <- 0 until length) { + exprOr = Or(EqualTo(BoundReference(i, StringType, true), Literal(s"c$i")), exprOr) + } + + val planOr = GenerateMutableProjection.generate(Seq(exprOr)) + val actualOr = planOr(input).toSeq(Seq(exprOr.dataType)) + assert(actualOr.length == 1) + val expectedOr = false + + if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) { + fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr") + } + + var exprAnd: Expression = Literal(true) + for (i <- 0 until length) { + exprAnd = And(EqualTo(BoundReference(i, StringType, true), Literal(s"c$i")), exprAnd) + } + + val planAnd = GenerateMutableProjection.generate(Seq(exprAnd)) + val actualAnd = planAnd(input).toSeq(Seq(exprAnd.dataType)) + assert(actualAnd.length == 1) + val expectedAnd = false + + if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) { + fail( + s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 31bfa77e76329..644e72c893ceb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2097,7 +2097,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .count } - testQuietly("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { + // The fix of SPARK-21720 avoid an exception regarding JVM code size limit + // TODO: When we make a threshold of splitting statements (1024) configurable, + // we will re-enable this with max threshold to cause an exception + // See https://github.com/apache/spark/pull/18972/files#r150223463 + ignore("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { val N = 400 val rows = Seq(Row.fromSeq(Seq.fill(N)("string"))) val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType))) From 3d90b2cb384affe8ceac9398615e9e21b8c8e0b0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 12 Nov 2017 14:37:20 -0800 Subject: [PATCH 1660/1765] [SPARK-21693][R][ML] Reduce max iterations in Linear SVM test in R to speed up AppVeyor build ## What changes were proposed in this pull request? This PR proposes to reduce max iteration in Linear SVM test in SparkR. This particular test elapses roughly 5 mins on my Mac and over 20 mins on Windows. The root cause appears, it triggers 2500ish jobs by the default 100 max iterations. In Linux, `daemon.R` is forked but on Windows another process is launched, which is extremely slow. So, given my observation, there are many processes (not forked) ran on Windows, which makes the differences of elapsed time. After reducing the max iteration to 10, the total jobs in this single test is reduced to 550ish. After reducing the max iteration to 5, the total jobs in this single test is reduced to 360ish. ## How was this patch tested? Manually tested the elapsed times. Author: hyukjinkwon Closes #19722 from HyukjinKwon/SPARK-21693-test. --- R/pkg/tests/fulltests/test_mllib_classification.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index a4d0397236d17..ad47717ddc12f 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -66,7 +66,7 @@ test_that("spark.svmLinear", { feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) data <- as.data.frame(cbind(label, feature)) df <- createDataFrame(data) - model <- spark.svmLinear(df, label ~ feature, regParam = 0.1) + model <- spark.svmLinear(df, label ~ feature, regParam = 0.1, maxIter = 5) prediction <- collect(select(predict(model, df), "prediction")) expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) @@ -77,10 +77,11 @@ test_that("spark.svmLinear", { trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) traindf <- as.DataFrame(data[trainidxs, ]) testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) - model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1) + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, maxIter = 5) predictions <- predict(model, testdf) expect_error(collect(predictions)) - model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip") + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, + handleInvalid = "skip", maxIter = 5) predictions <- predict(model, testdf) expect_equal(class(collect(predictions)$clicked[1]), "list") From 209b9361ac8a4410ff797cff1115e1888e2f7e66 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 13 Nov 2017 13:16:01 +0900 Subject: [PATCH 1661/1765] [SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas ## What changes were proposed in this pull request? This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input df is sliced according to the default parallelism. The optimization is enabled with the existing conf "spark.sql.execution.arrow.enabled" and is disabled by default. ## How was this patch tested? Added new unit test to create DataFrame with and without the optimization enabled, then compare results. Author: Bryan Cutler Author: Takuya UESHIN Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791. --- python/pyspark/context.py | 28 +++--- python/pyspark/java_gateway.py | 1 + python/pyspark/serializers.py | 10 ++- python/pyspark/sql/session.py | 88 ++++++++++++++---- python/pyspark/sql/tests.py | 89 ++++++++++++++++--- python/pyspark/sql/types.py | 49 ++++++++++ .../spark/sql/api/python/PythonSQLUtils.scala | 18 ++++ .../sql/execution/arrow/ArrowConverters.scala | 14 +++ 8 files changed, 254 insertions(+), 43 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a33f6dcf31fc0..24905f1c97b21 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -475,24 +475,30 @@ def f(split, iterator): return xrange(getStart(split), getStart(split + 1), step) return self.parallelize([], numSlices).mapPartitionsWithIndex(f) - # Calling the Java parallelize() method with an ArrayList is too slow, - # because it sends O(n) Py4J commands. As an alternative, serialized - # objects are written to a file and loaded through textFile(). + + # Make sure we distribute data evenly if it's smaller than self.batchSize + if "__len__" not in dir(c): + c = list(c) # Make it a list so we can compute its length + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) + jrdd = self._serialize_to_jvm(c, numSlices, serializer) + return RDD(jrdd, self, serializer) + + def _serialize_to_jvm(self, data, parallelism, serializer): + """ + Calling the Java parallelize() method with an ArrayList is too slow, + because it sends O(n) Py4J commands. As an alternative, serialized + objects are written to a file and loaded through textFile(). + """ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) try: - # Make sure we distribute data evenly if it's smaller than self.batchSize - if "__len__" not in dir(c): - c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) - serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - serializer.dump_stream(c, tempFile) + serializer.dump_stream(data, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + return readRDDFromFile(self._jsc, tempFile.name, parallelism) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) - return RDD(jrdd, self, serializer) def pickleFile(self, name, minPartitions=None): """ diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3c783ae541a1f..3e704fe9bf6ec 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -121,6 +121,7 @@ def killChild(): java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") # TODO(davies): move into sql java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d7979f095da76..e0afdafbfcd62 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -214,6 +214,13 @@ def __repr__(self): def _create_batch(series): + """ + Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. + + :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :return: Arrow RecordBatch + """ + from pyspark.sql.types import _check_series_convert_timestamps_internal import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] @@ -229,7 +236,8 @@ def cast_series(s, t): # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 return _check_series_convert_timestamps_internal(s.fillna(0))\ .values.astype('datetime64[us]', copy=False) - elif t == pa.date32(): + # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1 + elif t is not None and t == pa.date32(): # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 return s.dt.date elif t is None or s.dtype == t.to_pandas_dtype(): diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index d1d0b8b8fe5d9..589365b083012 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -25,7 +25,7 @@ basestring = unicode = str xrange = range else: - from itertools import imap as map + from itertools import izip as zip, imap as map from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix @@ -417,12 +417,12 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema - def _get_numpy_record_dtypes(self, rec): + def _get_numpy_record_dtype(self, rec): """ Used when converting a pandas.DataFrame to Spark using to_records(), this will correct - the dtypes of records so they can be properly loaded into Spark. - :param rec: a numpy record to check dtypes - :return corrected dtypes for a numpy.record or None if no correction needed + the dtypes of fields in a record so they can be properly loaded into Spark. + :param rec: a numpy record to check field dtypes + :return corrected dtype for a numpy.record or None if no correction needed """ import numpy as np cur_dtypes = rec.dtype @@ -438,28 +438,70 @@ def _get_numpy_record_dtypes(self, rec): curr_type = 'datetime64[us]' has_rec_fix = True record_type_list.append((str(col_names[i]), curr_type)) - return record_type_list if has_rec_fix else None + return np.dtype(record_type_list) if has_rec_fix else None - def _convert_from_pandas(self, pdf, schema): + def _convert_from_pandas(self, pdf): """ Convert a pandas.DataFrame to list of records that can be used to make a DataFrame - :return tuple of list of records and schema + :return list of records """ - # If no schema supplied by user then get the names of columns only - if schema is None: - schema = [str(x) for x in pdf.columns] # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) # Check if any columns need to be fixed for Spark to infer properly if len(np_records) > 0: - record_type_list = self._get_numpy_record_dtypes(np_records[0]) - if record_type_list is not None: - return [r.astype(record_type_list).tolist() for r in np_records], schema + record_dtype = self._get_numpy_record_dtype(np_records[0]) + if record_dtype is not None: + return [r.astype(record_dtype).tolist() for r in np_records] # Convert list of numpy records to python lists - return [r.tolist() for r in np_records], schema + return [r.tolist() for r in np_records] + + def _create_from_pandas_with_arrow(self, pdf, schema): + """ + Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting + to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the + data types will be used to coerce the data in Pandas to Arrow conversion. + """ + from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + + # Determine arrow types to coerce data when creating batches + if isinstance(schema, StructType): + arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] + elif isinstance(schema, DataType): + raise ValueError("Single data type %s is not supported with Arrow" % str(schema)) + else: + # Any timestamps must be coerced to be compatible with Spark + arrow_types = [to_arrow_type(TimestampType()) + if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None + for t in pdf.dtypes] + + # Slice the DataFrame to be batched + step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up + pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) + + # Create Arrow record batches + batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]) + for pdf_slice in pdf_slices] + + # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) + if isinstance(schema, (list, tuple)): + struct = from_arrow_schema(batches[0].schema) + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + # Create the Spark DataFrame directly from the Arrow data and schema + jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) + jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( + jrdd, schema.json(), self._wrapped._jsqlContext) + df = DataFrame(jdf, self._wrapped) + df._schema = schema + return df @since(2.0) @ignore_unicode_prefix @@ -557,7 +599,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - data, schema = self._convert_from_pandas(data, schema) + + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) for x in data.columns] + + if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ + and len(data) > 0: + try: + return self._create_from_pandas_with_arrow(data, schema) + except Exception as e: + warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) + # Fallback to create DataFrame without arrow if raise some exception + data = self._convert_from_pandas(data) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True @@ -576,7 +630,7 @@ def prepare(obj): verify_func(obj) return obj, else: - if isinstance(schema, list): + if isinstance(schema, (list, tuple)): schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4819f629c5310..6356d938db26a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3127,9 +3127,9 @@ def setUpClass(cls): StructField("5_double_t", DoubleType(), True), StructField("6_date_t", DateType(), True), StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3145,6 +3145,17 @@ def assertFramesEqual(self, df_with_arrow, df_without): ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + def create_pandas_data_frame(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + return pd.DataFrame(data=data_dict) + def test_unsupported_datatype(self): schema = StructType([StructField("decimal", DecimalType(), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) @@ -3161,21 +3172,15 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + try: + pdf = df.toPandas() + finally: + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - pdf = pd.DataFrame(data=data_dict) + pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) @@ -3187,6 +3192,62 @@ def test_filtered_frame(self): self.assertEqual(pdf.columns[0], "i") self.assertTrue(pdf.empty) + def test_createDataFrame_toggle(self): + pdf = self.create_pandas_data_frame() + self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") + try: + df_no_arrow = self.spark.createDataFrame(pdf) + finally: + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf) + self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) + + def test_createDataFrame_with_schema(self): + pdf = self.create_pandas_data_frame() + df = self.spark.createDataFrame(pdf, schema=self.schema) + self.assertEquals(self.schema, df.schema) + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_createDataFrame_with_incorrect_schema(self): + pdf = self.create_pandas_data_frame() + wrong_schema = StructType(list(reversed(self.schema))) + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + self.spark.createDataFrame(pdf, schema=wrong_schema) + + def test_createDataFrame_with_names(self): + pdf = self.create_pandas_data_frame() + # Test that schema as a list of column names gets applied + df = self.spark.createDataFrame(pdf, schema=list('abcdefg')) + self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + # Test that schema as tuple of column names gets applied + df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) + self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + + def test_createDataFrame_with_single_data_type(self): + import pandas as pd + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") + + def test_createDataFrame_does_not_modify_input(self): + # Some series get converted for Spark to consume, this makes sure input is unchanged + pdf = self.create_pandas_data_frame() + # Use a nanosecond value to make sure it is not truncated + pdf.ix[0, '7_timestamp_t'] = 1 + # Integers with nulls will get NaNs filled with 0 and will be casted + pdf.ix[1, '2_int_t'] = None + pdf_copy = pdf.copy(deep=True) + self.spark.createDataFrame(pdf, schema=self.schema) + self.assertTrue(pdf.equals(pdf_copy)) + + def test_schema_conversion_roundtrip(self): + from pyspark.sql.types import from_arrow_schema, to_arrow_schema + arrow_schema = to_arrow_schema(self.schema) + schema_rt = from_arrow_schema(arrow_schema) + self.assertEquals(self.schema, schema_rt) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 7dd8fa04160e0..fe62f60dd6d0e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1629,6 +1629,55 @@ def to_arrow_type(dt): return arrow_type +def to_arrow_schema(schema): + """ Convert a schema from Spark to Arrow + """ + import pyarrow as pa + fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + for field in schema] + return pa.schema(fields) + + +def from_arrow_type(at): + """ Convert pyarrow type to Spark data type. + """ + # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type + import pyarrow as pa + if at == pa.bool_(): + spark_type = BooleanType() + elif at == pa.int8(): + spark_type = ByteType() + elif at == pa.int16(): + spark_type = ShortType() + elif at == pa.int32(): + spark_type = IntegerType() + elif at == pa.int64(): + spark_type = LongType() + elif at == pa.float32(): + spark_type = FloatType() + elif at == pa.float64(): + spark_type = DoubleType() + elif type(at) == pa.DecimalType: + spark_type = DecimalType(precision=at.precision, scale=at.scale) + elif at == pa.string(): + spark_type = StringType() + elif at == pa.date32(): + spark_type = DateType() + elif type(at) == pa.TimestampType: + spark_type = TimestampType() + else: + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + return spark_type + + +def from_arrow_schema(arrow_schema): + """ Convert schema from Arrow to Spark. + """ + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in arrow_schema]) + + def _check_dataframe_localize_timestamps(pdf): """ Convert timezone aware timestamps to timezone-naive in local time diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 4d5ce0bb60c0b..b33760b1edbc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.api.python +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.DataType private[sql] object PythonSQLUtils { @@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils { def listBuiltinFunctionInfos(): Array[ExpressionInfo] = { FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray } + + /** + * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * + * @param payloadRDD A JavaRDD of ArrowPayloads. + * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param sqlContext The active [[SQLContext]]. + * @return The converted [[DataFrame]]. + */ + def arrowPayloadToDataFrame( + payloadRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 05ea1517fcac9..3cafb344ef553 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ @@ -204,4 +206,16 @@ private[sql] object ArrowConverters { reader.close() } } + + private[sql] def toDataFrame( + payloadRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + val rdd = payloadRDD.rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + } + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + sqlContext.internalCreateDataFrame(rdd, schema) + } } From 176ae4d53e0269cfc2cfa62d3a2991e28f5a9182 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 13 Nov 2017 06:19:13 -0600 Subject: [PATCH 1662/1765] [MINOR][CORE] Using bufferedInputStream for dataDeserializeStream ## What changes were proposed in this pull request? Small fix. Using bufferedInputStream for dataDeserializeStream. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19735 from ConeyLiu/smallfix. --- .../scala/org/apache/spark/serializer/SerializerManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 311383e7ea2bd..1d4b05caaa143 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -206,7 +206,7 @@ private[spark] class SerializerManager( val autoPick = !blockId.isInstanceOf[StreamBlockId] getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapForCompression(blockId, inputStream)) + .deserializeStream(wrapForCompression(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] } } From f7534b37ee91be14e511ab29259c3f83c7ad50af Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 13 Nov 2017 13:10:13 -0800 Subject: [PATCH 1663/1765] [SPARK-22487][SQL][FOLLOWUP] still keep spark.sql.hive.version ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/19712 , adds back the `spark.sql.hive.version`, so that if users try to read this config, they can still get a default value instead of null. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19719 from cloud-fan/minor. --- .../sql/hive/thriftserver/SparkSQLEnv.scala | 1 + .../thriftserver/SparkSQLSessionManager.scala | 1 + .../HiveThriftServer2Suites.scala | 23 +++++++++++++++---- .../org/apache/spark/sql/hive/HiveUtils.scala | 8 +++++++ 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 5db93b26f550e..6b19f971b73bb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -55,6 +55,7 @@ private[hive] object SparkSQLEnv extends Logging { metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + sparkSession.conf.set(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 00920c297d493..48c0ebef3e0ce 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -77,6 +77,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: } else { sqlContext.newSession() } + ctx.setConf(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion) if (sessionConf != null && sessionConf.containsKey("use:database")) { ctx.sql(s"use ${sessionConf.get("use:database")}") } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b80596f55bdea..7289da71a3365 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -155,9 +155,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(1) === "spark.sql.hive.version") assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } @@ -521,7 +521,20 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { conf += resultSet.getString(1) -> resultSet.getString(2) } - assert(conf.get("spark.sql.hive.metastore.version") === Some("1.2.1")) + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + } + } + + test("Checks Hive version via SET") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SET") + + val conf = mutable.Map.empty[String, String] + while (resultSet.next()) { + conf += resultSet.getString(1) -> resultSet.getString(2) + } + + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) } } @@ -708,9 +721,9 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(1) === "spark.sql.hive.version") assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index d8e08f1f6df50..f5e6720f6a510 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -66,6 +66,14 @@ private[spark] object HiveUtils extends Logging { .stringConf .createWithDefault(builtinHiveVersion) + // A fake config which is only here for backward compatibility reasons. This config has no effect + // to Spark, just for reporting the builtin Hive version of Spark to existing applications that + // already rely on this config. + val FAKE_HIVE_VERSION = buildConf("spark.sql.hive.version") + .doc(s"deprecated, please use ${HIVE_METASTORE_VERSION.key} to get the Hive version in Spark.") + .stringConf + .createWithDefault(builtinHiveVersion) + val HIVE_METASTORE_JARS = buildConf("spark.sql.hive.metastore.jars") .doc(s""" | Location of the jars that should be used to instantiate the HiveMetastoreClient. From c8b7f97b8a58bf4a9f6e3a07dd6e5b0f646d8d99 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 14 Nov 2017 08:28:13 +0900 Subject: [PATCH 1664/1765] [SPARK-22377][BUILD] Use /usr/sbin/lsof if lsof does not exists in release-build.sh ## What changes were proposed in this pull request? This PR proposes to use `/usr/sbin/lsof` if `lsof` is missing in the path to fix nightly snapshot jenkins jobs. Please refer https://github.com/apache/spark/pull/19359#issuecomment-340139557: > Looks like some of the snapshot builds are having lsof issues: > > https://amplab.cs.berkeley.edu/jenkins/view/Spark%20Packaging/job/spark-branch-2.1-maven-snapshots/182/console > >https://amplab.cs.berkeley.edu/jenkins/view/Spark%20Packaging/job/spark-branch-2.2-maven-snapshots/134/console > >spark-build/dev/create-release/release-build.sh: line 344: lsof: command not found >usage: kill [ -s signal | -p ] [ -a ] pid ... >kill -l [ signal ] Up to my knowledge, the full path of `lsof` is required for non-root user in few OSs. ## How was this patch tested? Manually tested as below: ```bash #!/usr/bin/env bash LSOF=lsof if ! hash $LSOF 2>/dev/null; then echo "a" LSOF=/usr/sbin/lsof fi $LSOF -P | grep "a" ``` Author: hyukjinkwon Closes #19695 from HyukjinKwon/SPARK-22377. --- dev/create-release/release-build.sh | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 7e8d5c7075195..5b43f9bab7505 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -130,6 +130,13 @@ else fi fi +# This is a band-aid fix to avoid the failure of Maven nightly snapshot in some Jenkins +# machines by explicitly calling /usr/sbin/lsof. Please see SPARK-22377 and the discussion +# in its pull request. +LSOF=lsof +if ! hash $LSOF 2>/dev/null; then + LSOF=/usr/sbin/lsof +fi if [ -z "$SPARK_PACKAGE_VERSION" ]; then SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" @@ -345,7 +352,7 @@ if [[ "$1" == "publish-snapshot" ]]; then # -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process - lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + $LSOF -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill rm $tmp_settings cd .. @@ -382,7 +389,7 @@ if [[ "$1" == "publish-release" ]]; then # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install # Clean-up Zinc nailgun process - lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + $LSOF -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill #./dev/change-scala-version.sh 2.11 From d8741b2b0fe8b8da74f120859e969326fb170629 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 13 Nov 2017 17:00:51 -0800 Subject: [PATCH 1665/1765] [SPARK-21911][ML][FOLLOW-UP] Fix doc for parallel ML Tuning in PySpark ## What changes were proposed in this pull request? Fix doc issue mentioned here: https://github.com/apache/spark/pull/19122#issuecomment-340111834 ## How was this patch tested? N/A Author: WeichenXu Closes #19641 from WeichenXu123/fix_doc. --- docs/ml-tuning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 64dc46cf0c0e7..54d9cd21909df 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -55,7 +55,7 @@ for multiclass problems. The default metric used to choose the best `ParamMap` c method in each of these evaluators. To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. -By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit` (NOTE: this is not yet supported in Python). +By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit`. The value of `parallelism` should be chosen carefully to maximize parallelism without exceeding cluster resources, and larger values may not always lead to improved performance. Generally speaking, a value up to 10 should be sufficient for most clusters. # Cross-Validation From 673c67046598d33b9ecf864024ca7a937c1998d6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 14 Nov 2017 12:34:21 +0100 Subject: [PATCH 1666/1765] [SPARK-17310][SQL] Add an option to disable record-level filter in Parquet-side ## What changes were proposed in this pull request? There is a concern that Spark-side codegen row-by-row filtering might be faster than Parquet's one in general due to type-boxing and additional fuction calls which Spark's one tries to avoid. So, this PR adds an option to disable/enable record-by-record filtering in Parquet side. It sets the default to `false` to take the advantage of the improvement. This was also discussed in https://github.com/apache/spark/pull/14671. ## How was this patch tested? Manually benchmarks were performed. I generated a billion (1,000,000,000) records and tested equality comparison concatenated with `OR`. This filter combinations were made from 5 to 30. It seem indeed Spark-filtering is faster in the test case and the gap increased as the filter tree becomes larger. The details are as below: **Code** ``` scala test("Parquet-side filter vs Spark-side filter - record by record") { withTempPath { path => val N = 1000 * 1000 * 1000 val df = spark.range(N).toDF("a") df.write.parquet(path.getAbsolutePath) val benchmark = new Benchmark("Parquet-side vs Spark-side", N) Seq(5, 10, 20, 30).foreach { num => val filterExpr = (0 to num).map(i => s"a = $i").mkString(" OR ") benchmark.addCase(s"Parquet-side filter - number of filters [$num]", 3) { _ => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString, SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> true.toString) { // We should strip Spark-side filter to compare correctly. stripSparkFilter( spark.read.parquet(path.getAbsolutePath).filter(filterExpr)).count() } } benchmark.addCase(s"Spark-side filter - number of filters [$num]", 3) { _ => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString, SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> false.toString) { spark.read.parquet(path.getAbsolutePath).filter(filterExpr).count() } } } benchmark.run() } } ``` **Result** ``` Parquet-side vs Spark-side: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Parquet-side filter - number of filters [5] 4268 / 4367 234.3 4.3 0.8X Spark-side filter - number of filters [5] 3709 / 3741 269.6 3.7 0.9X Parquet-side filter - number of filters [10] 5673 / 5727 176.3 5.7 0.6X Spark-side filter - number of filters [10] 3588 / 3632 278.7 3.6 0.9X Parquet-side filter - number of filters [20] 8024 / 8440 124.6 8.0 0.4X Spark-side filter - number of filters [20] 3912 / 3946 255.6 3.9 0.8X Parquet-side filter - number of filters [30] 11936 / 12041 83.8 11.9 0.3X Spark-side filter - number of filters [30] 3929 / 3978 254.5 3.9 0.8X ``` Author: hyukjinkwon Closes #15049 from HyukjinKwon/SPARK-17310. --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../parquet/ParquetFileFormat.scala | 14 ++--- .../parquet/ParquetFilterSuite.scala | 51 ++++++++++++++++++- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 831ef62d74c3c..0cb58fab47ac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -327,6 +327,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled") + .doc("If true, enables Parquet's native record-level filtering using the pushed down " + + "filters. This configuration only has an effect when 'spark.sql.parquet.filterPushdown' " + + "is enabled.") + .booleanConf + .createWithDefault(false) + val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + @@ -1173,6 +1180,8 @@ class SQLConf extends Serializable with Logging { def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) + def parquetRecordFilterEnabled: Boolean = getConf(PARQUET_RECORD_FILTER_ENABLED) + def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index a48f8d517b6ab..044b1a89d57c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -335,6 +335,8 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val enableRecordFilter: Boolean = + sparkSession.sessionState.conf.parquetRecordFilterEnabled // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -374,13 +376,11 @@ class ParquetFileFormat } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow - val reader = pushed match { - case Some(filter) => - new ParquetRecordReader[UnsafeRow]( - new ParquetReadSupport, - FilterCompat.get(filter, null)) - case _ => - new ParquetRecordReader[UnsafeRow](new ParquetReadSupport) + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[UnsafeRow](new ParquetReadSupport, parquetFilter) + } else { + new ParquetRecordReader[UnsafeRow](new ParquetReadSupport) } reader.initialize(split, hadoopAttemptContext) reader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 90f6620d990cb..33801954ebd51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -45,8 +45,29 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} * * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. + * + * NOTE: + * + * This file intendedly enables record-level filtering explicitly. If new test cases are + * dependent on this configuration, don't forget you better explicitly set this configuration + * within the test. */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + + override def beforeEach(): Unit = { + super.beforeEach() + // Note that there are many tests here that require record-level filtering set to be true. + spark.conf.set(SQLConf.PARQUET_RECORD_FILTER_ENABLED.key, "true") + } + + override def afterEach(): Unit = { + try { + spark.conf.unset(SQLConf.PARQUET_RECORD_FILTER_ENABLED.key) + } finally { + super.afterEach() + } + } + private def checkFilterPredicate( df: DataFrame, predicate: Predicate, @@ -369,7 +390,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex test("Filter applied on merged Parquet schema with new column should work") { import testImplicits._ - Seq("true", "false").map { vectorized => + Seq("true", "false").foreach { vectorized => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { @@ -491,7 +512,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("Fiters should be pushed down for vectorized Parquet reader at row group level") { + test("Filters should be pushed down for vectorized Parquet reader at row group level") { import testImplicits._ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", @@ -555,6 +576,32 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("Filters should be pushed down for Parquet readers at row group level") { + import testImplicits._ + + withSQLConf( + // Makes sure disabling 'spark.sql.parquet.recordFilter' still enables + // row group level filtering. + SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false", + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withTempPath { path => + val data = (1 to 1024) + data.toDF("a").coalesce(1) + .write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath).filter("a == 500") + // Here, we strip the Spark side filter and check the actual results from Parquet. + val actual = stripSparkFilter(df).collect().length + // Since those are filtered at row group level, the result count should be less + // than the total length but should not be a single record. + // Note that, if record level filtering is enabled, it should be a single record. + // If no filter is pushed down to Parquet, it should be the total length of data. + assert(actual > 1 && actual < data.length) + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 11b60af737a04d931356aa74ebf3c6cf4a6b08d6 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 14 Nov 2017 16:41:43 +0100 Subject: [PATCH 1667/1765] [SPARK-17074][SQL] Generate equi-height histogram in column statistics ## What changes were proposed in this pull request? Equi-height histogram is effective in cardinality estimation, and more accurate than basic column stats (min, max, ndv, etc) especially in skew distribution. So we need to support it. For equi-height histogram, all buckets (intervals) have the same height (frequency). In this PR, we use a two-step method to generate an equi-height histogram: 1. use `ApproximatePercentile` to get percentiles `p(0), p(1/n), p(2/n) ... p((n-1)/n), p(1)`; 2. construct range values of buckets, e.g. `[p(0), p(1/n)], [p(1/n), p(2/n)] ... [p((n-1)/n), p(1)]`, and use `ApproxCountDistinctForIntervals` to count ndv in each bucket. Each bucket is of the form: `(lowerBound, higherBound, ndv)`. ## How was this patch tested? Added new test cases and modified some existing test cases. Author: Zhenhua Wang Author: Zhenhua Wang Closes #19479 from wzhfy/generate_histogram. --- .../catalyst/plans/logical/Statistics.scala | 203 ++++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 34 ++- .../command/AnalyzeColumnCommand.scala | 57 +++- .../spark/sql/StatisticsCollectionSuite.scala | 15 +- .../sql/StatisticsCollectionTestBase.scala | 41 ++- .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../spark/sql/hive/StatisticsSuite.scala | 255 ++++++++++++------ 7 files changed, 484 insertions(+), 132 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 5ae1a55a8b66f..96b199d7f20b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,16 +17,20 @@ package org.apache.spark.sql.catalyst.plans.logical +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.math.{MathContext, RoundingMode} import scala.util.control.NonFatal +import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -88,6 +92,7 @@ case class Statistics( * @param nullCount number of nulls * @param avgLen average length of the values. For fixed-length types, this should be a constant. * @param maxLen maximum length of the values. For fixed-length types, this should be a constant. + * @param histogram histogram of the values */ case class ColumnStat( distinctCount: BigInt, @@ -95,7 +100,8 @@ case class ColumnStat( max: Option[Any], nullCount: BigInt, avgLen: Long, - maxLen: Long) { + maxLen: Long, + histogram: Option[Histogram] = None) { // We currently don't store min/max for binary/string type. This can change in the future and // then we need to remove this require. @@ -121,6 +127,7 @@ case class ColumnStat( map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } + histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } map.toMap } @@ -155,6 +162,7 @@ object ColumnStat extends Logging { private val KEY_NULL_COUNT = "nullCount" private val KEY_AVG_LEN = "avgLen" private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" /** Returns true iff the we support gathering column statistics on column of the given type. */ def supportsType(dataType: DataType): Boolean = dataType match { @@ -168,6 +176,16 @@ object ColumnStat extends Logging { case _ => false } + /** Returns true iff the we support gathering histogram on column of the given type. */ + def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + /** * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. @@ -183,7 +201,8 @@ object ColumnStat extends Logging { .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), nullCount = BigInt(map(KEY_NULL_COUNT).toLong), avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong + maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, + histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) )) } catch { case NonFatal(e) => @@ -220,12 +239,16 @@ object ColumnStat extends Logging { * Constructs an expression to compute column statistics for a given column. * * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] * * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and * as a result should stay in sync with it. */ - def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = { + def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } }) @@ -233,40 +256,55 @@ object ColumnStat extends Logging { // the approximate ndv (num distinct value) should never be larger than the number of rows val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls)) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) val numNulls = Subtract(Count(one), numNonNulls) val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) - def fixedLenTypeStruct(castType: DataType) = { + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } // For fixed width types, avg size should be the same as max size. - struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, defaultSize, - defaultSize) + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) } col.dataType match { - case dt: IntegralType => fixedLenTypeStruct(dt) - case _: DecimalType => fixedLenTypeStruct(col.dataType) - case dt @ (DoubleType | FloatType) => fixedLenTypeStruct(dt) - case BooleanType => fixedLenTypeStruct(col.dataType) - case DateType => fixedLenTypeStruct(col.dataType) - case TimestampType => fixedLenTypeStruct(col.dataType) + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct case BinaryType | StringType => - // For string and binary type, we don't store min/max. + // For string and binary type, we don't compute min, max or histogram val nullLit = Literal(null, col.dataType) struct( ndv, nullLit, nullLit, numNulls, // Set avg/max size to default size if all the values are null or there is no value. Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize))) + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) case _ => throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") + s"${col.name} of data type: ${col.dataType}.") } } - /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ - def rowToColumnStat(row: InternalRow, attr: Attribute): ColumnStat = { - ColumnStat( + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( distinctCount = BigInt(row.getLong(0)), // for string/binary min/max, get should return null min = Option(row.get(1, attr.dataType)), @@ -275,6 +313,129 @@ object ColumnStat extends Logging { avgLen = row.getLong(4), maxLen = row.getLong(5) ) + if (row.isNullAt(6)) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + +} + +/** + * This class is an implementation of equi-height histogram. + * Equi-height histogram represents the distribution of a column's values by a sequence of bins. + * Each bin has a value range and contains approximately the same number of rows. + * + * @param height number of rows in each bin + * @param bins equi-height histogram bins + */ +case class Histogram(height: Double, bins: Array[HistogramBin]) { + + // Only for histogram equality test. + override def equals(other: Any): Boolean = other match { + case otherHgm: Histogram => + height == otherHgm.height && bins.sameElements(otherHgm.bins) + case _ => false + } + + override def hashCode(): Int = { + val temp = java.lang.Double.doubleToLongBits(height) + var result = (temp ^ (temp >>> 32)).toInt + result = 31 * result + java.util.Arrays.hashCode(bins.asInstanceOf[Array[AnyRef]]) + result } +} + +/** + * A bin in an equi-height histogram. We use double type for lower/higher bound for simplicity. + * + * @param lo lower bound of the value range in this bin + * @param hi higher bound of the value range in this bin + * @param ndv approximate number of distinct values in this bin + */ +case class HistogramBin(lo: Double, hi: Double, ndv: Long) +object HistogramSerializer { + /** + * Serializes a given histogram to a string. For advanced statistics like histograms, sketches, + * etc, we don't provide readability for their serialized formats in metastore + * (string-to-string table properties). This is because it's hard or unnatural for these + * statistics to be human readable. For example, a histogram usually cannot fit in a single, + * self-described property. And for count-min-sketch, it's essentially unnatural to make it + * a readable string. + */ + final def serialize(histogram: Histogram): String = { + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(new LZ4BlockOutputStream(bos)) + out.writeDouble(histogram.height) + out.writeInt(histogram.bins.length) + // Write data with same type together for compression. + var i = 0 + while (i < histogram.bins.length) { + out.writeDouble(histogram.bins(i).lo) + i += 1 + } + i = 0 + while (i < histogram.bins.length) { + out.writeDouble(histogram.bins(i).hi) + i += 1 + } + i = 0 + while (i < histogram.bins.length) { + out.writeLong(histogram.bins(i).ndv) + i += 1 + } + out.writeInt(-1) + out.flush() + out.close() + + org.apache.commons.codec.binary.Base64.encodeBase64String(bos.toByteArray) + } + + /** Deserializes a given string to a histogram. */ + final def deserialize(str: String): Histogram = { + val bytes = org.apache.commons.codec.binary.Base64.decodeBase64(str) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(new LZ4BlockInputStream(bis)) + val height = ins.readDouble() + val numBins = ins.readInt() + + val los = new Array[Double](numBins) + var i = 0 + while (i < numBins) { + los(i) = ins.readDouble() + i += 1 + } + val his = new Array[Double](numBins) + i = 0 + while (i < numBins) { + his(i) = ins.readDouble() + i += 1 + } + val ndvs = new Array[Long](numBins) + i = 0 + while (i < numBins) { + ndvs(i) = ins.readLong() + i += 1 + } + ins.close() + + val bins = new Array[HistogramBin](numBins) + i = 0 + while (i < numBins) { + bins(i) = HistogramBin(los(i), his(i), ndvs(i)) + i += 1 + } + Histogram(height, bins) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0cb58fab47ac7..3452a1e715fb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -31,7 +31,6 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -835,6 +834,33 @@ object SQLConf { .doubleConf .createWithDefault(0.05) + val HISTOGRAM_ENABLED = + buildConf("spark.sql.statistics.histogram.enabled") + .doc("Generates histograms when computing column statistics if enabled. Histograms can " + + "provide better estimation accuracy. Currently, Spark only supports equi-height " + + "histogram. Note that collecting histograms takes extra cost. For example, collecting " + + "column statistics usually takes only one table scan, but generating equi-height " + + "histogram will cause an extra table scan.") + .booleanConf + .createWithDefault(false) + + val HISTOGRAM_NUM_BINS = + buildConf("spark.sql.statistics.histogram.numBins") + .internal() + .doc("The number of bins when generating histograms.") + .intConf + .checkValue(num => num > 1, "The number of bins must be large than 1.") + .createWithDefault(254) + + val PERCENTILE_ACCURACY = + buildConf("spark.sql.statistics.percentile.accuracy") + .internal() + .doc("Accuracy of percentile approximation when generating equi-height histograms. " + + "Larger value means better accuracy. The relative error can be deduced by " + + "1.0 / PERCENTILE_ACCURACY.") + .intConf + .createWithDefault(10000) + val AUTO_SIZE_UPDATE_ENABLED = buildConf("spark.sql.statistics.size.autoUpdate.enabled") .doc("Enables automatic update for table size once table's data is changed. Note that if " + @@ -1241,6 +1267,12 @@ class SQLConf extends Serializable with Logging { def ndvMaxError: Double = getConf(NDV_MAX_ERROR) + def histogramEnabled: Boolean = getConf(HISTOGRAM_ENABLED) + + def histogramNumBins: Int = getConf(HISTOGRAM_NUM_BINS) + + def percentileAccuracy: Int = getConf(PERCENTILE_ACCURACY) + def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index caf12ad745bb8..e3bb4d357b395 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.command +import scala.collection.mutable + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution @@ -68,11 +71,11 @@ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet - val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = columnNames.map { col => - val exprOption = relation.output.find(attr => resolver(attr.name, col)) + val exprOption = relation.output.find(attr => conf.resolver(attr.name, col)) exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) } @@ -86,12 +89,21 @@ case class AnalyzeColumnCommand( } // Collect statistics per column. + // If no histogram is required, we run a job to compute basic column stats such as + // min, max, ndv, etc. Otherwise, besides basic column stats, histogram will also be + // generated. Currently we only support equi-height histogram. + // To generate an equi-height histogram, we need two jobs: + // 1. compute percentiles p(0), p(1/n) ... p((n-1)/n), p(1). + // 2. use the percentiles as value intervals of bins, e.g. [p(0), p(1/n)], + // [p(1/n), p(2/n)], ..., [p((n-1)/n), p(1)], and then count ndv in each bin. + // Basic column stats will be computed together in the second job. + val attributePercentiles = computePercentiles(attributesToAnalyze, sparkSession, relation) + // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. - val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -99,9 +111,42 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 6 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 6), attr)) + // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr))) }.toMap (rowCount, columnStats) } + + /** Computes percentiles for each attribute. */ + private def computePercentiles( + attributesToAnalyze: Seq[Attribute], + sparkSession: SparkSession, + relation: LogicalPlan): AttributeMap[ArrayData] = { + val attrsToGenHistogram = if (conf.histogramEnabled) { + attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + } else { + Nil + } + val attributePercentiles = mutable.HashMap[Attribute, ArrayData]() + if (attrsToGenHistogram.nonEmpty) { + val percentiles = (0 to conf.histogramNumBins) + .map(i => i.toDouble / conf.histogramNumBins).toArray + + val namedExprs = attrsToGenHistogram.map { attr => + val aggFunc = + new ApproximatePercentile(attr, Literal(percentiles), Literal(conf.percentileAccuracy)) + val expr = aggFunc.toAggregateExpression() + Alias(expr, expr.toString)() + } + + val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation)) + .executedPlan.executeTake(1).head + attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) => + attributePercentiles += attr -> percentilesRow.getArray(i) + } + } + AttributeMap(attributePercentiles.toSeq) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 7247c3a876df3..fba5d2652d3f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -142,10 +142,12 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - stats.zip(df.schema).foreach { case ((k, v), field) => - withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) - assert(roundtrip == Some(v)) + Seq(stats, statsWithHgms).foreach { s => + s.zip(df.schema).foreach { case ((k, v), field) => + withClue(s"column $k with type ${field.dataType}") { + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + assert(roundtrip == Some(v)) + } } } } @@ -155,6 +157,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) checkColStats(df, stats) + + // test column stats with histograms + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> "true", SQLConf.HISTOGRAM_NUM_BINS.key -> "2") { + checkColStats(df, statsWithHgms) + } } test("column stats collection for null columns") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index a2f63edd786bf..f6df077ec5727 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf @@ -46,6 +46,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils private val d2 = Date.valueOf("2016-05-09") private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + private val d1Internal = DateTimeUtils.fromJavaDate(d1) + private val d2Internal = DateTimeUtils.fromJavaDate(d2) + private val t1Internal = DateTimeUtils.fromJavaTimestamp(t1) + private val t2Internal = DateTimeUtils.fromJavaTimestamp(t2) /** * Define a very simple 3 row table used for testing column serialization. @@ -73,12 +77,39 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), "cstring" -> ColumnStat(2, None, None, 1, 3, 3), "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), - Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), - Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) + "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) ) + /** + * A mapping from column to the stats collected including histograms. + * The number of bins in the histograms is 2. + */ + protected val statsWithHgms = { + val colStats = mutable.LinkedHashMap(stats.toSeq: _*) + colStats.update("cbyte", stats("cbyte").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 2, 1)))))) + colStats.update("cshort", stats("cshort").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 3, 1)))))) + colStats.update("cint", stats("cint").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 4, 1)))))) + colStats.update("clong", stats("clong").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 5, 1)))))) + colStats.update("cdouble", stats("cdouble").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 6, 1)))))) + colStats.update("cfloat", stats("cfloat").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 7, 1)))))) + colStats.update("cdecimal", stats("cdecimal").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 8, 1)))))) + colStats.update("cdate", stats("cdate").copy(histogram = + Some(Histogram(1, Array(HistogramBin(d1Internal, d1Internal, 1), + HistogramBin(d1Internal, d2Internal, 1)))))) + colStats.update("ctimestamp", stats("ctimestamp").copy(histogram = + Some(Histogram(1, Array(HistogramBin(t1Internal, t1Internal, 1), + HistogramBin(t1Internal, t2Internal, 1)))))) + colStats + } + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 7cd772544a96a..44e680dbd2f93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1032,8 +1032,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat stats: CatalogStatistics, schema: StructType): Map[String, String] = { - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + val statsProperties = new mutable.HashMap[String, String]() + statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } @@ -1046,7 +1046,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - statsProperties + statsProperties.toMap } private def statsFromProperties( @@ -1072,9 +1072,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => (k.drop(keyPrefix.length), v) } - - ColumnStat.fromMap(table, field, colStatMap).foreach { - colStat => colStats += field.name -> colStat + ColumnStat.fromMap(table, field, colStatMap).foreach { cs => + colStats += field.name -> cs } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9e8fc32a05471..7427948fe138b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} +import java.sql.Timestamp import scala.reflect.ClassTag import scala.util.matching.Regex @@ -28,8 +29,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ @@ -963,98 +964,174 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + + def checkColStatsProps(expected: Map[String, String]): Unit = { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + val table = hiveClient.getTable("default", tableName) + val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) + assert(props == expected) + } + withTable(tableName) { df.write.saveAsTable(tableName) - // Collect statistics - sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + // Collect and validate statistics + checkColStatsProps(expectedSerializedColStats) - // Validate statistics - val table = hiveClient.getTable("default", tableName) + withSQLConf( + SQLConf.HISTOGRAM_ENABLED.key -> "true", SQLConf.HISTOGRAM_NUM_BINS.key -> "2") { - val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) - assert(props == Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - "spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - )) + checkColStatsProps(expectedSerializedColStats ++ expectedSerializedHistograms) + } + } + } + + test("serialization and deserialization of histograms to/from hive metastore") { + import testImplicits._ + + def checkBinsOrder(bins: Array[HistogramBin]): Unit = { + for (i <- bins.indices) { + val b = bins(i) + assert(b.lo <= b.hi) + if (i > 0) { + val pre = bins(i - 1) + assert(pre.hi <= b.lo) + } + } + } + + val startTimestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) + val df = (1 to 5000) + .map(i => (i, DateTimeUtils.toJavaTimestamp(startTimestamp + i))) + .toDF("cint", "ctimestamp") + val tableName = "histogram_serde_test" + + withTable(tableName) { + df.write.saveAsTable(tableName) + + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> "true") { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS cint, ctimestamp") + val table = hiveClient.getTable("default", tableName) + val intHistogramProps = table.properties + .filterKeys(_.startsWith("spark.sql.statistics.colStats.cint.histogram")) + assert(intHistogramProps.size == 1) + + val tsHistogramProps = table.properties + .filterKeys(_.startsWith("spark.sql.statistics.colStats.ctimestamp.histogram")) + assert(tsHistogramProps.size == 1) + + // Validate histogram after deserialization. + val cs = getCatalogStatistics(tableName).colStats + val intHistogram = cs("cint").histogram.get + val tsHistogram = cs("ctimestamp").histogram.get + assert(intHistogram.bins.length == spark.sessionState.conf.histogramNumBins) + checkBinsOrder(intHistogram.bins) + assert(tsHistogram.bins.length == spark.sessionState.conf.histogramNumBins) + checkBinsOrder(tsHistogram.bins) + } } } From 4741c07809393ab85be8b4a169d4ed3da93a4781 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 14 Nov 2017 10:34:32 -0600 Subject: [PATCH 1668/1765] [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend. This change is a little larger because there's a whole lot of logic behind these pages, all really tied to internal types and listeners, and some of that logic had to be implemented in the new listener and the needed data exposed through the API types. - Added missing StageData and ExecutorStageSummary fields which are used by the UI. Some json golden files needed to be updated to account for new fields. - Save RDD graph data in the store. This tries to re-use existing types as much as possible, so that the code doesn't need to be re-written. So it's probably not very optimal. - Some old classes (e.g. JobProgressListener) still remain, since they're used in other parts of the code; they're not used by the UI anymore, though, and will be cleaned up in a separate change. - Save information about active pools in the store. This data is not really used in the SHS, but it's not a lot of data so it's still recorded when replaying applications. - Because the new store sorts things slightly differently from the previous code, some json golden files had some elements within them shuffled around. - The retention unit test in UISeleniumSuite was disabled because the code to throw away old stages / tasks hasn't been added yet. - The job description field in the API tries to follow the old behavior, which makes it be empty most of the time, even though there's information to fill it in. For stages, a new field was added to hold the description (which is basically the job description), so that the UI can be rendered in the old way. - A new stage status ("SKIPPED") was added to account for the fact that the API couldn't represent that state before. Without this, the stage would show up as "PENDING" in the UI, which is now based on API types. - The API used to expose "executorRunTime" as the value of the task's duration, which wasn't really correct (also because that value was easily available from the metrics object); this change fixes that by storing the correct duration, which also means a few expectation files needed to be updated to account for the new durations and sorting differences due to the changed values. - Added changes to implement SPARK-20713 and SPARK-21922 in the new code. Tested with existing unit tests (and by using the UI a lot). Author: Marcelo Vanzin Closes #19698 from vanzin/SPARK-20648. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../deploy/history/FsHistoryProvider.scala | 10 +- .../spark/status/AppStatusListener.scala | 121 +- .../apache/spark/status/AppStatusStore.scala | 135 ++- .../org/apache/spark/status/LiveEntity.scala | 57 +- .../spark/status/api/v1/AllJobsResource.scala | 70 +- .../status/api/v1/AllStagesResource.scala | 290 +---- .../spark/status/api/v1/OneJobResource.scala | 15 +- .../status/api/v1/OneStageResource.scala | 112 +- .../org/apache/spark/status/api/v1/api.scala | 21 +- .../org/apache/spark/status/config.scala | 4 + .../org/apache/spark/status/storeTypes.scala | 40 + .../scala/org/apache/spark/ui/SparkUI.scala | 30 +- .../apache/spark/ui/jobs/AllJobsPage.scala | 286 +++-- .../apache/spark/ui/jobs/AllStagesPage.scala | 189 +-- .../apache/spark/ui/jobs/ExecutorTable.scala | 158 +-- .../org/apache/spark/ui/jobs/JobPage.scala | 326 ++--- .../org/apache/spark/ui/jobs/JobsTab.scala | 38 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 57 +- .../org/apache/spark/ui/jobs/PoolTable.scala | 34 +- .../org/apache/spark/ui/jobs/StagePage.scala | 1064 ++++++++--------- .../org/apache/spark/ui/jobs/StageTable.scala | 100 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 28 +- .../spark/ui/scope/RDDOperationGraph.scala | 10 +- .../ui/scope/RDDOperationGraphListener.scala | 150 --- .../complete_stage_list_json_expectation.json | 21 +- .../failed_stage_list_json_expectation.json | 8 +- ...multi_attempt_app_json_1__expectation.json | 5 +- ...multi_attempt_app_json_2__expectation.json | 5 +- .../job_list_json_expectation.json | 15 +- .../one_job_json_expectation.json | 5 +- .../one_stage_attempt_json_expectation.json | 126 +- .../one_stage_json_expectation.json | 126 +- .../stage_list_json_expectation.json | 79 +- ...ist_with_accumulable_json_expectation.json | 7 +- .../stage_task_list_expectation.json | 60 +- ...multi_attempt_app_json_1__expectation.json | 24 +- ...multi_attempt_app_json_2__expectation.json | 24 +- ...k_list_w__offset___length_expectation.json | 150 ++- ...stage_task_list_w__sortBy_expectation.json | 190 +-- ...tBy_short_names___runtime_expectation.json | 190 +-- ...rtBy_short_names__runtime_expectation.json | 60 +- ...age_with_accumulable_json_expectation.json | 36 +- ...eded_failed_job_list_json_expectation.json | 15 +- .../succeeded_job_list_json_expectation.json | 10 +- .../deploy/history/HistoryServerSuite.scala | 13 +- .../spark/status/AppStatusListenerSuite.scala | 80 +- .../api/v1/AllStagesResourceSuite.scala | 62 - .../org/apache/spark/ui/StagePageSuite.scala | 64 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 60 +- .../RDDOperationGraphListenerSuite.scala | 226 ---- project/MimaExcludes.scala | 2 + 52 files changed, 2359 insertions(+), 2657 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala delete mode 100644 core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e5aaaf6c155eb..1d325e651b1d9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -426,7 +426,7 @@ class SparkContext(config: SparkConf) extends Logging { // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. - _statusStore = AppStatusStore.createLiveStore(conf, listenerBus) + _statusStore = AppStatusStore.createLiveStore(conf, l => listenerBus.addToStatusQueue(l)) // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) @@ -449,11 +449,7 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.create(Some(this), _statusStore, _conf, - l => listenerBus.addToStatusQueue(l), - _env.securityManager, - appName, - "", + Some(SparkUI.create(Some(this), _statusStore, _conf, _env.securityManager, appName, "", startTime)) } else { // For tests, do not enable the UI diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f16dddea9f784..a6dc53321d650 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -316,7 +316,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val listener = if (needReplay) { - val _listener = new AppStatusListener(kvstore, conf, false) + val _listener = new AppStatusListener(kvstore, conf, false, + lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) replayBus.addListener(_listener) Some(_listener) } else { @@ -324,13 +325,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val loadedUI = { - val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, - l => replayBus.addListener(l), - secManager, - app.info.name, + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, secManager, app.info.name, HistoryServer.getAttemptURI(appId, attempt.info.attemptId), attempt.info.startTime.getTime(), - appSparkVersion = attempt.info.appSparkVersion) + attempt.info.appSparkVersion) LoadedAppUI(ui) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 7f2c00c09d43d..f2d8e0a5480ba 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -28,16 +28,21 @@ import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.scope._ import org.apache.spark.util.kvstore.KVStore /** * A Spark listener that writes application information to a data store. The types written to the * store are defined in the `storeTypes.scala` file and are based on the public REST API. + * + * @param lastUpdateTime When replaying logs, the log's last update time, so that the duration of + * unfinished tasks can be more accurately calculated (see SPARK-21922). */ private[spark] class AppStatusListener( kvstore: KVStore, conf: SparkConf, - live: Boolean) extends SparkListener with Logging { + live: Boolean, + lastUpdateTime: Option[Long] = None) extends SparkListener with Logging { import config._ @@ -50,6 +55,8 @@ private[spark] class AppStatusListener( // operations that we can live without when rapidly processing incoming task events. private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + private val maxGraphRootNodes = conf.get(MAX_RETAINED_ROOT_NODES) + // Keep track of live entities, so that task metrics can be efficiently updated (without // causing too many writes to the underlying store, and other expensive operations). private val liveStages = new HashMap[(Int, Int), LiveStage]() @@ -57,6 +64,7 @@ private[spark] class AppStatusListener( private val liveExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() + private val pools = new HashMap[String, SchedulerPool]() override def onOtherEvent(event: SparkListenerEvent): Unit = event match { case SparkListenerLogStart(version) => sparkVersion = version @@ -210,16 +218,15 @@ private[spark] class AppStatusListener( missingStages.map(_.numTasks).sum } - val lastStageInfo = event.stageInfos.lastOption + val lastStageInfo = event.stageInfos.sortBy(_.stageId).lastOption val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val jobGroup = Option(event.properties) .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } val job = new LiveJob( event.jobId, lastStageName, - Some(new Date(event.time)), + if (event.time > 0) Some(new Date(event.time)) else None, event.stageIds, jobGroup, numTasks) @@ -234,17 +241,51 @@ private[spark] class AppStatusListener( stage.jobIds += event.jobId liveUpdate(stage, now) } + + // Create the graph data for all the job's stages. + event.stageInfos.foreach { stage => + val graph = RDDOperationGraph.makeOperationGraph(stage, maxGraphRootNodes) + val uigraph = new RDDOperationGraphWrapper( + stage.stageId, + graph.edges, + graph.outgoingEdges, + graph.incomingEdges, + newRDDOperationCluster(graph.rootCluster)) + kvstore.write(uigraph) + } + } + + private def newRDDOperationCluster(cluster: RDDOperationCluster): RDDOperationClusterWrapper = { + new RDDOperationClusterWrapper( + cluster.id, + cluster.name, + cluster.childNodes, + cluster.childClusters.map(newRDDOperationCluster)) } override def onJobEnd(event: SparkListenerJobEnd): Unit = { liveJobs.remove(event.jobId).foreach { job => + val now = System.nanoTime() + + // Check if there are any pending stages that match this job; mark those as skipped. + job.stageIds.foreach { sid => + val pending = liveStages.filter { case ((id, _), _) => id == sid } + pending.foreach { case (key, stage) => + stage.status = v1.StageStatus.SKIPPED + job.skippedStages += stage.info.stageId + job.skippedTasks += stage.info.numTasks + liveStages.remove(key) + update(stage, now) + } + } + job.status = event.jobResult match { case JobSucceeded => JobExecutionStatus.SUCCEEDED case JobFailed(_) => JobExecutionStatus.FAILED } - job.completionTime = Some(new Date(event.time)) - update(job, System.nanoTime()) + job.completionTime = if (event.time > 0) Some(new Date(event.time)) else None + update(job, now) } } @@ -262,12 +303,24 @@ private[spark] class AppStatusListener( .toSeq stage.jobIds = stage.jobs.map(_.jobId).toSet + stage.schedulingPool = Option(event.properties).flatMap { p => + Option(p.getProperty("spark.scheduler.pool")) + }.getOrElse(SparkUI.DEFAULT_POOL_NAME) + + stage.description = Option(event.properties).flatMap { p => + Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + stage.jobs.foreach { job => job.completedStages = job.completedStages - event.stageInfo.stageId job.activeStages += 1 liveUpdate(job, now) } + val pool = pools.getOrElseUpdate(stage.schedulingPool, new SchedulerPool(stage.schedulingPool)) + pool.stageIds = pool.stageIds + event.stageInfo.stageId + update(pool, now) + event.stageInfo.rddInfos.foreach { info => if (info.storageLevel.isValid) { liveUpdate(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info)), now) @@ -279,7 +332,7 @@ private[spark] class AppStatusListener( override def onTaskStart(event: SparkListenerTaskStart): Unit = { val now = System.nanoTime() - val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) + val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId, lastUpdateTime) liveTasks.put(event.taskInfo.taskId, task) liveUpdate(task, now) @@ -318,6 +371,8 @@ private[spark] class AppStatusListener( val now = System.nanoTime() val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => + task.info = event.taskInfo + val errorMessage = event.reason match { case Success => None @@ -337,11 +392,15 @@ private[spark] class AppStatusListener( delta }.orNull - val (completedDelta, failedDelta) = event.reason match { + val (completedDelta, failedDelta, killedDelta) = event.reason match { case Success => - (1, 0) + (1, 0, 0) + case _: TaskKilled => + (0, 0, 1) + case _: TaskCommitDenied => + (0, 0, 1) case _ => - (0, 1) + (0, 1, 0) } liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => @@ -350,13 +409,29 @@ private[spark] class AppStatusListener( } stage.activeTasks -= 1 stage.completedTasks += completedDelta + if (completedDelta > 0) { + stage.completedIndices.add(event.taskInfo.index) + } stage.failedTasks += failedDelta + stage.killedTasks += killedDelta + if (killedDelta > 0) { + stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) + } maybeUpdate(stage, now) + // Store both stage ID and task index in a single long variable for tracking at job level. + val taskIndex = (event.stageId.toLong << Integer.SIZE) | event.taskInfo.index stage.jobs.foreach { job => job.activeTasks -= 1 job.completedTasks += completedDelta + if (completedDelta > 0) { + job.completedIndices.add(taskIndex) + } job.failedTasks += failedDelta + job.killedTasks += killedDelta + if (killedDelta > 0) { + job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) + } maybeUpdate(job, now) } @@ -364,6 +439,7 @@ private[spark] class AppStatusListener( esummary.taskTime += event.taskInfo.duration esummary.succeededTasks += completedDelta esummary.failedTasks += failedDelta + esummary.killedTasks += killedDelta if (metricsDelta != null) { esummary.metrics.update(metricsDelta) } @@ -422,6 +498,11 @@ private[spark] class AppStatusListener( liveUpdate(job, now) } + pools.get(stage.schedulingPool).foreach { pool => + pool.stageIds = pool.stageIds - event.stageInfo.stageId + update(pool, now) + } + stage.executorSummaries.values.foreach(update(_, now)) update(stage, now) } @@ -482,11 +563,15 @@ private[spark] class AppStatusListener( /** Flush all live entities' data to the underlying store. */ def flush(): Unit = { val now = System.nanoTime() - liveStages.values.foreach(update(_, now)) + liveStages.values.foreach { stage => + update(stage, now) + stage.executorSummaries.values.foreach(update(_, now)) + } liveJobs.values.foreach(update(_, now)) liveExecutors.values.foreach(update(_, now)) liveTasks.values.foreach(update(_, now)) liveRDDs.values.foreach(update(_, now)) + pools.values.foreach(update(_, now)) } private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { @@ -628,6 +713,20 @@ private[spark] class AppStatusListener( stage } + private def killedTasksSummary( + reason: TaskEndReason, + oldSummary: Map[String, Int]): Map[String, Int] = { + reason match { + case k: TaskKilled => + oldSummary.updated(k.reason, oldSummary.getOrElse(k.reason, 0) + 1) + case denied: TaskCommitDenied => + val reason = denied.toErrorString + oldSummary.updated(reason, oldSummary.getOrElse(reason, 0) + 1) + case _ => + oldSummary + } + } + private def update(entity: LiveEntity, now: Long): Unit = { entity.write(kvstore, now) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 80c8d7d11a3c2..9b42f55605755 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -23,8 +23,9 @@ import java.util.{Arrays, List => JList} import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.scheduler.SparkListener import org.apache.spark.status.api.v1 +import org.apache.spark.ui.scope._ import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} @@ -43,8 +44,8 @@ private[spark] class AppStatusStore(store: KVStore) { } def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { - val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) - if (!statuses.isEmpty()) { + val it = store.view(classOf[JobDataWrapper]).reverse().asScala.map(_.info) + if (statuses != null && !statuses.isEmpty()) { it.filter { job => statuses.contains(job.status) }.toSeq } else { it.toSeq @@ -65,31 +66,40 @@ private[spark] class AppStatusStore(store: KVStore) { filtered.asScala.map(_.info).toSeq } - def executorSummary(executorId: String): Option[v1.ExecutorSummary] = { - try { - Some(store.read(classOf[ExecutorSummaryWrapper], executorId).info) - } catch { - case _: NoSuchElementException => - None - } + def executorSummary(executorId: String): v1.ExecutorSummary = { + store.read(classOf[ExecutorSummaryWrapper], executorId).info } def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { - val it = store.view(classOf[StageDataWrapper]).asScala.map(_.info) - if (!statuses.isEmpty) { + val it = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info) + if (statuses != null && !statuses.isEmpty()) { it.filter { s => statuses.contains(s.status) }.toSeq } else { it.toSeq } } - def stageData(stageId: Int): Seq[v1.StageData] = { + def stageData(stageId: Int, details: Boolean = false): Seq[v1.StageData] = { store.view(classOf[StageDataWrapper]).index("stageId").first(stageId).last(stageId) - .asScala.map(_.info).toSeq + .asScala.map { s => + if (details) stageWithDetails(s.info) else s.info + }.toSeq + } + + def lastStageAttempt(stageId: Int): v1.StageData = { + val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + .closeableIterator() + try { + it.next().info + } finally { + it.close() + } } - def stageAttempt(stageId: Int, stageAttemptId: Int): v1.StageData = { - store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).info + def stageAttempt(stageId: Int, stageAttemptId: Int, details: Boolean = false): v1.StageData = { + val stageKey = Array(stageId, stageAttemptId) + val stage = store.read(classOf[StageDataWrapper], stageKey).info + if (details) stageWithDetails(stage) else stage } def taskSummary( @@ -189,6 +199,12 @@ private[spark] class AppStatusStore(store: KVStore) { ) } + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { + val stageKey = Array(stageId, stageAttemptId) + store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() + .max(maxTasks).asScala.map(_.info).toSeq.reverse + } + def taskList( stageId: Int, stageAttemptId: Int, @@ -215,6 +231,66 @@ private[spark] class AppStatusStore(store: KVStore) { }.toSeq } + /** + * Calls a closure that may throw a NoSuchElementException and returns `None` when the exception + * is thrown. + */ + def asOption[T](fn: => T): Option[T] = { + try { + Some(fn) + } catch { + case _: NoSuchElementException => None + } + } + + private def stageWithDetails(stage: v1.StageData): v1.StageData = { + val tasks = taskList(stage.stageId, stage.attemptId, Int.MaxValue) + .map { t => (t.taskId, t) } + .toMap + + val stageKey = Array(stage.stageId, stage.attemptId) + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) + .last(stageKey).closeableIterator().asScala + .map { exec => (exec.executorId -> exec.info) } + .toMap + + new v1.StageData( + stage.status, + stage.stageId, + stage.attemptId, + stage.numTasks, + stage.numActiveTasks, + stage.numCompleteTasks, + stage.numFailedTasks, + stage.numKilledTasks, + stage.numCompletedIndices, + stage.executorRunTime, + stage.executorCpuTime, + stage.submissionTime, + stage.firstTaskLaunchedTime, + stage.completionTime, + stage.failureReason, + stage.inputBytes, + stage.inputRecords, + stage.outputBytes, + stage.outputRecords, + stage.shuffleReadBytes, + stage.shuffleReadRecords, + stage.shuffleWriteBytes, + stage.shuffleWriteRecords, + stage.memoryBytesSpilled, + stage.diskBytesSpilled, + stage.name, + stage.description, + stage.details, + stage.schedulingPool, + stage.rddIds, + stage.accumulatorUpdates, + Some(tasks), + Some(execs), + stage.killedTasksSummary) + } + def rdd(rddId: Int): v1.RDDStorageInfo = { store.read(classOf[RDDStorageInfoWrapper], rddId).info } @@ -223,6 +299,27 @@ private[spark] class AppStatusStore(store: KVStore) { store.view(classOf[StreamBlockData]).asScala.toSeq } + def operationGraphForStage(stageId: Int): RDDOperationGraph = { + store.read(classOf[RDDOperationGraphWrapper], stageId).toRDDOperationGraph() + } + + def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { + val job = store.read(classOf[JobDataWrapper], jobId) + val stages = job.info.stageIds + + stages.map { id => + val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() + if (job.skippedStages.contains(id) && !g.rootCluster.name.contains("skipped")) { + g.rootCluster.setName(g.rootCluster.name + " (skipped)") + } + g + } + } + + def pool(name: String): PoolData = { + store.read(classOf[PoolData], name) + } + def close(): Unit = { store.close() } @@ -237,12 +334,12 @@ private[spark] object AppStatusStore { * Create an in-memory store for a live application. * * @param conf Configuration. - * @param bus Where to attach the listener to populate the store. + * @param addListenerFn Function to register a listener with a bus. */ - def createLiveStore(conf: SparkConf, bus: LiveListenerBus): AppStatusStore = { + def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { val store = new InMemoryStore() val stateStore = new AppStatusStore(store) - bus.addToStatusQueue(new AppStatusListener(store, conf, true)) + addListenerFn(new AppStatusListener(store, conf, true)) stateStore } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 706d94c3a59b9..ef2936c9b69a4 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -28,6 +28,7 @@ import org.apache.spark.status.api.v1 import org.apache.spark.storage.RDDInfo import org.apache.spark.ui.SparkUI import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.kvstore.KVStore /** @@ -64,6 +65,12 @@ private class LiveJob( var completedTasks = 0 var failedTasks = 0 + // Holds both the stage ID and the task index, packed into a single long value. + val completedIndices = new OpenHashSet[Long]() + + var killedTasks = 0 + var killedSummary: Map[String, Int] = Map() + var skippedTasks = 0 var skippedStages = Set[Int]() @@ -89,19 +96,23 @@ private class LiveJob( completedTasks, skippedTasks, failedTasks, + killedTasks, + completedIndices.size, activeStages, completedStages.size, skippedStages.size, - failedStages) + failedStages, + killedSummary) new JobDataWrapper(info, skippedStages) } } private class LiveTask( - info: TaskInfo, + var info: TaskInfo, stageId: Int, - stageAttemptId: Int) extends LiveEntity { + stageAttemptId: Int, + lastUpdateTime: Option[Long]) extends LiveEntity { import LiveEntityHelpers._ @@ -126,6 +137,7 @@ private class LiveTask( metrics.resultSerializationTime, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, + metrics.peakExecutionMemory, new v1.InputMetrics( metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead), @@ -186,6 +198,7 @@ private class LiveTask( 0L, 0L, 0L, metrics.memoryBytesSpilled - old.memoryBytesSpilled, metrics.diskBytesSpilled - old.diskBytesSpilled, + 0L, inputDelta, outputDelta, shuffleReadDelta, @@ -193,12 +206,19 @@ private class LiveTask( } override protected def doUpdate(): Any = { + val duration = if (info.finished) { + info.duration + } else { + info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) + } + val task = new v1.TaskData( info.taskId, info.index, info.attemptNumber, new Date(info.launchTime), - if (info.finished) Some(info.duration) else None, + if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, + Some(duration), info.executorId, info.host, info.status, @@ -340,10 +360,15 @@ private class LiveExecutorStageSummary( taskTime, failedTasks, succeededTasks, + killedTasks, metrics.inputBytes, + metrics.inputRecords, metrics.outputBytes, + metrics.outputRecords, metrics.shuffleReadBytes, + metrics.shuffleReadRecords, metrics.shuffleWriteBytes, + metrics.shuffleWriteRecords, metrics.memoryBytesSpilled, metrics.diskBytesSpilled) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) @@ -361,11 +386,16 @@ private class LiveStage extends LiveEntity { var info: StageInfo = null var status = v1.StageStatus.PENDING + var description: Option[String] = None var schedulingPool: String = SparkUI.DEFAULT_POOL_NAME var activeTasks = 0 var completedTasks = 0 var failedTasks = 0 + val completedIndices = new OpenHashSet[Int]() + + var killedTasks = 0 + var killedSummary: Map[String, Int] = Map() var firstLaunchTime = Long.MaxValue @@ -384,15 +414,19 @@ private class LiveStage extends LiveEntity { info.stageId, info.attemptId, + info.numTasks, activeTasks, completedTasks, failedTasks, + killedTasks, + completedIndices.size, metrics.executorRunTime, metrics.executorCpuTime, info.submissionTime.map(new Date(_)), if (firstLaunchTime < Long.MaxValue) Some(new Date(firstLaunchTime)) else None, info.completionTime.map(new Date(_)), + info.failureReason, metrics.inputBytes, metrics.inputRecords, @@ -406,12 +440,15 @@ private class LiveStage extends LiveEntity { metrics.diskBytesSpilled, info.name, + description, info.details, schedulingPool, + info.rddInfos.map(_.id), newAccumulatorInfos(info.accumulables.values), None, - None) + None, + killedSummary) new StageDataWrapper(update, jobIds) } @@ -535,6 +572,16 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { } +private class SchedulerPool(name: String) extends LiveEntity { + + var stageIds = Set[Int]() + + override protected def doUpdate(): Any = { + new PoolData(name, stageIds) + } + +} + private object LiveEntityHelpers { def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala index d0d9ef1165e81..b4fa3e633f6c1 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala @@ -22,7 +22,6 @@ import javax.ws.rs.core.MediaType import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.ui.jobs.UIData.JobUIData @Produces(Array(MediaType.APPLICATION_JSON)) @@ -30,74 +29,7 @@ private[v1] class AllJobsResource(ui: SparkUI) { @GET def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { - val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = - AllJobsResource.getStatusToJobs(ui) - val adjStatuses: JList[JobExecutionStatus] = { - if (statuses.isEmpty) { - Arrays.asList(JobExecutionStatus.values(): _*) - } else { - statuses - } - } - val jobInfos = for { - (status, jobs) <- statusToJobs - job <- jobs if adjStatuses.contains(status) - } yield { - AllJobsResource.convertJobData(job, ui.jobProgressListener, false) - } - jobInfos.sortBy{- _.jobId} + ui.store.jobsList(statuses) } } - -private[v1] object AllJobsResource { - - def getStatusToJobs(ui: SparkUI): Seq[(JobExecutionStatus, Seq[JobUIData])] = { - val statusToJobs = ui.jobProgressListener.synchronized { - Seq( - JobExecutionStatus.RUNNING -> ui.jobProgressListener.activeJobs.values.toSeq, - JobExecutionStatus.SUCCEEDED -> ui.jobProgressListener.completedJobs.toSeq, - JobExecutionStatus.FAILED -> ui.jobProgressListener.failedJobs.reverse.toSeq - ) - } - statusToJobs - } - - def convertJobData( - job: JobUIData, - listener: JobProgressListener, - includeStageDetails: Boolean): JobData = { - listener.synchronized { - val lastStageInfo = - if (job.stageIds.isEmpty) { - None - } else { - listener.stageIdToInfo.get(job.stageIds.max) - } - val lastStageData = lastStageInfo.flatMap { s => - listener.stageIdToData.get((s.stageId, s.attemptId)) - } - val lastStageName = lastStageInfo.map { _.name }.getOrElse("(Unknown Stage Name)") - val lastStageDescription = lastStageData.flatMap { _.description } - new JobData( - jobId = job.jobId, - name = lastStageName, - description = lastStageDescription, - submissionTime = job.submissionTime.map{new Date(_)}, - completionTime = job.completionTime.map{new Date(_)}, - stageIds = job.stageIds, - jobGroup = job.jobGroup, - status = job.status, - numTasks = job.numTasks, - numActiveTasks = job.numActiveTasks, - numCompletedTasks = job.numCompletedTasks, - numSkippedTasks = job.numSkippedTasks, - numFailedTasks = job.numFailedTasks, - numActiveStages = job.numActiveStages, - numCompletedStages = job.completedStageIndices.size, - numSkippedStages = job.numSkippedStages, - numFailedStages = job.numFailedStages - ) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 5f69949c618fd..e1c91cb527a51 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -16,304 +16,18 @@ */ package org.apache.spark.status.api.v1 -import java.util.{Arrays, Date, List => JList} +import java.util.{List => JList} import javax.ws.rs.{GET, Produces, QueryParam} import javax.ws.rs.core.MediaType -import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} -import org.apache.spark.ui.jobs.UIData.{InputMetricsUIData => InternalInputMetrics, OutputMetricsUIData => InternalOutputMetrics, ShuffleReadMetricsUIData => InternalShuffleReadMetrics, ShuffleWriteMetricsUIData => InternalShuffleWriteMetrics, TaskMetricsUIData => InternalTaskMetrics} -import org.apache.spark.util.Distribution @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AllStagesResource(ui: SparkUI) { @GET def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { - val listener = ui.jobProgressListener - val stageAndStatus = AllStagesResource.stagesAndStatus(ui) - val adjStatuses = { - if (statuses.isEmpty()) { - Arrays.asList(StageStatus.values(): _*) - } else { - statuses - } - } - for { - (status, stageList) <- stageAndStatus - stageInfo: StageInfo <- stageList if adjStatuses.contains(status) - stageUiData: StageUIData <- listener.synchronized { - listener.stageIdToData.get((stageInfo.stageId, stageInfo.attemptId)) - } - } yield { - stageUiData.lastUpdateTime = ui.lastUpdateTime - AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, includeDetails = false) - } + ui.store.stageList(statuses) } -} - -private[v1] object AllStagesResource { - def stageUiToStageData( - status: StageStatus, - stageInfo: StageInfo, - stageUiData: StageUIData, - includeDetails: Boolean): StageData = { - - val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) - - val firstTaskLaunchedTime: Option[Date] = - if (taskLaunchTimes.nonEmpty) { - Some(new Date(taskLaunchTimes.min)) - } else { - None - } - - val taskData = if (includeDetails) { - Some(stageUiData.taskData.map { case (k, v) => - k -> convertTaskData(v, stageUiData.lastUpdateTime) }.toMap) - } else { - None - } - val executorSummary = if (includeDetails) { - Some(stageUiData.executorSummary.map { case (k, summary) => - k -> new ExecutorStageSummary( - taskTime = summary.taskTime, - failedTasks = summary.failedTasks, - succeededTasks = summary.succeededTasks, - inputBytes = summary.inputBytes, - outputBytes = summary.outputBytes, - shuffleRead = summary.shuffleRead, - shuffleWrite = summary.shuffleWrite, - memoryBytesSpilled = summary.memoryBytesSpilled, - diskBytesSpilled = summary.diskBytesSpilled - ) - }.toMap) - } else { - None - } - - val accumulableInfo = stageUiData.accumulables.values.map { convertAccumulableInfo }.toSeq - - new StageData( - status = status, - stageId = stageInfo.stageId, - attemptId = stageInfo.attemptId, - numActiveTasks = stageUiData.numActiveTasks, - numCompleteTasks = stageUiData.numCompleteTasks, - numFailedTasks = stageUiData.numFailedTasks, - executorRunTime = stageUiData.executorRunTime, - executorCpuTime = stageUiData.executorCpuTime, - submissionTime = stageInfo.submissionTime.map(new Date(_)), - firstTaskLaunchedTime, - completionTime = stageInfo.completionTime.map(new Date(_)), - inputBytes = stageUiData.inputBytes, - inputRecords = stageUiData.inputRecords, - outputBytes = stageUiData.outputBytes, - outputRecords = stageUiData.outputRecords, - shuffleReadBytes = stageUiData.shuffleReadTotalBytes, - shuffleReadRecords = stageUiData.shuffleReadRecords, - shuffleWriteBytes = stageUiData.shuffleWriteBytes, - shuffleWriteRecords = stageUiData.shuffleWriteRecords, - memoryBytesSpilled = stageUiData.memoryBytesSpilled, - diskBytesSpilled = stageUiData.diskBytesSpilled, - schedulingPool = stageUiData.schedulingPool, - name = stageInfo.name, - details = stageInfo.details, - accumulatorUpdates = accumulableInfo, - tasks = taskData, - executorSummary = executorSummary - ) - } - - def stagesAndStatus(ui: SparkUI): Seq[(StageStatus, Seq[StageInfo])] = { - val listener = ui.jobProgressListener - listener.synchronized { - Seq( - StageStatus.ACTIVE -> listener.activeStages.values.toSeq, - StageStatus.COMPLETE -> listener.completedStages.reverse.toSeq, - StageStatus.FAILED -> listener.failedStages.reverse.toSeq, - StageStatus.PENDING -> listener.pendingStages.values.toSeq - ) - } - } - - def convertTaskData(uiData: TaskUIData, lastUpdateTime: Option[Long]): TaskData = { - new TaskData( - taskId = uiData.taskInfo.taskId, - index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attemptNumber, - launchTime = new Date(uiData.taskInfo.launchTime), - duration = uiData.taskDuration(lastUpdateTime), - executorId = uiData.taskInfo.executorId, - host = uiData.taskInfo.host, - status = uiData.taskInfo.status, - taskLocality = uiData.taskInfo.taskLocality.toString(), - speculative = uiData.taskInfo.speculative, - accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, - errorMessage = uiData.errorMessage, - taskMetrics = uiData.metrics.map { convertUiTaskMetrics } - ) - } - - def taskMetricDistributions( - allTaskData: Iterable[TaskUIData], - quantiles: Array[Double]): TaskMetricDistributions = { - - val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq - - def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics: InputMetricDistributions = - new MetricHelper[InternalInputMetrics, InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalInputMetrics = raw.inputMetrics - - def build: InputMetricDistributions = new InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics: OutputMetricDistributions = - new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalOutputMetrics = raw.outputMetrics - - def build: OutputMetricDistributions = new OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics: ShuffleReadMetricDistributions = - new MetricHelper[InternalShuffleReadMetrics, ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: ShuffleReadMetricDistributions = new ShuffleReadMetricDistributions( - readBytes = submetricQuantiles(_.totalBytesRead), - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics: ShuffleWriteMetricDistributions = - new MetricHelper[InternalShuffleWriteMetrics, ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleWriteMetrics = - raw.shuffleWriteMetrics - def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new TaskMetricDistributions( - quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGCTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) - } - - def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = { - new AccumulableInfo( - acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull) - } - - def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { - new TaskMetrics( - executorDeserializeTime = internal.executorDeserializeTime, - executorDeserializeCpuTime = internal.executorDeserializeCpuTime, - executorRunTime = internal.executorRunTime, - executorCpuTime = internal.executorCpuTime, - resultSize = internal.resultSize, - jvmGcTime = internal.jvmGCTime, - resultSerializationTime = internal.resultSerializationTime, - memoryBytesSpilled = internal.memoryBytesSpilled, - diskBytesSpilled = internal.diskBytesSpilled, - inputMetrics = convertInputMetrics(internal.inputMetrics), - outputMetrics = convertOutputMetrics(internal.outputMetrics), - shuffleReadMetrics = convertShuffleReadMetrics(internal.shuffleReadMetrics), - shuffleWriteMetrics = convertShuffleWriteMetrics(internal.shuffleWriteMetrics) - ) - } - - def convertInputMetrics(internal: InternalInputMetrics): InputMetrics = { - new InputMetrics( - bytesRead = internal.bytesRead, - recordsRead = internal.recordsRead - ) - } - - def convertOutputMetrics(internal: InternalOutputMetrics): OutputMetrics = { - new OutputMetrics( - bytesWritten = internal.bytesWritten, - recordsWritten = internal.recordsWritten - ) - } - - def convertShuffleReadMetrics(internal: InternalShuffleReadMetrics): ShuffleReadMetrics = { - new ShuffleReadMetrics( - remoteBlocksFetched = internal.remoteBlocksFetched, - localBlocksFetched = internal.localBlocksFetched, - fetchWaitTime = internal.fetchWaitTime, - remoteBytesRead = internal.remoteBytesRead, - remoteBytesReadToDisk = internal.remoteBytesReadToDisk, - localBytesRead = internal.localBytesRead, - recordsRead = internal.recordsRead - ) - } - - def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): ShuffleWriteMetrics = { - new ShuffleWriteMetrics( - bytesWritten = internal.bytesWritten, - writeTime = internal.writeTime, - recordsWritten = internal.recordsWritten - ) - } -} - -/** - * Helper for getting distributions from nested metric types. - */ -private[v1] abstract class MetricHelper[I, O]( - rawMetrics: Seq[InternalTaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: InternalTaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala index 653150385c732..3ee884e084c12 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala @@ -16,25 +16,22 @@ */ package org.apache.spark.status.api.v1 +import java.util.NoSuchElementException import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType -import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.JobUIData @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class OneJobResource(ui: SparkUI) { @GET def oneJob(@PathParam("jobId") jobId: Int): JobData = { - val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = - AllJobsResource.getStatusToJobs(ui) - val jobOpt = statusToJobs.flatMap(_._2).find { jobInfo => jobInfo.jobId == jobId} - jobOpt.map { job => - AllJobsResource.convertJobData(job, ui.jobProgressListener, false) - }.getOrElse { - throw new NotFoundException("unknown job: " + jobId) + try { + ui.store.job(jobId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException("unknown job: " + jobId) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index f15073bccced2..20dd73e916613 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -24,7 +24,6 @@ import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.ui.jobs.UIData.StageUIData @Produces(Array(MediaType.APPLICATION_JSON)) @@ -32,13 +31,14 @@ private[v1] class OneStageResource(ui: SparkUI) { @GET @Path("") - def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = { - withStage(stageId) { stageAttempts => - stageAttempts.map { stage => - stage.ui.lastUpdateTime = ui.lastUpdateTime - AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, - includeDetails = true) - } + def stageData( + @PathParam("stageId") stageId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean): Seq[StageData] = { + val ret = ui.store.stageData(stageId, details = details) + if (ret.nonEmpty) { + ret + } else { + throw new NotFoundException(s"unknown stage: $stageId") } } @@ -46,11 +46,13 @@ private[v1] class OneStageResource(ui: SparkUI) { @Path("/{stageAttemptId: \\d+}") def oneAttemptData( @PathParam("stageId") stageId: Int, - @PathParam("stageAttemptId") stageAttemptId: Int): StageData = { - withStageAttempt(stageId, stageAttemptId) { stage => - stage.ui.lastUpdateTime = ui.lastUpdateTime - AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, - includeDetails = true) + @PathParam("stageAttemptId") stageAttemptId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = { + try { + ui.store.stageAttempt(stageId, stageAttemptId, details = details) + } catch { + case _: NoSuchElementException => + throw new NotFoundException(s"unknown attempt $stageAttemptId for stage $stageId.") } } @@ -61,17 +63,16 @@ private[v1] class OneStageResource(ui: SparkUI) { @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0.05,0.25,0.5,0.75,0.95") @QueryParam("quantiles") quantileString: String) : TaskMetricDistributions = { - withStageAttempt(stageId, stageAttemptId) { stage => - val quantiles = quantileString.split(",").map { s => - try { - s.toDouble - } catch { - case nfe: NumberFormatException => - throw new BadParameterException("quantiles", "double", s) - } + val quantiles = quantileString.split(",").map { s => + try { + s.toDouble + } catch { + case nfe: NumberFormatException => + throw new BadParameterException("quantiles", "double", s) } - AllStagesResource.taskMetricDistributions(stage.ui.taskData.values, quantiles) } + + ui.store.taskSummary(stageId, stageAttemptId, quantiles) } @GET @@ -82,72 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { @DefaultValue("0") @QueryParam("offset") offset: Int, @DefaultValue("20") @QueryParam("length") length: Int, @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { - withStageAttempt(stageId, stageAttemptId) { stage => - val tasks = stage.ui.taskData.values - .map{ AllStagesResource.convertTaskData(_, ui.lastUpdateTime)}.toIndexedSeq - .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) - } - } - - private case class StageStatusInfoUi(status: StageStatus, info: StageInfo, ui: StageUIData) - - private def withStage[T](stageId: Int)(f: Seq[StageStatusInfoUi] => T): T = { - val stageAttempts = findStageStatusUIData(ui.jobProgressListener, stageId) - if (stageAttempts.isEmpty) { - throw new NotFoundException("unknown stage: " + stageId) - } else { - f(stageAttempts) - } + ui.store.taskList(stageId, stageAttemptId, offset, length, sortBy) } - private def findStageStatusUIData( - listener: JobProgressListener, - stageId: Int): Seq[StageStatusInfoUi] = { - listener.synchronized { - def getStatusInfoUi(status: StageStatus, infos: Seq[StageInfo]): Seq[StageStatusInfoUi] = { - infos.filter { _.stageId == stageId }.map { info => - val ui = listener.stageIdToData.getOrElse((info.stageId, info.attemptId), - // this is an internal error -- we should always have uiData - throw new SparkException( - s"no stage ui data found for stage: ${info.stageId}:${info.attemptId}") - ) - StageStatusInfoUi(status, info, ui) - } - } - getStatusInfoUi(ACTIVE, listener.activeStages.values.toSeq) ++ - getStatusInfoUi(COMPLETE, listener.completedStages) ++ - getStatusInfoUi(FAILED, listener.failedStages) ++ - getStatusInfoUi(PENDING, listener.pendingStages.values.toSeq) - } - } - - private def withStageAttempt[T]( - stageId: Int, - stageAttemptId: Int) - (f: StageStatusInfoUi => T): T = { - withStage(stageId) { attempts => - val oneAttempt = attempts.find { stage => stage.info.attemptId == stageAttemptId } - oneAttempt match { - case Some(stage) => - f(stage) - case None => - val stageAttempts = attempts.map { _.info.attemptId } - throw new NotFoundException(s"unknown attempt for stage $stageId. " + - s"Found attempts: ${stageAttempts.mkString("[", ",", "]")}") - } - } - } -} - -object OneStageResource { - def ordering(taskSorting: TaskSorting): Ordering[TaskData] = { - val extractor: (TaskData => Long) = td => - taskSorting match { - case ID => td.taskId - case INCREASING_RUNTIME => td.taskMetrics.map{_.executorRunTime}.getOrElse(-1L) - case DECREASING_RUNTIME => -td.taskMetrics.map{_.executorRunTime}.getOrElse(-1L) - } - Ordering.by(extractor) - } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index b338b1f3fd073..14280099f6422 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -58,10 +58,15 @@ class ExecutorStageSummary private[spark]( val taskTime : Long, val failedTasks : Int, val succeededTasks : Int, + val killedTasks : Int, val inputBytes : Long, + val inputRecords : Long, val outputBytes : Long, + val outputRecords : Long, val shuffleRead : Long, + val shuffleReadRecords : Long, val shuffleWrite : Long, + val shuffleWriteRecords : Long, val memoryBytesSpilled : Long, val diskBytesSpilled : Long) @@ -111,10 +116,13 @@ class JobData private[spark]( val numCompletedTasks: Int, val numSkippedTasks: Int, val numFailedTasks: Int, + val numKilledTasks: Int, + val numCompletedIndices: Int, val numActiveStages: Int, val numCompletedStages: Int, val numSkippedStages: Int, - val numFailedStages: Int) + val numFailedStages: Int, + val killedTasksSummary: Map[String, Int]) class RDDStorageInfo private[spark]( val id: Int, @@ -152,15 +160,19 @@ class StageData private[spark]( val status: StageStatus, val stageId: Int, val attemptId: Int, + val numTasks: Int, val numActiveTasks: Int, val numCompleteTasks: Int, val numFailedTasks: Int, + val numKilledTasks: Int, + val numCompletedIndices: Int, val executorRunTime: Long, val executorCpuTime: Long, val submissionTime: Option[Date], val firstTaskLaunchedTime: Option[Date], val completionTime: Option[Date], + val failureReason: Option[String], val inputBytes: Long, val inputRecords: Long, @@ -174,18 +186,22 @@ class StageData private[spark]( val diskBytesSpilled: Long, val name: String, + val description: Option[String], val details: String, val schedulingPool: String, + val rddIds: Seq[Int], val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary: Option[Map[String, ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]], + val killedTasksSummary: Map[String, Int]) class TaskData private[spark]( val taskId: Long, val index: Int, val attempt: Int, val launchTime: Date, + val resultFetchStart: Option[Date], @JsonDeserialize(contentAs = classOf[JLong]) val duration: Option[Long], val executorId: String, @@ -207,6 +223,7 @@ class TaskMetrics private[spark]( val resultSerializationTime: Long, val memoryBytesSpilled: Long, val diskBytesSpilled: Long, + val peakExecutionMemory: Long, val inputMetrics: InputMetrics, val outputMetrics: OutputMetrics, val shuffleReadMetrics: ShuffleReadMetrics, diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/status/config.scala index 49144fc883e69..7af9dff977a86 100644 --- a/core/src/main/scala/org/apache/spark/status/config.scala +++ b/core/src/main/scala/org/apache/spark/status/config.scala @@ -27,4 +27,8 @@ private[spark] object config { .timeConf(TimeUnit.NANOSECONDS) .createWithDefaultString("100ms") + val MAX_RETAINED_ROOT_NODES = ConfigBuilder("spark.ui.dagGraph.retainedRootRDDs") + .intConf + .createWithDefault(Int.MaxValue) + } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index f44b7935bfaa3..c1ea87542d6cc 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ +import org.apache.spark.ui.scope._ import org.apache.spark.util.kvstore.KVIndex private[spark] case class AppStatusStoreMetadata(version: Long) @@ -106,6 +107,11 @@ private[spark] class TaskDataWrapper( Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) } + @JsonIgnore @KVIndex("startTime") + def startTime: Array[AnyRef] = { + Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + } + } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { @@ -147,3 +153,37 @@ private[spark] class StreamBlockData( def key: Array[String] = Array(name, executorId) } + +private[spark] class RDDOperationClusterWrapper( + val id: String, + val name: String, + val childNodes: Seq[RDDOperationNode], + val childClusters: Seq[RDDOperationClusterWrapper]) { + + def toRDDOperationCluster(): RDDOperationCluster = { + val cluster = new RDDOperationCluster(id, name) + childNodes.foreach(cluster.attachChildNode) + childClusters.foreach { child => + cluster.attachChildCluster(child.toRDDOperationCluster()) + } + cluster + } + +} + +private[spark] class RDDOperationGraphWrapper( + @KVIndexParam val stageId: Int, + val edges: Seq[RDDOperationEdge], + val outgoingEdges: Seq[RDDOperationEdge], + val incomingEdges: Seq[RDDOperationEdge], + val rootCluster: RDDOperationClusterWrapper) { + + def toRDDOperationGraph(): RDDOperationGraph = { + new RDDOperationGraph(edges, outgoingEdges, incomingEdges, rootCluster.toRDDOperationCluster()) + } + +} + +private[spark] class PoolData( + @KVIndexParam val name: String, + val stageIds: Set[Int]) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index e93ade001c607..35da3c3bfd1a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,11 +17,11 @@ package org.apache.spark.ui -import java.util.{Date, ServiceLoader} +import java.util.{Date, List => JList, ServiceLoader} import scala.collection.JavaConverters._ -import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.{JobExecutionStatus, SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore @@ -29,8 +29,7 @@ import org.apache.spark.status.api.v1._ import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.EnvironmentTab import org.apache.spark.ui.exec.ExecutorsTab -import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} -import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.jobs.{JobsTab, StagesTab} import org.apache.spark.ui.storage.StorageTab import org.apache.spark.util.Utils @@ -42,11 +41,8 @@ private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, - val jobProgressListener: JobProgressListener, - val operationGraphListener: RDDOperationGraphListener, var appName: String, val basePath: String, - val lastUpdateTime: Option[Long] = None, val startTime: Long, val appSparkVersion: String) extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), @@ -61,8 +57,8 @@ private[spark] class SparkUI private ( private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. */ - def initialize() { - val jobsTab = new JobsTab(this) + def initialize(): Unit = { + val jobsTab = new JobsTab(this, store) attachTab(jobsTab) val stagesTab = new StagesTab(this, store) attachTab(stagesTab) @@ -72,6 +68,7 @@ private[spark] class SparkUI private ( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) + // These should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( "/jobs/job/kill", "/jobs/", jobsTab.handleKillRequest, httpMethods = Set("GET", "POST"))) @@ -79,6 +76,7 @@ private[spark] class SparkUI private ( "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } + initialize() def getSparkUser: String = { @@ -170,25 +168,13 @@ private[spark] object SparkUI { sc: Option[SparkContext], store: AppStatusStore, conf: SparkConf, - addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, basePath: String, startTime: Long, - lastUpdateTime: Option[Long] = None, appSparkVersion: String = org.apache.spark.SPARK_VERSION): SparkUI = { - val jobProgressListener = sc.map(_.jobProgressListener).getOrElse { - val listener = new JobProgressListener(conf) - addListenerFn(listener) - listener - } - val operationGraphListener = new RDDOperationGraphListener(conf) - - addListenerFn(operationGraphListener) - - new SparkUI(store, sc, conf, securityManager, jobProgressListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime, appSparkVersion) + new SparkUI(store, sc, conf, securityManager, appName, basePath, startTime, appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index a647a1173a8cb..b60d39b21b4bf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -22,20 +22,20 @@ import java.util.Date import javax.servlet.http.HttpServletRequest import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.ListBuffer import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1 import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData.{JobUIData, StageUIData} import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ -private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { +private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("") { private val JOBS_LEGEND =
    Removed
    .toString.filter(_ != '\n') - private def getLastStageNameAndDescription(job: JobUIData): (String, String) = { - val lastStageInfo = Option(job.stageIds) - .filter(_.nonEmpty) - .flatMap { ids => parent.jobProgresslistener.stageIdToInfo.get(ids.max)} - val lastStageData = lastStageInfo.flatMap { s => - parent.jobProgresslistener.stageIdToData.get((s.stageId, s.attemptId)) - } - val name = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val description = lastStageData.flatMap(_.description).getOrElse("") - (name, description) - } - - private def makeJobEvent(jobUIDatas: Seq[JobUIData]): Seq[String] = { - jobUIDatas.filter { jobUIData => - jobUIData.status != JobExecutionStatus.UNKNOWN && jobUIData.submissionTime.isDefined - }.map { jobUIData => - val jobId = jobUIData.jobId - val status = jobUIData.status - val (jobName, jobDescription) = getLastStageNameAndDescription(jobUIData) + private def makeJobEvent(jobs: Seq[v1.JobData]): Seq[String] = { + jobs.filter { job => + job.status != JobExecutionStatus.UNKNOWN && job.submissionTime.isDefined + }.map { job => + val jobId = job.jobId + val status = job.status val displayJobDescription = - if (jobDescription.isEmpty) { - jobName + if (job.description.isEmpty) { + job.name } else { - UIUtils.makeDescription(jobDescription, "", plainText = true).text + UIUtils.makeDescription(job.description.get, "", plainText = true).text } - val submissionTime = jobUIData.submissionTime.get - val completionTimeOpt = jobUIData.completionTime - val completionTime = completionTimeOpt.getOrElse(System.currentTimeMillis()) + val submissionTime = job.submissionTime.get.getTime() + val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { case JobExecutionStatus.SUCCEEDED => "succeeded" case JobExecutionStatus.FAILED => "failed" @@ -124,7 +110,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } } - private def makeExecutorEvent(executors: Seq[ExecutorSummary]): + private def makeExecutorEvent(executors: Seq[v1.ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() executors.foreach { e => @@ -169,8 +155,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } private def makeTimeline( - jobs: Seq[JobUIData], - executors: Seq[ExecutorSummary], + jobs: Seq[v1.JobData], + executors: Seq[v1.ExecutorSummary], startTime: Long): Seq[Node] = { val jobEventJsonAsStrSeq = makeJobEvent(jobs) @@ -217,7 +203,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { request: HttpServletRequest, tableHeaderId: String, jobTag: String, - jobs: Seq[JobUIData], + jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { // stripXSS is called to remove suspicious characters used in XSS attacks val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) @@ -258,14 +244,13 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { try { new JobPagedTable( + store, jobs, tableHeaderId, jobTag, UIUtils.prependBaseUri(parent.basePath), "jobs", // subPath parameterOtherTable, - parent.jobProgresslistener.stageIdToInfo, - parent.jobProgresslistener.stageIdToData, killEnabled, currentTime, jobIdTitle, @@ -285,106 +270,117 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } def render(request: HttpServletRequest): Seq[Node] = { - val listener = parent.jobProgresslistener - listener.synchronized { - val startTime = listener.startTime - val endTime = listener.endTime - val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse - val failedJobs = listener.failedJobs.reverse - - val activeJobsTable = - jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) - val completedJobsTable = - jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false) - val failedJobsTable = - jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false) - - val shouldShowActiveJobs = activeJobs.nonEmpty - val shouldShowCompletedJobs = completedJobs.nonEmpty - val shouldShowFailedJobs = failedJobs.nonEmpty - - val completedJobNumStr = if (completedJobs.size == listener.numCompletedJobs) { - s"${completedJobs.size}" - } else { - s"${listener.numCompletedJobs}, only showing ${completedJobs.size}" + val appInfo = store.applicationInfo() + val startTime = appInfo.attempts.head.startTime.getTime() + val endTime = appInfo.attempts.head.endTime.getTime() + + val activeJobs = new ListBuffer[v1.JobData]() + val completedJobs = new ListBuffer[v1.JobData]() + val failedJobs = new ListBuffer[v1.JobData]() + + store.jobsList(null).foreach { job => + job.status match { + case JobExecutionStatus.SUCCEEDED => + completedJobs += job + case JobExecutionStatus.FAILED => + failedJobs += job + case _ => + activeJobs += job } + } - val summary: NodeSeq = -
    -
      -
    • - User: - {parent.getSparkUser} -
    • -
    • - Total Uptime: - { - if (endTime < 0 && parent.sc.isDefined) { - UIUtils.formatDuration(System.currentTimeMillis() - startTime) - } else if (endTime > 0) { - UIUtils.formatDuration(endTime - startTime) - } - } -
    • -
    • - Scheduling Mode: - {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} -
    • + val activeJobsTable = + jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) + val completedJobsTable = + jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false) + val failedJobsTable = + jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false) + + val shouldShowActiveJobs = activeJobs.nonEmpty + val shouldShowCompletedJobs = completedJobs.nonEmpty + val shouldShowFailedJobs = failedJobs.nonEmpty + + val completedJobNumStr = s"${completedJobs.size}" + val schedulingMode = store.environmentInfo().sparkProperties.toMap + .get("spark.scheduler.mode") + .map { mode => SchedulingMode.withName(mode).toString } + .getOrElse("Unknown") + + val summary: NodeSeq = +
      +
        +
      • + User: + {parent.getSparkUser} +
      • +
      • + Total Uptime: { - if (shouldShowActiveJobs) { -
      • - Active Jobs: - {activeJobs.size} -
      • + if (endTime < 0 && parent.sc.isDefined) { + UIUtils.formatDuration(System.currentTimeMillis() - startTime) + } else if (endTime > 0) { + UIUtils.formatDuration(endTime - startTime) } } - { - if (shouldShowCompletedJobs) { -
      • - Completed Jobs: - {completedJobNumStr} -
      • - } + +
      • + Scheduling Mode: + {schedulingMode} +
      • + { + if (shouldShowActiveJobs) { +
      • + Active Jobs: + {activeJobs.size} +
      • } - { - if (shouldShowFailedJobs) { -
      • - Failed Jobs: - {listener.numFailedJobs} -
      • - } + } + { + if (shouldShowCompletedJobs) { +
      • + Completed Jobs: + {completedJobNumStr} +
      • } -
      -
      + } + { + if (shouldShowFailedJobs) { +
    • + Failed Jobs: + {failedJobs.size} +
    • + } + } +
    +
    - var content = summary - content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, - parent.parent.store.executorList(false), startTime) + var content = summary + content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, + store.executorList(false), startTime) - if (shouldShowActiveJobs) { - content ++=

    Active Jobs ({activeJobs.size})

    ++ - activeJobsTable - } - if (shouldShowCompletedJobs) { - content ++=

    Completed Jobs ({completedJobNumStr})

    ++ - completedJobsTable - } - if (shouldShowFailedJobs) { - content ++=

    Failed Jobs ({failedJobs.size})

    ++ - failedJobsTable - } + if (shouldShowActiveJobs) { + content ++=

    Active Jobs ({activeJobs.size})

    ++ + activeJobsTable + } + if (shouldShowCompletedJobs) { + content ++=

    Completed Jobs ({completedJobNumStr})

    ++ + completedJobsTable + } + if (shouldShowFailedJobs) { + content ++=

    Failed Jobs ({failedJobs.size})

    ++ + failedJobsTable + } - val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + - " Click on a job to see information about the stages of tasks inside it." + val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + + " Click on a job to see information about the stages of tasks inside it." - UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) - } + UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) } + } private[ui] class JobTableRowData( - val jobData: JobUIData, + val jobData: v1.JobData, val lastStageName: String, val lastStageDescription: String, val duration: Long, @@ -395,9 +391,8 @@ private[ui] class JobTableRowData( val detailUrl: String) private[ui] class JobDataSource( - jobs: Seq[JobUIData], - stageIdToInfo: HashMap[Int, StageInfo], - stageIdToData: HashMap[(Int, Int), StageUIData], + store: AppStatusStore, + jobs: Seq[v1.JobData], basePath: String, currentTime: Long, pageSize: Int, @@ -418,40 +413,28 @@ private[ui] class JobDataSource( r } - private def getLastStageNameAndDescription(job: JobUIData): (String, String) = { - val lastStageInfo = Option(job.stageIds) - .filter(_.nonEmpty) - .flatMap { ids => stageIdToInfo.get(ids.max)} - val lastStageData = lastStageInfo.flatMap { s => - stageIdToData.get((s.stageId, s.attemptId)) - } - val name = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val description = lastStageData.flatMap(_.description).getOrElse("") - (name, description) - } - - private def jobRow(jobData: JobUIData): JobTableRowData = { - val (lastStageName, lastStageDescription) = getLastStageNameAndDescription(jobData) + private def jobRow(jobData: v1.JobData): JobTableRowData = { val duration: Option[Long] = { jobData.submissionTime.map { start => - val end = jobData.completionTime.getOrElse(System.currentTimeMillis()) - end - start + val end = jobData.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) + end - start.getTime() } } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) + val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), + basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) - new JobTableRowData ( + new JobTableRowData( jobData, - lastStageName, - lastStageDescription, + jobData.name, + jobData.description.getOrElse(jobData.name), duration.getOrElse(-1), formattedDuration, - submissionTime.getOrElse(-1), + submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, jobDescription, detailUrl @@ -479,15 +462,15 @@ private[ui] class JobDataSource( } } + private[ui] class JobPagedTable( - data: Seq[JobUIData], + store: AppStatusStore, + data: Seq[v1.JobData], tableHeaderId: String, jobTag: String, basePath: String, subPath: String, parameterOtherTable: Iterable[String], - stageIdToInfo: HashMap[Int, StageInfo], - stageIdToData: HashMap[(Int, Int), StageUIData], killEnabled: Boolean, currentTime: Long, jobIdTitle: String, @@ -510,9 +493,8 @@ private[ui] class JobPagedTable( override def pageNumberFormField: String = jobTag + ".page" override val dataSource = new JobDataSource( + store, data, - stageIdToInfo, - stageIdToData, basePath, currentTime, pageSize, @@ -624,15 +606,15 @@ private[ui] class JobPagedTable( {jobTableRow.formattedDuration} - {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages} + {job.numCompletedStages}/{job.stageIds.size - job.numSkippedStages} {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"} {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} {UIUtils.makeProgressBar(started = job.numActiveTasks, - completed = job.completedIndices.size, + completed = job.numCompletedIndices, failed = job.numFailedTasks, skipped = job.numSkippedTasks, - reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} + reasonToNumKilled = job.killedTasksSummary, total = job.numTasks - job.numSkippedTasks)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index dc5b03c5269a9..e4cf99e7b9e04 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -22,120 +22,121 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, NodeSeq} import org.apache.spark.scheduler.Schedulable +import org.apache.spark.status.PoolData +import org.apache.spark.status.api.v1._ import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc - private val listener = parent.progressListener private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val activeStages = listener.activeStages.values.toSeq - val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse - val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse - val numFailedStages = listener.numFailedStages - val subPath = "stages" + val allStages = parent.store.stageList(null) - val activeStagesTable = - new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, - killEnabled = parent.killEnabled, isFailedStage = false) - val pendingStagesTable = - new StageTableBase(request, pendingStages, "pending", "pendingStage", parent.basePath, - subPath, parent.progressListener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val completedStagesTable = - new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, - subPath, parent.progressListener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val failedStagesTable = - new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, - killEnabled = false, isFailedStage = true) + val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) + val pendingStages = allStages.filter(_.status == StageStatus.PENDING) + val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) + val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - // For now, pool information is only accessible in live UIs - val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) - val poolTable = new PoolTable(pools, parent) + val numCompletedStages = completedStages.size + val numFailedStages = failedStages.size + val subPath = "stages" - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val activeStagesTable = + new StageTableBase(parent.store, request, activeStages, "active", "activeStage", + parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) + val pendingStagesTable = + new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) + val completedStagesTable = + new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) + val failedStagesTable = + new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", + parent.basePath, subPath, parent.isFairScheduler, false, true) - val completedStageNumStr = if (numCompletedStages == completedStages.size) { - s"$numCompletedStages" - } else { - s"$numCompletedStages, only showing ${completedStages.size}" - } + // For now, pool information is only accessible in live UIs + val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => + val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( + new PoolData(pool.name, Set())) + pool -> uiPool + }.toMap + val poolTable = new PoolTable(pools, parent) + + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = pendingStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + + val completedStageNumStr = if (numCompletedStages == completedStages.size) { + s"$numCompletedStages" + } else { + s"$numCompletedStages, only showing ${completedStages.size}" + } - val summary: NodeSeq = -
    -
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } + val summary: NodeSeq = +
      +
        + { + if (shouldShowActiveStages) { +
      • + Active Stages: + {activeStages.size} +
      • } - { - if (shouldShowPendingStages) { -
      • - Pending Stages: - {pendingStages.size} -
      • - } + } + { + if (shouldShowPendingStages) { +
      • + Pending Stages: + {pendingStages.size} +
      • } - { - if (shouldShowCompletedStages) { -
      • - Completed Stages: - {completedStageNumStr} -
      • - } + } + { + if (shouldShowCompletedStages) { +
      • + Completed Stages: + {completedStageNumStr} +
      • } - { - if (shouldShowFailedStages) { -
      • - Failed Stages: - {numFailedStages} -
      • - } + } + { + if (shouldShowFailedStages) { +
      • + Failed Stages: + {numFailedStages} +
      • } -
      -
      - - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { -

      Fair Scheduler Pools ({pools.size})

      ++ poolTable.toNodeSeq - } else { - Seq.empty[Node] } +
    +
    + + var content = summary ++ + { + if (sc.isDefined && isFairScheduler) { +

    Fair Scheduler Pools ({pools.size})

    ++ poolTable.toNodeSeq + } else { + Seq.empty[Node] } - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq - } - if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingStages.size})

    ++ - pendingStagesTable.toNodeSeq - } - if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStageNumStr})

    ++ - completedStagesTable.toNodeSeq - } - if (shouldShowFailedStages) { - content ++=

    Failed Stages ({numFailedStages})

    ++ - failedStagesTable.toNodeSeq } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingStages.size})

    ++ + pendingStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({completedStageNumStr})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({numFailedStages})

    ++ + failedStagesTable.toNodeSeq + } + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) } } - diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 07a41d195a191..41d42b52430a5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -17,44 +17,19 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.StageData import org.apache.spark.ui.{ToolTips, UIUtils} -import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable( - stageId: Int, - stageAttemptId: Int, - parent: StagesTab, - store: AppStatusStore) { - private val listener = parent.progressListener +private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { - def toNodeSeq: Seq[Node] = { - listener.synchronized { - executorTable() - } - } - - /** Special table which merges two header cells. */ - private def executorTable[T](): Seq[Node] = { - val stageData = listener.stageIdToData.get((stageId, stageAttemptId)) - var hasInput = false - var hasOutput = false - var hasShuffleWrite = false - var hasShuffleRead = false - var hasBytesSpilled = false - stageData.foreach { data => - hasInput = data.hasInput - hasOutput = data.hasOutput - hasShuffleRead = data.hasShuffleRead - hasShuffleWrite = data.hasShuffleWrite - hasBytesSpilled = data.hasBytesSpilled - } + import ApiHelper._ + def toNodeSeq: Seq[Node] = { @@ -64,29 +39,29 @@ private[ui] class ExecutorTable( - {if (hasInput) { + {if (hasInput(stage)) { }} - {if (hasOutput) { + {if (hasOutput(stage)) { }} - {if (hasShuffleRead) { + {if (hasShuffleRead(stage)) { }} - {if (hasShuffleWrite) { + {if (hasShuffleWrite(stage)) { }} - {if (hasBytesSpilled) { + {if (hasBytesSpilled(stage)) { }} @@ -97,7 +72,7 @@ private[ui] class ExecutorTable( - {createExecutorTable()} + {createExecutorTable(stage)}
    Executor IDFailed Tasks Killed Tasks Succeeded Tasks Input Size / Records Output Size / Records Shuffle Read Size / Records Shuffle Write Size / Records Shuffle Spill (Memory) Shuffle Spill (Disk)
    } - private def createExecutorTable() : Seq[Node] = { - // Make an executor-id -> address map - val executorIdToAddress = mutable.HashMap[String, String]() - listener.blockManagerIds.foreach { blockManagerId => - val address = blockManagerId.hostPort - val executorId = blockManagerId.executorId - executorIdToAddress.put(executorId, address) - } - - listener.stageIdToData.get((stageId, stageAttemptId)) match { - case Some(stageData: StageUIData) => - stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) => - - -
    {k}
    -
    - { - store.executorSummary(k).map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - } - } -
    - - {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} - {v.failedTasks} - {v.reasonToNumKilled.values.sum} - {v.succeededTasks} - {if (stageData.hasInput) { - - {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} - - }} - {if (stageData.hasOutput) { - - {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} - - }} - {if (stageData.hasShuffleRead) { - - {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} - - }} - {if (stageData.hasShuffleWrite) { - - {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} - - }} - {if (stageData.hasBytesSpilled) { - - {Utils.bytesToString(v.memoryBytesSpilled)} - - - {Utils.bytesToString(v.diskBytesSpilled)} - - }} - {v.isBlacklisted} - - } - case None => - Seq.empty[Node] + private def createExecutorTable(stage: StageData) : Seq[Node] = { + stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executor = store.asOption(store.executorSummary(k)) + + +
    {k}
    +
    + { + executor.map(_.executorLogs).getOrElse(Map.empty).map { + case (logName, logUrl) => + } + } +
    + + {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} + {UIUtils.formatDuration(v.taskTime)} + {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks} + {v.killedTasks} + {v.succeededTasks} + {if (hasInput(stage)) { + + {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} + + }} + {if (hasOutput(stage)) { + + {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} + + }} + {if (hasShuffleRead(stage)) { + + {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} + + }} + {if (hasShuffleWrite(stage)) { + + {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} + + }} + {if (hasBytesSpilled(stage)) { + + {Utils.bytesToString(v.memoryBytesSpilled)} + + + {Utils.bytesToString(v.diskBytesSpilled)} + + }} + {executor.map(_.isBlacklisted).getOrElse(false)} + } } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 7ed01646f3621..740f12e7d13d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.{Date, Locale} +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -27,11 +27,12 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ -import org.apache.spark.status.api.v1.ExecutorSummary -import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1 +import org.apache.spark.ui._ /** Page showing statistics and stage list for a given job */ -private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { +private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("job") { private val STAGES_LEGEND =
    @@ -56,14 +57,15 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { Removed
    .toString.filter(_ != '\n') - private def makeStageEvent(stageInfos: Seq[StageInfo]): Seq[String] = { + private def makeStageEvent(stageInfos: Seq[v1.StageData]): Seq[String] = { stageInfos.map { stage => val stageId = stage.stageId val attemptId = stage.attemptId val name = stage.name - val status = stage.getStatusString - val submissionTime = stage.submissionTime.get - val completionTime = stage.completionTime.getOrElse(System.currentTimeMillis()) + val status = stage.status.toString + val submissionTime = stage.submissionTime.get.getTime() + val completionTime = stage.completionTime.map(_.getTime()) + .getOrElse(System.currentTimeMillis()) // The timeline library treats contents as HTML, so we have to escape them. We need to add // extra layers of escaping in order to embed this in a Javascript string literal. @@ -79,10 +81,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'data-placement="top" data-html="true"' + | 'data-title="${jsEscapedName} (Stage ${stageId}.${attemptId})
    ' + | 'Status: ${status.toUpperCase(Locale.ROOT)}
    ' + - | 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' + + | 'Submitted: ${UIUtils.formatDate(submissionTime)}' + | '${ if (status != "running") { - s"""
    Completed: ${UIUtils.formatDate(new Date(completionTime))}""" + s"""
    Completed: ${UIUtils.formatDate(completionTime)}""" } else { "" } @@ -93,7 +95,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } } - def makeExecutorEvent(executors: Seq[ExecutorSummary]): Seq[String] = { + def makeExecutorEvent(executors: Seq[v1.ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() executors.foreach { e => val addedEvent = @@ -137,8 +139,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } private def makeTimeline( - stages: Seq[StageInfo], - executors: Seq[ExecutorSummary], + stages: Seq[v1.StageData], + executors: Seq[v1.ExecutorSummary], appStartTime: Long): Seq[Node] = { val stageEventJsonAsStrSeq = makeStageEvent(stages) @@ -182,173 +184,181 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } def render(request: HttpServletRequest): Seq[Node] = { - val listener = parent.jobProgresslistener + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - listener.synchronized { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) - require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - - val jobId = parameterId.toInt - val jobDataOption = listener.jobIdToData.get(jobId) - if (jobDataOption.isEmpty) { - val content = -
    -

    No information to display for job {jobId}

    -
    - return UIUtils.headerSparkPage( - s"Details for Job $jobId", content, parent) - } - val jobData = jobDataOption.get - val isComplete = jobData.status != JobExecutionStatus.RUNNING - val stages = jobData.stageIds.map { stageId => - // This could be empty if the JobProgressListener hasn't received information about the - // stage or if the stage information has been garbage collected - listener.stageIdToInfo.getOrElse(stageId, - new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown")) + val jobId = parameterId.toInt + val jobData = store.asOption(store.job(jobId)).getOrElse { + val content = +
    +

    No information to display for job {jobId}

    +
    + return UIUtils.headerSparkPage( + s"Details for Job $jobId", content, parent) + } + val isComplete = jobData.status != JobExecutionStatus.RUNNING + val stages = jobData.stageIds.map { stageId => + // This could be empty if the listener hasn't received information about the + // stage or if the stage information has been garbage collected + store.stageData(stageId).lastOption.getOrElse { + new v1.StageData( + v1.StageStatus.PENDING, + stageId, + 0, 0, 0, 0, 0, 0, 0, + 0L, 0L, None, None, None, None, + 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, + "Unknown", + None, + "Unknown", + null, + Nil, + Nil, + None, + None, + Map()) } + } - val activeStages = Buffer[StageInfo]() - val completedStages = Buffer[StageInfo]() - // If the job is completed, then any pending stages are displayed as "skipped": - val pendingOrSkippedStages = Buffer[StageInfo]() - val failedStages = Buffer[StageInfo]() - for (stage <- stages) { - if (stage.submissionTime.isEmpty) { - pendingOrSkippedStages += stage - } else if (stage.completionTime.isDefined) { - if (stage.failureReason.isDefined) { - failedStages += stage - } else { - completedStages += stage - } + val activeStages = Buffer[v1.StageData]() + val completedStages = Buffer[v1.StageData]() + // If the job is completed, then any pending stages are displayed as "skipped": + val pendingOrSkippedStages = Buffer[v1.StageData]() + val failedStages = Buffer[v1.StageData]() + for (stage <- stages) { + if (stage.submissionTime.isEmpty) { + pendingOrSkippedStages += stage + } else if (stage.completionTime.isDefined) { + if (stage.status == v1.StageStatus.FAILED) { + failedStages += stage } else { - activeStages += stage + completedStages += stage } + } else { + activeStages += stage } + } - val basePath = "jobs/job" + val basePath = "jobs/job" - val pendingOrSkippedTableId = - if (isComplete) { - "pending" - } else { - "skipped" - } + val pendingOrSkippedTableId = + if (isComplete) { + "pending" + } else { + "skipped" + } - val activeStagesTable = - new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = parent.killEnabled, isFailedStage = false) - val pendingOrSkippedStagesTable = - new StageTableBase(request, pendingOrSkippedStages, pendingOrSkippedTableId, "pendingStage", - parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val completedStagesTable = - new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val failedStagesTable = - new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = false, isFailedStage = true) + val activeStagesTable = + new StageTableBase(store, request, activeStages, "active", "activeStage", parent.basePath, + basePath, parent.isFairScheduler, + killEnabled = parent.killEnabled, isFailedStage = false) + val pendingOrSkippedStagesTable = + new StageTableBase(store, request, pendingOrSkippedStages, pendingOrSkippedTableId, + "pendingStage", parent.basePath, basePath, parent.isFairScheduler, + killEnabled = false, isFailedStage = false) + val completedStagesTable = + new StageTableBase(store, request, completedStages, "completed", "completedStage", + parent.basePath, basePath, parent.isFairScheduler, + killEnabled = false, isFailedStage = false) + val failedStagesTable = + new StageTableBase(store, request, failedStages, "failed", "failedStage", parent.basePath, + basePath, parent.isFairScheduler, + killEnabled = false, isFailedStage = true) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty - val summary: NodeSeq = -
    -
      -
    • - Status: - {jobData.status} -
    • - { - if (jobData.jobGroup.isDefined) { -
    • - Job Group: - {jobData.jobGroup.get} -
    • - } - } - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - - Pending Stages: - {pendingOrSkippedStages.size} -
    • - } + val summary: NodeSeq = +
      +
        +
      • + Status: + {jobData.status} +
      • + { + if (jobData.jobGroup.isDefined) { +
      • + Job Group: + {jobData.jobGroup.get} +
      • } - { - if (shouldShowCompletedStages) { -
      • - Completed Stages: - {completedStages.size} -
      • - } + } + { + if (shouldShowActiveStages) { +
      • + Active Stages: + {activeStages.size} +
      • } - { - if (shouldShowSkippedStages) { + } + { + if (shouldShowPendingStages) {
      • - Skipped Stages: - {pendingOrSkippedStages.size} + + Pending Stages: + {pendingOrSkippedStages.size}
      • } + } + { + if (shouldShowCompletedStages) { +
      • + Completed Stages: + {completedStages.size} +
      • } - { - if (shouldShowFailedStages) { -
      • - Failed Stages: - {failedStages.size} -
      • - } + } + { + if (shouldShowSkippedStages) { +
      • + Skipped Stages: + {pendingOrSkippedStages.size} +
      • + } + } + { + if (shouldShowFailedStages) { +
      • + Failed Stages: + {failedStages.size} +
      • } -
      -
      + } +
    +
    - var content = summary - val appStartTime = listener.startTime - val operationGraphListener = parent.operationGraphListener + var content = summary + val appStartTime = store.applicationInfo().attempts.head.startTime.getTime() - content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, - parent.parent.store.executorList(false), appStartTime) + content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, + store.executorList(false), appStartTime) - content ++= UIUtils.showDagVizForJob( - jobId, operationGraphListener.getOperationGraphForJob(jobId)) + content ++= UIUtils.showDagVizForJob( + jobId, store.operationGraphForJob(jobId)) - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq - } - if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq - } - if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStages.size})

    ++ - completedStagesTable.toNodeSeq - } - if (shouldShowSkippedStages) { - content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq - } - if (shouldShowFailedStages) { - content ++=

    Failed Stages ({failedStages.size})

    ++ - failedStagesTable.toNodeSeq - } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({completedStages.size})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowSkippedStages) { + content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({failedStages.size})

    ++ + failedStagesTable.toNodeSeq } + UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 81ffe04aca49a..99eab1b2a27d8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -19,35 +19,45 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest +import scala.collection.JavaConverters._ + +import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} +import org.apache.spark.status.AppStatusStore +import org.apache.spark.ui._ /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobsTab(val parent: SparkUI) extends SparkUITab(parent, "jobs") { +private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) + extends SparkUITab(parent, "jobs") { + val sc = parent.sc val killEnabled = parent.killEnabled - val jobProgresslistener = parent.jobProgressListener - val operationGraphListener = parent.operationGraphListener - def isFairScheduler: Boolean = - jobProgresslistener.schedulingMode == Some(SchedulingMode.FAIR) + def isFairScheduler: Boolean = { + store.environmentInfo().sparkProperties.toMap + .get("spark.scheduler.mode") + .map { mode => mode == SchedulingMode.FAIR } + .getOrElse(false) + } def getSparkUser: String = parent.getSparkUser - attachPage(new AllJobsPage(this)) - attachPage(new JobPage(this)) + attachPage(new AllJobsPage(this, store)) + attachPage(new JobPage(this, store)) def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { // stripXSS is called first to remove suspicious characters used in XSS attacks val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) jobId.foreach { id => - if (jobProgresslistener.activeJobs.contains(id)) { - sc.foreach(_.cancelJob(id)) - // Do a quick pause here to give Spark time to kill the job so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) + store.asOption(store.job(id)).foreach { job => + if (job.status == JobExecutionStatus.RUNNING) { + sc.foreach(_.cancelJob(id)) + // Do a quick pause here to give Spark time to kill the job so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 4b8c7b203771d..98fbd7aceaa11 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -21,46 +21,39 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.status.PoolData +import org.apache.spark.status.api.v1._ import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing specific pool details */ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { - private val sc = parent.sc - private val listener = parent.progressListener def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => - UIUtils.decodeURLParameter(poolname) - }.getOrElse { - throw new IllegalArgumentException(s"Missing poolname parameter") - } - - val poolToActiveStages = listener.poolToActiveStages - val activeStages = poolToActiveStages.get(poolName) match { - case Some(s) => s.values.toSeq - case None => Seq.empty[StageInfo] - } - val shouldShowActiveStages = activeStages.nonEmpty - val activeStagesTable = - new StageTableBase(request, activeStages, "", "activeStage", parent.basePath, "stages/pool", - parent.progressListener, parent.isFairScheduler, parent.killEnabled, - isFailedStage = false) - - // For now, pool information is only accessible in live UIs - val pools = sc.map(_.getPoolForName(poolName).getOrElse { - throw new IllegalArgumentException(s"Unknown poolname: $poolName") - }).toSeq - val poolTable = new PoolTable(pools, parent) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => + UIUtils.decodeURLParameter(poolname) + }.getOrElse { + throw new IllegalArgumentException(s"Missing poolname parameter") + } - var content =

    Summary

    ++ poolTable.toNodeSeq - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq - } + // For now, pool information is only accessible in live UIs + val pool = parent.sc.flatMap(_.getPoolForName(poolName)).getOrElse { + throw new IllegalArgumentException(s"Unknown pool: $poolName") + } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + val uiPool = parent.store.asOption(parent.store.pool(poolName)).getOrElse( + new PoolData(poolName, Set())) + val activeStages = uiPool.stageIds.toSeq.map(parent.store.lastStageAttempt(_)) + val activeStagesTable = + new StageTableBase(parent.store, request, activeStages, "", "activeStage", parent.basePath, + "stages/pool", parent.isFairScheduler, parent.killEnabled, false) + + val poolTable = new PoolTable(Map(pool -> uiPool), parent) + var content =

    Summary

    ++ poolTable.toNodeSeq + if (activeStages.nonEmpty) { + content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq } + + UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index ea02968733cac..5dfce858dec07 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -19,25 +19,16 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder -import scala.collection.mutable.HashMap import scala.xml.Node -import org.apache.spark.scheduler.{Schedulable, StageInfo} +import org.apache.spark.scheduler.Schedulable +import org.apache.spark.status.PoolData import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { - private val listener = parent.progressListener +private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { def toNodeSeq: Seq[Node] = { - listener.synchronized { - poolTable(poolRow, pools) - } - } - - private def poolTable( - makeRow: (Schedulable, HashMap[String, HashMap[Int, StageInfo]]) => Seq[Node], - rows: Seq[Schedulable]): Seq[Node] = { @@ -48,29 +39,24 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { - {rows.map(r => makeRow(r, listener.poolToActiveStages))} + {pools.map { case (s, p) => poolRow(s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow( - p: Schedulable, - poolToActiveStages: HashMap[String, HashMap[Int, StageInfo]]): Seq[Node] = { - val activeStages = poolToActiveStages.get(p.name) match { - case Some(stages) => stages.size - case None => 0 - } + private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) {p.name} - {p.minShare} - {p.weight} + {s.minShare} + {s.weight} {activeStages} - {p.runningTasks} - {p.schedulingMode} + {s.runningTasks} + {s.schedulingMode} } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 3151b8d554658..5f93f2ffb412f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -21,26 +21,25 @@ import java.net.URLEncoder import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.collection.mutable.HashSet +import scala.collection.mutable.{HashMap, HashSet} import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { + import ApiHelper._ import StagePage._ - private val progressListener = parent.progressListener - private val operationGraphListener = parent.operationGraphListener - private val TIMELINE_LEGEND = {
    @@ -69,555 +68,521 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageUIData): String = { - val localities = stageData.taskData.values.map(_.taskInfo.taskLocality) + private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { + val localities = taskList.map(_.taskLocality) val localityCounts = localities.groupBy(identity).mapValues(_.size) + val names = Map( + TaskLocality.PROCESS_LOCAL.toString() -> "Process local", + TaskLocality.NODE_LOCAL.toString() -> "Node local", + TaskLocality.RACK_LOCAL.toString() -> "Rack local", + TaskLocality.ANY.toString() -> "Any") val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - val localityName = locality match { - case TaskLocality.PROCESS_LOCAL => "Process local" - case TaskLocality.NODE_LOCAL => "Node local" - case TaskLocality.RACK_LOCAL => "Rack local" - case TaskLocality.ANY => "Any" - } - s"$localityName: $count" + s"${names(locality)}: $count" } localityNamesAndCounts.sorted.mkString("; ") } def render(request: HttpServletRequest): Seq[Node] = { - progressListener.synchronized { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) - require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - - val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) - require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") - - val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) - val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) - val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) - val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) - val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) - - val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) - val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse("Index") - val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) - val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) - val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) - - val stageId = parameterId.toInt - val stageAttemptId = parameterAttempt.toInt - val stageDataOption = progressListener.stageIdToData.get((stageId, stageAttemptId)) - - val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" - if (stageDataOption.isEmpty) { + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) + require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + + val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) + val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) + val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) + val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) + + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) + val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("Index") + val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) + val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) + + val stageId = parameterId.toInt + val stageAttemptId = parameterAttempt.toInt + + val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" + val stageData = parent.store + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .getOrElse { val content =

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    return UIUtils.headerSparkPage(stageHeader, content, parent) - - } - if (stageDataOption.get.taskData.isEmpty) { - val content = -
    -

    Summary Metrics

    No tasks have started yet -

    Tasks

    No tasks have started yet -
    - return UIUtils.headerSparkPage(stageHeader, content, parent) } - val stageData = stageDataOption.get - val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + - stageData.numCompleteTasks + stageData.numFailedTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { - s"$totalTasks" - } else { - s"$totalTasks, showing ${tasks.size}" - } + val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq + if (tasks.isEmpty) { + val content = +
    +

    Summary Metrics

    No tasks have started yet +

    Tasks

    No tasks have started yet +
    + return UIUtils.headerSparkPage(stageHeader, content, parent) + } - val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables - val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.nonEmpty + val numCompleted = stageData.numCompleteTasks + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + val totalTasksNumStr = if (totalTasks == tasks.size) { + s"$totalTasks" + } else { + s"$totalTasks, showing ${tasks.size}" + } - val summary = -
    -
      + val externalAccumulables = stageData.accumulatorUpdates + val hasAccumulators = externalAccumulables.size > 0 + + val summary = +
      +
        +
      • + Total Time Across All Tasks: + {UIUtils.formatDuration(stageData.executorRunTime)} +
      • +
      • + Locality Level Summary: + {getLocalitySummaryString(stageData, tasks)} +
      • + {if (hasInput(stageData)) {
      • - Total Time Across All Tasks: - {UIUtils.formatDuration(stageData.executorRunTime)} + Input Size / Records: + {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"}
      • + }} + {if (hasOutput(stageData)) {
      • - Locality Level Summary: - {getLocalitySummaryString(stageData)} + Output: + {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
      • - {if (stageData.hasInput) { -
      • - Input Size / Records: - {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"} -
      • - }} - {if (stageData.hasOutput) { -
      • - Output: - {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"} -
      • - }} - {if (stageData.hasShuffleRead) { -
      • - Shuffle Read: - {s"${Utils.bytesToString(stageData.shuffleReadTotalBytes)} / " + - s"${stageData.shuffleReadRecords}"} -
      • - }} - {if (stageData.hasShuffleWrite) { -
      • - Shuffle Write: - {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + - s"${stageData.shuffleWriteRecords}"} -
      • - }} - {if (stageData.hasBytesSpilled) { -
      • - Shuffle Spill (Memory): - {Utils.bytesToString(stageData.memoryBytesSpilled)} -
      • -
      • - Shuffle Spill (Disk): - {Utils.bytesToString(stageData.diskBytesSpilled)} -
      • - }} -
      -
      + }} + {if (hasShuffleRead(stageData)) { +
    • + Shuffle Read: + {s"${Utils.bytesToString(stageData.shuffleReadBytes)} / " + + s"${stageData.shuffleReadRecords}"} +
    • + }} + {if (hasShuffleWrite(stageData)) { +
    • + Shuffle Write: + {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + + s"${stageData.shuffleWriteRecords}"} +
    • + }} + {if (hasBytesSpilled(stageData)) { +
    • + Shuffle Spill (Memory): + {Utils.bytesToString(stageData.memoryBytesSpilled)} +
    • +
    • + Shuffle Spill (Disk): + {Utils.bytesToString(stageData.diskBytesSpilled)} +
    • + }} +
    +
    - val showAdditionalMetrics = -
    - - - Show Additional Metrics - - +
    - val outputSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.outputMetrics.bytesWritten.toDouble - } + val stageGraph = parent.store.asOption(parent.store.operationGraphForStage(stageId)) + val dagViz = UIUtils.showDagVizForStage(stageId, stageGraph) - val outputRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.outputMetrics.recordsWritten.toDouble - } + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") + def accumulableRow(acc: AccumulableInfo): Seq[Node] = { + {acc.name}{acc.value} + } + val accumulableTable = UIUtils.listingTable( + accumulableHeaders, + accumulableRow, + externalAccumulables.toSeq) + + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (taskPageSize <= taskPrevPageSize) { + taskPage + } else { + 1 + } + } + val currentTime = System.currentTimeMillis() + val (taskTable, taskTableHTML) = try { + val _taskTable = new TaskPagedTable( + parent.conf, + UIUtils.prependBaseUri(parent.basePath) + + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + tasks, + hasAccumulators, + hasInput(stageData), + hasOutput(stageData), + hasShuffleRead(stageData), + hasShuffleWrite(stageData), + hasBytesSpilled(stageData), + currentTime, + pageSize = taskPageSize, + sortColumn = taskSortColumn, + desc = taskSortDesc, + store = parent.store + ) + (_taskTable, _taskTable.table(page)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + val errorMessage = +
    +

    Error while rendering stage table:

    +
    +              {Utils.exceptionString(e)}
    +            
    +
    + (null, errorMessage) + } - val outputQuantiles = Output Size / Records +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + val jsForScrollingDownToTaskTable = + - val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - - - Shuffle Read Blocked Time - - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) + val taskIdsInPage = if (taskTable == null) Set.empty[Long] + else taskTable.dataSource.slicedTaskIds - val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.totalBytesRead.toDouble - } - val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - - - Shuffle Read Size / Records - - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + // Excludes tasks which failed and have incomplete metrics + val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) - val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.remoteBytesRead.toDouble + val summaryTable: Option[Seq[Node]] = + if (validTasks.size == 0) { + None + } else { + def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { + Distribution(data).get.getQuantiles() + } + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { + getDistributionQuantiles(times).map { millis => + {UIUtils.formatDuration(millis.toLong)} } - val shuffleReadRemoteQuantiles = - - - Shuffle Remote Reads - - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) + } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } - val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationTimes = validTasks.map { task => + task.taskMetrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + + + Task Deserialization Time + + +: getFormattedTimeQuantiles(deserializationTimes) + + val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) + val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) + + val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) + val gcQuantiles = + + GC Time + + +: getFormattedTimeQuantiles(gcTimes) + + val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) + val serializationQuantiles = + + + Result Serialization Time + + +: getFormattedTimeQuantiles(serializationTimes) + + val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) + val gettingResultQuantiles = + + + Getting Result Time + + +: + getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } - val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelays = validTasks.map { task => + getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble + } + val schedulerDelayTitle = Scheduler Delay + val schedulerDelayQuantiles = schedulerDelayTitle +: + getFormattedTimeQuantiles(schedulerDelays) + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) + : Seq[Elem] = { + val recordDist = getDistributionQuantiles(records).iterator + getDistributionQuantiles(data).map(d => + {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} + ) + } - val shuffleWriteQuantiles = Shuffle Write Size / Records +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) + val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) + val inputQuantiles = Input Size / Records +: + getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) - val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.memoryBytesSpilled.toDouble - } - val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) + val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) + val outputQuantiles = Output Size / Records +: + getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) - val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.diskBytesSpilled.toDouble - } - val diskBytesSpilledQuantiles = Shuffle spill (disk) +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) - - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (stageData.hasInput) {inputQuantiles} else Nil, - if (stageData.hasOutput) {outputQuantiles} else Nil, - if (stageData.hasShuffleRead) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (stageData.hasShuffleWrite) {shuffleWriteQuantiles} else Nil, - if (stageData.hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, - if (stageData.hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val shuffleReadBlockedTimes = validTasks.map { task => + task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble + } + val shuffleReadBlockedQuantiles = + + + Shuffle Read Blocked Time + + +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + + val shuffleReadTotalSizes = validTasks.map { task => + totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble + } + val shuffleReadTotalRecords = validTasks.map { task => + task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble + } + val shuffleReadTotalQuantiles = + + + Shuffle Read Size / Records + + +: + getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + + val shuffleReadRemoteSizes = validTasks.map { task => + task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble + } + val shuffleReadRemoteQuantiles = + + + Shuffle Remote Reads + + +: + getFormattedSizeQuantiles(shuffleReadRemoteSizes) + + val shuffleWriteSizes = validTasks.map { task => + task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble } - val executorTable = new ExecutorTable(stageId, stageAttemptId, parent, store) + val shuffleWriteRecords = validTasks.map { task => + task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble + } - val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq.empty + val shuffleWriteQuantiles = Shuffle Write Size / Records +: + getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + + val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) + val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: + getFormattedSizeQuantiles(memoryBytesSpilledSizes) + + val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) + val diskBytesSpilledQuantiles = Shuffle spill (disk) +: + getFormattedSizeQuantiles(diskBytesSpilledSizes) + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", + "Median", "75th percentile", "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + Some(UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false)) + } - val aggMetrics = - -

    - - Aggregated Metrics by Executor -

    -
    -
    - {executorTable.toNodeSeq} -
    + val executorTable = new ExecutorTable(stageData, parent.store) + + val maybeAccumulableTable: Seq[Node] = + if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq() + + val aggMetrics = + +

    + + Aggregated Metrics by Executor +

    +
    +
    + {executorTable.toNodeSeq} +
    - val content = - summary ++ - dagViz ++ - showAdditionalMetrics ++ - makeTimeline( - // Only show the tasks in the table - stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)), - currentTime) ++ -

    Summary Metrics for {numCompleted} Completed Tasks

    ++ -
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ - aggMetrics ++ - maybeAccumulableTable ++ -

    Tasks ({totalTasksNumStr})

    ++ - taskTableHTML ++ jsForScrollingDownToTaskTable - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) - } + val content = + summary ++ + dagViz ++ + showAdditionalMetrics ++ + makeTimeline( + // Only show the tasks in the table + tasks.filter { t => taskIdsInPage.contains(t.taskId) }, + currentTime) ++ +

    Summary Metrics for {numCompleted} Completed Tasks

    ++ +
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ + aggMetrics ++ + maybeAccumulableTable ++ +

    Tasks ({totalTasksNumStr})

    ++ + taskTableHTML ++ jsForScrollingDownToTaskTable + UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } - def makeTimeline(tasks: Seq[TaskUIData], currentTime: Long): Seq[Node] = { + def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { val executorsSet = new HashSet[(String, String)] var minLaunchTime = Long.MaxValue var maxFinishTime = Long.MinValue val executorsArrayStr = - tasks.sortBy(-_.taskInfo.launchTime).take(MAX_TIMELINE_TASKS).map { taskUIData => - val taskInfo = taskUIData.taskInfo + tasks.sortBy(-_.launchTime.getTime()).take(MAX_TIMELINE_TASKS).map { taskInfo => val executorId = taskInfo.executorId val host = taskInfo.host executorsSet += ((executorId, host)) - val launchTime = taskInfo.launchTime - val finishTime = if (!taskInfo.running) taskInfo.finishTime else currentTime + val launchTime = taskInfo.launchTime.getTime() + val finishTime = taskInfo.duration.map(taskInfo.launchTime.getTime() + _) + .getOrElse(currentTime) val totalExecutionTime = finishTime - launchTime minLaunchTime = launchTime.min(minLaunchTime) maxFinishTime = finishTime.max(maxFinishTime) def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 - val metricsOpt = taskUIData.metrics + val metricsOpt = taskInfo.taskMetrics val shuffleReadTime = metricsOpt.map(_.shuffleReadMetrics.fetchWaitTime).getOrElse(0L) val shuffleReadTimeProportion = toProportion(shuffleReadTime) @@ -629,14 +594,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) + val gettingResultTime = getGettingResultTime(taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) val schedulerDelay = metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime - val executorRunTime = if (taskInfo.running) { + val executorRunTime = if (taskInfo.duration.isDefined) { totalExecutionTime - executorOverhead - gettingResultTime } else { metricsOpt.map(_.executorRunTime).getOrElse( @@ -663,7 +628,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attemptNumber + val attempt = taskInfo.attempt val svgTag = if (totalExecutionTime == 0) { @@ -705,7 +670,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We |Status: ${taskInfo.status}
    |Launch Time: ${UIUtils.formatDate(new Date(launchTime))} |${ - if (!taskInfo.running) { + if (!taskInfo.duration.isDefined) { s"""
    Finish Time: ${UIUtils.formatDate(new Date(finishTime))}""" } else { "" @@ -770,34 +735,40 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { - if (info.gettingResult) { - if (info.finished) { - info.finishTime - info.gettingResultTime - } else { - // The task is still fetching the result. - currentTime - info.gettingResultTime - } - } else { - 0L + private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { + info.resultFetchStart match { + case Some(start) => + info.duration match { + case Some(duration) => + info.launchTime.getTime() + duration - start.getTime() + + case _ => + currentTime - start.getTime() + } + + case _ => + 0L } } private[ui] def getSchedulerDelay( - info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { - if (info.finished) { - val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = metrics.executorDeserializeTime + - metrics.resultSerializationTime - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - } else { - // The task is still running and the metrics like executorRunTime are not available. - 0L + info: TaskData, + metrics: TaskMetrics, + currentTime: Long): Long = { + info.duration match { + case Some(duration) => + val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime + math.max( + 0, + duration - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + + case _ => + // The task is still running and the metrics like executorRunTime are not available. + 0L } } + } private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) @@ -826,7 +797,7 @@ private[ui] case class TaskTableRowBytesSpilledData( /** * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskUIData to avoid creating duplicate contents during sorting the data. + * TaskData to avoid creating duplicate contents during sorting the data. */ private[ui] class TaskTableRowData( val index: Int, @@ -856,14 +827,13 @@ private[ui] class TaskTableRowData( val logs: Map[String, String]) private[ui] class TaskDataSource( - tasks: Seq[TaskUIData], + tasks: Seq[TaskData], hasAccumulators: Boolean, hasInput: Boolean, hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean, - lastUpdateTime: Option[Long], currentTime: Long, pageSize: Int, sortColumn: String, @@ -871,7 +841,10 @@ private[ui] class TaskDataSource( store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { import StagePage._ - // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table + // Keep an internal cache of executor log maps so that long task lists render faster. + private val executorIdToLogs = new HashMap[String, Map[String, String]]() + + // Convert TaskData to TaskTableRowData which contains the final contents to show in the table // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) @@ -887,26 +860,19 @@ private[ui] class TaskDataSource( def slicedTaskIds: Set[Long] = _slicedTaskIds - private def taskRow(taskData: TaskUIData): TaskTableRowData = { - val info = taskData.taskInfo - val metrics = taskData.metrics - val duration = taskData.taskDuration(lastUpdateTime).getOrElse(1L) - val formatDuration = - taskData.taskDuration(lastUpdateTime).map(d => UIUtils.formatDuration(d)).getOrElse("") + private def taskRow(info: TaskData): TaskTableRowData = { + val metrics = info.taskMetrics + val duration = info.duration.getOrElse(1L) + val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val externalAccumulableReadable = info.accumulables - .filterNot(_.internal) - .flatMap { a => - (a.name, a.update) match { - case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update")) - case _ => None - } - } + val externalAccumulableReadable = info.accumulatorUpdates.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + } val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) val maybeInput = metrics.map(_.inputMetrics) @@ -928,7 +894,7 @@ private[ui] class TaskDataSource( val shuffleReadBlockedTimeReadable = maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") @@ -1011,17 +977,16 @@ private[ui] class TaskDataSource( None } - val logs = store.executorSummary(info.executorId).map(_.executorLogs).getOrElse(Map.empty) new TaskTableRowData( info.index, info.taskId, - info.attemptNumber, + info.attempt, info.speculative, info.status, info.taskLocality.toString, info.executorId, info.host, - info.launchTime, + info.launchTime.getTime(), duration, formatDuration, schedulerDelay, @@ -1036,8 +1001,13 @@ private[ui] class TaskDataSource( shuffleRead, shuffleWrite, bytesSpilled, - taskData.errorMessage.getOrElse(""), - logs) + info.errorMessage.getOrElse(""), + executorLogs(info.executorId)) + } + + private def executorLogs(id: String): Map[String, String] = { + executorIdToLogs.getOrElseUpdate(id, + store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } /** @@ -1148,14 +1118,13 @@ private[ui] class TaskDataSource( private[ui] class TaskPagedTable( conf: SparkConf, basePath: String, - data: Seq[TaskUIData], + data: Seq[TaskData], hasAccumulators: Boolean, hasInput: Boolean, hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean, - lastUpdateTime: Option[Long], currentTime: Long, pageSize: Int, sortColumn: String, @@ -1181,7 +1150,6 @@ private[ui] class TaskPagedTable( hasShuffleRead, hasShuffleWrite, hasBytesSpilled, - lastUpdateTime, currentTime, pageSize, sortColumn, @@ -1363,3 +1331,23 @@ private[ui] class TaskPagedTable( {errorSummary}{details} } } + +private object ApiHelper { + + def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 + + def hasOutput(stageData: StageData): Boolean = stageData.outputBytes > 0 + + def hasShuffleRead(stageData: StageData): Boolean = stageData.shuffleReadBytes > 0 + + def hasShuffleWrite(stageData: StageData): Boolean = stageData.shuffleWriteBytes > 0 + + def hasBytesSpilled(stageData: StageData): Boolean = { + stageData.diskBytesSpilled > 0 || stageData.memoryBytesSpilled > 0 + } + + def totalBytesRead(metrics: ShuffleReadMetrics): Long = { + metrics.localBytesRead + metrics.remoteBytesRead + } + +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f0a12a28de069..18a4926f2f6c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -26,19 +26,19 @@ import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1 import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils private[ui] class StageTableBase( + store: AppStatusStore, request: HttpServletRequest, - stages: Seq[StageInfo], + stages: Seq[v1.StageData], tableHeaderID: String, stageTag: String, basePath: String, subPath: String, - progressListener: JobProgressListener, isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { @@ -79,12 +79,12 @@ private[ui] class StageTableBase( val toNodeSeq = try { new StagePagedTable( + store, stages, tableHeaderID, stageTag, basePath, subPath, - progressListener, isFairScheduler, killEnabled, currentTime, @@ -106,13 +106,13 @@ private[ui] class StageTableBase( } private[ui] class StageTableRowData( - val stageInfo: StageInfo, - val stageData: Option[StageUIData], + val stage: v1.StageData, + val option: Option[v1.StageData], val stageId: Int, val attemptId: Int, val schedulingPool: String, val descriptionOption: Option[String], - val submissionTime: Long, + val submissionTime: Date, val formattedSubmissionTime: String, val duration: Long, val formattedDuration: String, @@ -126,19 +126,20 @@ private[ui] class StageTableRowData( val shuffleWriteWithUnit: String) private[ui] class MissingStageTableRowData( - stageInfo: StageInfo, + stageInfo: v1.StageData, stageId: Int, attemptId: Int) extends StageTableRowData( - stageInfo, None, stageId, attemptId, "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "") + stageInfo, None, stageId, attemptId, "", None, new Date(0), "", -1, "", 0, "", 0, "", 0, "", 0, + "") /** Page showing list of all ongoing and recently finished stages */ private[ui] class StagePagedTable( - stages: Seq[StageInfo], + store: AppStatusStore, + stages: Seq[v1.StageData], tableHeaderId: String, stageTag: String, basePath: String, subPath: String, - listener: JobProgressListener, isFairScheduler: Boolean, killEnabled: Boolean, currentTime: Long, @@ -164,8 +165,8 @@ private[ui] class StagePagedTable( parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( + store, stages, - listener, currentTime, pageSize, sortColumn, @@ -274,10 +275,10 @@ private[ui] class StagePagedTable( } private def rowContent(data: StageTableRowData): Seq[Node] = { - data.stageData match { + data.option match { case None => missingStageRow(data.stageId) case Some(stageData) => - val info = data.stageInfo + val info = data.stage {if (data.attemptId > 0) { {data.stageId} (retry {data.attemptId}) @@ -301,8 +302,8 @@ private[ui] class StagePagedTable( {data.formattedDuration} {UIUtils.makeProgressBar(started = stageData.numActiveTasks, - completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} + completed = stageData.numCompleteTasks, failed = stageData.numFailedTasks, + skipped = 0, reasonToNumKilled = stageData.killedTasksSummary, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} @@ -318,7 +319,7 @@ private[ui] class StagePagedTable( } } - private def failureReasonHtml(s: StageInfo): Seq[Node] = { + private def failureReasonHtml(s: v1.StageData): Seq[Node] = { val failureReason = s.failureReason.getOrElse("") val isMultiline = failureReason.indexOf('\n') >= 0 // Display the first line by default @@ -344,7 +345,7 @@ private[ui] class StagePagedTable( {failureReasonSummary}{details} } - private def makeDescription(s: StageInfo, descriptionOption: Option[String]): Seq[Node] = { + private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { val basePathUri = UIUtils.prependBaseUri(basePath) val killLink = if (killEnabled) { @@ -368,8 +369,8 @@ private[ui] class StagePagedTable( val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} - val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) - val details = if (s.details.nonEmpty) { + val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } + val details = if (s.details != null && s.details.nonEmpty) { +details @@ -404,14 +405,14 @@ private[ui] class StagePagedTable( } private[ui] class StageDataSource( - stages: Seq[StageInfo], - listener: JobProgressListener, + store: AppStatusStore, + stages: Seq[v1.StageData], currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[StageTableRowData](pageSize) { - // Convert StageInfo to StageTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data + // Convert v1.StageData to StageTableRowData which contains the final contents to show in the + // table so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) private var _slicedStageIds: Set[Int] = _ @@ -424,57 +425,46 @@ private[ui] class StageDataSource( r } - private def stageRow(s: StageInfo): StageTableRowData = { - val stageDataOption = listener.stageIdToData.get((s.stageId, s.attemptId)) + private def stageRow(stageData: v1.StageData): StageTableRowData = { + val description = stageData.description.getOrElse("") - if (stageDataOption.isEmpty) { - return new MissingStageTableRowData(s, s.stageId, s.attemptId) - } - val stageData = stageDataOption.get - - val description = stageData.description - - val formattedSubmissionTime = s.submissionTime match { - case Some(t) => UIUtils.formatDate(new Date(t)) + val formattedSubmissionTime = stageData.submissionTime match { + case Some(t) => UIUtils.formatDate(t) case None => "Unknown" } - val finishTime = s.completionTime.getOrElse(currentTime) + val finishTime = stageData.completionTime.map(_.getTime()).getOrElse(currentTime) // The submission time for a stage is misleading because it counts the time // the stage waits to be launched. (SPARK-10930) - val taskLaunchTimes = - stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) - val duration: Option[Long] = - if (taskLaunchTimes.nonEmpty) { - val startTime = taskLaunchTimes.min - if (finishTime > startTime) { - Some(finishTime - startTime) - } else { - Some(currentTime - startTime) - } + val duration = stageData.firstTaskLaunchedTime.map { date => + val time = date.getTime() + if (finishTime > time) { + finishTime - time } else { None + currentTime - time } + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" val outputWrite = stageData.outputBytes val outputWriteWithUnit = if (outputWrite > 0) Utils.bytesToString(outputWrite) else "" - val shuffleRead = stageData.shuffleReadTotalBytes + val shuffleRead = stageData.shuffleReadBytes val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" new StageTableRowData( - s, - stageDataOption, - s.stageId, - s.attemptId, + stageData, + Some(stageData), + stageData.stageId, + stageData.attemptId, stageData.schedulingPool, - description, - s.submissionTime.getOrElse(0), + stageData.description, + stageData.submissionTime.getOrElse(new Date(0)), formattedSubmissionTime, duration.getOrElse(-1), formattedDuration, @@ -496,7 +486,7 @@ private[ui] class StageDataSource( val ordering: Ordering[StageTableRowData] = sortColumn match { case "Stage Id" => Ordering.by(_.stageId) case "Pool Name" => Ordering.by(_.schedulingPool) - case "Description" => Ordering.by(x => (x.descriptionOption, x.stageInfo.name)) + case "Description" => Ordering.by(x => (x.descriptionOption, x.stage.name)) case "Submitted" => Ordering.by(_.submissionTime) case "Duration" => Ordering.by(_.duration) case "Input" => Ordering.by(_.inputRead) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 65446f967ad76..be05a963f0e68 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -21,36 +21,42 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.StageStatus import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. */ -private[ui] class StagesTab(val parent: SparkUI, store: AppStatusStore) +private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) extends SparkUITab(parent, "stages") { val sc = parent.sc val conf = parent.conf val killEnabled = parent.killEnabled - val progressListener = parent.jobProgressListener - val operationGraphListener = parent.operationGraphListener - val lastUpdateTime = parent.lastUpdateTime attachPage(new AllStagesPage(this)) attachPage(new StagePage(this, store)) attachPage(new PoolPage(this)) - def isFairScheduler: Boolean = progressListener.schedulingMode == Some(SchedulingMode.FAIR) + def isFairScheduler: Boolean = { + store.environmentInfo().sparkProperties.toMap + .get("spark.scheduler.mode") + .map { mode => mode == SchedulingMode.FAIR } + .getOrElse(false) + } def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { // stripXSS is called first to remove suspicious characters used in XSS attacks val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) stageId.foreach { id => - if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id, "killed via the Web UI")) - // Do a quick pause here to give Spark time to kill the stage so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) + store.asOption(store.lastStageAttempt(id)).foreach { stage => + val status = stage.status + if (status == StageStatus.ACTIVE || status == StageStatus.PENDING) { + sc.foreach(_.cancelStage(id, "killed via the Web UI")) + // Do a quick pause here to give Spark time to kill the stage so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index bb763248cd7e0..827a8637b9bd2 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -35,20 +35,20 @@ import org.apache.spark.storage.StorageLevel * nodes and children clusters. Additionally, a graph may also have edges that enter or exit * the graph from nodes that belong to adjacent graphs. */ -private[ui] case class RDDOperationGraph( +private[spark] case class RDDOperationGraph( edges: Seq[RDDOperationEdge], outgoingEdges: Seq[RDDOperationEdge], incomingEdges: Seq[RDDOperationEdge], rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) +private[spark] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. * This represents an RDD dependency. */ -private[ui] case class RDDOperationEdge(fromId: Int, toId: Int) +private[spark] case class RDDOperationEdge(fromId: Int, toId: Int) /** * A cluster that groups nodes together in an RDDOperationGraph. @@ -56,7 +56,7 @@ private[ui] case class RDDOperationEdge(fromId: Int, toId: Int) * This represents any grouping of RDDs, including operation scopes (e.g. textFile, flatMap), * stages, jobs, or any higher level construct. A cluster may be nested inside of other clusters. */ -private[ui] class RDDOperationCluster(val id: String, private var _name: String) { +private[spark] class RDDOperationCluster(val id: String, private var _name: String) { private val _childNodes = new ListBuffer[RDDOperationNode] private val _childClusters = new ListBuffer[RDDOperationCluster] @@ -92,7 +92,7 @@ private[ui] class RDDOperationCluster(val id: String, private var _name: String) } } -private[ui] object RDDOperationGraph extends Logging { +private[spark] object RDDOperationGraph extends Logging { val STAGE_CLUSTER_PREFIX = "stage_" diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala deleted file mode 100644 index 37a12a8646938..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.scope - -import scala.collection.mutable - -import org.apache.spark.SparkConf -import org.apache.spark.scheduler._ -import org.apache.spark.ui.SparkUI - -/** - * A SparkListener that constructs a DAG of RDD operations. - */ -private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListener { - - // Note: the fate of jobs and stages are tied. This means when we clean up a job, - // we always clean up all of its stages. Similarly, when we clean up a stage, we - // always clean up its job (and, transitively, other stages in the same job). - private[ui] val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]] - private[ui] val jobIdToSkippedStageIds = new mutable.HashMap[Int, Seq[Int]] - private[ui] val stageIdToJobId = new mutable.HashMap[Int, Int] - private[ui] val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph] - private[ui] val completedStageIds = new mutable.HashSet[Int] - - // Keep track of the order in which these are inserted so we can remove old ones - private[ui] val jobIds = new mutable.ArrayBuffer[Int] - private[ui] val stageIds = new mutable.ArrayBuffer[Int] - - // How many root nodes to retain in DAG Graph - private[ui] val retainedNodes = - conf.getInt("spark.ui.dagGraph.retainedRootRDDs", Int.MaxValue) - - // How many jobs or stages to retain graph metadata for - private val retainedJobs = - conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) - private val retainedStages = - conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) - - /** - * Return the graph metadata for all stages in the given job. - * An empty list is returned if one or more of its stages has been cleaned up. - */ - def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized { - val skippedStageIds = jobIdToSkippedStageIds.getOrElse(jobId, Seq.empty) - val graphs = jobIdToStageIds.getOrElse(jobId, Seq.empty) - .flatMap { sid => stageIdToGraph.get(sid) } - // Mark any skipped stages as such - graphs.foreach { g => - val stageId = g.rootCluster.id.replaceAll(RDDOperationGraph.STAGE_CLUSTER_PREFIX, "").toInt - if (skippedStageIds.contains(stageId) && !g.rootCluster.name.contains("skipped")) { - g.rootCluster.setName(g.rootCluster.name + " (skipped)") - } - } - graphs - } - - /** Return the graph metadata for the given stage, or None if no such information exists. */ - def getOperationGraphForStage(stageId: Int): Option[RDDOperationGraph] = synchronized { - stageIdToGraph.get(stageId) - } - - /** On job start, construct a RDDOperationGraph for each stage in the job for display later. */ - override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { - val jobId = jobStart.jobId - val stageInfos = jobStart.stageInfos - - jobIds += jobId - jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted - - stageInfos.foreach { stageInfo => - val stageId = stageInfo.stageId - stageIds += stageId - stageIdToJobId(stageId) = jobId - stageIdToGraph(stageId) = RDDOperationGraph.makeOperationGraph(stageInfo, retainedNodes) - trimStagesIfNecessary() - } - - trimJobsIfNecessary() - } - - /** Keep track of stages that have completed. */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - val stageId = stageCompleted.stageInfo.stageId - if (stageIdToJobId.contains(stageId)) { - // Note: Only do this if the stage has not already been cleaned up - // Otherwise, we may never clean this stage from `completedStageIds` - completedStageIds += stageCompleted.stageInfo.stageId - } - } - - /** On job end, find all stages in this job that are skipped and mark them as such. */ - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - val jobId = jobEnd.jobId - jobIdToStageIds.get(jobId).foreach { stageIds => - val skippedStageIds = stageIds.filter { sid => !completedStageIds.contains(sid) } - // Note: Only do this if the job has not already been cleaned up - // Otherwise, we may never clean this job from `jobIdToSkippedStageIds` - jobIdToSkippedStageIds(jobId) = skippedStageIds - } - } - - /** Clean metadata for old stages if we have exceeded the number to retain. */ - private def trimStagesIfNecessary(): Unit = { - if (stageIds.size >= retainedStages) { - val toRemove = math.max(retainedStages / 10, 1) - stageIds.take(toRemove).foreach { id => cleanStage(id) } - stageIds.trimStart(toRemove) - } - } - - /** Clean metadata for old jobs if we have exceeded the number to retain. */ - private def trimJobsIfNecessary(): Unit = { - if (jobIds.size >= retainedJobs) { - val toRemove = math.max(retainedJobs / 10, 1) - jobIds.take(toRemove).foreach { id => cleanJob(id) } - jobIds.trimStart(toRemove) - } - } - - /** Clean metadata for the given stage, its job, and all other stages that belong to the job. */ - private[ui] def cleanStage(stageId: Int): Unit = { - completedStageIds.remove(stageId) - stageIdToGraph.remove(stageId) - stageIdToJobId.remove(stageId).foreach { jobId => cleanJob(jobId) } - } - - /** Clean metadata for the given job and all stages that belong to it. */ - private[ui] def cleanJob(jobId: Int): Unit = { - jobIdToSkippedStageIds.remove(jobId) - jobIdToStageIds.remove(jobId).foreach { stageIds => - stageIds.foreach { stageId => cleanStage(stageId) } - } - } - -} diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 25c4fff77e0ad..37b7d7269059f 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 3, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 162, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", @@ -23,14 +26,19 @@ "name" : "count at :17", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line19.$read$$iwC$$iwC$$iwC$$iwC.(:17)\n$line19.$read$$iwC$$iwC$$iwC.(:22)\n$line19.$read$$iwC$$iwC.(:24)\n$line19.$read$$iwC.(:26)\n$line19.$read.(:28)\n$line19.$read$.(:32)\n$line19.$read$.()\n$line19.$eval$.(:7)\n$line19.$eval$.()\n$line19.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 6, 5 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -49,14 +57,19 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 4338, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", @@ -75,5 +88,7 @@ "name" : "count at :15", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index b86ba1e65de12..2fd55666fa018 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -2,14 +2,18 @@ "status" : "FAILED", "stageId" : 2, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 7, "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 7, "executorRunTime" : 278, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:06.296GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", "completionTime" : "2015-02-03T16:43:06.347GMT", + "failureReason" : "Job aborted due to stage failure: Task 3 in stage 2.0 failed 1 times, most recent failure: Lost task 3.0 in stage 2.0 (TID 19, localhost): java.lang.RuntimeException: got a 3, failing\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:18)\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:17)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:328)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1311)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:56)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:196)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)\n\tat java.lang.Thread.run(Thread.java:745)\n\nDriver stacktrace:", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -23,5 +27,7 @@ "name" : "count at :20", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 3, 2 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json index c108fa61a4318..2f275c7bfe2f4 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json @@ -8,8 +8,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json index c108fa61a4318..2f275c7bfe2f4 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json @@ -8,8 +8,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json index 3d7407004d262..71bf8706307c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json @@ -8,10 +8,13 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } }, { "jobId" : 1, "name" : "count at :20", @@ -22,10 +25,13 @@ "numCompletedTasks" : 15, "numSkippedTasks" : 0, "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 15, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 1 + "numFailedStages" : 1, + "killedTasksSummary" : { } }, { "jobId" : 0, "name" : "count at :15", @@ -36,8 +42,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json index 10c7e1c0b36fd..1eae5f3d5beb3 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json @@ -8,8 +8,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 6fb40f6f1713b..31093a661663b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -23,14 +26,15 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], "accumulatorUpdates" : [ ], "tasks" : { - "8" : { - "taskId" : 8, - "index" : 0, + "10" : { + "taskId" : 10, + "index" : 2, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.829GMT", - "duration" : 435, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 456, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -38,15 +42,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 435, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -66,17 +71,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 94000, + "writeTime" : 76000, "recordsWritten" : 0 } } }, - "9" : { - "taskId" : 9, - "index" : 1, + "14" : { + "taskId" : 14, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 436, + "launchTime" : "2015-02-03T16:43:05.832GMT", + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -84,15 +89,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 436, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -112,17 +118,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 98000, + "writeTime" : 88000, "recordsWritten" : 0 } } }, - "10" : { - "taskId" : 10, - "index" : 2, + "9" : { + "taskId" : 9, + "index" : 1, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -130,15 +136,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 436, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -158,17 +165,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 76000, + "writeTime" : 98000, "recordsWritten" : 0 } } }, - "11" : { - "taskId" : 11, - "index" : 3, + "13" : { + "taskId" : 13, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 452, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -182,9 +189,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -203,8 +211,8 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1647, - "writeTime" : 83000, + "bytesWritten" : 1648, + "writeTime" : 73000, "recordsWritten" : 0 } } @@ -214,7 +222,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -231,6 +239,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -255,12 +264,12 @@ } } }, - "13" : { - "taskId" : 13, - "index" : 5, + "11" : { + "taskId" : 11, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -274,9 +283,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -295,18 +305,18 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 73000, + "bytesWritten" : 1647, + "writeTime" : 83000, "recordsWritten" : 0 } } }, - "14" : { - "taskId" : 14, - "index" : 6, + "8" : { + "taskId" : 8, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.832GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.829GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -314,15 +324,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 435, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -342,7 +353,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 88000, + "writeTime" : 94000, "recordsWritten" : 0 } } @@ -352,7 +363,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.833GMT", - "duration" : 435, + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -369,6 +380,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -399,12 +411,18 @@ "taskTime" : 3624, "failedTasks" : 0, "succeededTasks" : 8, + "killedTasks" : 0, "inputBytes" : 28000128, + "inputRecords" : 0, "outputBytes" : 0, + "outputRecords" : 0, "shuffleRead" : 0, + "shuffleReadRecords" : 0, "shuffleWrite" : 13180, + "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0 } - } + }, + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index f5a89a2107646..601d70695b17c 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -23,14 +26,15 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], "accumulatorUpdates" : [ ], "tasks" : { - "8" : { - "taskId" : 8, - "index" : 0, + "10" : { + "taskId" : 10, + "index" : 2, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.829GMT", - "duration" : 435, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 456, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -38,15 +42,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 435, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -66,17 +71,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 94000, + "writeTime" : 76000, "recordsWritten" : 0 } } }, - "9" : { - "taskId" : 9, - "index" : 1, + "14" : { + "taskId" : 14, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 436, + "launchTime" : "2015-02-03T16:43:05.832GMT", + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -84,15 +89,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 436, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -112,17 +118,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 98000, + "writeTime" : 88000, "recordsWritten" : 0 } } }, - "10" : { - "taskId" : 10, - "index" : 2, + "9" : { + "taskId" : 9, + "index" : 1, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -130,15 +136,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 436, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -158,17 +165,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 76000, + "writeTime" : 98000, "recordsWritten" : 0 } } }, - "11" : { - "taskId" : 11, - "index" : 3, + "13" : { + "taskId" : 13, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 452, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -182,9 +189,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -203,8 +211,8 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1647, - "writeTime" : 83000, + "bytesWritten" : 1648, + "writeTime" : 73000, "recordsWritten" : 0 } } @@ -214,7 +222,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -231,6 +239,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -255,12 +264,12 @@ } } }, - "13" : { - "taskId" : 13, - "index" : 5, + "11" : { + "taskId" : 11, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -274,9 +283,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -295,18 +305,18 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 73000, + "bytesWritten" : 1647, + "writeTime" : 83000, "recordsWritten" : 0 } } }, - "14" : { - "taskId" : 14, - "index" : 6, + "8" : { + "taskId" : 8, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.832GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.829GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -314,15 +324,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 435, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -342,7 +353,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 88000, + "writeTime" : 94000, "recordsWritten" : 0 } } @@ -352,7 +363,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.833GMT", - "duration" : 435, + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -369,6 +380,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -399,12 +411,18 @@ "taskTime" : 3624, "failedTasks" : 0, "succeededTasks" : 8, + "killedTasks" : 0, "inputBytes" : 28000128, + "inputRecords" : 0, "outputBytes" : 0, + "outputRecords" : 0, "shuffleRead" : 0, + "shuffleReadRecords" : 0, "shuffleWrite" : 13180, + "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0 } - } + }, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 6509df1508b30..1e6fb40d60284 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 3, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 162, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", @@ -23,14 +26,51 @@ "name" : "count at :17", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line19.$read$$iwC$$iwC$$iwC$$iwC.(:17)\n$line19.$read$$iwC$$iwC$$iwC.(:22)\n$line19.$read$$iwC$$iwC.(:24)\n$line19.$read$$iwC.(:26)\n$line19.$read.(:28)\n$line19.$read$.(:32)\n$line19.$read$.()\n$line19.$eval$.(:7)\n$line19.$eval$.()\n$line19.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 6, 5 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } +}, { + "status" : "FAILED", + "stageId" : 2, + "attemptId" : 0, + "numTasks" : 8, + "numActiveTasks" : 0, + "numCompleteTasks" : 7, + "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 7, + "executorRunTime" : 278, + "executorCpuTime" : 0, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", + "failureReason" : "Job aborted due to stage failure: Task 3 in stage 2.0 failed 1 times, most recent failure: Lost task 3.0 in stage 2.0 (TID 19, localhost): java.lang.RuntimeException: got a 3, failing\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:18)\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:17)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:328)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1311)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:56)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:196)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)\n\tat java.lang.Thread.run(Thread.java:745)\n\nDriver stacktrace:", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "count at :20", + "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", + "schedulingPool" : "default", + "rddIds" : [ 3, 2 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -49,14 +89,19 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 4338, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", @@ -75,31 +120,7 @@ "name" : "count at :15", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] -}, { - "status" : "FAILED", - "stageId" : 2, - "attemptId" : 0, - "numActiveTasks" : 0, - "numCompleteTasks" : 7, - "numFailedTasks" : 1, - "executorRunTime" : 278, - "executorCpuTime" : 0, - "submissionTime" : "2015-02-03T16:43:06.296GMT", - "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", - "completionTime" : "2015-02-03T16:43:06.347GMT", - "inputBytes" : 0, - "inputRecords" : 0, - "outputBytes" : 0, - "outputRecords" : 0, - "shuffleReadBytes" : 0, - "shuffleReadRecords" : 0, - "shuffleWriteBytes" : 0, - "shuffleWriteRecords" : 0, - "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0, - "name" : "count at :20", - "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", - "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 8496863a93469..e6284ccf9b73d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 120, "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", @@ -23,9 +26,11 @@ "name" : "foreach at :15", "details" : "org.apache.spark.rdd.RDD.foreach(RDD.scala:765)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:483)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 0 ], "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", "value" : "5050" - } ] + } ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index 9b401b414f8d4..a15ee23523365 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -3,7 +3,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.494GMT", - "duration" : 349, + "duration" : 435, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 49294, "recordsRead" : 10000 @@ -48,7 +49,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.502GMT", - "duration" : 350, + "duration" : 421, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -65,6 +66,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -93,7 +95,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", - "duration" : 348, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -110,6 +112,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -138,7 +141,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -155,6 +158,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -183,7 +187,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -200,6 +204,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -228,7 +233,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 350, + "duration" : 414, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -245,6 +250,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -273,7 +279,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 351, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -290,6 +296,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -318,7 +325,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.506GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", - "duration" : 80, + "duration" : 88, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 9, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.915GMT", - "duration" : 84, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60489, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 10, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.916GMT", - "duration" : 73, + "duration" : 99, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 11, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.918GMT", - "duration" : 75, + "duration" : 89, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -543,7 +555,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -560,6 +572,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -588,7 +601,7 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "duration" : 138, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -633,7 +647,7 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -650,6 +664,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -678,7 +693,7 @@ "index" : 15, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.928GMT", - "duration" : 76, + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -695,6 +710,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -723,7 +739,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -740,6 +756,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -768,7 +785,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -858,7 +877,7 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 2ebee66a6d7c2..f9182b1658334 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -3,7 +3,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.515GMT", - "duration" : 15, + "duration" : 61, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -25,6 +25,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -53,7 +54,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 15, + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -75,6 +76,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -103,7 +105,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -125,6 +127,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -153,7 +156,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -175,6 +178,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -203,7 +207,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -225,6 +229,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -253,7 +258,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -275,6 +280,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -303,7 +309,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -325,6 +331,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -353,7 +360,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -375,6 +382,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 965a31a4104c3..76dd2f710b90f 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -3,7 +3,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.515GMT", - "duration" : 15, + "duration" : 61, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -25,6 +25,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -53,7 +54,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.521GMT", - "duration" : 15, + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -75,6 +76,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -103,7 +105,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", - "duration" : 15, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -125,6 +127,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -153,7 +156,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", - "duration" : 15, + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -175,6 +178,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -203,7 +207,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -225,6 +229,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -253,7 +258,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.523GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -275,6 +280,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -303,7 +309,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.523GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -325,6 +331,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -353,7 +360,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.524GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -375,6 +382,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 31132e156937c..6bdc10465d89e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -3,7 +3,7 @@ "index" : 10, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.916GMT", - "duration" : 73, + "duration" : 99, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -48,7 +49,7 @@ "index" : 11, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.918GMT", - "duration" : 75, + "duration" : 89, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -65,6 +66,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -93,7 +95,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -110,6 +112,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -138,7 +141,7 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "duration" : 138, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -155,6 +158,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -183,7 +187,7 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -200,6 +204,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -228,7 +233,7 @@ "index" : 15, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.928GMT", - "duration" : 76, + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -245,6 +250,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -273,7 +279,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -290,6 +296,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -318,7 +325,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 20, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.014GMT", - "duration" : 83, + "duration" : 90, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", - "duration" : 88, + "duration" : 96, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -543,7 +555,7 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", - "duration" : 93, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -560,6 +572,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -588,7 +601,7 @@ "index" : 23, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.031GMT", - "duration" : 65, + "duration" : 84, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -633,7 +647,7 @@ "index" : 24, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.098GMT", - "duration" : 43, + "duration" : 52, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -650,6 +664,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -678,7 +693,7 @@ "index" : 25, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.103GMT", - "duration" : 49, + "duration" : 61, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -695,6 +710,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -723,7 +739,7 @@ "index" : 26, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.105GMT", - "duration" : 38, + "duration" : 52, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -740,6 +756,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -768,7 +785,7 @@ "index" : 27, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.110GMT", - "duration" : 32, + "duration" : 41, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 28, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.113GMT", - "duration" : 29, + "duration" : 49, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -858,7 +877,7 @@ "index" : 29, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.114GMT", - "duration" : 39, + "duration" : 52, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -903,7 +923,7 @@ "index" : 30, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.118GMT", - "duration" : 34, + "duration" : 62, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -920,6 +940,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -948,7 +969,7 @@ "index" : 31, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.127GMT", - "duration" : 24, + "duration" : 74, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -965,6 +986,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -993,7 +1015,7 @@ "index" : 32, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.148GMT", - "duration" : 17, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1010,6 +1032,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1038,7 +1061,7 @@ "index" : 33, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.149GMT", - "duration" : 43, + "duration" : 58, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1055,6 +1078,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1083,7 +1107,7 @@ "index" : 34, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.156GMT", - "duration" : 27, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1100,6 +1124,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1128,7 +1153,7 @@ "index" : 35, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.161GMT", - "duration" : 35, + "duration" : 50, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1145,6 +1170,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1173,7 +1199,7 @@ "index" : 36, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.164GMT", - "duration" : 29, + "duration" : 40, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1190,6 +1216,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1218,7 +1245,7 @@ "index" : 37, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.165GMT", - "duration" : 32, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1235,6 +1262,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1263,7 +1291,7 @@ "index" : 38, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.166GMT", - "duration" : 31, + "duration" : 47, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1280,6 +1308,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1308,7 +1337,7 @@ "index" : 39, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.180GMT", - "duration" : 17, + "duration" : 32, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1325,6 +1354,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1353,7 +1383,7 @@ "index" : 40, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.197GMT", - "duration" : 14, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1370,6 +1400,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1398,7 +1429,7 @@ "index" : 41, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.200GMT", - "duration" : 16, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1415,6 +1446,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1443,7 +1475,7 @@ "index" : 42, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.203GMT", - "duration" : 17, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1460,6 +1492,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1488,7 +1521,7 @@ "index" : 43, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.204GMT", - "duration" : 16, + "duration" : 39, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1505,6 +1538,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1533,7 +1567,7 @@ "index" : 44, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.205GMT", - "duration" : 18, + "duration" : 37, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1550,6 +1584,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1578,7 +1613,7 @@ "index" : 45, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.206GMT", - "duration" : 19, + "duration" : 37, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1595,6 +1630,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1623,7 +1659,7 @@ "index" : 46, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.210GMT", - "duration" : 31, + "duration" : 43, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1640,6 +1676,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1668,7 +1705,7 @@ "index" : 47, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.212GMT", - "duration" : 18, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1685,6 +1722,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1713,7 +1751,7 @@ "index" : 48, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.220GMT", - "duration" : 24, + "duration" : 30, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1730,6 +1768,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1758,7 +1797,7 @@ "index" : 49, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.223GMT", - "duration" : 23, + "duration" : 34, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1775,6 +1814,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1803,7 +1843,7 @@ "index" : 50, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.240GMT", - "duration" : 18, + "duration" : 26, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1820,6 +1860,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1848,7 +1889,7 @@ "index" : 51, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.242GMT", - "duration" : 17, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1865,6 +1906,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1893,7 +1935,7 @@ "index" : 52, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.243GMT", - "duration" : 18, + "duration" : 28, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1910,6 +1952,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1938,7 +1981,7 @@ "index" : 53, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.244GMT", - "duration" : 18, + "duration" : 29, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1955,6 +1998,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1983,7 +2027,7 @@ "index" : 54, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.244GMT", - "duration" : 18, + "duration" : 59, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2000,6 +2044,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2028,7 +2073,7 @@ "index" : 55, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.246GMT", - "duration" : 21, + "duration" : 30, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2045,6 +2090,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2073,7 +2119,7 @@ "index" : 56, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.249GMT", - "duration" : 20, + "duration" : 31, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2090,6 +2136,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2118,7 +2165,7 @@ "index" : 57, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.257GMT", - "duration" : 16, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2135,6 +2182,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2163,7 +2211,7 @@ "index" : 58, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.263GMT", - "duration" : 16, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2180,6 +2228,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2208,7 +2257,7 @@ "index" : 59, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.265GMT", - "duration" : 17, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2225,6 +2274,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 6af1cfbeb8f7e..bc1cd49909d31 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -3,7 +3,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 351, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -44,11 +45,11 @@ } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 414, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -56,15 +57,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -84,16 +86,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 421, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -101,15 +103,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -129,16 +132,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -146,17 +149,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -174,16 +178,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -197,9 +201,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -219,16 +224,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 3, + "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -242,9 +247,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -264,16 +270,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 435, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -281,17 +287,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -309,7 +316,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 3842811, "recordsWritten" : 10 } } @@ -318,7 +325,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", - "duration" : 348, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", - "duration" : 93, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", - "duration" : 88, + "duration" : 96, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -539,11 +551,11 @@ } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -551,17 +563,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -579,7 +592,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } @@ -588,7 +601,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -629,11 +643,11 @@ } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:06.915GMT", + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -641,17 +655,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 9, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -669,16 +684,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 90, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -686,15 +701,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -714,16 +730,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -731,15 +747,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -759,7 +776,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } @@ -768,7 +785,7 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", - "duration" : 80, + "duration" : 88, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -854,11 +873,11 @@ } } }, { - "taskId" : 13, - "index" : 13, + "taskId" : 15, + "index" : 15, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "launchTime" : "2015-05-06T13:03:06.928GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -866,7 +885,7 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, "executorCpuTime" : 0, @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -894,7 +914,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95004, + "writeTime" : 602780, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 6af1cfbeb8f7e..bc1cd49909d31 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -3,7 +3,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 351, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -44,11 +45,11 @@ } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 414, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -56,15 +57,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -84,16 +86,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 421, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -101,15 +103,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -129,16 +132,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -146,17 +149,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -174,16 +178,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -197,9 +201,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -219,16 +224,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 3, + "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -242,9 +247,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -264,16 +270,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 435, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -281,17 +287,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -309,7 +316,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 3842811, "recordsWritten" : 10 } } @@ -318,7 +325,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", - "duration" : 348, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", - "duration" : 93, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", - "duration" : 88, + "duration" : 96, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -539,11 +551,11 @@ } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -551,17 +563,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -579,7 +592,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } @@ -588,7 +601,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -629,11 +643,11 @@ } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:06.915GMT", + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -641,17 +655,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 9, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -669,16 +684,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 90, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -686,15 +701,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -714,16 +730,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -731,15 +747,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -759,7 +776,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } @@ -768,7 +785,7 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", - "duration" : 80, + "duration" : 88, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -854,11 +873,11 @@ } } }, { - "taskId" : 13, - "index" : 13, + "taskId" : 15, + "index" : 15, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "launchTime" : "2015-05-06T13:03:06.928GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -866,7 +885,7 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, "executorCpuTime" : 0, @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -894,7 +914,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95004, + "writeTime" : 602780, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index c26daf4b8d7bd..09857cb401acd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -3,7 +3,7 @@ "index" : 40, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.197GMT", - "duration" : 14, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -48,7 +49,7 @@ "index" : 41, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.200GMT", - "duration" : 16, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -65,6 +66,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -93,7 +95,7 @@ "index" : 43, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.204GMT", - "duration" : 16, + "duration" : 39, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -110,6 +112,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -138,7 +141,7 @@ "index" : 57, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.257GMT", - "duration" : 16, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -155,6 +158,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -183,7 +187,7 @@ "index" : 58, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.263GMT", - "duration" : 16, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -200,6 +204,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -228,7 +233,7 @@ "index" : 68, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.306GMT", - "duration" : 16, + "duration" : 22, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -245,6 +250,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -273,7 +279,7 @@ "index" : 86, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.374GMT", - "duration" : 16, + "duration" : 28, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -290,6 +296,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -318,7 +325,7 @@ "index" : 32, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.148GMT", - "duration" : 17, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 39, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.180GMT", - "duration" : 17, + "duration" : 32, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 42, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.203GMT", - "duration" : 17, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 51, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.242GMT", - "duration" : 17, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 59, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.265GMT", - "duration" : 17, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -543,7 +555,7 @@ "index" : 63, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.276GMT", - "duration" : 17, + "duration" : 40, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -560,6 +572,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -588,7 +601,7 @@ "index" : 87, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.374GMT", - "duration" : 17, + "duration" : 36, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -633,7 +647,7 @@ "index" : 90, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.385GMT", - "duration" : 17, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -650,6 +664,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -678,7 +693,7 @@ "index" : 99, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.426GMT", - "duration" : 17, + "duration" : 22, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -695,6 +710,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70565, "recordsRead" : 10000 @@ -723,7 +739,7 @@ "index" : 44, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.205GMT", - "duration" : 18, + "duration" : 37, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -740,6 +756,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -768,7 +785,7 @@ "index" : 47, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.212GMT", - "duration" : 18, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 50, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.240GMT", - "duration" : 18, + "duration" : 26, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -858,7 +877,7 @@ "index" : 52, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.243GMT", - "duration" : 18, + "duration" : 28, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 44b5f66efe339..9cdcef0746185 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 120, "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", @@ -23,6 +26,7 @@ "name" : "foreach at :15", "details" : "org.apache.spark.rdd.RDD.foreach(RDD.scala:765)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:483)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 0 ], "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", @@ -34,7 +38,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.515GMT", - "duration" : 15, + "duration" : 61, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -56,6 +60,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -85,7 +90,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 15, + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -107,6 +112,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -136,7 +142,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -158,6 +164,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -187,7 +194,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -209,6 +216,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -238,7 +246,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -260,6 +268,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -289,7 +298,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -311,6 +320,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -340,7 +350,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -362,6 +372,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -391,7 +402,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -413,6 +424,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -443,12 +455,18 @@ "taskTime" : 418, "failedTasks" : 0, "succeededTasks" : 8, + "killedTasks" : 0, "inputBytes" : 0, + "inputRecords" : 0, "outputBytes" : 0, + "outputRecords" : 0, "shuffleRead" : 0, + "shuffleReadRecords" : 0, "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0 } - } + }, + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json index 3d7407004d262..71bf8706307c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json @@ -8,10 +8,13 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } }, { "jobId" : 1, "name" : "count at :20", @@ -22,10 +25,13 @@ "numCompletedTasks" : 15, "numSkippedTasks" : 0, "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 15, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 1 + "numFailedStages" : 1, + "killedTasksSummary" : { } }, { "jobId" : 0, "name" : "count at :15", @@ -36,8 +42,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json index 6a9bafd6b2191..b1ddd760c9714 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json @@ -8,10 +8,13 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } }, { "jobId" : 0, "name" : "count at :15", @@ -22,8 +25,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 6a1abceaeb63c..d22a19e8af74a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.ZipInputStream import javax.servlet._ import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -44,8 +45,8 @@ import org.scalatest.selenium.WebBrowser import org.apache.spark._ import org.apache.spark.deploy.history.config._ +import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.JobUIData import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -262,7 +263,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val badStageAttemptId = getContentAndCode("applications/local-1422981780767/stages/1/1") badStageAttemptId._1 should be (HttpServletResponse.SC_NOT_FOUND) - badStageAttemptId._3 should be (Some("unknown attempt for stage 1. Found attempts: [0]")) + badStageAttemptId._3 should be (Some("unknown attempt 1 for stage 1.")) val badStageId2 = getContentAndCode("applications/local-1422981780767/stages/flimflam") badStageId2._1 should be (HttpServletResponse.SC_NOT_FOUND) @@ -496,12 +497,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - def completedJobs(): Seq[JobUIData] = { - getAppUI.jobProgressListener.completedJobs + def completedJobs(): Seq[JobData] = { + getAppUI.store.jobsList(List(JobExecutionStatus.SUCCEEDED).asJava) } - def activeJobs(): Seq[JobUIData] = { - getAppUI.jobProgressListener.activeJobs.values.toSeq + def activeJobs(): Seq[JobData] = { + getAppUI.store.jobsList(List(JobExecutionStatus.RUNNING).asJava) } activeJobs() should have size 0 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ba082bc93dd42..88fe6bd70a14e 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -180,7 +180,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[StageDataWrapper](key(stages.head)) { stage => assert(stage.info.status === v1.StageStatus.ACTIVE) assert(stage.info.submissionTime === Some(new Date(stages.head.submissionTime.get))) - assert(stage.info.schedulingPool === "schedPool") + assert(stage.info.numTasks === stages.head.numTasks) } // Start tasks from stage 1 @@ -265,12 +265,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 - val reattempt = { - val orig = s1Tasks.head - // Task reattempts have a different ID, but the same index as the original. - new TaskInfo(nextTaskId(), orig.index, orig.attemptNumber + 1, time, orig.executorId, - s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) - } + val reattempt = newAttempt(s1Tasks.head, nextTaskId()) listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, reattempt)) @@ -288,7 +283,6 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](s1Tasks.head.taskId) { task => assert(task.info.status === s1Tasks.head.status) - assert(task.info.duration === Some(s1Tasks.head.duration)) assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) } @@ -297,8 +291,64 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(task.info.attempt === reattempt.attemptNumber) } + // Kill one task, restart it. + time += 1 + val killed = s1Tasks.drop(1).head + killed.finishTime = time + killed.failed = true + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", TaskKilled("killed"), killed, null)) + + check[JobDataWrapper](1) { job => + assert(job.info.numKilledTasks === 1) + assert(job.info.killedTasksSummary === Map("killed" -> 1)) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numKilledTasks === 1) + assert(stage.info.killedTasksSummary === Map("killed" -> 1)) + } + + check[TaskDataWrapper](killed.taskId) { task => + assert(task.info.index === killed.index) + assert(task.info.errorMessage === Some("killed")) + } + + // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill. + time += 1 + val denied = newAttempt(killed, nextTaskId()) + val denyReason = TaskCommitDenied(1, 1, 1) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + denied)) + + time += 1 + denied.finishTime = time + denied.failed = true + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", denyReason, denied, null)) + + check[JobDataWrapper](1) { job => + assert(job.info.numKilledTasks === 2) + assert(job.info.killedTasksSummary === Map("killed" -> 1, denyReason.toErrorString -> 1)) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numKilledTasks === 2) + assert(stage.info.killedTasksSummary === Map("killed" -> 1, denyReason.toErrorString -> 1)) + } + + check[TaskDataWrapper](denied.taskId) { task => + assert(task.info.index === killed.index) + assert(task.info.errorMessage === Some(denyReason.toErrorString)) + } + + // Start a new attempt. + val reattempt2 = newAttempt(denied, nextTaskId()) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + reattempt2)) + // Succeed all tasks in stage 1. - val pending = s1Tasks.drop(1) ++ Seq(reattempt) + val pending = s1Tasks.drop(2) ++ Seq(reattempt, reattempt2) val s1Metrics = TaskMetrics.empty s1Metrics.setExecutorCpuTime(2L) @@ -313,12 +363,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[JobDataWrapper](1) { job => assert(job.info.numFailedTasks === 1) + assert(job.info.numKilledTasks === 2) assert(job.info.numActiveTasks === 0) assert(job.info.numCompletedTasks === pending.size) } check[StageDataWrapper](key(stages.head)) { stage => assert(stage.info.numFailedTasks === 1) + assert(stage.info.numKilledTasks === 2) assert(stage.info.numActiveTasks === 0) assert(stage.info.numCompleteTasks === pending.size) } @@ -328,10 +380,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(wrapper.info.errorMessage === None) assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) + assert(wrapper.info.duration === Some(task.duration)) } } - assert(store.count(classOf[TaskDataWrapper]) === pending.size + 1) + assert(store.count(classOf[TaskDataWrapper]) === pending.size + 3) // End stage 1. time += 1 @@ -404,6 +457,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numFailedTasks === s2Tasks.size) assert(stage.info.numActiveTasks === 0) assert(stage.info.numCompleteTasks === 0) + assert(stage.info.failureReason === stages.last.failureReason) } // - Re-submit stage 2, all tasks, and succeed them and the stage. @@ -804,6 +858,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { fn(value) } + private def newAttempt(orig: TaskInfo, nextId: Long): TaskInfo = { + // Task reattempts have a different ID, but the same index as the original. + new TaskInfo(nextId, orig.index, orig.attemptNumber + 1, time, orig.executorId, + s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) + } + private case class RddBlock( rddId: Int, partId: Int, diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala deleted file mode 100644 index 82bd7c4ff6604..0000000000000 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1 - -import java.util.Date - -import scala.collection.mutable.LinkedHashMap - -import org.apache.spark.SparkFunSuite -import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} -import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} - -class AllStagesResourceSuite extends SparkFunSuite { - - def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { - val tasks = new LinkedHashMap[Long, TaskUIData] - taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => - tasks(idx.toLong) = TaskUIData( - new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false)) - } - - val stageUiData = new StageUIData() - stageUiData.taskData = tasks - val status = StageStatus.ACTIVE - val stageInfo = new StageInfo( - 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc") - val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) - - stageData.firstTaskLaunchedTime - } - - test("firstTaskLaunchedTime when there are no tasks") { - val result = getFirstTaskLaunchTime(Seq()) - assert(result == None) - } - - test("firstTaskLaunchedTime when there are tasks but none launched") { - val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) - assert(result == None) - } - - test("firstTaskLaunchedTime when there are tasks and some launched") { - val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) - assert(result == Some(new Date(1449255596000L))) - } - -} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 1c51c148ae61b..46932a02f1a1b 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore -import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} -import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.jobs.{StagePage, StagesTab} +import org.apache.spark.util.Utils class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -55,38 +55,40 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * This also runs a dummy stage to populate the page with useful content. */ private def renderStagePage(conf: SparkConf): Seq[Node] = { - val store = mock(classOf[AppStatusStore]) - when(store.executorSummary(anyString())).thenReturn(None) + val bus = new ReplayListenerBus() + val store = AppStatusStore.createLiveStore(conf, l => bus.addListener(l)) - val jobListener = new JobProgressListener(conf) - val graphListener = new RDDOperationGraphListener(conf) - val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) - val request = mock(classOf[HttpServletRequest]) - when(tab.conf).thenReturn(conf) - when(tab.progressListener).thenReturn(jobListener) - when(tab.operationGraphListener).thenReturn(graphListener) - when(tab.appName).thenReturn("testing") - when(tab.headerTabs).thenReturn(Seq.empty) - when(request.getParameter("id")).thenReturn("0") - when(request.getParameter("attempt")).thenReturn("0") - val page = new StagePage(tab, store) + try { + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + when(tab.store).thenReturn(store) - // Simulate a stage in job progress listener - val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") - // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness - (1 to 2).foreach { - taskId => - val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) - jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) - jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) - val taskMetrics = TaskMetrics.empty - taskMetrics.incPeakExecutionMemory(peakExecutionMemory) - jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab, store) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, + false) + bus.postToAll(SparkListenerStageSubmitted(stageInfo)) + bus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) + val taskMetrics = TaskMetrics.empty + taskMetrics.incPeakExecutionMemory(peakExecutionMemory) + bus.postToAll(SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) + } + bus.postToAll(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } finally { + store.close() } - jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) - page.render(request) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 267c8dc1bd750..6a6c37873e1c2 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -524,7 +524,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } - test("stage & job retention") { + ignore("stage & job retention") { val conf = new SparkConf() .setMaster("local") .setAppName("test") @@ -670,34 +670,36 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.webUrl + - "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString - assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + - "label="Stage 0";\n subgraph ")) - assert(stage0.contains("{\n label="parallelize";\n " + - "0 [label="ParallelCollectionRDD [0]")) - assert(stage0.contains("{\n label="map";\n " + - "1 [label="MapPartitionsRDD [1]")) - assert(stage0.contains("{\n label="groupBy";\n " + - "2 [label="MapPartitionsRDD [2]")) - - val stage1 = Source.fromURL(sc.ui.get.webUrl + - "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString - assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + - "label="Stage 1";\n subgraph ")) - assert(stage1.contains("{\n label="groupBy";\n " + - "3 [label="ShuffledRDD [3]")) - assert(stage1.contains("{\n label="map";\n " + - "4 [label="MapPartitionsRDD [4]")) - assert(stage1.contains("{\n label="groupBy";\n " + - "5 [label="MapPartitionsRDD [5]")) - - val stage2 = Source.fromURL(sc.ui.get.webUrl + - "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString - assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + - "label="Stage 2";\n subgraph ")) - assert(stage2.contains("{\n label="groupBy";\n " + - "6 [label="ShuffledRDD [6]")) + eventually(timeout(5 seconds), interval(100 milliseconds)) { + val stage0 = Source.fromURL(sc.ui.get.webUrl + + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString + assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + + "label="Stage 0";\n subgraph ")) + assert(stage0.contains("{\n label="parallelize";\n " + + "0 [label="ParallelCollectionRDD [0]")) + assert(stage0.contains("{\n label="map";\n " + + "1 [label="MapPartitionsRDD [1]")) + assert(stage0.contains("{\n label="groupBy";\n " + + "2 [label="MapPartitionsRDD [2]")) + + val stage1 = Source.fromURL(sc.ui.get.webUrl + + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString + assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + + "label="Stage 1";\n subgraph ")) + assert(stage1.contains("{\n label="groupBy";\n " + + "3 [label="ShuffledRDD [3]")) + assert(stage1.contains("{\n label="map";\n " + + "4 [label="MapPartitionsRDD [4]")) + assert(stage1.contains("{\n label="groupBy";\n " + + "5 [label="MapPartitionsRDD [5]")) + + val stage2 = Source.fromURL(sc.ui.get.webUrl + + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString + assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + + "label="Stage 2";\n subgraph ")) + assert(stage2.contains("{\n label="groupBy";\n " + + "6 [label="ShuffledRDD [6]")) + } } } diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala deleted file mode 100644 index 3fb78da0c7476..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.scope - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.scheduler._ - -/** - * Tests that this listener populates and cleans up its data structures properly. - */ -class RDDOperationGraphListenerSuite extends SparkFunSuite { - private var jobIdCounter = 0 - private var stageIdCounter = 0 - private val maxRetainedJobs = 10 - private val maxRetainedStages = 10 - private val conf = new SparkConf() - .set("spark.ui.retainedJobs", maxRetainedJobs.toString) - .set("spark.ui.retainedStages", maxRetainedStages.toString) - - test("run normal jobs") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - assert(listener.jobIdToStageIds.isEmpty) - assert(listener.jobIdToSkippedStageIds.isEmpty) - assert(listener.stageIdToJobId.isEmpty) - assert(listener.stageIdToGraph.isEmpty) - assert(listener.completedStageIds.isEmpty) - assert(listener.jobIds.isEmpty) - assert(listener.stageIds.isEmpty) - - // Run a few jobs, but not enough for clean up yet - (1 to 3).foreach { numStages => startJob(numStages, listener) } // start 3 jobs and 6 stages - (0 to 5).foreach { i => endStage(startingStageId + i, listener) } // finish all 6 stages - (0 to 2).foreach { i => endJob(startingJobId + i, listener) } // finish all 3 jobs - - assert(listener.jobIdToStageIds.size === 3) - assert(listener.jobIdToStageIds(startingJobId).size === 1) - assert(listener.jobIdToStageIds(startingJobId + 1).size === 2) - assert(listener.jobIdToStageIds(startingJobId + 2).size === 3) - assert(listener.jobIdToSkippedStageIds.size === 3) - assert(listener.jobIdToSkippedStageIds.values.forall(_.isEmpty)) // no skipped stages - assert(listener.stageIdToJobId.size === 6) - assert(listener.stageIdToJobId(startingStageId) === startingJobId) - assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2) - assert(listener.stageIdToGraph.size === 6) - assert(listener.completedStageIds.size === 6) - assert(listener.jobIds.size === 3) - assert(listener.stageIds.size === 6) - } - - test("run jobs with skipped stages") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - - // Run a few jobs, but not enough for clean up yet - // Leave some stages unfinished so that they are marked as skipped - (1 to 3).foreach { numStages => startJob(numStages, listener) } // start 3 jobs and 6 stages - (4 to 5).foreach { i => endStage(startingStageId + i, listener) } // finish only last 2 stages - (0 to 2).foreach { i => endJob(startingJobId + i, listener) } // finish all 3 jobs - - assert(listener.jobIdToSkippedStageIds.size === 3) - assert(listener.jobIdToSkippedStageIds(startingJobId).size === 1) - assert(listener.jobIdToSkippedStageIds(startingJobId + 1).size === 2) - assert(listener.jobIdToSkippedStageIds(startingJobId + 2).size === 1) // 2 stages not skipped - assert(listener.completedStageIds.size === 2) - - // The rest should be the same as before - assert(listener.jobIdToStageIds.size === 3) - assert(listener.jobIdToStageIds(startingJobId).size === 1) - assert(listener.jobIdToStageIds(startingJobId + 1).size === 2) - assert(listener.jobIdToStageIds(startingJobId + 2).size === 3) - assert(listener.stageIdToJobId.size === 6) - assert(listener.stageIdToJobId(startingStageId) === startingJobId) - assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2) - assert(listener.stageIdToGraph.size === 6) - assert(listener.jobIds.size === 3) - assert(listener.stageIds.size === 6) - } - - test("clean up metadata") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - - // Run many jobs and stages to trigger clean up - (1 to 10000).foreach { i => - // Note: this must be less than `maxRetainedStages` - val numStages = i % (maxRetainedStages - 2) + 1 - val startingStageIdForJob = stageIdCounter - val jobId = startJob(numStages, listener) - // End some, but not all, stages that belong to this job - // This is to ensure that we have both completed and skipped stages - (startingStageIdForJob until stageIdCounter) - .filter { i => i % 2 == 0 } - .foreach { i => endStage(i, listener) } - // End all jobs - endJob(jobId, listener) - } - - // Ensure we never exceed the max retained thresholds - assert(listener.jobIdToStageIds.size <= maxRetainedJobs) - assert(listener.jobIdToSkippedStageIds.size <= maxRetainedJobs) - assert(listener.stageIdToJobId.size <= maxRetainedStages) - assert(listener.stageIdToGraph.size <= maxRetainedStages) - assert(listener.completedStageIds.size <= maxRetainedStages) - assert(listener.jobIds.size <= maxRetainedJobs) - assert(listener.stageIds.size <= maxRetainedStages) - - // Also ensure we're actually populating these data structures - // Otherwise the previous group of asserts will be meaningless - assert(listener.jobIdToStageIds.nonEmpty) - assert(listener.jobIdToSkippedStageIds.nonEmpty) - assert(listener.stageIdToJobId.nonEmpty) - assert(listener.stageIdToGraph.nonEmpty) - assert(listener.completedStageIds.nonEmpty) - assert(listener.jobIds.nonEmpty) - assert(listener.stageIds.nonEmpty) - - // Ensure we clean up old jobs and stages, not arbitrary ones - assert(!listener.jobIdToStageIds.contains(startingJobId)) - assert(!listener.jobIdToSkippedStageIds.contains(startingJobId)) - assert(!listener.stageIdToJobId.contains(startingStageId)) - assert(!listener.stageIdToGraph.contains(startingStageId)) - assert(!listener.completedStageIds.contains(startingStageId)) - assert(!listener.stageIds.contains(startingStageId)) - assert(!listener.jobIds.contains(startingJobId)) - } - - test("fate sharing between jobs and stages") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - - // Run 3 jobs and 8 stages, finishing all 3 jobs but only 2 stages - startJob(5, listener) - startJob(1, listener) - startJob(2, listener) - (0 until 8).foreach { i => startStage(i + startingStageId, listener) } - endStage(startingStageId + 3, listener) - endStage(startingStageId + 4, listener) - (0 until 3).foreach { i => endJob(i + startingJobId, listener) } - - // First, assert the old stuff - assert(listener.jobIdToStageIds.size === 3) - assert(listener.jobIdToSkippedStageIds.size === 3) - assert(listener.stageIdToJobId.size === 8) - assert(listener.stageIdToGraph.size === 8) - assert(listener.completedStageIds.size === 2) - - // Cleaning the third job should clean all of its stages - listener.cleanJob(startingJobId + 2) - assert(listener.jobIdToStageIds.size === 2) - assert(listener.jobIdToSkippedStageIds.size === 2) - assert(listener.stageIdToJobId.size === 6) - assert(listener.stageIdToGraph.size === 6) - assert(listener.completedStageIds.size === 2) - - // Cleaning one of the stages in the first job should clean that job and all of its stages - // Note that we still keep around the last stage because it belongs to a different job - listener.cleanStage(startingStageId) - assert(listener.jobIdToStageIds.size === 1) - assert(listener.jobIdToSkippedStageIds.size === 1) - assert(listener.stageIdToJobId.size === 1) - assert(listener.stageIdToGraph.size === 1) - assert(listener.completedStageIds.size === 0) - } - - /** Start a job with the specified number of stages. */ - private def startJob(numStages: Int, listener: RDDOperationGraphListener): Int = { - assert(numStages > 0, "I will not run a job with 0 stages for you.") - val stageInfos = (0 until numStages).map { _ => - val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d") - stageIdCounter += 1 - stageInfo - } - val jobId = jobIdCounter - listener.onJobStart(new SparkListenerJobStart(jobId, 0, stageInfos)) - // Also start all stages that belong to this job - stageInfos.map(_.stageId).foreach { sid => startStage(sid, listener) } - jobIdCounter += 1 - jobId - } - - /** Start the stage specified by the given ID. */ - private def startStage(stageId: Int, listener: RDDOperationGraphListener): Unit = { - val stageInfo = new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d") - listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo)) - } - - /** Finish the stage specified by the given ID. */ - private def endStage(stageId: Int, listener: RDDOperationGraphListener): Unit = { - val stageInfo = new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d") - listener.onStageCompleted(new SparkListenerStageCompleted(stageInfo)) - } - - /** Finish the job specified by the given ID. */ - private def endJob(jobId: Int, listener: RDDOperationGraphListener): Unit = { - listener.onJobEnd(new SparkListenerJobEnd(jobId, 0, JobSucceeded)) - } - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e6f136c7c8b0a..7f18b40f9d960 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -43,6 +43,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.storage.StorageListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 0ffa7c488fa8156e2a1aa282e60b7c36b86d8af8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 14 Nov 2017 15:28:22 -0600 Subject: [PATCH 1669/1765] [SPARK-20652][SQL] Store SQL UI data in the new app status store. This change replaces the SQLListener with a new implementation that saves the data to the same store used by the SparkContext's status store. For that, the types used by the old SQLListener had to be updated a bit so that they're more serialization-friendly. The interface for getting data from the store was abstracted into a new class, SQLAppStatusStore (following the convention used in core). Another change is the way that the SQL UI hooks up into the core UI or the SHS. The old "SparkHistoryListenerFactory" was replaced with a new "AppStatePlugin" that more explicitly differentiates between the two use cases: processing events, and showing the UI. Both live apps and the SHS use this new API (previously, it was restricted to the SHS). Note on the above: this causes a slight change of behavior for live apps; the SQL tab will only show up after the first execution is started. The metrics gathering code was re-worked a bit so that the types used are less memory hungry and more serialization-friendly. This reduces memory usage when using in-memory stores, and reduces load times when using disk stores. Tested with existing and added unit tests. Note one unit test was disabled because it depends on SPARK-20653, which isn't in yet. Author: Marcelo Vanzin Closes #19681 from vanzin/SPARK-20652. --- .../scala/org/apache/spark/SparkContext.scala | 15 +- .../deploy/history/FsHistoryProvider.scala | 12 +- .../spark/scheduler/SparkListener.scala | 12 - .../apache/spark/status/AppStatusPlugin.scala | 71 +++ .../apache/spark/status/AppStatusStore.scala | 8 +- ...park.scheduler.SparkHistoryListenerFactory | 1 - .../org.apache.spark.status.AppStatusPlugin | 1 + .../org/apache/spark/sql/SparkSession.scala | 5 - .../sql/execution/ui/AllExecutionsPage.scala | 86 ++-- .../sql/execution/ui/ExecutionPage.scala | 60 ++- .../execution/ui/SQLAppStatusListener.scala | 366 ++++++++++++++++ .../sql/execution/ui/SQLAppStatusStore.scala | 179 ++++++++ .../spark/sql/execution/ui/SQLListener.scala | 403 +----------------- .../spark/sql/execution/ui/SQLTab.scala | 2 +- .../spark/sql/internal/SharedState.scala | 19 - .../execution/metric/SQLMetricsSuite.scala | 18 +- .../metric/SQLMetricsTestUtils.scala | 30 +- .../sql/execution/ui/SQLListenerSuite.scala | 340 ++++++++------- .../spark/sql/test/SharedSparkSession.scala | 1 - 19 files changed, 920 insertions(+), 709 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala delete mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1d325e651b1d9..23fd54f59268a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,7 +54,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.{AppStatusPlugin, AppStatusStore} import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -246,6 +246,8 @@ class SparkContext(config: SparkConf) extends Logging { */ def isStopped: Boolean = stopped.get() + private[spark] def statusStore: AppStatusStore = _statusStore + // An asynchronous listener bus for Spark events private[spark] def listenerBus: LiveListenerBus = _listenerBus @@ -455,9 +457,14 @@ class SparkContext(config: SparkConf) extends Logging { // For tests, do not enable the UI None } - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - _ui.foreach(_.bind()) + _ui.foreach { ui => + // Load any plugins that might want to modify the UI. + AppStatusPlugin.loadPlugins().foreach(_.setupUI(ui)) + + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + ui.bind() + } _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a6dc53321d650..25f82b55f2003 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -41,7 +41,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ -import org.apache.spark.status.{AppStatusListener, AppStatusStore, AppStatusStoreMetadata, KVUtils} +import org.apache.spark.status._ import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI @@ -319,6 +319,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val _listener = new AppStatusListener(kvstore, conf, false, lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) replayBus.addListener(_listener) + AppStatusPlugin.loadPlugins().foreach { plugin => + plugin.setupListeners(conf, kvstore, l => replayBus.addListener(l), false) + } Some(_listener) } else { None @@ -333,11 +336,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } try { - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, loadedUI.ui) - listeners.foreach(replayBus.addListener) + AppStatusPlugin.loadPlugins().foreach { plugin => + plugin.setupUI(loadedUI.ui) } val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index b76e560669d59..3b677ca9657db 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -167,18 +167,6 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent @DeveloperApi case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent -/** - * Interface for creating history listeners defined in other modules like SQL, which are used to - * rebuild the history UI. - */ -private[spark] trait SparkHistoryListenerFactory { - /** - * Create listeners used to rebuild the history UI. - */ - def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] -} - - /** * Interface for listening to events from the Spark scheduler. Most applications should probably * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala new file mode 100644 index 0000000000000..69ca02ec76293 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.SparkListener +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.KVStore + +/** + * An interface that defines plugins for collecting and storing application state. + * + * The plugin implementations are invoked for both live and replayed applications. For live + * applications, it's recommended that plugins defer creation of UI tabs until there's actual + * data to be shown. + */ +private[spark] trait AppStatusPlugin { + + /** + * Install listeners to collect data about the running application and populate the given + * store. + * + * @param conf The Spark configuration. + * @param store The KVStore where to keep application data. + * @param addListenerFn Function to register listeners with a bus. + * @param live Whether this is a live application (or an application being replayed by the + * HistoryServer). + */ + def setupListeners( + conf: SparkConf, + store: KVStore, + addListenerFn: SparkListener => Unit, + live: Boolean): Unit + + /** + * Install any needed extensions (tabs, pages, etc) to a Spark UI. The plugin can detect whether + * the app is live or replayed by looking at the UI's SparkContext field `sc`. + * + * @param ui The Spark UI instance for the application. + */ + def setupUI(ui: SparkUI): Unit + +} + +private[spark] object AppStatusPlugin { + + def loadPlugins(): Iterable[AppStatusPlugin] = { + ServiceLoader.load(classOf[AppStatusPlugin], Utils.getContextOrSparkClassLoader).asScala + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 9b42f55605755..d0615e5dd0223 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** * A wrapper around a KVStore that provides methods for accessing the API data stored within. */ -private[spark] class AppStatusStore(store: KVStore) { +private[spark] class AppStatusStore(val store: KVStore) { def applicationInfo(): v1.ApplicationInfo = { store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info @@ -338,9 +338,11 @@ private[spark] object AppStatusStore { */ def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { val store = new InMemoryStore() - val stateStore = new AppStatusStore(store) addListenerFn(new AppStatusListener(store, conf, true)) - stateStore + AppStatusPlugin.loadPlugins().foreach { p => + p.setupListeners(conf, store, addListenerFn, true) + } + new AppStatusStore(store) } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory deleted file mode 100644 index 507100be90967..0000000000000 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin new file mode 100644 index 0000000000000..ac6d7f6962f85 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLAppStatusPlugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2821f5ee7feee..272eb844226d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation @@ -957,7 +956,6 @@ object SparkSession { sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { defaultSession.set(null) - sqlListener.set(null) } }) } @@ -1026,9 +1024,6 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) - /** A global SQL listener used for the SQL UI. */ - private[sql] val sqlListener = new AtomicReference[SQLListener]() - //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index f9c69864a3361..7019d98e1619f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -24,34 +24,54 @@ import scala.xml.{Node, NodeSeq} import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.JobExecutionStatus import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging { - private val listener = parent.listener + private val sqlStore = parent.sqlStore override def render(request: HttpServletRequest): Seq[Node] = { val currentTime = System.currentTimeMillis() - val content = listener.synchronized { + val running = new mutable.ArrayBuffer[SQLExecutionUIData]() + val completed = new mutable.ArrayBuffer[SQLExecutionUIData]() + val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() + + sqlStore.executionsList().foreach { e => + val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } + if (isRunning) { + running += e + } else if (isFailed) { + failed += e + } else { + completed += e + } + } + + val content = { val _content = mutable.ListBuffer[Node]() - if (listener.getRunningExecutions.nonEmpty) { + + if (running.nonEmpty) { _content ++= new RunningExecutionTable( - parent, s"Running Queries (${listener.getRunningExecutions.size})", currentTime, - listener.getRunningExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + parent, s"Running Queries (${running.size})", currentTime, + running.sortBy(_.submissionTime).reverse).toNodeSeq } - if (listener.getCompletedExecutions.nonEmpty) { + + if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( - parent, s"Completed Queries (${listener.getCompletedExecutions.size})", currentTime, - listener.getCompletedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + parent, s"Completed Queries (${completed.size})", currentTime, + completed.sortBy(_.submissionTime).reverse).toNodeSeq } - if (listener.getFailedExecutions.nonEmpty) { + + if (failed.nonEmpty) { _content ++= new FailedExecutionTable( - parent, s"Failed Queries (${listener.getFailedExecutions.size})", currentTime, - listener.getFailedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + parent, s"Failed Queries (${failed.size})", currentTime, + failed.sortBy(_.submissionTime).reverse).toNodeSeq } _content } @@ -65,26 +85,26 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
      { - if (listener.getRunningExecutions.nonEmpty) { + if (running.nonEmpty) {
    • Running Queries: - {listener.getRunningExecutions.size} + {running.size}
    • } } { - if (listener.getCompletedExecutions.nonEmpty) { + if (completed.nonEmpty) {
    • Completed Queries: - {listener.getCompletedExecutions.size} + {completed.size}
    • } } { - if (listener.getFailedExecutions.nonEmpty) { + if (failed.nonEmpty) {
    • Failed Queries: - {listener.getFailedExecutions.size} + {failed.size}
    • } } @@ -114,23 +134,19 @@ private[ui] abstract class ExecutionTable( protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime - val duration = executionUIData.completionTime.getOrElse(currentTime) - submissionTime + val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - + submissionTime - val runningJobs = executionUIData.runningJobs.map { jobId => - - [{jobId.toString}] - - } - val succeededJobs = executionUIData.succeededJobs.sorted.map { jobId => - - [{jobId.toString}] - - } - val failedJobs = executionUIData.failedJobs.sorted.map { jobId => - - [{jobId.toString}] - + def jobLinks(status: JobExecutionStatus): Seq[Node] = { + executionUIData.jobs.flatMap { case (jobId, jobStatus) => + if (jobStatus == status) { + [{jobId.toString}] + } else { + None + } + }.toSeq } + {executionUIData.executionId.toString} @@ -146,17 +162,17 @@ private[ui] abstract class ExecutionTable( {if (showRunningJobs) { - {runningJobs} + {jobLinks(JobExecutionStatus.RUNNING)} }} {if (showSucceededJobs) { - {succeededJobs} + {jobLinks(JobExecutionStatus.SUCCEEDED)} }} {if (showFailedJobs) { - {failedJobs} + {jobLinks(JobExecutionStatus.FAILED)} }} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 460fc946c3e6f..f29e135ac357f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -21,24 +21,42 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.apache.spark.JobExecutionStatus import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { - private val listener = parent.listener + private val sqlStore = parent.sqlStore - override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized { + override def render(request: HttpServletRequest): Seq[Node] = { // stripXSS is called first to remove suspicious characters used in XSS attacks val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id")) require(parameterExecutionId != null && parameterExecutionId.nonEmpty, "Missing execution id parameter") val executionId = parameterExecutionId.toLong - val content = listener.getExecution(executionId).map { executionUIData => + val content = sqlStore.execution(executionId).map { executionUIData => val currentTime = System.currentTimeMillis() - val duration = - executionUIData.completionTime.getOrElse(currentTime) - executionUIData.submissionTime + val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - + executionUIData.submissionTime + + def jobLinks(status: JobExecutionStatus, label: String): Seq[Node] = { + val jobs = executionUIData.jobs.flatMap { case (jobId, jobStatus) => + if (jobStatus == status) Some(jobId) else None + } + if (jobs.nonEmpty) { +
    • + {label} + {jobs.toSeq.sorted.map { jobId => + {jobId.toString}  + }} +
    • + } else { + Nil + } + } + val summary =
      @@ -49,37 +67,17 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    • Duration: {UIUtils.formatDuration(duration)}
    • - {if (executionUIData.runningJobs.nonEmpty) { -
    • - Running Jobs: - {executionUIData.runningJobs.sorted.map { jobId => - {jobId.toString}  - }} -
    • - }} - {if (executionUIData.succeededJobs.nonEmpty) { -
    • - Succeeded Jobs: - {executionUIData.succeededJobs.sorted.map { jobId => - {jobId.toString}  - }} -
    • - }} - {if (executionUIData.failedJobs.nonEmpty) { -
    • - Failed Jobs: - {executionUIData.failedJobs.sorted.map { jobId => - {jobId.toString}  - }} -
    • - }} + {jobLinks(JobExecutionStatus.RUNNING, "Running Jobs:")} + {jobLinks(JobExecutionStatus.SUCCEEDED, "Succeeded Jobs:")} + {jobLinks(JobExecutionStatus.FAILED, "Failed Jobs:")}
    - val metrics = listener.getExecutionMetrics(executionId) + val metrics = sqlStore.executionMetrics(executionId) + val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, executionUIData.physicalPlanGraph) ++ + planVisualization(metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
    No information to display for Plan {executionId}
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala new file mode 100644 index 0000000000000..43cec4807ae4d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.ui + +import java.util.Date +import java.util.concurrent.ConcurrentHashMap +import java.util.function.Function + +import scala.collection.JavaConverters._ + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.status.LiveEntity +import org.apache.spark.status.config._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.kvstore.KVStore + +private[sql] class SQLAppStatusListener( + conf: SparkConf, + kvstore: KVStore, + live: Boolean, + ui: Option[SparkUI] = None) + extends SparkListener with Logging { + + // How often to flush intermediate state of a live execution to the store. When replaying logs, + // never flush (only do the very last write). + private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + + // Live tracked data is needed by the SQL status store to calculate metrics for in-flight + // executions; that means arbitrary threads may be querying these maps, so they need to be + // thread-safe. + private val liveExecutions = new ConcurrentHashMap[Long, LiveExecutionData]() + private val stageMetrics = new ConcurrentHashMap[Int, LiveStageMetrics]() + + private var uiInitialized = false + + override def onJobStart(event: SparkListenerJobStart): Unit = { + val executionIdString = event.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdString == null) { + // This is not a job created by SQL + return + } + + val executionId = executionIdString.toLong + val jobId = event.jobId + val exec = getOrCreateExecution(executionId) + + // Record the accumulator IDs for the stages of this job, so that the code that keeps + // track of the metrics knows which accumulators to look at. + val accumIds = exec.metrics.map(_.accumulatorId).sorted.toList + event.stageIds.foreach { id => + stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds.toArray, new ConcurrentHashMap())) + } + + exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) + exec.stages = event.stageIds.toSet + update(exec) + } + + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + if (!isSQLStage(event.stageInfo.stageId)) { + return + } + + // Reset the metrics tracking object for the new attempt. + Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => + metrics.taskMetrics.clear() + metrics.attemptId = event.stageInfo.attemptId + } + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + liveExecutions.values().asScala.foreach { exec => + if (exec.jobs.contains(event.jobId)) { + val result = event.jobResult match { + case JobSucceeded => JobExecutionStatus.SUCCEEDED + case _ => JobExecutionStatus.FAILED + } + exec.jobs = exec.jobs + (event.jobId -> result) + exec.endEvents += 1 + update(exec) + } + } + } + + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) => + updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false) + } + } + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + if (!isSQLStage(event.stageId)) { + return + } + + val info = event.taskInfo + // SPARK-20342. If processing events from a live application, use the task metrics info to + // work around a race in the DAGScheduler. The metrics info does not contain accumulator info + // when reading event logs in the SHS, so we have to rely on the accumulator in that case. + val accums = if (live && event.taskMetrics != null) { + event.taskMetrics.externalAccums.flatMap { a => + // This call may fail if the accumulator is gc'ed, so account for that. + try { + Some(a.toInfo(Some(a.value), None)) + } catch { + case _: IllegalAccessError => None + } + } + } else { + info.accumulables + } + updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums, + info.successful) + } + + def liveExecutionMetrics(executionId: Long): Option[Map[Long, String]] = { + Option(liveExecutions.get(executionId)).map { exec => + if (exec.metricsValues != null) { + exec.metricsValues + } else { + aggregateMetrics(exec) + } + } + } + + private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = { + val metricIds = exec.metrics.map(_.accumulatorId).sorted + val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap + val metrics = exec.stages.toSeq + .flatMap { stageId => Option(stageMetrics.get(stageId)) } + .flatMap(_.taskMetrics.values().asScala) + .flatMap { metrics => metrics.ids.zip(metrics.values) } + + val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq) + .filter { case (id, _) => metricIds.contains(id) } + .groupBy(_._1) + .map { case (id, values) => + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) + } + + // Check the execution again for whether the aggregated metrics data has been calculated. + // This can happen if the UI is requesting this data, and the onExecutionEnd handler is + // running at the same time. The metrics calculcated for the UI can be innacurate in that + // case, since the onExecutionEnd handler will clean up tracked stage metrics. + if (exec.metricsValues != null) { + exec.metricsValues + } else { + aggregatedMetrics + } + } + + private def updateStageMetrics( + stageId: Int, + attemptId: Int, + taskId: Long, + accumUpdates: Seq[AccumulableInfo], + succeeded: Boolean): Unit = { + Option(stageMetrics.get(stageId)).foreach { metrics => + if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) { + return + } + + val oldTaskMetrics = metrics.taskMetrics.get(taskId) + if (oldTaskMetrics != null && oldTaskMetrics.succeeded) { + return + } + + val updates = accumUpdates + .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) } + .sortBy(_.id) + + if (updates.isEmpty) { + return + } + + val ids = new Array[Long](updates.size) + val values = new Array[Long](updates.size) + updates.zipWithIndex.foreach { case (acc, idx) => + ids(idx) = acc.id + // In a live application, accumulators have Long values, but when reading from event + // logs, they have String values. For now, assume all accumulators are Long and covert + // accordingly. + values(idx) = acc.update.get match { + case s: String => s.toLong + case l: Long => l + case o => throw new IllegalArgumentException(s"Unexpected: $o") + } + } + + // TODO: storing metrics by task ID can cause metrics for the same task index to be + // counted multiple times, for example due to speculation or re-attempts. + metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded)) + } + } + + private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { + // Install the SQL tab in a live app if it hasn't been initialized yet. + if (!uiInitialized) { + ui.foreach { _ui => + new SQLTab(new SQLAppStatusStore(kvstore, Some(this)), _ui) + } + uiInitialized = true + } + + val SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) = event + + def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = { + nodes.map { + case cluster: SparkPlanGraphCluster => + val storedCluster = new SparkPlanGraphClusterWrapper( + cluster.id, + cluster.name, + cluster.desc, + toStoredNodes(cluster.nodes), + cluster.metrics) + new SparkPlanGraphNodeWrapper(null, storedCluster) + + case node => + new SparkPlanGraphNodeWrapper(node, null) + } + } + + val planGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = planGraph.allNodes.flatMap { node => + node.metrics.map { metric => (metric.accumulatorId, metric) } + }.toMap.values.toList + + val graphToStore = new SparkPlanGraphWrapper( + executionId, + toStoredNodes(planGraph.nodes), + planGraph.edges) + kvstore.write(graphToStore) + + val exec = getOrCreateExecution(executionId) + exec.description = description + exec.details = details + exec.physicalPlanDescription = physicalPlanDescription + exec.metrics = sqlPlanMetrics + exec.submissionTime = time + update(exec) + } + + private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = { + val SparkListenerSQLExecutionEnd(executionId, time) = event + Option(liveExecutions.get(executionId)).foreach { exec => + exec.metricsValues = aggregateMetrics(exec) + exec.completionTime = Some(new Date(time)) + exec.endEvents += 1 + update(exec) + + // Remove stale LiveStageMetrics objects for stages that are not active anymore. + val activeStages = liveExecutions.values().asScala.flatMap { other => + if (other != exec) other.stages else Nil + }.toSet + stageMetrics.keySet().asScala + .filter(!activeStages.contains(_)) + .foreach(stageMetrics.remove) + } + } + + private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { + val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event + Option(liveExecutions.get(executionId)).foreach { exec => + exec.driverAccumUpdates = accumUpdates.toMap + update(exec) + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case e: SparkListenerSQLExecutionStart => onExecutionStart(e) + case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e) + case e: SparkListenerDriverAccumUpdates => onDriverAccumUpdates(e) + case _ => // Ignore + } + + private def getOrCreateExecution(executionId: Long): LiveExecutionData = { + liveExecutions.computeIfAbsent(executionId, + new Function[Long, LiveExecutionData]() { + override def apply(key: Long): LiveExecutionData = new LiveExecutionData(executionId) + }) + } + + private def update(exec: LiveExecutionData): Unit = { + val now = System.nanoTime() + if (exec.endEvents >= exec.jobs.size + 1) { + exec.write(kvstore, now) + liveExecutions.remove(exec.executionId) + } else if (liveUpdatePeriodNs >= 0) { + if (now - exec.lastWriteTime > liveUpdatePeriodNs) { + exec.write(kvstore, now) + } + } + } + + private def isSQLStage(stageId: Int): Boolean = { + liveExecutions.values().asScala.exists { exec => + exec.stages.contains(stageId) + } + } + +} + +private class LiveExecutionData(val executionId: Long) extends LiveEntity { + + var description: String = null + var details: String = null + var physicalPlanDescription: String = null + var metrics = Seq[SQLPlanMetric]() + var submissionTime = -1L + var completionTime: Option[Date] = None + + var jobs = Map[Int, JobExecutionStatus]() + var stages = Set[Int]() + var driverAccumUpdates = Map[Long, Long]() + + @volatile var metricsValues: Map[Long, String] = null + + // Just in case job end and execution end arrive out of order, keep track of how many + // end events arrived so that the listener can stop tracking the execution. + var endEvents = 0 + + override protected def doUpdate(): Any = { + new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + metrics, + submissionTime, + completionTime, + jobs, + stages, + metricsValues) + } + +} + +private class LiveStageMetrics( + val stageId: Int, + var attemptId: Int, + val accumulatorIds: Array[Long], + val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics]) + +private[sql] class LiveTaskMetrics( + val ids: Array[Long], + val values: Array[Long], + val succeeded: Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala new file mode 100644 index 0000000000000..586d3ae411c74 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.lang.{Long => JLong} +import java.util.Date + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.scheduler.SparkListener +import org.apache.spark.status.AppStatusPlugin +import org.apache.spark.status.KVUtils.KVIndexParam +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.KVStore + +/** + * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's + * no state kept in this class, so it's ok to have multiple instances of it in an application. + */ +private[sql] class SQLAppStatusStore( + store: KVStore, + listener: Option[SQLAppStatusListener] = None) { + + def executionsList(): Seq[SQLExecutionUIData] = { + store.view(classOf[SQLExecutionUIData]).asScala.toSeq + } + + def execution(executionId: Long): Option[SQLExecutionUIData] = { + try { + Some(store.read(classOf[SQLExecutionUIData], executionId)) + } catch { + case _: NoSuchElementException => None + } + } + + def executionsCount(): Long = { + store.count(classOf[SQLExecutionUIData]) + } + + def executionMetrics(executionId: Long): Map[Long, String] = { + def metricsFromStore(): Option[Map[Long, String]] = { + val exec = store.read(classOf[SQLExecutionUIData], executionId) + Option(exec.metricValues) + } + + metricsFromStore() + .orElse(listener.flatMap(_.liveExecutionMetrics(executionId))) + // Try a second time in case the execution finished while this method is trying to + // get the metrics. + .orElse(metricsFromStore()) + .getOrElse(Map()) + } + + def planGraph(executionId: Long): SparkPlanGraph = { + store.read(classOf[SparkPlanGraphWrapper], executionId).toSparkPlanGraph() + } + +} + +/** + * An AppStatusPlugin for handling the SQL UI and listeners. + */ +private[sql] class SQLAppStatusPlugin extends AppStatusPlugin { + + override def setupListeners( + conf: SparkConf, + store: KVStore, + addListenerFn: SparkListener => Unit, + live: Boolean): Unit = { + // For live applications, the listener is installed in [[setupUI]]. This also avoids adding + // the listener when the UI is disabled. Force installation during testing, though. + if (!live || Utils.isTesting) { + val listener = new SQLAppStatusListener(conf, store, live, None) + addListenerFn(listener) + } + } + + override def setupUI(ui: SparkUI): Unit = { + ui.sc match { + case Some(sc) => + // If this is a live application, then install a listener that will enable the SQL + // tab as soon as there's a SQL event posted to the bus. + val listener = new SQLAppStatusListener(sc.conf, ui.store.store, true, Some(ui)) + sc.listenerBus.addToStatusQueue(listener) + + case _ => + // For a replayed application, only add the tab if the store already contains SQL data. + val sqlStore = new SQLAppStatusStore(ui.store.store) + if (sqlStore.executionsCount() > 0) { + new SQLTab(sqlStore, ui) + } + } + } + +} + +private[sql] class SQLExecutionUIData( + @KVIndexParam val executionId: Long, + val description: String, + val details: String, + val physicalPlanDescription: String, + val metrics: Seq[SQLPlanMetric], + val submissionTime: Long, + val completionTime: Option[Date], + @JsonDeserialize(keyAs = classOf[Integer]) + val jobs: Map[Int, JobExecutionStatus], + @JsonDeserialize(contentAs = classOf[Integer]) + val stages: Set[Int], + /** + * This field is only populated after the execution is finished; it will be null while the + * execution is still running. During execution, aggregate metrics need to be retrieved + * from the SQL listener instance. + */ + @JsonDeserialize(keyAs = classOf[JLong]) + val metricValues: Map[Long, String] + ) + +private[sql] class SparkPlanGraphWrapper( + @KVIndexParam val executionId: Long, + val nodes: Seq[SparkPlanGraphNodeWrapper], + val edges: Seq[SparkPlanGraphEdge]) { + + def toSparkPlanGraph(): SparkPlanGraph = { + SparkPlanGraph(nodes.map(_.toSparkPlanGraphNode()), edges) + } + +} + +private[sql] class SparkPlanGraphClusterWrapper( + val id: Long, + val name: String, + val desc: String, + val nodes: Seq[SparkPlanGraphNodeWrapper], + val metrics: Seq[SQLPlanMetric]) { + + def toSparkPlanGraphCluster(): SparkPlanGraphCluster = { + new SparkPlanGraphCluster(id, name, desc, + new ArrayBuffer() ++ nodes.map(_.toSparkPlanGraphNode()), + metrics) + } + +} + +/** Only one of the values should be set. */ +private[sql] class SparkPlanGraphNodeWrapper( + val node: SparkPlanGraphNode, + val cluster: SparkPlanGraphClusterWrapper) { + + def toSparkPlanGraphNode(): SparkPlanGraphNode = { + assert(node == null ^ cluster == null, "One and only of of nore or cluster must be set.") + if (node != null) node else cluster.toSparkPlanGraphCluster() + } + +} + +private[sql] case class SQLPlanMetric( + name: String, + accumulatorId: Long, + metricType: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 8c27af374febd..b58b8c6d45e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.ui -import scala.collection.mutable - import com.fasterxml.jackson.databind.JavaType import com.fasterxml.jackson.databind.`type`.TypeFactory import com.fasterxml.jackson.databind.annotation.JsonDeserialize import com.fasterxml.jackson.databind.util.Converter -import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ -import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.metric._ -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.AccumulatorContext @DeveloperApi case class SparkListenerSQLExecutionStart( @@ -89,398 +83,3 @@ private class LongLongTupleConverter extends Converter[(Object, Object), (Long, typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) } } - -class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { - - override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { - List(new SQLHistoryListener(conf, sparkUI)) - } -} - -class SQLListener(conf: SparkConf) extends SparkListener with Logging { - - private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) - - private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() - - // Old data in the following fields must be removed in "trimExecutionsIfNecessary". - // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data - private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() - - /** - * Maintain the relation between job id and execution id so that we can get the execution id in - * the "onJobEnd" method. - */ - private val _jobIdToExecutionId = mutable.HashMap[Long, Long]() - - private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]() - - private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - - private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - - def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { - _executionIdToData.toMap - } - - def jobIdToExecutionId: Map[Long, Long] = synchronized { - _jobIdToExecutionId.toMap - } - - def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { - _stageIdToStageMetrics.toMap - } - - private def trimExecutionsIfNecessary( - executions: mutable.ListBuffer[SQLExecutionUIData]): Unit = { - if (executions.size > retainedExecutions) { - val toRemove = math.max(retainedExecutions / 10, 1) - executions.take(toRemove).foreach { execution => - for (executionUIData <- _executionIdToData.remove(execution.executionId)) { - for (jobId <- executionUIData.jobs.keys) { - _jobIdToExecutionId.remove(jobId) - } - for (stageId <- executionUIData.stages) { - _stageIdToStageMetrics.remove(stageId) - } - } - } - executions.trimStart(toRemove) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - val executionIdString = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) - if (executionIdString == null) { - // This is not a job created by SQL - return - } - val executionId = executionIdString.toLong - val jobId = jobStart.jobId - val stageIds = jobStart.stageIds - - synchronized { - activeExecutions.get(executionId).foreach { executionUIData => - executionUIData.jobs(jobId) = JobExecutionStatus.RUNNING - executionUIData.stages ++= stageIds - stageIds.foreach(stageId => - _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId = 0)) - _jobIdToExecutionId(jobId) = executionId - } - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - val jobId = jobEnd.jobId - for (executionId <- _jobIdToExecutionId.get(jobId); - executionUIData <- _executionIdToData.get(executionId)) { - jobEnd.jobResult match { - case JobSucceeded => executionUIData.jobs(jobId) = JobExecutionStatus.SUCCEEDED - case JobFailed(_) => executionUIData.jobs(jobId) = JobExecutionStatus.FAILED - } - if (executionUIData.completionTime.nonEmpty && !executionUIData.hasRunningJobs) { - // We are the last job of this execution, so mark the execution as finished. Note that - // `onExecutionEnd` also does this, but currently that can be called before `onJobEnd` - // since these are called on different threads. - markExecutionFinished(executionId) - } - } - } - - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false) - } - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val stageId = stageSubmitted.stageInfo.stageId - val stageAttemptId = stageSubmitted.stageInfo.attemptId - // Always override metrics for old stage attempt - if (_stageIdToStageMetrics.contains(stageId)) { - _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) - } else { - // If a stage belongs to some SQL execution, its stageId will be put in "onJobStart". - // Since "_stageIdToStageMetrics" doesn't contain it, it must not belong to any SQL execution. - // So we can ignore it. Otherwise, this may lead to memory leaks (SPARK-11126). - } - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - if (taskEnd.taskMetrics != null) { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskMetrics.externalAccums.map(a => a.toInfo(Some(a.value), None)), - finishTask = true) - } - } - - /** - * Update the accumulator values of a task with the latest metrics for this task. This is called - * every time we receive an executor heartbeat or when a task finishes. - */ - protected def updateTaskAccumulatorValues( - taskId: Long, - stageId: Int, - stageAttemptID: Int, - _accumulatorUpdates: Seq[AccumulableInfo], - finishTask: Boolean): Unit = { - val accumulatorUpdates = - _accumulatorUpdates.filter(_.update.isDefined).map(accum => (accum.id, accum.update.get)) - - _stageIdToStageMetrics.get(stageId) match { - case Some(stageMetrics) => - if (stageAttemptID < stageMetrics.stageAttemptId) { - // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. - } else if (stageAttemptID > stageMetrics.stageAttemptId) { - logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + - s"what we have seen (${stageMetrics.stageAttemptId})") - } else { - // TODO We don't know the attemptId. Currently, what we can do is overriding the - // accumulator updates. However, if there are two same task are running, such as - // speculation, the accumulator updates will be overriding by different task attempts, - // the results will be weird. - stageMetrics.taskIdToMetricUpdates.get(taskId) match { - case Some(taskMetrics) => - if (finishTask) { - taskMetrics.finished = true - taskMetrics.accumulatorUpdates = accumulatorUpdates - } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = accumulatorUpdates - } else { - // If a task is finished, we should not override with accumulator updates from - // heartbeat reports - } - case None => - stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - finished = finishTask, accumulatorUpdates) - } - } - case None => - // This execution and its stage have been dropped - } - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) => - val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.allNodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - val executionUIData = new SQLExecutionUIData( - executionId, - description, - details, - physicalPlanDescription, - physicalPlanGraph, - sqlPlanMetrics.toMap, - time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. - } - } - } - case SparkListenerDriverAccumUpdates(executionId, accumUpdates) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - for ((accId, accValue) <- accumUpdates) { - executionUIData.driverAccumUpdates(accId) = accValue - } - } - } - case _ => // Ignore - } - - private def markExecutionFinished(executionId: Long): Unit = { - activeExecutions.remove(executionId).foreach { executionUIData => - if (executionUIData.isFailed) { - failedExecutions += executionUIData - trimExecutionsIfNecessary(failedExecutions) - } else { - completedExecutions += executionUIData - trimExecutionsIfNecessary(completedExecutions) - } - } - } - - def getRunningExecutions: Seq[SQLExecutionUIData] = synchronized { - activeExecutions.values.toSeq - } - - def getFailedExecutions: Seq[SQLExecutionUIData] = synchronized { - failedExecutions - } - - def getCompletedExecutions: Seq[SQLExecutionUIData] = synchronized { - completedExecutions - } - - def getExecution(executionId: Long): Option[SQLExecutionUIData] = synchronized { - _executionIdToData.get(executionId) - } - - /** - * Get all accumulator updates from all tasks which belong to this execution and merge them. - */ - def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized { - _executionIdToData.get(executionId) match { - case Some(executionUIData) => - val accumulatorUpdates = { - for (stageId <- executionUIData.stages; - stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; - taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; - accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { - (accumulatorUpdate._1, accumulatorUpdate._2) - } - } - - val driverUpdates = executionUIData.driverAccumUpdates.toSeq - val totalUpdates = (accumulatorUpdates ++ driverUpdates).filter { - case (id, _) => executionUIData.accumulatorMetrics.contains(id) - } - mergeAccumulatorUpdates(totalUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).metricType) - case None => - // This execution has been dropped - Map.empty - } - } - - private def mergeAccumulatorUpdates( - accumulatorUpdates: Seq[(Long, Any)], - metricTypeFunc: Long => String): Map[Long, String] = { - accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => - val metricType = metricTypeFunc(accumulatorId) - accumulatorId -> - SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long])) - } - } - -} - - -/** - * A [[SQLListener]] for rendering the SQL UI in the history server. - */ -class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) - extends SQLListener(conf) { - - private var sqlTabAttached = false - - override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = { - // Do nothing; these events are not logged - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.flatMap { a => - // Filter out accumulators that are not SQL metrics - // For now we assume all SQL metrics are Long's that have been JSON serialized as String's - if (a.metadata == Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) { - val newValue = a.update.map(_.toString.toLong).getOrElse(0L) - Some(a.copy(update = Some(newValue))) - } else { - None - } - }, - finishTask = true) - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case _: SparkListenerSQLExecutionStart => - if (!sqlTabAttached) { - new SQLTab(this, sparkUI) - sqlTabAttached = true - } - super.onOtherEvent(event) - case _ => super.onOtherEvent(event) - } -} - -/** - * Represent all necessary data for an execution that will be used in Web UI. - */ -private[ui] class SQLExecutionUIData( - val executionId: Long, - val description: String, - val details: String, - val physicalPlanDescription: String, - val physicalPlanGraph: SparkPlanGraph, - val accumulatorMetrics: Map[Long, SQLPlanMetric], - val submissionTime: Long) { - - var completionTime: Option[Long] = None - - val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty - - val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer() - - val driverAccumUpdates: mutable.HashMap[Long, Long] = mutable.HashMap.empty - - /** - * Return whether there are running jobs in this execution. - */ - def hasRunningJobs: Boolean = jobs.values.exists(_ == JobExecutionStatus.RUNNING) - - /** - * Return whether there are any failed jobs in this execution. - */ - def isFailed: Boolean = jobs.values.exists(_ == JobExecutionStatus.FAILED) - - def runningJobs: Seq[Long] = - jobs.filter { case (_, status) => status == JobExecutionStatus.RUNNING }.keys.toSeq - - def succeededJobs: Seq[Long] = - jobs.filter { case (_, status) => status == JobExecutionStatus.SUCCEEDED }.keys.toSeq - - def failedJobs: Seq[Long] = - jobs.filter { case (_, status) => status == JobExecutionStatus.FAILED }.keys.toSeq -} - -/** - * Represent a metric in a SQLPlan. - * - * Because we cannot revert our changes for an "Accumulator", we need to maintain accumulator - * updates for each task. So that if a task is retried, we can simply override the old updates with - * the new updates of the new attempt task. Since we cannot add them to accumulator, we need to use - * "AccumulatorParam" to get the aggregation value. - */ -private[ui] case class SQLPlanMetric( - name: String, - accumulatorId: Long, - metricType: String) - -/** - * Store all accumulatorUpdates for all tasks in a Spark stage. - */ -private[ui] class SQLStageMetrics( - val stageAttemptId: Long, - val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty) - - -// TODO Should add attemptId here when we can get it from SparkListenerExecutorMetricsUpdate -/** - * Store all accumulatorUpdates for a Spark task. - */ -private[ui] class SQLTaskMetrics( - var finished: Boolean, - var accumulatorUpdates: Seq[(Long, Any)]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index d0376af3e31ca..a321a22f10789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.ui import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} -class SQLTab(val listener: SQLListener, sparkUI: SparkUI) +class SQLTab(val sqlStore: SQLAppStatusStore, sparkUI: SparkUI) extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index ad9db308b2627..3e479faed72ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -83,11 +82,6 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { */ val cacheManager: CacheManager = new CacheManager - /** - * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. - */ - val listener: SQLListener = createListenerAndUI(sparkContext) - /** * A catalog that interacts with external systems. */ @@ -142,19 +136,6 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { val jarClassLoader = new NonClosableMutableURLClassLoader( org.apache.spark.util.Utils.getContextOrSparkClassLoader) - /** - * Create a SQLListener then add it into SparkContext, and create a SQLTab if there is SparkUI. - */ - private def createListenerAndUI(sc: SparkContext): SQLListener = { - if (SparkSession.sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (SparkSession.sqlListener.compareAndSet(null, listener)) { - sc.listenerBus.addToStatusQueue(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - SparkSession.sqlListener.get() - } } object SharedState extends Logging { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 58a194b8af62b..d588af3e19dde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -32,6 +33,13 @@ import org.apache.spark.util.{AccumulatorContext, JsonProtocol} class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext { import testImplicits._ + private def statusStore: SQLAppStatusStore = { + new SQLAppStatusStore(sparkContext.statusStore.store) + } + + private def currentExecutionIds(): Set[Long] = { + statusStore.executionsList.map(_.executionId).toSet + } /** * Generates a `DataFrame` by filling randomly generated bytes for hash collision. @@ -420,21 +428,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared withTempPath { file => // person creates a temporary view. get the DF before listing previous execution IDs val data = person.select('name) - sparkContext.listenerBus.waitUntilEmpty(10000) - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + val previousExecutionIds = currentExecutionIds() // Assume the execution plan is // PhysicalRDD(nodeId = 0) data.write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = currentExecutionIds().diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs + val jobs = statusStore.execution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) + val metricValues = statusStore.executionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq.exists(_ === "2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 3966e98c1ce06..d89c4b14619fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -34,6 +34,14 @@ trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ + private def statusStore: SQLAppStatusStore = { + new SQLAppStatusStore(sparkContext.statusStore.store) + } + + private def currentExecutionIds(): Set[Long] = { + statusStore.executionsList.map(_.executionId).toSet + } + /** * Get execution metrics for the SQL execution and verify metrics values. * @@ -41,24 +49,23 @@ trait SQLMetricsTestUtils extends SQLTestUtils { * @param func the function can produce execution id after running. */ private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = { - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + val previousExecutionIds = currentExecutionIds() // Run the given function to trigger query execution. func spark.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = currentExecutionIds().diff(previousExecutionIds) assert(executionIds.size == 1) val executionId = executionIds.head - val executionData = spark.sharedState.listener.getExecution(executionId).get - val executedNode = executionData.physicalPlanGraph.nodes.head + val executionData = statusStore.execution(executionId).get + val executedNode = statusStore.planGraph(executionId).nodes.head val metricsNames = Seq( "number of written files", "number of dynamic part", "number of output rows") - val metrics = spark.sharedState.listener.getExecutionMetrics(executionId) + val metrics = statusStore.executionMetrics(executionId) metricsNames.zip(metricsValues).foreach { case (metricsName, expected) => val sqlMetric = executedNode.metrics.find(_.name == metricsName) @@ -134,22 +141,21 @@ trait SQLMetricsTestUtils extends SQLTestUtils { expectedNumOfJobs: Int, expectedNodeIds: Set[Long], enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + val previousExecutionIds = currentExecutionIds() withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = currentExecutionIds().diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs + val jobs = statusStore.execution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) + val metricValues = statusStore.executionMetrics(executionId) val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => expectedNodeIds.contains(node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 1055f09f5411c..eba8d55daad58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import scala.collection.mutable.ListBuffer + import org.json4s.jackson.JsonMethods._ -import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ @@ -36,13 +36,14 @@ import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.ui.SparkUI +import org.apache.spark.status.config._ import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} - +import org.apache.spark.util.kvstore.InMemoryStore class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ - import org.apache.spark.AccumulatorSuite.makeInfo + + override protected def sparkConf = super.sparkConf.set(LIVE_ENTITY_UPDATE_PERIOD, 0L) private def createTestDataFrame: DataFrame = { Seq( @@ -68,44 +69,67 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest details = "" ) - private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( - taskId = taskId, - attemptNumber = attemptNumber, - // The following fields are not used in tests - index = 0, - launchTime = 0, - executorId = "", - host = "", - taskLocality = null, - speculative = false - ) + private def createTaskInfo( + taskId: Int, + attemptNumber: Int, + accums: Map[Long, Long] = Map()): TaskInfo = { + val info = new TaskInfo( + taskId = taskId, + attemptNumber = attemptNumber, + // The following fields are not used in tests + index = 0, + launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false) + info.markFinished(TaskState.FINISHED, 1L) + info.setAccumulables(createAccumulatorInfos(accums)) + info + } - private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { - val metrics = TaskMetrics.empty - accumulatorUpdates.foreach { case (id, update) => + private def createAccumulatorInfos(accumulatorUpdates: Map[Long, Long]): Seq[AccumulableInfo] = { + accumulatorUpdates.map { case (id, value) => val acc = new LongAccumulator - acc.metadata = AccumulatorMetadata(id, Some(""), true) - acc.add(update) - metrics.registerAccumulator(acc) + acc.metadata = AccumulatorMetadata(id, None, false) + acc.toInfo(Some(value), None) + }.toSeq + } + + /** Return the shared SQL store from the active SparkSession. */ + private def statusStore: SQLAppStatusStore = + new SQLAppStatusStore(spark.sparkContext.statusStore.store) + + /** + * Runs a test with a temporary SQLAppStatusStore tied to a listener bus. Events can be sent to + * the listener bus to update the store, and all data will be cleaned up at the end of the test. + */ + private def sqlStoreTest(name: String) + (fn: (SQLAppStatusStore, SparkListenerBus) => Unit): Unit = { + test(name) { + val store = new InMemoryStore() + val bus = new ReplayListenerBus() + val listener = new SQLAppStatusListener(sparkConf, store, true) + bus.addListener(listener) + val sqlStore = new SQLAppStatusStore(store, Some(listener)) + fn(sqlStore, bus) } - metrics } - test("basic") { + sqlStoreTest("basic") { (store, bus) => def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { assert(actual.size == expected.size) - expected.foreach { e => + expected.foreach { case (id, value) => // The values in actual can be SQL metrics meaning that they contain additional formatting // when converted to string. Verify that they start with the expected value. // TODO: this is brittle. There is no requirement that the actual string needs to start // with the accumulator value. - assert(actual.contains(e._1)) - val v = actual.get(e._1).get.trim - assert(v.startsWith(e._2.toString)) + assert(actual.contains(id)) + val v = actual.get(id).get.trim + assert(v.startsWith(value.toString), s"Wrong value for accumulator $id") } } - val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -118,7 +142,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest (id, accumulatorValue) }.toMap - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", @@ -126,9 +150,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - val executionUIData = listener.executionIdToData(0) - - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq( @@ -136,291 +158,270 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest createStageInfo(1, 0) ), createProperties(executionId))) - listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 0))) - assert(listener.getExecutionMetrics(0).isEmpty) + assert(store.executionMetrics(0).isEmpty) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)) + (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), + (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Driver accumulator updates don't belong to this execution should be filtered and no // exception will be thrown. - listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + bus.postToAll(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 0, 0, - createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo)) + (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), + (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates.mapValues(_ * 2))) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) // Retrying a stage should reset the metrics - listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)) + (0L, 0, 1, createAccumulatorInfos(accumulatorUpdates)), + (1L, 0, 1, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Ignore the task end for the first attempt - listener.onTaskEnd(SparkListenerTaskEnd( + bus.postToAll(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(0, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 100)), + null)) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Finish two tasks - listener.onTaskEnd(SparkListenerTaskEnd( + bus.postToAll(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 1, taskType = "", reason = null, - createTaskInfo(0, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))) - listener.onTaskEnd(SparkListenerTaskEnd( + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 2)), + null)) + bus.postToAll(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 1, taskType = "", reason = null, - createTaskInfo(1, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + null)) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) // Summit a new stage - listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + bus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 0))) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)) + (0L, 1, 0, createAccumulatorInfos(accumulatorUpdates)), + (1L, 1, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) // Finish two tasks - listener.onTaskEnd(SparkListenerTaskEnd( + bus.postToAll(SparkListenerTaskEnd( stageId = 1, stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(0, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) - listener.onTaskEnd(SparkListenerTaskEnd( + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + null)) + bus.postToAll(SparkListenerTaskEnd( stageId = 1, stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(1, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + null)) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) - assert(executionUIData.runningJobs === Seq(0)) - assert(executionUIData.succeededJobs.isEmpty) - assert(executionUIData.failedJobs.isEmpty) + assertJobs(store.execution(0), running = Seq(0)) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs === Seq(0)) - assert(executionUIData.failedJobs.isEmpty) - - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + assertJobs(store.execution(0), completed = Seq(0)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) } - test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(spark.sparkContext.conf) + sqlStoreTest("onExecutionEnd happens before onJobEnd(JobSucceeded)") { (store, bus) => val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - val executionUIData = listener.executionIdToData(0) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs === Seq(0)) - assert(executionUIData.failedJobs.isEmpty) + assertJobs(store.execution(0), completed = Seq(0)) } - test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(spark.sparkContext.conf) + sqlStoreTest("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { (store, bus) => val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 1, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), JobSucceeded )) - val executionUIData = listener.executionIdToData(0) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs.sorted === Seq(0, 1)) - assert(executionUIData.failedJobs.isEmpty) + assertJobs(store.execution(0), completed = Seq(0, 1)) } - test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(spark.sparkContext.conf) + sqlStoreTest("onExecutionEnd happens before onJobEnd(JobFailed)") { (store, bus) => val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobFailed(new RuntimeException("Oops")) )) - val executionUIData = listener.executionIdToData(0) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs.isEmpty) - assert(executionUIData.failedJobs === Seq(0)) + assertJobs(store.execution(0), failed = Seq(0)) } test("SPARK-11126: no memory leak when running non SQL jobs") { - val previousStageNumber = spark.sharedState.listener.stageIdToStageMetrics.size + val previousStageNumber = statusStore.executionsList().size spark.sparkContext.parallelize(1 to 10).foreach(i => ()) spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage - assert(spark.sharedState.listener.stageIdToStageMetrics.size == previousStageNumber) + assert(statusStore.executionsList().size == previousStageNumber) spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage - assert(spark.sharedState.listener.stageIdToStageMetrics.size == previousStageNumber + 1) - } - - test("SPARK-13055: history listener only tracks SQL metrics") { - val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI])) - // We need to post other events for the listener to track our accumulators. - // These are largely just boilerplate unrelated to what we're trying to test. - val df = createTestDataFrame - val executionStart = SparkListenerSQLExecutionStart( - 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0) - val stageInfo = createStageInfo(0, 0) - val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0)) - val stageSubmitted = SparkListenerStageSubmitted(stageInfo) - // This task has both accumulators that are SQL metrics and accumulators that are not. - // The listener should only track the ones that are actually SQL metrics. - val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella") - val nonSqlMetric = sparkContext.longAccumulator("baseball") - val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None) - val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None) - val taskInfo = createTaskInfo(0, 0) - taskInfo.setAccumulables(List(sqlMetricInfo, nonSqlMetricInfo)) - val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) - listener.onOtherEvent(executionStart) - listener.onJobStart(jobStart) - listener.onStageSubmitted(stageSubmitted) - // Before SPARK-13055, this throws ClassCastException because the history listener would - // assume that the accumulator value is of type Long, but this may not be true for - // accumulators that are not SQL metrics. - listener.onTaskEnd(taskEnd) - val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics => - stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates) - } - // Listener tracks only SQL metrics, not other accumulators - assert(trackedAccums.size === 1) - assert(trackedAccums.head === ((sqlMetricInfo.id, sqlMetricInfo.update.get))) + assert(statusStore.executionsList().size == previousStageNumber + 1) } test("driver side SQL metrics") { - val listener = new SQLListener(spark.sparkContext.conf) - val expectedAccumValue = 12345 + val oldCount = statusStore.executionsList().size + val expectedAccumValue = 12345L val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) - sqlContext.sparkContext.addSparkListener(listener) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan } + SQLExecution.withNewExecutionId(spark, dummyQueryExecution) { physicalPlan.execute().collect() } - def waitTillExecutionFinished(): Unit = { - while (listener.getCompletedExecutions.isEmpty) { - Thread.sleep(100) + while (statusStore.executionsList().size < oldCount) { + Thread.sleep(100) + } + + // Wait for listener to finish computing the metrics for the execution. + while (statusStore.executionsList().last.metricValues == null) { + Thread.sleep(100) + } + + val execId = statusStore.executionsList().last.executionId + val metrics = statusStore.executionMetrics(execId) + val driverMetric = physicalPlan.metrics("dummy") + val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + + assert(metrics.contains(driverMetric.id)) + assert(metrics(driverMetric.id) === expectedValue) + } + + private def assertJobs( + exec: Option[SQLExecutionUIData], + running: Seq[Int] = Nil, + completed: Seq[Int] = Nil, + failed: Seq[Int] = Nil): Unit = { + + val actualRunning = new ListBuffer[Int]() + val actualCompleted = new ListBuffer[Int]() + val actualFailed = new ListBuffer[Int]() + + exec.get.jobs.foreach { case (jobId, jobStatus) => + jobStatus match { + case JobExecutionStatus.RUNNING => actualRunning += jobId + case JobExecutionStatus.SUCCEEDED => actualCompleted += jobId + case JobExecutionStatus.FAILED => actualFailed += jobId + case _ => fail(s"Unexpected status $jobStatus") } } - waitTillExecutionFinished() - val driverUpdates = listener.getCompletedExecutions.head.driverAccumUpdates - assert(driverUpdates.size == 1) - assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue) + assert(actualRunning.toSeq.sorted === running) + assert(actualCompleted.toSeq.sorted === completed) + assert(actualFailed.toSeq.sorted === failed) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -490,7 +491,8 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe class SQLListenerMemoryLeakSuite extends SparkFunSuite { - test("no memory leak") { + // TODO: this feature is not yet available in SQLAppStatusStore. + ignore("no memory leak") { quietly { val conf = new SparkConf() .setMaster("local") @@ -498,7 +500,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly withSpark(new SparkContext(conf)) { sc => - SparkSession.sqlListener.set(null) val spark = new SparkSession(sc) import spark.implicits._ // Run 100 successful executions and 100 failed executions. @@ -516,12 +517,9 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - assert(spark.sharedState.listener.getCompletedExecutions.size <= 50) - assert(spark.sharedState.listener.getFailedExecutions.size <= 50) - // 50 for successful executions and 50 for failed executions - assert(spark.sharedState.listener.executionIdToData.size <= 100) - assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) - assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) + + val statusStore = new SQLAppStatusStore(sc.statusStore.store) + assert(statusStore.executionsList().size <= 50) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e0568a3c5c99f..0b4629a51b425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -73,7 +73,6 @@ trait SharedSparkSession * call 'beforeAll'. */ protected def initializeSession(): Unit = { - SparkSession.sqlListener.set(null) if (_spark == null) { _spark = createSparkSession } From eaff295a232217c4424f2885303f9a127dea0422 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Tue, 14 Nov 2017 15:20:03 -0800 Subject: [PATCH 1670/1765] [SPARK-22519][YARN] Remove unnecessary stagingDirPath null check in ApplicationMaster.cleanupStagingDir() ## What changes were proposed in this pull request? Removed the unnecessary stagingDirPath null check in ApplicationMaster.cleanupStagingDir(). ## How was this patch tested? I verified with the existing test cases. Author: Devaraj K Closes #19749 from devaraj-kavali/SPARK-22519. --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 244d912b9f3aa..ca0aa0ea3bc73 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -608,10 +608,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) if (!preserveFiles) { stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) - if (stagingDirPath == null) { - logError("Staging directory is null") - return - } logInfo("Deleting staging directory " + stagingDirPath) val fs = stagingDirPath.getFileSystem(yarnConf) fs.delete(stagingDirPath, true) From b009722591d2635698233c84f6e7e6cde7177019 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 14 Nov 2017 17:58:07 -0600 Subject: [PATCH 1671/1765] [SPARK-22511][BUILD] Update maven central repo address ## What changes were proposed in this pull request? Use repo.maven.apache.org repo address; use latest ASF parent POM version 18 ## How was this patch tested? Existing tests; no functional change Author: Sean Owen Closes #19742 from srowen/SPARK-22511. --- dev/check-license | 2 +- pom.xml | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dev/check-license b/dev/check-license index 8cee09a53e087..b729f34f24059 100755 --- a/dev/check-license +++ b/dev/check-license @@ -20,7 +20,7 @@ acquire_rat_jar () { - URL="https://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" + URL="https://repo.maven.apache.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" JAR="$rat_jar" diff --git a/pom.xml b/pom.xml index 8570338df6878..0297311dd6e64 100644 --- a/pom.xml +++ b/pom.xml @@ -22,7 +22,7 @@ org.apache apache - 14 + 18 org.apache.spark spark-parent_2.11 @@ -111,6 +111,8 @@ UTF-8 UTF-8 1.8 + ${java.version} + ${java.version} 3.3.9 spark 1.7.16 @@ -226,7 +228,7 @@ central Maven Repository - https://repo1.maven.org/maven2 + https://repo.maven.apache.org/maven2 true @@ -238,7 +240,7 @@ central - https://repo1.maven.org/maven2 + https://repo.maven.apache.org/maven2 true From 774398045b7b0cde4afb3f3c1a19ad491cf71ed1 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 14 Nov 2017 16:48:26 -0800 Subject: [PATCH 1672/1765] [SPARK-21087][ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? We add a parameter whether to collect the full model list when CrossValidator/TrainValidationSplit training (Default is NOT), avoid the change cause OOM) - Add a method in CrossValidatorModel/TrainValidationSplitModel, allow user to get the model list - CrossValidatorModelWriter add a “option”, allow user to control whether to persist the model list to disk (will persist by default). - Note: when persisting the model list, use indices as the sub-model path ## How was this patch tested? Test cases added. Author: WeichenXu Closes #19208 from WeichenXu123/expose-model-list. --- .../ml/param/shared/SharedParamsCodeGen.scala | 8 +- .../spark/ml/param/shared/sharedParams.scala | 17 +++ .../spark/ml/tuning/CrossValidator.scala | 137 ++++++++++++++++-- .../ml/tuning/TrainValidationSplit.scala | 128 ++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 19 +++ .../spark/ml/tuning/CrossValidatorSuite.scala | 54 ++++++- .../ml/tuning/TrainValidationSplitSuite.scala | 48 +++++- project/MimaExcludes.scala | 6 +- 8 files changed, 388 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 20a1db854e3a6..c54062921fce6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,7 +83,13 @@ private[shared] object SharedParamsCodeGen { "all instance weights as 1.0"), ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), - isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) + isValid = "ParamValidators.gtEq(2)", isExpertParam = true), + ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " + + "sub-model will be available after fitting. If set to true, then all sub-models will be " + + "available. Warning: For large models, collecting all sub-models can cause OOMs on the " + + "Spark driver.", + Some("false"), isExpertParam = true) + ) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 0d5fb28ae783c..34aa38ac751fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -468,4 +468,21 @@ trait HasAggregationDepth extends Params { /** @group expertGetParam */ final def getAggregationDepth: Int = $(aggregationDepth) } + +/** + * Trait for shared param collectSubModels (default: false). + */ +private[ml] trait HasCollectSubModels extends Params { + + /** + * Param for whether to collect a list of sub-models trained during tuning. + * @group expertParam + */ + final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning") + + setDefault(collectSubModels, false) + + /** @group expertGetParam */ + final def getCollectSubModels: Boolean = $(collectSubModels) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 7c81cb96e07f2..1682ca91bf832 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { @Since("1.2.0") class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with HasParallelism with MLWritable with Logging { + with CrossValidatorParams with HasParallelism with HasCollectSubModels + with MLWritable with Logging { @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) @@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If set this param, when you save the returned model, you can set an option + * "persistSubModels" to be "true" before saving, in order to save these submodels. + * You can check documents of + * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} + * for more information. + * + * @group expertSetParam + */ + @Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) instr.logParams(numFolds, seed, parallelism) logTuningParams(instr) + val collectSubModelsParam = $(collectSubModels) + + var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) { + Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) + } else None + // Compute metrics for each model over each split val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => @@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels.get(splitIndex)(paramIndex) = model + } + model } (executionContext) } @@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.4.0") @@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] ( this(uid, bestModel, avgMetrics.asScala.toArray) } + private var _subModels: Option[Array[Array[Model[_]]]] = None + + private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) + : CrossValidatorModel = { + _subModels = subModels + this + } + + /** + * @return submodels represented in two dimension array. The index of outer array is the + * fold index, and the index of inner array corresponds to the ordering of + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting. + */ + @Since("2.3.0") + def subModels: Array[Array[Model[_]]] = { + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") + _subModels.get + } + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] ( val copied = new CrossValidatorModel( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - avgMetrics.clone()) + avgMetrics.clone() + ).setSubModels(CrossValidatorModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("1.6.0") - override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) + override def write: CrossValidatorModel.CrossValidatorModelWriter = { + new CrossValidatorModel.CrossValidatorModelWriter(this) + } } @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) + : Option[Array[Array[Model[_]]]] = { + subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]]))) + } + @Since("1.6.0") override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader @Since("1.6.0") override def load(path: String): CrossValidatorModel = super.load(path) - private[CrossValidatorModel] - class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + /** + * Writer for CrossValidatorModel. + * @param instance CrossValidatorModel instance used to construct the writer + * + * CrossValidatorModelWriter supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. + */ + @Since("2.3.0") + final class CrossValidatorModelWriter private[tuning] ( + instance: CrossValidatorModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean + import org.json4s.JsonDSL._ - val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. " + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") + val subModelsPath = new Path(path, "subModels") + for (splitIndex <- 0 until instance.getNumFolds) { + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } + } } } @@ -301,11 +395,30 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) + + val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( + estimatorParamMaps.length)(null)) + for (splitIndex <- 0 until numFolds) { + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + _subModels(splitIndex)(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + } + Some(_subModels) + } else None val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 6e3ad40706803..c73bd18475475 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.tuning -import java.io.IOException -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -67,7 +66,8 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { @Since("1.5.0") class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[TrainValidationSplitModel] - with TrainValidationSplitParams with HasParallelism with MLWritable with Logging { + with TrainValidationSplitParams with HasParallelism with HasCollectSubModels + with MLWritable with Logging { @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) @@ -101,6 +101,20 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If set this param, when you save the returned model, you can set an option + * "persistSubModels" to be "true" before saving, in order to save these submodels. + * You can check documents of + * {@link org.apache.spark.ml.tuning.TrainValidationSplitModel.TrainValidationSplitModelWriter} + * for more information. + * + * @group expertSetParam + */@Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema @@ -121,12 +135,22 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.cache() validationDataset.cache() + val collectSubModelsParam = $(collectSubModels) + + var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) { + Some(Array.fill[Model[_]](epm.length)(null)) + } else None + // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels.get(paramIndex) = model + } + model } (executionContext) } @@ -158,7 +182,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.5.0") @@ -238,6 +263,30 @@ class TrainValidationSplitModel private[ml] ( this(uid, bestModel, validationMetrics.asScala.toArray) } + private var _subModels: Option[Array[Model[_]]] = None + + private[tuning] def setSubModels(subModels: Option[Array[Model[_]]]) + : TrainValidationSplitModel = { + _subModels = subModels + this + } + + /** + * @return submodels represented in array. The index of array corresponds to the ordering of + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting. + */ + @Since("2.3.0") + def subModels: Array[Model[_]] = { + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") + _subModels.get + } + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -254,34 +303,73 @@ class TrainValidationSplitModel private[ml] ( val copied = new TrainValidationSplitModel ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - validationMetrics.clone()) + validationMetrics.clone() + ).setSubModels(TrainValidationSplitModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("2.0.0") - override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = { + new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + } } @Since("2.0.0") object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { + private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) + : Option[Array[Model[_]]] = { + subModels.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]])) + } + @Since("2.0.0") override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader @Since("2.0.0") override def load(path: String): TrainValidationSplitModel = super.load(path) - private[TrainValidationSplitModel] - class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { + /** + * Writer for TrainValidationSplitModel. + * @param instance TrainValidationSplitModel instance used to construct the writer + * + * TrainValidationSplitModel supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. + */ + @Since("2.3.0") + final class TrainValidationSplitModelWriter private[tuning] ( + instance: TrainValidationSplitModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean + import org.json4s.JsonDSL._ - val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. " + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") + val subModelsPath = new Path(path, "subModels") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } } } @@ -298,8 +386,22 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) + + val subModels: Option[Array[Model[_]]] = if (persistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null) + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + _subModels(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + Some(_subModels) + } else None val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7188da3531267..a616907800969 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.util import java.io.IOException +import java.util.Locale + +import scala.collection.mutable import org.apache.hadoop.fs.Path import org.json4s._ @@ -107,6 +110,22 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Map to store extra options for this writer. + */ + protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() + + /** + * Adds an option to the underlying MLWriter. See the documentation for the specific model's + * writer for possible options. The option name (key) is case-insensitive. + */ + @Since("2.3.0") + def option(key: String, value: String): this.type = { + require(key != null && !key.isEmpty) + optionMap.put(key.toLowerCase(Locale.ROOT), value) + this + } + /** * Overwrites if the output path already exists. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 853eeb39bf8df..15dade2627090 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} @@ -27,7 +29,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -161,6 +163,7 @@ class CrossValidatorSuite .setEstimatorParamMaps(paramMaps) .setSeed(42L) .setParallelism(2) + .setCollectSubModels(true) val cv2 = testDefaultReadWrite(cv, testParams = false) @@ -168,6 +171,7 @@ class CrossValidatorSuite assert(cv.getNumFolds === cv2.getNumFolds) assert(cv.getSeed === cv2.getSeed) assert(cv.getParallelism === cv2.getParallelism) + assert(cv.getCollectSubModels === cv2.getCollectSubModels) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -187,6 +191,54 @@ class CrossValidatorSuite .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) } + test("CrossValidator expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val numFolds = 3 + val subPath = new File(tempDir, "testCrossValidatorSubModels") + + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setParallelism(1) + .setCollectSubModels(true) + + val cvModel = cv.fit(dataset) + + assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds) + cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + // Test the default value for option "persistSubModel" to be "true" + val savingPathWithSubModels = new File(subPath, "cvModel3").getPath + cvModel.save(savingPathWithSubModels) + val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds) + cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath + cvModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + assert(!cvModel2.hasSubModels) + + for (i <- 0 until numFolds) { + for (j <- 0 until lrParamMaps.length) { + assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === + cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) + } + } + + val savingPathTestingIllegalParam = new File(subPath, "cvModel4").getPath + intercept[IllegalArgumentException] { + cvModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } + } + test("read/write: CrossValidator with nested estimator") { val ova = new OneVsRest().setClassifier(new LogisticRegression) val evaluator = new MulticlassClassificationEvaluator() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index f8d9c66be2c40..9024342d9c831 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} @@ -26,7 +28,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -161,12 +163,14 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(paramMaps) .setSeed(42L) .setParallelism(2) + .setCollectSubModels(true) val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) assert(tvs.getParallelism === tvs2.getParallelism) + assert(tvs.getCollectSubModels === tvs2.getCollectSubModels) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) @@ -181,6 +185,48 @@ class TrainValidationSplitSuite } } + test("TrainValidationSplit expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val subPath = new File(tempDir, "testTrainValidationSplitSubModels") + + val tvs = new TrainValidationSplit() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setParallelism(1) + .setCollectSubModels(true) + + val tvsModel = tvs.fit(dataset) + + assert(tvsModel.hasSubModels && tvsModel.subModels.length == lrParamMaps.length) + + // Test the default value for option "persistSubModel" to be "true" + val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath + tvsModel.save(savingPathWithSubModels) + val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + assert(tvsModel3.hasSubModels && tvsModel3.subModels.length == lrParamMaps.length) + + val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath + tvsModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + assert(!tvsModel2.hasSubModels) + + for (i <- 0 until lrParamMaps.length) { + assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === + tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) + } + + val savingPathTestingIllegalParam = new File(subPath, "tvsModel4").getPath + intercept[IllegalArgumentException] { + tvsModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } + } + test("read/write: TrainValidationSplit with nested estimator") { val ova = new OneVsRest() .setClassifier(new LogisticRegression) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7f18b40f9d960..915c7e2e2fda3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -78,7 +78,11 @@ object MimaExcludes { // [SPARK-14280] Support Scala 2.12 ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"), + + // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") ) // Exclude rules for 2.2.x From 1e6f760593d81def059c514d34173bf2777d71ec Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 14 Nov 2017 16:58:18 -0800 Subject: [PATCH 1673/1765] [SPARK-12375][ML] VectorIndexerModel support handle unseen categories via handleInvalid ## What changes were proposed in this pull request? Support skip/error/keep strategy, similar to `StringIndexer`. Implemented via `try...catch`, so that it can avoid possible performance impact. ## How was this patch tested? Unit test added. Author: WeichenXu Closes #19588 from WeichenXu123/handle_invalid_for_vector_indexer. --- .../spark/ml/feature/VectorIndexer.scala | 92 +++++++++++++++++-- .../spark/ml/feature/VectorIndexerSuite.scala | 39 ++++++++ 2 files changed, 121 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index d371da762c55d..3403ec4259b86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.feature import java.lang.{Double => JDouble, Integer => JInt} -import java.util.{Map => JMap} +import java.util.{Map => JMap, NoSuchElementException} import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ @@ -37,7 +38,27 @@ import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet /** Private trait for params for VectorIndexer and VectorIndexerModel */ -private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { + + /** + * Param for how to handle invalid data (unseen labels or NULL values). + * Note: this param only applies to categorical features, not continuous ones. + * Options are: + * 'skip': filter out rows with invalid data. + * 'error': throw an error. + * 'keep': put invalid data in a special additional bucket, at index numCategories. + * Default value: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(VectorIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorIndexer.ERROR_INVALID) /** * Threshold for the number of values a categorical feature can take. @@ -113,6 +134,10 @@ class VectorIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) @@ -148,6 +173,11 @@ class VectorIndexer @Since("1.4.0") ( @Since("1.6.0") object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): VectorIndexer = super.load(path) @@ -287,9 +317,15 @@ class VectorIndexerModel private[ml] ( while (featureIndex < numFeatures) { if (categoryMaps.contains(featureIndex)) { // categorical feature - val featureValues: Array[String] = + val rawFeatureValues: Array[String] = categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) - if (featureValues.length == 2) { + + val featureValues = if (getHandleInvalid == VectorIndexer.KEEP_INVALID) { + (rawFeatureValues.toList :+ "__unknown").toArray + } else { + rawFeatureValues + } + if (featureValues.length == 2 && getHandleInvalid != VectorIndexer.KEEP_INVALID) { attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), values = Some(featureValues)) } else { @@ -311,22 +347,39 @@ class VectorIndexerModel private[ml] ( // TODO: Check more carefully about whether this whole class will be included in a closure. /** Per-vector transform function */ - private val transformFunc: Vector => Vector = { + private lazy val transformFunc: Vector => Vector = { val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted val localVectorMap = categoryMaps val localNumFeatures = numFeatures + val localHandleInvalid = getHandleInvalid val f: Vector => Vector = { (v: Vector) => assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" + s" $numFeatures but found length ${v.size}") v match { case dv: DenseVector => + var hasInvalid = false val tmpv = dv.copy localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => - tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + try { + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } catch { + case _: NoSuchElementException => + localHandleInvalid match { + case VectorIndexer.ERROR_INVALID => + throw new SparkException(s"VectorIndexer encountered invalid value " + + s"${tmpv(featureIndex)} on feature index ${featureIndex}. To handle " + + s"or skip invalid value, try setting VectorIndexer.handleInvalid.") + case VectorIndexer.KEEP_INVALID => + tmpv.values(featureIndex) = categoryMap.size + case VectorIndexer.SKIP_INVALID => + hasInvalid = true + } + } } - tmpv + if (hasInvalid) null else tmpv case sv: SparseVector => // We use the fact that categorical value 0 is always mapped to index 0. + var hasInvalid = false val tmpv = sv.copy var catFeatureIdx = 0 // index into sortedCatFeatureIndices var k = 0 // index into non-zero elements of sparse vector @@ -337,12 +390,26 @@ class VectorIndexerModel private[ml] ( } else if (featureIndex > tmpv.indices(k)) { k += 1 } else { - tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + try { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + } catch { + case _: NoSuchElementException => + localHandleInvalid match { + case VectorIndexer.ERROR_INVALID => + throw new SparkException(s"VectorIndexer encountered invalid value " + + s"${tmpv.values(k)} on feature index ${featureIndex}. To handle " + + s"or skip invalid value, try setting VectorIndexer.handleInvalid.") + case VectorIndexer.KEEP_INVALID => + tmpv.values(k) = localVectorMap(featureIndex).size + case VectorIndexer.SKIP_INVALID => + hasInvalid = true + } + } catFeatureIdx += 1 k += 1 } } - tmpv + if (hasInvalid) null else tmpv } } f @@ -362,7 +429,12 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol, newField.metadata) + val ds = dataset.withColumn($(outputCol), newCol, newField.metadata) + if (getHandleInvalid == VectorIndexer.SKIP_INVALID) { + ds.na.drop(Array($(outputCol))) + } else { + ds + } } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index f2cca8aa82e85..69a7b75e32eb7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -38,6 +38,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext // identical, of length 3 @transient var densePoints1: DataFrame = _ @transient var sparsePoints1: DataFrame = _ + @transient var densePoints1TestInvalid: DataFrame = _ + @transient var sparsePoints1TestInvalid: DataFrame = _ @transient var point1maxes: Array[Double] = _ // identical, of length 2 @@ -55,11 +57,19 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext Vectors.dense(0.0, 1.0, 2.0), Vectors.dense(0.0, 0.0, -1.0), Vectors.dense(1.0, 3.0, 2.0)) + val densePoints1SeqTestInvalid = densePoints1Seq ++ Seq( + Vectors.dense(10.0, 2.0, 0.0), + Vectors.dense(0.0, 10.0, 2.0), + Vectors.dense(1.0, 3.0, 10.0)) val sparsePoints1Seq = Seq( Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), Vectors.sparse(3, Array(2), Array(-1.0)), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + val sparsePoints1SeqTestInvalid = sparsePoints1Seq ++ Seq( + Vectors.sparse(3, Array(0, 1), Array(10.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(10.0, 2.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 10.0))) point1maxes = Array(1.0, 3.0, 2.0) val densePoints2Seq = Seq( @@ -88,6 +98,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext densePoints1 = densePoints1Seq.map(FeatureData).toDF() sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() + densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData).toDF() + sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData).toDF() densePoints2 = densePoints2Seq.map(FeatureData).toDF() sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() badPoints = badPointsSeq.map(FeatureData).toDF() @@ -219,6 +231,33 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) } + test("handle invalid") { + for ((points, pointsTestInvalid) <- Seq((densePoints1, densePoints1TestInvalid), + (sparsePoints1, sparsePoints1TestInvalid))) { + val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") + val model = vectorIndexer.fit(points) + intercept[SparkException] { + model.transform(pointsTestInvalid).collect() + } + val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") + val model1 = vectorIndexer1.fit(points) + val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") + .collect().map(_(0)) + val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) + assert(transformed1 === invalidTransformed1) + + val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") + val model2 = vectorIndexer2.fit(points) + val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") + .collect().map(_(0)) + assert(invalidTransformed2 === transformed1 ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors.dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0)) + ) + } + } + test("Maintain sparsity for sparse vectors") { def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { val points = data.collect().map(_.getAs[Vector](0)) From dce1610ae376af00712ba7f4c99bfb4c006dbaec Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Nov 2017 14:42:37 +0100 Subject: [PATCH 1674/1765] [SPARK-22514][SQL] move ColumnVector.Array and ColumnarBatch.Row to individual files ## What changes were proposed in this pull request? Logically the `Array` doesn't belong to `ColumnVector`, and `Row` doesn't belong to `ColumnarBatch`. e.g. `ColumnVector` needs to return `Array` for `getArray`, and `Row` for `getStruct`. `Array` and `Row` can return each other with the `getArray`/`getStruct` methods. This is also a step to make `ColumnVector` public, it's cleaner to have `Array` and `Row` as top-level classes. This PR is just code moving around, with 2 renaming: `Array` -> `VectorBasedArray`, `Row` -> `VectorBasedRow`. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19740 from cloud-fan/vector. --- .../vectorized/AggregateHashMap.java | 2 +- .../vectorized/ArrowColumnVector.java | 6 +- .../execution/vectorized/ColumnVector.java | 202 +---------- .../vectorized/ColumnVectorUtils.java | 2 +- .../execution/vectorized/ColumnarArray.java | 208 +++++++++++ .../execution/vectorized/ColumnarBatch.java | 326 +---------------- .../sql/execution/vectorized/ColumnarRow.java | 327 ++++++++++++++++++ .../vectorized/OffHeapColumnVector.java | 2 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../vectorized/WritableColumnVector.java | 14 +- .../aggregate/HashAggregateExec.scala | 10 +- .../VectorizedHashMapGenerator.scala | 12 +- .../vectorized/ColumnVectorSuite.scala | 40 +-- 13 files changed, 597 insertions(+), 556 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index cb3ad4eab1f60..9467435435d1f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -72,7 +72,7 @@ public AggregateHashMap(StructType schema) { this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); } - public ColumnarBatch.Row findOrInsert(long key) { + public ColumnarRow findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { columnVectors[0].putLong(numRows, key); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 51ea719f8c4a6..949035bfb177c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -251,7 +251,7 @@ public int getArrayOffset(int rowId) { } @Override - public void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnarArray array) { throw new UnsupportedOperationException(); } @@ -330,7 +330,7 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - resultArray = new ColumnVector.Array(childColumns[0]); + resultArray = new ColumnarArray(childColumns[0]); } else if (vector instanceof MapVector) { MapVector mapVector = (MapVector) vector; accessor = new StructAccessor(mapVector); @@ -339,7 +339,7 @@ public ArrowColumnVector(ValueVector vector) { for (int i = 0; i < childColumns.length; ++i) { childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); } - resultStruct = new ColumnarBatch.Row(childColumns); + resultStruct = new ColumnarRow(childColumns); } else { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index c4b519f0b153f..666fd63fdcf2f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,11 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.UTF8String; /** @@ -42,190 +40,6 @@ * ColumnVectors are intended to be reused. */ public abstract class ColumnVector implements AutoCloseable { - - /** - * Holder object to return an array. This object is intended to be reused. Callers should - * copy the data out if it needs to be stored. - */ - public static final class Array extends ArrayData { - // The data for this array. This array contains elements from - // data[offset] to data[offset + length). - public final ColumnVector data; - public int length; - public int offset; - - // Populate if binary data is required for the Array. This is stored here as an optimization - // for string data. - public byte[] byteArray; - public int byteArrayOffset; - - // Reused staging buffer, used for loading from offheap. - protected byte[] tmpByteArray = new byte[1]; - - protected Array(ColumnVector data) { - this.data = data; - } - - @Override - public int numElements() { return length; } - - @Override - public ArrayData copy() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } - - @Override - public byte[] toByteArray() { return data.getBytes(offset, length); } - - @Override - public short[] toShortArray() { return data.getShorts(offset, length); } - - @Override - public int[] toIntArray() { return data.getInts(offset, length); } - - @Override - public long[] toLongArray() { return data.getLongs(offset, length); } - - @Override - public float[] toFloatArray() { return data.getFloats(offset, length); } - - @Override - public double[] toDoubleArray() { return data.getDoubles(offset, length); } - - // TODO: this is extremely expensive. - @Override - public Object[] array() { - DataType dt = data.dataType(); - Object[] list = new Object[length]; - try { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = get(i, dt); - } - } - return list; - } catch(Exception e) { - throw new RuntimeException("Could not get the array", e); - } - } - - @Override - public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } - - @Override - public boolean getBoolean(int ordinal) { - return data.getBoolean(offset + ordinal); - } - - @Override - public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } - - @Override - public short getShort(int ordinal) { - return data.getShort(offset + ordinal); - } - - @Override - public int getInt(int ordinal) { return data.getInt(offset + ordinal); } - - @Override - public long getLong(int ordinal) { return data.getLong(offset + ordinal); } - - @Override - public float getFloat(int ordinal) { - return data.getFloat(offset + ordinal); - } - - @Override - public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - return data.getDecimal(offset + ordinal, precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - return data.getUTF8String(offset + ordinal); - } - - @Override - public byte[] getBinary(int ordinal) { - return data.getBinary(offset + ordinal); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - return data.getStruct(offset + ordinal); - } - - @Override - public ArrayData getArray(int ordinal) { - return data.getArray(offset + ordinal); - } - - @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { - return getByte(ordinal); - } else if (dataType instanceof ShortType) { - return getShort(ordinal); - } else if (dataType instanceof IntegerType) { - return getInt(ordinal); - } else if (dataType instanceof LongType) { - return getLong(ordinal); - } else if (dataType instanceof FloatType) { - return getFloat(ordinal); - } else if (dataType instanceof DoubleType) { - return getDouble(ordinal); - } else if (dataType instanceof StringType) { - return getUTF8String(ordinal); - } else if (dataType instanceof BinaryType) { - return getBinary(ordinal); - } else if (dataType instanceof DecimalType) { - DecimalType t = (DecimalType) dataType; - return getDecimal(ordinal, t.precision(), t.scale()); - } else if (dataType instanceof DateType) { - return getInt(ordinal); - } else if (dataType instanceof TimestampType) { - return getLong(ordinal); - } else if (dataType instanceof ArrayType) { - return getArray(ordinal); - } else if (dataType instanceof StructType) { - return getStruct(ordinal, ((StructType)dataType).fields().length); - } else if (dataType instanceof MapType) { - return getMap(ordinal); - } else if (dataType instanceof CalendarIntervalType) { - return getInterval(ordinal); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dataType); - } - } - - @Override - public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } - - @Override - public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } - } - /** * Returns the data type of this column. */ @@ -350,7 +164,7 @@ public Object get(int ordinal, DataType dataType) { /** * Returns a utility object to get structs. */ - public ColumnarBatch.Row getStruct(int rowId) { + public ColumnarRow getStruct(int rowId) { resultStruct.rowId = rowId; return resultStruct; } @@ -359,7 +173,7 @@ public ColumnarBatch.Row getStruct(int rowId) { * Returns a utility object to get structs. * provided to keep API compatibility with InternalRow for code generation */ - public ColumnarBatch.Row getStruct(int rowId, int size) { + public ColumnarRow getStruct(int rowId, int size) { resultStruct.rowId = rowId; return resultStruct; } @@ -367,7 +181,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { /** * Returns the array at rowid. */ - public final ColumnVector.Array getArray(int rowId) { + public final ColumnarArray getArray(int rowId) { resultArray.length = getArrayLength(rowId); resultArray.offset = getArrayOffset(rowId); return resultArray; @@ -376,7 +190,7 @@ public final ColumnVector.Array getArray(int rowId) { /** * Loads the data into array.byteArray. */ - public abstract void loadBytes(ColumnVector.Array array); + public abstract void loadBytes(ColumnarArray array); /** * Returns the value for rowId. @@ -423,12 +237,12 @@ public MapData getMap(int ordinal) { /** * Reusable Array holder for getArray(). */ - protected ColumnVector.Array resultArray; + protected ColumnarArray resultArray; /** * Reusable Struct holder for getStruct(). */ - protected ColumnarBatch.Row resultStruct; + protected ColumnarRow resultStruct; /** * The Dictionary for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index adb859ed17757..b4b5f0a265934 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -98,7 +98,7 @@ public static void populate(WritableColumnVector col, InternalRow row, int field * For example, an array of IntegerType will return an int[]. * Throws exceptions for unhandled schemas. */ - public static Object toPrimitiveJavaArray(ColumnVector.Array array) { + public static Object toPrimitiveJavaArray(ColumnarArray array) { DataType dt = array.data.dataType(); if (dt instanceof IntegerType) { int[] result = new int[array.length]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java new file mode 100644 index 0000000000000..5e88ce0321084 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Array abstraction in {@link ColumnVector}. The instance of this class is intended + * to be reused, callers should copy the data out if it needs to be stored. + */ +public final class ColumnarArray extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + public final ColumnVector data; + public int length; + public int offset; + + // Populate if binary data is required for the Array. This is stored here as an optimization + // for string data. + public byte[] byteArray; + public int byteArrayOffset; + + // Reused staging buffer, used for loading from offheap. + protected byte[] tmpByteArray = new byte[1]; + + protected ColumnarArray(ColumnVector data) { + this.data = data; + } + + @Override + public int numElements() { + return length; + } + + @Override + public ArrayData copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } + + @Override + public byte[] toByteArray() { return data.getBytes(offset, length); } + + @Override + public short[] toShortArray() { return data.getShorts(offset, length); } + + @Override + public int[] toIntArray() { return data.getInts(offset, length); } + + @Override + public long[] toLongArray() { return data.getLongs(offset, length); } + + @Override + public float[] toFloatArray() { return data.getFloats(offset, length); } + + @Override + public double[] toDoubleArray() { return data.getDoubles(offset, length); } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch(Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { return data.getInt(offset + ordinal); } + + @Override + public long getLong(int ordinal) { return data.getLong(offset + ordinal); } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + @Override + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index bc546c7c425b1..8849a20d6ceb5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,17 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; import java.util.*; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.sql.types.StructType; /** * This class is the in memory representation of rows as they are streamed through operators. It @@ -48,7 +40,7 @@ public final class ColumnarBatch { private final StructType schema; private final int capacity; private int numRows; - private final ColumnVector[] columns; + final ColumnVector[] columns; // True if the row is filtered. private final boolean[] filteredRows; @@ -60,7 +52,7 @@ public final class ColumnarBatch { private int numRowsFiltered = 0; // Staging row returned from getRow. - final Row row; + final ColumnarRow row; /** * Called to close all the columns in this batch. It is not valid to access the data after @@ -72,313 +64,13 @@ public void close() { } } - /** - * Adapter class to interop with existing components that expect internal row. A lot of - * performance is lost with this translation. - */ - public static final class Row extends InternalRow { - protected int rowId; - private final ColumnarBatch parent; - private final int fixedLenRowSize; - private final ColumnVector[] columns; - private final WritableColumnVector[] writableColumns; - - // Ctor used if this is a top level row. - private Row(ColumnarBatch parent) { - this.parent = parent; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); - this.columns = parent.columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } - } - - // Ctor used if this is a struct. - protected Row(ColumnVector[] columns) { - this.parent = null; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); - this.columns = columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } - } - - /** - * Marks this row as being filtered out. This means a subsequent iteration over the rows - * in this batch will not include this row. - */ - public void markFiltered() { - parent.markFiltered(rowId); - } - - public ColumnVector[] columns() { return columns; } - - @Override - public int numFields() { return columns.length; } - - @Override - /** - * Revisit this. This is expensive. This is currently only used in test paths. - */ - public InternalRow copy() { - GenericInternalRow row = new GenericInternalRow(columns.length); - for (int i = 0; i < numFields(); i++) { - if (isNullAt(i)) { - row.setNullAt(i); - } else { - DataType dt = columns[i].dataType(); - if (dt instanceof BooleanType) { - row.setBoolean(i, getBoolean(i)); - } else if (dt instanceof ByteType) { - row.setByte(i, getByte(i)); - } else if (dt instanceof ShortType) { - row.setShort(i, getShort(i)); - } else if (dt instanceof IntegerType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof LongType) { - row.setLong(i, getLong(i)); - } else if (dt instanceof FloatType) { - row.setFloat(i, getFloat(i)); - } else if (dt instanceof DoubleType) { - row.setDouble(i, getDouble(i)); - } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i).copy()); - } else if (dt instanceof BinaryType) { - row.update(i, getBinary(i)); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType)dt; - row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); - } else if (dt instanceof DateType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof TimestampType) { - row.setLong(i, getLong(i)); - } else { - throw new RuntimeException("Not implemented. " + dt); - } - } - } - return row; - } - - @Override - public boolean anyNull() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } - - @Override - public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } - - @Override - public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } - - @Override - public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } - - @Override - public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } - - @Override - public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } - - @Override - public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } - - @Override - public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getDecimal(rowId, precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getUTF8String(rowId); - } - - @Override - public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getBinary(rowId); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); - return new CalendarInterval(months, microseconds); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getStruct(rowId); - } - - @Override - public ArrayData getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getArray(rowId); - } - - @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { - return getByte(ordinal); - } else if (dataType instanceof ShortType) { - return getShort(ordinal); - } else if (dataType instanceof IntegerType) { - return getInt(ordinal); - } else if (dataType instanceof LongType) { - return getLong(ordinal); - } else if (dataType instanceof FloatType) { - return getFloat(ordinal); - } else if (dataType instanceof DoubleType) { - return getDouble(ordinal); - } else if (dataType instanceof StringType) { - return getUTF8String(ordinal); - } else if (dataType instanceof BinaryType) { - return getBinary(ordinal); - } else if (dataType instanceof DecimalType) { - DecimalType t = (DecimalType) dataType; - return getDecimal(ordinal, t.precision(), t.scale()); - } else if (dataType instanceof DateType) { - return getInt(ordinal); - } else if (dataType instanceof TimestampType) { - return getLong(ordinal); - } else if (dataType instanceof ArrayType) { - return getArray(ordinal); - } else if (dataType instanceof StructType) { - return getStruct(ordinal, ((StructType)dataType).fields().length); - } else if (dataType instanceof MapType) { - return getMap(ordinal); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dataType); - } - } - - @Override - public void update(int ordinal, Object value) { - if (value == null) { - setNullAt(ordinal); - } else { - DataType dt = columns[ordinal].dataType(); - if (dt instanceof BooleanType) { - setBoolean(ordinal, (boolean) value); - } else if (dt instanceof IntegerType) { - setInt(ordinal, (int) value); - } else if (dt instanceof ShortType) { - setShort(ordinal, (short) value); - } else if (dt instanceof LongType) { - setLong(ordinal, (long) value); - } else if (dt instanceof FloatType) { - setFloat(ordinal, (float) value); - } else if (dt instanceof DoubleType) { - setDouble(ordinal, (double) value); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType) dt; - setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), - t.precision()); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dt); - } - } - } - - @Override - public void setNullAt(int ordinal) { - getWritableColumn(ordinal).putNull(rowId); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putBoolean(rowId, value); - } - - @Override - public void setByte(int ordinal, byte value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putByte(rowId, value); - } - - @Override - public void setShort(int ordinal, short value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putShort(rowId, value); - } - - @Override - public void setInt(int ordinal, int value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putInt(rowId, value); - } - - @Override - public void setLong(int ordinal, long value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putLong(rowId, value); - } - - @Override - public void setFloat(int ordinal, float value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putFloat(rowId, value); - } - - @Override - public void setDouble(int ordinal, double value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDouble(rowId, value); - } - - @Override - public void setDecimal(int ordinal, Decimal value, int precision) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDecimal(rowId, value, precision); - } - - private WritableColumnVector getWritableColumn(int ordinal) { - WritableColumnVector column = writableColumns[ordinal]; - assert (!column.isConstant); - return column; - } - } - /** * Returns an iterator over the rows in this batch. This skips rows that are filtered out. */ - public Iterator rowIterator() { + public Iterator rowIterator() { final int maxRows = ColumnarBatch.this.numRows(); - final Row row = new Row(this); - return new Iterator() { + final ColumnarRow row = new ColumnarRow(this); + return new Iterator() { int rowId = 0; @Override @@ -390,7 +82,7 @@ public boolean hasNext() { } @Override - public Row next() { + public ColumnarRow next() { while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { ++rowId; } @@ -491,7 +183,7 @@ public void setColumn(int ordinal, ColumnVector column) { /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. */ - public ColumnarBatch.Row getRow(int rowId) { + public ColumnarRow getRow(int rowId) { assert(rowId >= 0); assert(rowId < numRows); row.rowId = rowId; @@ -522,6 +214,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.capacity = capacity; this.nullFilteredColumns = new HashSet<>(); this.filteredRows = new boolean[capacity]; - this.row = new Row(this); + this.row = new ColumnarRow(this); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java new file mode 100644 index 0000000000000..c75adafd69461 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Row abstraction in {@link ColumnVector}. The instance of this class is intended + * to be reused, callers should copy the data out if it needs to be stored. + */ +public final class ColumnarRow extends InternalRow { + protected int rowId; + private final ColumnarBatch parent; + private final int fixedLenRowSize; + private final ColumnVector[] columns; + private final WritableColumnVector[] writableColumns; + + // Ctor used if this is a top level row. + ColumnarRow(ColumnarBatch parent) { + this.parent = parent; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + this.columns = parent.columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } + } + + // Ctor used if this is a struct. + ColumnarRow(ColumnVector[] columns) { + this.parent = null; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); + this.columns = columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } + } + + /** + * Marks this row as being filtered out. This means a subsequent iteration over the rows + * in this batch will not include this row. + */ + public void markFiltered() { + parent.markFiltered(rowId); + } + + public ColumnVector[] columns() { return columns; } + + @Override + public int numFields() { return columns.length; } + + /** + * Revisit this. This is expensive. This is currently only used in test paths. + */ + @Override + public InternalRow copy() { + GenericInternalRow row = new GenericInternalRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i).copy()); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); + } else { + throw new RuntimeException("Not implemented. " + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getStruct(rowId); + } + + @Override + public ColumnarArray getArray(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getArray(rowId); + } + + @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { + if (value == null) { + setNullAt(ordinal); + } else { + DataType dt = columns[ordinal].dataType(); + if (dt instanceof BooleanType) { + setBoolean(ordinal, (boolean) value); + } else if (dt instanceof IntegerType) { + setInt(ordinal, (int) value); + } else if (dt instanceof ShortType) { + setShort(ordinal, (short) value); + } else if (dt instanceof LongType) { + setLong(ordinal, (long) value); + } else if (dt instanceof FloatType) { + setFloat(ordinal, (float) value); + } else if (dt instanceof DoubleType) { + setDouble(ordinal, (double) value); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType) dt; + setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), + t.precision()); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dt); + } + } + } + + @Override + public void setNullAt(int ordinal) { + getWritableColumn(ordinal).putNull(rowId); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putBoolean(rowId, value); + } + + @Override + public void setByte(int ordinal, byte value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putByte(rowId, value); + } + + @Override + public void setShort(int ordinal, short value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putShort(rowId, value); + } + + @Override + public void setInt(int ordinal, int value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putInt(rowId, value); + } + + @Override + public void setLong(int ordinal, long value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putLong(rowId, value); + } + + @Override + public void setFloat(int ordinal, float value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putFloat(rowId, value); + } + + @Override + public void setDouble(int ordinal, double value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDouble(rowId, value); + } + + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDecimal(rowId, value, precision); + } + + private WritableColumnVector getWritableColumn(int ordinal) { + WritableColumnVector column = writableColumns[ordinal]; + assert (!column.isConstant); + return column; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index a7522ebf5821a..2bf523b7e7198 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -523,7 +523,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { } @Override - public void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnarArray array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; Platform.copyMemory( null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 166a39e0fabd9..d699d292711dc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -494,7 +494,7 @@ public void putArray(int rowId, int offset, int length) { } @Override - public void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnarArray array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index d3a14b9d8bd74..96cfeed34f300 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -283,8 +283,8 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - private ColumnVector.Array getByteArray(int rowId) { - ColumnVector.Array array = getArray(rowId); + private ColumnarArray getByteArray(int rowId) { + ColumnarArray array = getArray(rowId); array.data.loadBytes(array); return array; } @@ -324,7 +324,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { @Override public UTF8String getUTF8String(int rowId) { if (dictionary == null) { - ColumnVector.Array a = getByteArray(rowId); + ColumnarArray a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } else { byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); @@ -338,7 +338,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { if (dictionary == null) { - ColumnVector.Array array = getByteArray(rowId); + ColumnarArray array = getByteArray(rowId); byte[] bytes = new byte[array.length]; System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); return bytes; @@ -685,7 +685,7 @@ protected WritableColumnVector(int capacity, DataType type) { } this.childColumns = new WritableColumnVector[1]; this.childColumns[0] = reserveNewColumn(childCapacity, childType); - this.resultArray = new ColumnVector.Array(this.childColumns[0]); + this.resultArray = new ColumnarArray(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; @@ -694,14 +694,14 @@ protected WritableColumnVector(int capacity, DataType type) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } this.resultArray = null; - this.resultStruct = new ColumnarBatch.Row(this.childColumns); + this.resultStruct = new ColumnarRow(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); this.resultArray = null; - this.resultStruct = new ColumnarBatch.Row(this.childColumns); + this.resultStruct = new ColumnarRow(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2a208a2722550..51f7c9e22b902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -595,7 +595,7 @@ case class HashAggregateExec( ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName();") ctx.addMutableState( - "java.util.Iterator", + "java.util.Iterator", iterTermForFastHashMap, "") } else { ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, @@ -681,7 +681,7 @@ case class HashAggregateExec( """ } - // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow + // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow def outputFromVectorizedMap: String = { val row = ctx.freshName("fastHashMapRow") ctx.currentVars = null @@ -697,8 +697,8 @@ case class HashAggregateExec( s""" | while ($iterTermForFastHashMap.hasNext()) { | $numOutput.add(1); - | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = - | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) + | org.apache.spark.sql.execution.vectorized.ColumnarRow $row = + | (org.apache.spark.sql.execution.vectorized.ColumnarRow) | $iterTermForFastHashMap.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} @@ -892,7 +892,7 @@ case class HashAggregateExec( ${ if (isVectorizedHashMapEnabled) { s""" - | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $fastRowBuffer = null; + | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null; """.stripMargin } else { s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 812d405d5ebfe..fd783d905b776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ /** @@ -142,14 +142,14 @@ class VectorizedHashMapGenerator( /** * Generates a method that returns a mutable - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the + * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the * generated method adds the corresponding row in the associated * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we * have 2 long group-by keys, the generated function would be of the form: * * {{{ - * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert( + * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert( * long agg_key, long agg_key1) { * long h = hash(agg_key, agg_key1); * int step = 0; @@ -189,7 +189,7 @@ class VectorizedHashMapGenerator( } s""" - |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ + |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${ groupingKeySignature}) { | long h = hash(${groupingKeys.map(_.name).mkString(", ")}); | int step = 0; @@ -229,7 +229,7 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" - |public java.util.Iterator + |public java.util.Iterator | rowIterator() { | return batch.rowIterator(); |} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index c5c8ae3a17c6c..3c76ca79f5dda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -57,7 +57,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendBoolean(i % 2 == 0) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, BooleanType) === (i % 2 == 0)) @@ -69,7 +69,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByte(i.toByte) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, ByteType) === i.toByte) @@ -81,7 +81,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendShort(i.toShort) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, ShortType) === i.toShort) @@ -93,7 +93,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendInt(i) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, IntegerType) === i) @@ -105,7 +105,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendLong(i) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, LongType) === i) @@ -117,7 +117,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendFloat(i.toFloat) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, FloatType) === i.toFloat) @@ -129,7 +129,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendDouble(i.toDouble) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, DoubleType) === i.toDouble) @@ -142,7 +142,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) @@ -155,7 +155,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") @@ -179,12 +179,12 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.putArray(2, 3, 0) testVector.putArray(3, 3, 3) - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) - assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0)) - assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2)) - assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int]) - assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) + assert(array.getArray(0).toIntArray() === Array(0)) + assert(array.getArray(1).toIntArray() === Array(1, 2)) + assert(array.getArray(2).toIntArray() === Array.empty[Int]) + assert(array.getArray(3).toIntArray() === Array(3, 4, 5)) } val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) @@ -196,12 +196,12 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { c1.putInt(1, 456) c2.putDouble(1, 5.67) - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) - assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) - assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) - assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) - assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + assert(array.getStruct(0, structType.length).get(0, IntegerType) === 123) + assert(array.getStruct(0, structType.length).get(1, DoubleType) === 3.45) + assert(array.getStruct(1, structType.length).get(0, IntegerType) === 456) + assert(array.getStruct(1, structType.length).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { @@ -214,7 +214,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.reserve(16) // Check that none of the values got lost/overwritten. - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 8).foreach { i => assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) } From 8f0e88df03a06a91bb61c6e0d69b1b19e2bfb3f7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 15 Nov 2017 23:35:13 +0900 Subject: [PATCH 1675/1765] [SPARK-20791][PYTHON][FOLLOWUP] Check for unicode column names in createDataFrame with Arrow ## What changes were proposed in this pull request? If schema is passed as a list of unicode strings for column names, they should be re-encoded to 'utf-8' to be consistent. This is similar to the #13097 but for creation of DataFrame using Arrow. ## How was this patch tested? Added new test of using unicode names for schema. Author: Bryan Cutler Closes #19738 from BryanCutler/arrow-createDataFrame-followup-unicode-SPARK-20791. --- python/pyspark/sql/session.py | 7 ++++--- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 589365b083012..dbbcfff6db91b 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -592,6 +592,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr if isinstance(schema, basestring): schema = _parse_datatype_string(schema) + elif isinstance(schema, (list, tuple)): + # Must re-encode any unicode strings to be consistent with StructField names + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] try: import pandas @@ -602,7 +605,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [str(x) for x in data.columns] + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: @@ -630,8 +633,6 @@ def prepare(obj): verify_func(obj) return obj, else: - if isinstance(schema, (list, tuple)): - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj if isinstance(data, RDD): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6356d938db26a..ef592c2356a8c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3225,6 +3225,16 @@ def test_createDataFrame_with_names(self): df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + def test_createDataFrame_column_name_encoding(self): + import pandas as pd + pdf = pd.DataFrame({u'a': [1]}) + columns = self.spark.createDataFrame(pdf).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'a') + columns = self.spark.createDataFrame(pdf, [u'b']).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'b') + def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): From 7f99a05e6ff258fc2192130451aa8aa1304bfe93 Mon Sep 17 00:00:00 2001 From: test Date: Wed, 15 Nov 2017 10:13:01 -0600 Subject: [PATCH 1676/1765] [SPARK-22422][ML] Add Adjusted R2 to RegressionMetrics ## What changes were proposed in this pull request? I added adjusted R2 as a regression metric which was implemented in all major statistical analysis tools. In practice, no one looks at R2 alone. The reason is R2 itself is misleading. If we add more parameters, R2 will not decrease but only increase (or stay the same). This leads to overfitting. Adjusted R2 addressed this issue by using number of parameters as "weight" for the sum of errors. ## How was this patch tested? - Added a new unit test and passed. - ./dev/run-tests all passed. Author: test Author: tengpeng Closes #19638 from tengpeng/master. --- .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ .../ml/regression/LinearRegressionSuite.scala | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index df1aa609c1b71..da6bcf07e4742 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -722,6 +722,21 @@ class LinearRegressionSummary private[regression] ( @Since("1.5.0") val r2: Double = metrics.r2 + /** + * Returns Adjusted R^2^, the adjusted coefficient of determination. + * Reference: + * Wikipedia coefficient of determination + * + * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. + * This will change in later Spark versions. + */ + @Since("2.3.0") + val r2adj: Double = { + val interceptDOF = if (privateModel.getFitIntercept) 1 else 0 + 1 - (1 - r2) * (numInstances - interceptDOF) / + (numInstances - privateModel.coefficients.size - interceptDOF) + } + /** Residuals (label - predicted value) */ @Since("1.5.0") @transient lazy val residuals: DataFrame = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index f470dca7dbd0a..0e0be58dbf022 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -764,6 +764,11 @@ class LinearRegressionSuite (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** V2 4.6982442 0.0011805 3980 <2e-16 *** V3 7.1994344 0.0009044 7961 <2e-16 *** + + # R code for r2adj + lm_fit <- lm(V1 ~ V2 + V3, data = d1) + summary(lm_fit)$adj.r.squared + [1] 0.9998736 --- .... @@ -771,6 +776,7 @@ class LinearRegressionSuite assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) + assert(model.summary.r2adj ~== 0.9998736 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". If no regularization is applied or only L2 // regularization is applied, this algorithm uses a direct solver and does not generate an From aa88b8dbbb7e71b282f31ae775140c783e83b4d6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 15 Nov 2017 08:59:29 -0800 Subject: [PATCH 1677/1765] [SPARK-22490][DOC] Add PySpark doc for SparkSession.builder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In PySpark API Document, [SparkSession.build](http://spark.apache.org/docs/2.2.0/api/python/pyspark.sql.html) is not documented and shows default value description. ``` SparkSession.builder =
    >
    > builder
    >

    A class attribute having a Builder to construct SparkSession instances

    >
    > 212,216d217 <
    < builder = <pyspark.sql.session.SparkSession.Builder object>
    <
    < <
    ``` ## How was this patch tested? Manual. ``` cd python/docs make html open _build/html/pyspark.sql.html ``` Author: Dongjoon Hyun Closes #19726 from dongjoon-hyun/SPARK-22490. --- python/docs/pyspark.sql.rst | 3 +++ python/pyspark/sql/session.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 09848b880194d..5c3b7e274857a 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -7,6 +7,9 @@ Module Context .. automodule:: pyspark.sql :members: :undoc-members: + :exclude-members: builder +.. We need `exclude-members` to prevent default description generations + as a workaround for old Sphinx (< 1.6.6). pyspark.sql.types module ------------------------ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index dbbcfff6db91b..47c58bb28221c 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -72,6 +72,9 @@ class SparkSession(object): ... .appName("Word Count") \\ ... .config("spark.some.config.option", "some-value") \\ ... .getOrCreate() + + .. autoattribute:: builder + :annotation: """ class Builder(object): @@ -183,6 +186,7 @@ def getOrCreate(self): return session builder = Builder() + """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances""" _instantiatedSession = None From bc0848b4c1ab84ccef047363a70fd11df240dbbf Mon Sep 17 00:00:00 2001 From: liutang123 Date: Wed, 15 Nov 2017 09:02:54 -0800 Subject: [PATCH 1678/1765] [SPARK-22469][SQL] Accuracy problem in comparison with string and numeric ## What changes were proposed in this pull request? This fixes a problem caused by #15880 `select '1.5' > 0.5; // Result is NULL in Spark but is true in Hive. ` When compare string and numeric, cast them as double like Hive. Author: liutang123 Closes #19692 from liutang123/SPARK-22469. --- .../sql/catalyst/analysis/TypeCoercion.scala | 7 + .../catalyst/analysis/TypeCoercionSuite.scala | 3 + .../sql-tests/inputs/predicate-functions.sql | 5 + .../results/predicate-functions.sql.out | 140 +++++++++++------- 4 files changed, 105 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 532d22dbf2321..074eda56199e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -137,6 +137,13 @@ object TypeCoercion { case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + + // There is no proper decimal type we can pick, + // using double type is the best we can do. + // See SPARK-22469 for details. + case (n: DecimalType, s: StringType) => Some(DoubleType) + case (s: StringType, n: DecimalType) => Some(DoubleType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) case (l, r) => None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 793e04f66f0f9..5dcd653e9b341 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1152,6 +1152,9 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(PromoteStrings, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) + ruleTest(PromoteStrings, + GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index 3b3d4ad64b3ec..e99d5cef81f64 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -2,12 +2,14 @@ select 1 = 1; select 1 = '1'; select 1.0 = '1'; +select 1.5 = '1.51'; -- GreaterThan select 1 > '1'; select 2 > '1.0'; select 2 > '2.0'; select 2 > '2.2'; +select '1.5' > 0.5; select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'; @@ -16,6 +18,7 @@ select 1 >= '1'; select 2 >= '1.0'; select 2 >= '2.0'; select 2.0 >= '2.2'; +select '1.5' >= 0.5; select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'; @@ -24,6 +27,7 @@ select 1 < '1'; select 2 < '1.0'; select 2 < '2.0'; select 2.0 < '2.2'; +select 0.5 < '1.5'; select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'; @@ -32,5 +36,6 @@ select 1 <= '1'; select 2 <= '1.0'; select 2 <= '2.0'; select 2.0 <= '2.2'; +select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index 8e7e04c8e1c4f..8cd0d51da64f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 31 -- !query 0 @@ -21,11 +21,19 @@ true -- !query 2 select 1.0 = '1' -- !query 2 schema -struct<(1.0 = CAST(1 AS DECIMAL(2,1))):boolean> +struct<(CAST(1.0 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> -- !query 2 output true +-- !query 3 +select 1.5 = '1.51' +-- !query 3 schema +struct<(CAST(1.5 AS DOUBLE) = CAST(1.51 AS DOUBLE)):boolean> +-- !query 3 output +false + + -- !query 3 select 1 > '1' -- !query 3 schema @@ -59,160 +67,192 @@ false -- !query 7 -select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') +select '1.5' > 0.5 -- !query 7 schema -struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(1.5 AS DOUBLE) > CAST(0.5 AS DOUBLE)):boolean> -- !query 7 output -false +true -- !query 8 -select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') -- !query 8 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> +struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> -- !query 8 output false -- !query 9 -select 1 >= '1' +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' -- !query 9 schema -struct<(1 >= CAST(1 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> -- !query 9 output -true +false -- !query 10 -select 2 >= '1.0' +select 1 >= '1' -- !query 10 schema -struct<(2 >= CAST(1.0 AS INT)):boolean> +struct<(1 >= CAST(1 AS INT)):boolean> -- !query 10 output true -- !query 11 -select 2 >= '2.0' +select 2 >= '1.0' -- !query 11 schema -struct<(2 >= CAST(2.0 AS INT)):boolean> +struct<(2 >= CAST(1.0 AS INT)):boolean> -- !query 11 output true -- !query 12 -select 2.0 >= '2.2' +select 2 >= '2.0' -- !query 12 schema -struct<(2.0 >= CAST(2.2 AS DECIMAL(2,1))):boolean> +struct<(2 >= CAST(2.0 AS INT)):boolean> -- !query 12 output -false +true -- !query 13 -select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') +select 2.0 >= '2.2' -- !query 13 schema -struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(2.0 AS DOUBLE) >= CAST(2.2 AS DOUBLE)):boolean> -- !query 13 output -true +false -- !query 14 -select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' +select '1.5' >= 0.5 -- !query 14 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> +struct<(CAST(1.5 AS DOUBLE) >= CAST(0.5 AS DOUBLE)):boolean> -- !query 14 output -false +true -- !query 15 -select 1 < '1' +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') -- !query 15 schema -struct<(1 < CAST(1 AS INT)):boolean> +struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> -- !query 15 output -false +true -- !query 16 -select 2 < '1.0' +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' -- !query 16 schema -struct<(2 < CAST(1.0 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> -- !query 16 output false -- !query 17 -select 2 < '2.0' +select 1 < '1' -- !query 17 schema -struct<(2 < CAST(2.0 AS INT)):boolean> +struct<(1 < CAST(1 AS INT)):boolean> -- !query 17 output false -- !query 18 -select 2.0 < '2.2' +select 2 < '1.0' -- !query 18 schema -struct<(2.0 < CAST(2.2 AS DECIMAL(2,1))):boolean> +struct<(2 < CAST(1.0 AS INT)):boolean> -- !query 18 output -true +false -- !query 19 -select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') +select 2 < '2.0' -- !query 19 schema -struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> +struct<(2 < CAST(2.0 AS INT)):boolean> -- !query 19 output false -- !query 20 -select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' +select 2.0 < '2.2' -- !query 20 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> +struct<(CAST(2.0 AS DOUBLE) < CAST(2.2 AS DOUBLE)):boolean> -- !query 20 output true -- !query 21 -select 1 <= '1' +select 0.5 < '1.5' -- !query 21 schema -struct<(1 <= CAST(1 AS INT)):boolean> +struct<(CAST(0.5 AS DOUBLE) < CAST(1.5 AS DOUBLE)):boolean> -- !query 21 output true -- !query 22 -select 2 <= '1.0' +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') -- !query 22 schema -struct<(2 <= CAST(1.0 AS INT)):boolean> +struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> -- !query 22 output false -- !query 23 -select 2 <= '2.0' +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' -- !query 23 schema -struct<(2 <= CAST(2.0 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> -- !query 23 output true -- !query 24 -select 2.0 <= '2.2' +select 1 <= '1' -- !query 24 schema -struct<(2.0 <= CAST(2.2 AS DECIMAL(2,1))):boolean> +struct<(1 <= CAST(1 AS INT)):boolean> -- !query 24 output true -- !query 25 -select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +select 2 <= '1.0' -- !query 25 schema -struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +struct<(2 <= CAST(1.0 AS INT)):boolean> -- !query 25 output -true +false -- !query 26 -select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +select 2 <= '2.0' -- !query 26 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +struct<(2 <= CAST(2.0 AS INT)):boolean> -- !query 26 output true + + +-- !query 27 +select 2.0 <= '2.2' +-- !query 27 schema +struct<(CAST(2.0 AS DOUBLE) <= CAST(2.2 AS DOUBLE)):boolean> +-- !query 27 output +true + + +-- !query 28 +select 0.5 <= '1.5' +-- !query 28 schema +struct<(CAST(0.5 AS DOUBLE) <= CAST(1.5 AS DOUBLE)):boolean> +-- !query 28 output +true + + +-- !query 29 +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +-- !query 29 schema +struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +-- !query 29 output +true + + +-- !query 30 +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +-- !query 30 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +-- !query 30 output +true From 39b3f10dda73f4a1f735f17467e5c6c45c44e977 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 15 Nov 2017 15:41:53 -0600 Subject: [PATCH 1679/1765] [SPARK-20649][CORE] Simplify REST API resource structure. With the new UI store, the API resource classes have a lot less code, since there's no need for complicated translations between the UI types and the API types. So the code ended up with a bunch of files with a single method declared in them. This change re-structures the API code so that it uses less classes; mainly, most sub-resources were removed, and the code to deal with single-attempt and multi-attempt apps was simplified. The only change was the addition of a method to return a single attempt's information; that was missing in the old API, so trying to retrieve "/v1/applications/appId/attemptId" would result in a 404 even if the attempt existed (and URIs under that one would return valid data). The streaming API resources also overtook the same treatment, even though the data is not stored in the new UI store. Author: Marcelo Vanzin Closes #19748 from vanzin/SPARK-20649. --- .../api/v1/AllExecutorListResource.scala | 30 --- .../spark/status/api/v1/AllJobsResource.scala | 35 --- .../spark/status/api/v1/AllRDDResource.scala | 31 --- .../status/api/v1/AllStagesResource.scala | 33 --- .../spark/status/api/v1/ApiRootResource.scala | 203 ++---------------- .../v1/ApplicationEnvironmentResource.scala | 32 --- .../api/v1/ApplicationListResource.scala | 2 +- .../api/v1/EventLogDownloadResource.scala | 71 ------ .../status/api/v1/ExecutorListResource.scala | 30 --- .../api/v1/OneApplicationResource.scala | 146 ++++++++++++- .../spark/status/api/v1/OneJobResource.scala | 38 ---- .../spark/status/api/v1/OneRDDResource.scala | 38 ---- ...ageResource.scala => StagesResource.scala} | 34 +-- .../spark/status/api/v1/VersionResource.scala | 30 --- .../api/v1/streaming/AllBatchesResource.scala | 78 ------- .../AllOutputOperationsResource.scala | 66 ------ .../v1/streaming/AllReceiversResource.scala | 76 ------- .../api/v1/streaming/ApiStreamingApp.scala | 31 ++- .../streaming/ApiStreamingRootResource.scala | 172 ++++++++++++--- .../api/v1/streaming/OneBatchResource.scala | 35 --- .../OneOutputOperationResource.scala | 39 ---- .../v1/streaming/OneReceiverResource.scala | 35 --- .../StreamingStatisticsResource.scala | 64 ------ 23 files changed, 349 insertions(+), 1000 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala rename core/src/main/scala/org/apache/spark/status/api/v1/{OneStageResource.scala => StagesResource.scala} (77%) delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala deleted file mode 100644 index 5522f4cebd773..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllExecutorListResource(ui: SparkUI) { - - @GET - def executorList(): Seq[ExecutorSummary] = ui.store.executorList(false) - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala deleted file mode 100644 index b4fa3e633f6c1..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.util.{Arrays, Date, List => JList} -import javax.ws.rs._ -import javax.ws.rs.core.MediaType - -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.JobUIData - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllJobsResource(ui: SparkUI) { - - @GET - def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { - ui.store.jobsList(statuses) - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala deleted file mode 100644 index 2189e1da91841..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.storage.{RDDInfo, StorageStatus, StorageUtils} -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllRDDResource(ui: SparkUI) { - - @GET - def rddList(): Seq[RDDStorageInfo] = ui.store.rddList() - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala deleted file mode 100644 index e1c91cb527a51..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.util.{List => JList} -import javax.ws.rs.{GET, Produces, QueryParam} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllStagesResource(ui: SparkUI) { - - @GET - def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { - ui.store.stageList(statuses) - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 9d3833086172f..ed9bdc6e1e3c2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -44,189 +44,14 @@ import org.apache.spark.ui.SparkUI private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") - def getApplicationList(): ApplicationListResource = { - new ApplicationListResource(uiRoot) - } + def applicationList(): Class[ApplicationListResource] = classOf[ApplicationListResource] @Path("applications/{appId}") - def getApplication(): OneApplicationResource = { - new OneApplicationResource(uiRoot) - } - - @Path("applications/{appId}/{attemptId}/jobs") - def getJobs( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllJobsResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllJobsResource(ui) - } - } - - @Path("applications/{appId}/jobs") - def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - withSparkUI(appId, None) { ui => - new AllJobsResource(ui) - } - } - - @Path("applications/{appId}/jobs/{jobId: \\d+}") - def getJob(@PathParam("appId") appId: String): OneJobResource = { - withSparkUI(appId, None) { ui => - new OneJobResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/jobs/{jobId: \\d+}") - def getJob( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): OneJobResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new OneJobResource(ui) - } - } - - @Path("applications/{appId}/executors") - def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - withSparkUI(appId, None) { ui => - new ExecutorListResource(ui) - } - } - - @Path("applications/{appId}/allexecutors") - def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - withSparkUI(appId, None) { ui => - new AllExecutorListResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/executors") - def getExecutors( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): ExecutorListResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new ExecutorListResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/allexecutors") - def getAllExecutors( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllExecutorListResource(ui) - } - } - - @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource = { - withSparkUI(appId, None) { ui => - new AllStagesResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/stages") - def getStages( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllStagesResource(ui) - } - } - - @Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource = { - withSparkUI(appId, None) { ui => - new OneStageResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/stages/{stageId: \\d+}") - def getStage( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): OneStageResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new OneStageResource(ui) - } - } - - @Path("applications/{appId}/storage/rdd") - def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - withSparkUI(appId, None) { ui => - new AllRDDResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/storage/rdd") - def getRdds( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllRDDResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllRDDResource(ui) - } - } - - @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") - def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - withSparkUI(appId, None) { ui => - new OneRDDResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/storage/rdd/{rddId: \\d+}") - def getRdd( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): OneRDDResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new OneRDDResource(ui) - } - } - - @Path("applications/{appId}/logs") - def getEventLogs( - @PathParam("appId") appId: String): EventLogDownloadResource = { - try { - // withSparkUI will throw NotFoundException if attemptId exists for this application. - // So we need to try again with attempt id "1". - withSparkUI(appId, None) { _ => - new EventLogDownloadResource(uiRoot, appId, None) - } - } catch { - case _: NotFoundException => - withSparkUI(appId, Some("1")) { _ => - new EventLogDownloadResource(uiRoot, appId, None) - } - } - } - - @Path("applications/{appId}/{attemptId}/logs") - def getEventLogs( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - withSparkUI(appId, Some(attemptId)) { _ => - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) - } - } + def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] @Path("version") - def getVersion(): VersionResource = { - new VersionResource(uiRoot) - } - - @Path("applications/{appId}/environment") - def getEnvironment(@PathParam("appId") appId: String): ApplicationEnvironmentResource = { - withSparkUI(appId, None) { ui => - new ApplicationEnvironmentResource(ui) - } - } + def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) - @Path("applications/{appId}/{attemptId}/environment") - def getEnvironment( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): ApplicationEnvironmentResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new ApplicationEnvironmentResource(ui) - } - } } private[spark] object ApiRootResource { @@ -293,23 +118,29 @@ private[v1] trait ApiRequestContext { def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) +} - /** - * Get the spark UI with the given appID, and apply a function - * to it. If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { +/** + * Base class for resource handlers that use app-specific data. Abstracts away dealing with + * application and attempt IDs, and finding the app's UI. + */ +private[v1] trait BaseAppResource extends ApiRequestContext { + + @PathParam("appId") protected[this] var appId: String = _ + @PathParam("attemptId") protected[this] var attemptId: String = _ + + protected def withUI[T](fn: SparkUI => T): T = { try { - uiRoot.withSparkUI(appId, attemptId) { ui => + uiRoot.withSparkUI(appId, Option(attemptId)) { ui => val user = httpRequest.getRemoteUser() if (!ui.securityManager.checkUIViewPermissions(user)) { throw new ForbiddenException(raw"""user "$user" is not authorized""") } - f(ui) + fn(ui) } } catch { case _: NoSuchElementException => - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + val appKey = Option(attemptId).map(appId + "/" + _).getOrElse(appId) throw new NotFoundException(s"no such app: $appKey") } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala deleted file mode 100644 index e702f8aa2ef2d..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs._ -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { - - @GET - def getEnvironmentInfo(): ApplicationEnvironmentInfo = { - ui.store.environmentInfo() - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index f039744e7f67f..91660a524ca93 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -23,7 +23,7 @@ import javax.ws.rs.core.MediaType import org.apache.spark.deploy.history.ApplicationHistoryInfo @Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class ApplicationListResource(uiRoot: UIRoot) { +private[v1] class ApplicationListResource extends ApiRequestContext { @GET def appList( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala deleted file mode 100644 index c84022ddfeef0..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.io.OutputStream -import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.{MediaType, Response, StreamingOutput} - -import scala.util.control.NonFatal - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.Logging - -@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) -private[v1] class EventLogDownloadResource( - val uIRoot: UIRoot, - val appId: String, - val attemptId: Option[String]) extends Logging { - val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) - - @GET - def getEventLogs(): Response = { - try { - val fileName = { - attemptId match { - case Some(id) => s"eventLogs-$appId-$id.zip" - case None => s"eventLogs-$appId.zip" - } - } - - val stream = new StreamingOutput { - override def write(output: OutputStream): Unit = { - val zipStream = new ZipOutputStream(output) - try { - uIRoot.writeEventLogs(appId, attemptId, zipStream) - } finally { - zipStream.close() - } - - } - } - - Response.ok(stream) - .header("Content-Disposition", s"attachment; filename=$fileName") - .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) - .build() - } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala deleted file mode 100644 index 975101c33c59c..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class ExecutorListResource(ui: SparkUI) { - - @GET - def executorList(): Seq[ExecutorSummary] = ui.store.executorList(true) - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 18c3e2f407360..bd4df07e7afc6 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -16,16 +16,150 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType +import java.io.OutputStream +import java.util.{List => JList} +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.SparkUI @Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneApplicationResource(uiRoot: UIRoot) { +private[v1] class AbstractApplicationResource extends BaseAppResource { + + @GET + @Path("jobs") + def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { + withUI(_.store.jobsList(statuses)) + } + + @GET + @Path("jobs/{jobId: \\d+}") + def oneJob(@PathParam("jobId") jobId: Int): JobData = withUI { ui => + try { + ui.store.job(jobId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException("unknown job: " + jobId) + } + } + + @GET + @Path("executors") + def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + + @GET + @Path("allexecutors") + def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) + + @Path("stages") + def stages(): Class[StagesResource] = classOf[StagesResource] + + @GET + @Path("storage/rdd") + def rddList(): Seq[RDDStorageInfo] = withUI(_.store.rddList()) + + @GET + @Path("storage/rdd/{rddId: \\d+}") + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = withUI { ui => + try { + ui.store.rdd(rddId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException(s"no rdd found w/ id $rddId") + } + } + + @GET + @Path("environment") + def environmentInfo(): ApplicationEnvironmentInfo = withUI(_.store.environmentInfo()) + + @GET + @Path("logs") + @Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) + def getEventLogs(): Response = { + // Retrieve the UI for the application just to do access permission checks. For backwards + // compatibility, this code also tries with attemptId "1" if the UI without an attempt ID does + // not exist. + try { + withUI { _ => } + } catch { + case _: NotFoundException if attemptId == null => + attemptId = "1" + withUI { _ => } + attemptId = null + } + + try { + val fileName = if (attemptId != null) { + s"eventLogs-$appId-$attemptId.zip" + } else { + s"eventLogs-$appId.zip" + } + + val stream = new StreamingOutput { + override def write(output: OutputStream): Unit = { + val zipStream = new ZipOutputStream(output) + try { + uiRoot.writeEventLogs(appId, Option(attemptId), zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } + + /** + * This method needs to be last, otherwise it clashes with the paths for the above methods + * and causes JAX-RS to not find things. + */ + @Path("{attemptId}") + def applicationAttempt(): Class[OneApplicationAttemptResource] = { + if (attemptId != null) { + throw new NotFoundException(httpRequest.getRequestURI()) + } + classOf[OneApplicationAttemptResource] + } + +} + +private[v1] class OneApplicationResource extends AbstractApplicationResource { + + @GET + def getApp(): ApplicationInfo = { + val app = uiRoot.getApplicationInfo(appId) + app.getOrElse(throw new NotFoundException("unknown app: " + appId)) + } + +} + +private[v1] class OneApplicationAttemptResource extends AbstractApplicationResource { @GET - def getApp(@PathParam("appId") appId: String): ApplicationInfo = { - val apps = uiRoot.getApplicationInfo(appId) - apps.getOrElse(throw new NotFoundException("unknown app: " + appId)) + def getAttempt(): ApplicationAttemptInfo = { + uiRoot.getApplicationInfo(appId) + .flatMap { app => + app.attempts.filter(_.attemptId == attemptId).headOption + } + .getOrElse { + throw new NotFoundException(s"unknown app $appId, attempt $attemptId") + } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala deleted file mode 100644 index 3ee884e084c12..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.util.NoSuchElementException -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneJobResource(ui: SparkUI) { - - @GET - def oneJob(@PathParam("jobId") jobId: Int): JobData = { - try { - ui.store.job(jobId) - } catch { - case _: NoSuchElementException => - throw new NotFoundException("unknown job: " + jobId) - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala deleted file mode 100644 index ca9758cf0d109..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.util.NoSuchElementException -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneRDDResource(ui: SparkUI) { - - @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { - try { - ui.store.rdd(rddId) - } catch { - case _: NoSuchElementException => - throw new NotFoundException(s"no rdd found w/ id $rddId") - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala similarity index 77% rename from core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala rename to core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 20dd73e916613..bd4dfe3c68885 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.{List => JList} import javax.ws.rs._ import javax.ws.rs.core.MediaType @@ -27,27 +28,34 @@ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.UIData.StageUIData @Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneStageResource(ui: SparkUI) { +private[v1] class StagesResource extends BaseAppResource { @GET - @Path("") + def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { + withUI(_.store.stageList(statuses)) + } + + @GET + @Path("{stageId: \\d+}") def stageData( @PathParam("stageId") stageId: Int, @QueryParam("details") @DefaultValue("true") details: Boolean): Seq[StageData] = { - val ret = ui.store.stageData(stageId, details = details) - if (ret.nonEmpty) { - ret - } else { - throw new NotFoundException(s"unknown stage: $stageId") + withUI { ui => + val ret = ui.store.stageData(stageId, details = details) + if (ret.nonEmpty) { + ret + } else { + throw new NotFoundException(s"unknown stage: $stageId") + } } } @GET - @Path("/{stageAttemptId: \\d+}") + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}") def oneAttemptData( @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int, - @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = { + @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = withUI { ui => try { ui.store.stageAttempt(stageId, stageAttemptId, details = details) } catch { @@ -57,12 +65,12 @@ private[v1] class OneStageResource(ui: SparkUI) { } @GET - @Path("/{stageAttemptId: \\d+}/taskSummary") + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskSummary") def taskSummary( @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0.05,0.25,0.5,0.75,0.95") @QueryParam("quantiles") quantileString: String) - : TaskMetricDistributions = { + : TaskMetricDistributions = withUI { ui => val quantiles = quantileString.split(",").map { s => try { s.toDouble @@ -76,14 +84,14 @@ private[v1] class OneStageResource(ui: SparkUI) { } @GET - @Path("/{stageAttemptId: \\d+}/taskList") + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskList") def taskList( @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0") @QueryParam("offset") offset: Int, @DefaultValue("20") @QueryParam("length") length: Int, @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { - ui.store.taskList(stageId, stageAttemptId, offset, length, sortBy) + withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy)) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala deleted file mode 100644 index 673da1ce36b57..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs._ -import javax.ws.rs.core.MediaType - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class VersionResource(ui: UIRoot) { - - @GET - def getVersionInfo(): VersionInfo = new VersionInfo( - org.apache.spark.SPARK_VERSION - ) - -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala deleted file mode 100644 index 3a51ae609303a..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.{ArrayList => JArrayList, Arrays => JArrays, Date, List => JList} -import javax.ws.rs.{GET, Produces, QueryParam} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.streaming.AllBatchesResource._ -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllBatchesResource(listener: StreamingJobProgressListener) { - - @GET - def batchesList(@QueryParam("status") statusParams: JList[BatchStatus]): Seq[BatchInfo] = { - batchInfoList(listener, statusParams).sortBy(- _.batchId) - } -} - -private[v1] object AllBatchesResource { - - def batchInfoList( - listener: StreamingJobProgressListener, - statusParams: JList[BatchStatus] = new JArrayList[BatchStatus]()): Seq[BatchInfo] = { - - listener.synchronized { - val statuses = - if (statusParams.isEmpty) JArrays.asList(BatchStatus.values(): _*) else statusParams - val statusToBatches = Seq( - BatchStatus.COMPLETED -> listener.retainedCompletedBatches, - BatchStatus.QUEUED -> listener.waitingBatches, - BatchStatus.PROCESSING -> listener.runningBatches - ) - - val batchInfos = for { - (status, batches) <- statusToBatches - batch <- batches if statuses.contains(status) - } yield { - val batchId = batch.batchTime.milliseconds - val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption - - new BatchInfo( - batchId = batchId, - batchTime = new Date(batchId), - status = status.toString, - batchDuration = listener.batchDuration, - inputSize = batch.numRecords, - schedulingDelay = batch.schedulingDelay, - processingTime = batch.processingDelay, - totalDelay = batch.totalDelay, - numActiveOutputOps = batch.numActiveOutputOp, - numCompletedOutputOps = batch.numCompletedOutputOp, - numFailedOutputOps = batch.numFailedOutputOp, - numTotalOutputOps = batch.outputOperations.size, - firstFailureReason = firstFailureReason - ) - } - - batchInfos - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala deleted file mode 100644 index 0eb649f0e1b72..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.Date -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.status.api.v1.streaming.AllOutputOperationsResource._ -import org.apache.spark.streaming.Time -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllOutputOperationsResource(listener: StreamingJobProgressListener) { - - @GET - def operationsList(@PathParam("batchId") batchId: Long): Seq[OutputOperationInfo] = { - outputOperationInfoList(listener, batchId).sortBy(_.outputOpId) - } -} - -private[v1] object AllOutputOperationsResource { - - def outputOperationInfoList( - listener: StreamingJobProgressListener, - batchId: Long): Seq[OutputOperationInfo] = { - - listener.synchronized { - listener.getBatchUIData(Time(batchId)) match { - case Some(batch) => - for ((opId, op) <- batch.outputOperations) yield { - val jobIds = batch.outputOpIdSparkJobIdPairs - .filter(_.outputOpId == opId).map(_.sparkJobId).toSeq.sorted - - new OutputOperationInfo( - outputOpId = opId, - name = op.name, - description = op.description, - startTime = op.startTime.map(new Date(_)), - endTime = op.endTime.map(new Date(_)), - duration = op.duration, - failureReason = op.failureReason, - jobIds = jobIds - ) - } - case None => throw new NotFoundException("unknown batch: " + batchId) - } - }.toSeq - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala deleted file mode 100644 index 5a276a9236a0f..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.Date -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.streaming.AllReceiversResource._ -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllReceiversResource(listener: StreamingJobProgressListener) { - - @GET - def receiversList(): Seq[ReceiverInfo] = { - receiverInfoList(listener).sortBy(_.streamId) - } -} - -private[v1] object AllReceiversResource { - - def receiverInfoList(listener: StreamingJobProgressListener): Seq[ReceiverInfo] = { - listener.synchronized { - listener.receivedRecordRateWithBatchTime.map { case (streamId, eventRates) => - - val receiverInfo = listener.receiverInfo(streamId) - val streamName = receiverInfo.map(_.name) - .orElse(listener.streamName(streamId)).getOrElse(s"Stream-$streamId") - val avgEventRate = - if (eventRates.isEmpty) None else Some(eventRates.map(_._2).sum / eventRates.size) - - val (errorTime, errorMessage, error) = receiverInfo match { - case None => (None, None, None) - case Some(info) => - val someTime = - if (info.lastErrorTime >= 0) Some(new Date(info.lastErrorTime)) else None - val someMessage = - if (info.lastErrorMessage.length > 0) Some(info.lastErrorMessage) else None - val someError = - if (info.lastError.length > 0) Some(info.lastError) else None - - (someTime, someMessage, someError) - } - - new ReceiverInfo( - streamId = streamId, - streamName = streamName, - isActive = receiverInfo.map(_.active), - executorId = receiverInfo.map(_.executorId), - executorHost = receiverInfo.map(_.location), - lastErrorTime = errorTime, - lastErrorMessage = errorMessage, - lastError = error, - avgEventRate = avgEventRate, - eventRates = eventRates - ) - }.toSeq - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala index aea75d5a9c8d0..07d8164e1d2c0 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala @@ -19,24 +19,39 @@ package org.apache.spark.status.api.v1.streaming import javax.ws.rs.{Path, PathParam} -import org.apache.spark.status.api.v1.ApiRequestContext +import org.apache.spark.status.api.v1._ +import org.apache.spark.streaming.ui.StreamingJobProgressListener @Path("/v1") private[v1] class ApiStreamingApp extends ApiRequestContext { @Path("applications/{appId}/streaming") - def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { - withSparkUI(appId, None) { ui => - new ApiStreamingRootResource(ui) - } + def getStreamingRoot(@PathParam("appId") appId: String): Class[ApiStreamingRootResource] = { + classOf[ApiStreamingRootResource] } @Path("applications/{appId}/{attemptId}/streaming") def getStreamingRoot( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): ApiStreamingRootResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new ApiStreamingRootResource(ui) + @PathParam("attemptId") attemptId: String): Class[ApiStreamingRootResource] = { + classOf[ApiStreamingRootResource] + } +} + +/** + * Base class for streaming API handlers, provides easy access to the streaming listener that + * holds the app's information. + */ +private[v1] trait BaseStreamingAppResource extends BaseAppResource { + + protected def withListener[T](fn: StreamingJobProgressListener => T): T = withUI { ui => + val listener = ui.getStreamingJobProgressListener match { + case Some(listener) => listener.asInstanceOf[StreamingJobProgressListener] + case None => throw new NotFoundException("no streaming listener attached to " + ui.getAppName) + } + listener.synchronized { + fn(listener) } } + } diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala index 1ccd586c848bd..a2571b910f615 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala @@ -17,58 +17,180 @@ package org.apache.spark.status.api.v1.streaming -import javax.ws.rs.Path +import java.util.{Arrays => JArrays, Collections, Date, List => JList} +import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs.core.MediaType import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.streaming.Time import org.apache.spark.streaming.ui.StreamingJobProgressListener +import org.apache.spark.streaming.ui.StreamingJobProgressListener._ import org.apache.spark.ui.SparkUI -private[v1] class ApiStreamingRootResource(ui: SparkUI) { - - import org.apache.spark.status.api.v1.streaming.ApiStreamingRootResource._ +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ApiStreamingRootResource extends BaseStreamingAppResource { + @GET @Path("statistics") - def getStreamingStatistics(): StreamingStatisticsResource = { - new StreamingStatisticsResource(getListener(ui)) + def streamingStatistics(): StreamingStatistics = withListener { listener => + val batches = listener.retainedBatches + val avgInputRate = avgRate(batches.map(_.numRecords * 1000.0 / listener.batchDuration)) + val avgSchedulingDelay = avgTime(batches.flatMap(_.schedulingDelay)) + val avgProcessingTime = avgTime(batches.flatMap(_.processingDelay)) + val avgTotalDelay = avgTime(batches.flatMap(_.totalDelay)) + + new StreamingStatistics( + startTime = new Date(listener.startTime), + batchDuration = listener.batchDuration, + numReceivers = listener.numReceivers, + numActiveReceivers = listener.numActiveReceivers, + numInactiveReceivers = listener.numInactiveReceivers, + numTotalCompletedBatches = listener.numTotalCompletedBatches, + numRetainedCompletedBatches = listener.retainedCompletedBatches.size, + numActiveBatches = listener.numUnprocessedBatches, + numProcessedRecords = listener.numTotalProcessedRecords, + numReceivedRecords = listener.numTotalReceivedRecords, + avgInputRate = avgInputRate, + avgSchedulingDelay = avgSchedulingDelay, + avgProcessingTime = avgProcessingTime, + avgTotalDelay = avgTotalDelay + ) } + @GET @Path("receivers") - def getReceivers(): AllReceiversResource = { - new AllReceiversResource(getListener(ui)) + def receiversList(): Seq[ReceiverInfo] = withListener { listener => + listener.receivedRecordRateWithBatchTime.map { case (streamId, eventRates) => + val receiverInfo = listener.receiverInfo(streamId) + val streamName = receiverInfo.map(_.name) + .orElse(listener.streamName(streamId)).getOrElse(s"Stream-$streamId") + val avgEventRate = + if (eventRates.isEmpty) None else Some(eventRates.map(_._2).sum / eventRates.size) + + val (errorTime, errorMessage, error) = receiverInfo match { + case None => (None, None, None) + case Some(info) => + val someTime = + if (info.lastErrorTime >= 0) Some(new Date(info.lastErrorTime)) else None + val someMessage = + if (info.lastErrorMessage.length > 0) Some(info.lastErrorMessage) else None + val someError = + if (info.lastError.length > 0) Some(info.lastError) else None + + (someTime, someMessage, someError) + } + + new ReceiverInfo( + streamId = streamId, + streamName = streamName, + isActive = receiverInfo.map(_.active), + executorId = receiverInfo.map(_.executorId), + executorHost = receiverInfo.map(_.location), + lastErrorTime = errorTime, + lastErrorMessage = errorMessage, + lastError = error, + avgEventRate = avgEventRate, + eventRates = eventRates + ) + }.toSeq.sortBy(_.streamId) } + @GET @Path("receivers/{streamId: \\d+}") - def getReceiver(): OneReceiverResource = { - new OneReceiverResource(getListener(ui)) + def oneReceiver(@PathParam("streamId") streamId: Int): ReceiverInfo = { + receiversList().find { _.streamId == streamId }.getOrElse( + throw new NotFoundException("unknown receiver: " + streamId)) } + @GET @Path("batches") - def getBatches(): AllBatchesResource = { - new AllBatchesResource(getListener(ui)) + def batchesList(@QueryParam("status") statusParams: JList[BatchStatus]): Seq[BatchInfo] = { + withListener { listener => + val statuses = + if (statusParams.isEmpty) JArrays.asList(BatchStatus.values(): _*) else statusParams + val statusToBatches = Seq( + BatchStatus.COMPLETED -> listener.retainedCompletedBatches, + BatchStatus.QUEUED -> listener.waitingBatches, + BatchStatus.PROCESSING -> listener.runningBatches + ) + + val batchInfos = for { + (status, batches) <- statusToBatches + batch <- batches if statuses.contains(status) + } yield { + val batchId = batch.batchTime.milliseconds + val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption + + new BatchInfo( + batchId = batchId, + batchTime = new Date(batchId), + status = status.toString, + batchDuration = listener.batchDuration, + inputSize = batch.numRecords, + schedulingDelay = batch.schedulingDelay, + processingTime = batch.processingDelay, + totalDelay = batch.totalDelay, + numActiveOutputOps = batch.numActiveOutputOp, + numCompletedOutputOps = batch.numCompletedOutputOp, + numFailedOutputOps = batch.numFailedOutputOp, + numTotalOutputOps = batch.outputOperations.size, + firstFailureReason = firstFailureReason + ) + } + + batchInfos.sortBy(- _.batchId) + } } + @GET @Path("batches/{batchId: \\d+}") - def getBatch(): OneBatchResource = { - new OneBatchResource(getListener(ui)) + def oneBatch(@PathParam("batchId") batchId: Long): BatchInfo = { + batchesList(Collections.emptyList()).find { _.batchId == batchId }.getOrElse( + throw new NotFoundException("unknown batch: " + batchId)) } + @GET @Path("batches/{batchId: \\d+}/operations") - def getOutputOperations(): AllOutputOperationsResource = { - new AllOutputOperationsResource(getListener(ui)) + def operationsList(@PathParam("batchId") batchId: Long): Seq[OutputOperationInfo] = { + withListener { listener => + val ops = listener.getBatchUIData(Time(batchId)) match { + case Some(batch) => + for ((opId, op) <- batch.outputOperations) yield { + val jobIds = batch.outputOpIdSparkJobIdPairs + .filter(_.outputOpId == opId).map(_.sparkJobId).toSeq.sorted + + new OutputOperationInfo( + outputOpId = opId, + name = op.name, + description = op.description, + startTime = op.startTime.map(new Date(_)), + endTime = op.endTime.map(new Date(_)), + duration = op.duration, + failureReason = op.failureReason, + jobIds = jobIds + ) + } + case None => throw new NotFoundException("unknown batch: " + batchId) + } + ops.toSeq + } } + @GET @Path("batches/{batchId: \\d+}/operations/{outputOpId: \\d+}") - def getOutputOperation(): OneOutputOperationResource = { - new OneOutputOperationResource(getListener(ui)) + def oneOperation( + @PathParam("batchId") batchId: Long, + @PathParam("outputOpId") opId: OutputOpId): OutputOperationInfo = { + operationsList(batchId).find { _.outputOpId == opId }.getOrElse( + throw new NotFoundException("unknown output operation: " + opId)) } -} + private def avgRate(data: Seq[Double]): Option[Double] = { + if (data.isEmpty) None else Some(data.sum / data.size) + } -private[v1] object ApiStreamingRootResource { - def getListener(ui: SparkUI): StreamingJobProgressListener = { - ui.getStreamingJobProgressListener match { - case Some(listener) => listener.asInstanceOf[StreamingJobProgressListener] - case None => throw new NotFoundException("no streaming listener attached to " + ui.getAppName) - } + private def avgTime(data: Seq[Long]): Option[Long] = { + if (data.isEmpty) None else Some(data.sum / data.size) } + } diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala deleted file mode 100644 index d3c689c790cfc..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneBatchResource(listener: StreamingJobProgressListener) { - - @GET - def oneBatch(@PathParam("batchId") batchId: Long): BatchInfo = { - val someBatch = AllBatchesResource.batchInfoList(listener) - .find { _.batchId == batchId } - someBatch.getOrElse(throw new NotFoundException("unknown batch: " + batchId)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala deleted file mode 100644 index aabcdb29b0d4c..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.streaming.ui.StreamingJobProgressListener -import org.apache.spark.streaming.ui.StreamingJobProgressListener._ - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneOutputOperationResource(listener: StreamingJobProgressListener) { - - @GET - def oneOperation( - @PathParam("batchId") batchId: Long, - @PathParam("outputOpId") opId: OutputOpId): OutputOperationInfo = { - - val someOutputOp = AllOutputOperationsResource.outputOperationInfoList(listener, batchId) - .find { _.outputOpId == opId } - someOutputOp.getOrElse(throw new NotFoundException("unknown output operation: " + opId)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala deleted file mode 100644 index c0cc99da3a9c7..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneReceiverResource(listener: StreamingJobProgressListener) { - - @GET - def oneReceiver(@PathParam("streamId") streamId: Int): ReceiverInfo = { - val someReceiver = AllReceiversResource.receiverInfoList(listener) - .find { _.streamId == streamId } - someReceiver.getOrElse(throw new NotFoundException("unknown receiver: " + streamId)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala deleted file mode 100644 index 6cff87be59ca8..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.Date -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class StreamingStatisticsResource(listener: StreamingJobProgressListener) { - - @GET - def streamingStatistics(): StreamingStatistics = { - listener.synchronized { - val batches = listener.retainedBatches - val avgInputRate = avgRate(batches.map(_.numRecords * 1000.0 / listener.batchDuration)) - val avgSchedulingDelay = avgTime(batches.flatMap(_.schedulingDelay)) - val avgProcessingTime = avgTime(batches.flatMap(_.processingDelay)) - val avgTotalDelay = avgTime(batches.flatMap(_.totalDelay)) - - new StreamingStatistics( - startTime = new Date(listener.startTime), - batchDuration = listener.batchDuration, - numReceivers = listener.numReceivers, - numActiveReceivers = listener.numActiveReceivers, - numInactiveReceivers = listener.numInactiveReceivers, - numTotalCompletedBatches = listener.numTotalCompletedBatches, - numRetainedCompletedBatches = listener.retainedCompletedBatches.size, - numActiveBatches = listener.numUnprocessedBatches, - numProcessedRecords = listener.numTotalProcessedRecords, - numReceivedRecords = listener.numTotalReceivedRecords, - avgInputRate = avgInputRate, - avgSchedulingDelay = avgSchedulingDelay, - avgProcessingTime = avgProcessingTime, - avgTotalDelay = avgTotalDelay - ) - } - } - - private def avgRate(data: Seq[Double]): Option[Double] = { - if (data.isEmpty) None else Some(data.sum / data.size) - } - - private def avgTime(data: Seq[Long]): Option[Long] = { - if (data.isEmpty) None else Some(data.sum / data.size) - } -} From 2014e7a789d36e376ca62b1e24636d79c1b19745 Mon Sep 17 00:00:00 2001 From: osatici Date: Wed, 15 Nov 2017 14:08:51 -0800 Subject: [PATCH 1680/1765] [SPARK-22479][SQL] Exclude credentials from SaveintoDataSourceCommand.simpleString ## What changes were proposed in this pull request? Do not include jdbc properties which may contain credentials in logging a logical plan with `SaveIntoDataSourceCommand` in it. ## How was this patch tested? building locally and trying to reproduce (per the steps in https://issues.apache.org/jira/browse/SPARK-22479): ``` == Parsed Logical Plan == SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) == Analyzed Logical Plan == SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) == Optimized Logical Plan == SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) == Physical Plan == Execute SaveIntoDataSourceCommand +- SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) ``` Author: osatici Closes #19708 from onursatici/os/redact-jdbc-creds. --- .../spark/internal/config/package.scala | 2 +- .../SaveIntoDataSourceCommand.scala | 7 +++ .../SaveIntoDataSourceCommandSuite.scala | 48 +++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 57e2da8353d6d..84315f55a59ad 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -307,7 +307,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password".r) + .createWithDefault("(?i)secret|password|url|user|username".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 96c84eab1c894..568e953a5db66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.CreatableRelationProvider +import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -46,4 +48,9 @@ case class SaveIntoDataSourceCommand( Seq.empty[Row] } + + override def simpleString: String = { + val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala new file mode 100644 index 0000000000000..4b3ca8e60cab6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.test.SharedSQLContext + +class SaveIntoDataSourceCommandSuite extends SharedSQLContext { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.regex", "(?i)password|url") + + test("simpleString is redacted") { + val URL = "connection.url" + val PASS = "123" + val DRIVER = "mydriver" + + val dataSource = DataSource( + sparkSession = spark, + className = "jdbc", + partitionColumns = Nil, + options = Map("password" -> PASS, "url" -> URL, "driver" -> DRIVER)) + + val logicalPlanString = dataSource + .planForWriting(SaveMode.ErrorIfExists, spark.range(1).logicalPlan) + .treeString(true) + + assert(!logicalPlanString.contains(URL)) + assert(!logicalPlanString.contains(PASS)) + assert(logicalPlanString.contains(DRIVER)) + } +} From 1e82335413bc2384073ead0d6d581c862036d0f5 Mon Sep 17 00:00:00 2001 From: ArtRand Date: Wed, 15 Nov 2017 15:53:05 -0800 Subject: [PATCH 1681/1765] [SPARK-21842][MESOS] Support Kerberos ticket renewal and creation in Mesos ## What changes were proposed in this pull request? tl;dr: Add a class, `MesosHadoopDelegationTokenManager` that updates delegation tokens on a schedule on the behalf of Spark Drivers. Broadcast renewed credentials to the executors. ## The problem We recently added Kerberos support to Mesos-based Spark jobs as well as Secrets support to the Mesos Dispatcher (SPARK-16742, SPARK-20812, respectively). However the delegation tokens have a defined expiration. This poses a problem for long running Spark jobs (e.g. Spark Streaming applications). YARN has a solution for this where a thread is scheduled to renew the tokens they reach 75% of their way to expiration. It then writes the tokens to HDFS for the executors to find (uses a monotonically increasing suffix). ## This solution We replace the current method in `CoarseGrainedSchedulerBackend` which used to discard the token renewal time with a protected method `fetchHadoopDelegationTokens`. Now the individual cluster backends are responsible for overriding this method to fetch and manage token renewal. The delegation tokens themselves, are still part of the `CoarseGrainedSchedulerBackend` as before. In the case of Mesos renewed Credentials are broadcasted to the executors. This maintains all transfer of Credentials within Spark (as opposed to Spark-to-HDFS). It also does not require any writing of Credentials to disk. It also does not require any GC of old files. ## How was this patch tested? Manually against a Kerberized HDFS cluster. Thank you for the reviews. Author: ArtRand Closes #19272 from ArtRand/spark-21842-450-kerberos-ticket-renewal. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 28 +++- .../HadoopDelegationTokenManager.scala | 3 + .../CoarseGrainedExecutorBackend.scala | 9 +- .../cluster/CoarseGrainedClusterMessage.scala | 3 + .../CoarseGrainedSchedulerBackend.scala | 30 +--- .../cluster/mesos/MesosClusterManager.scala | 2 +- .../MesosCoarseGrainedSchedulerBackend.scala | 19 ++- .../MesosHadoopDelegationTokenManager.scala | 157 ++++++++++++++++++ .../yarn/security/AMCredentialRenewer.scala | 21 +-- 9 files changed, 228 insertions(+), 44 deletions(-) create mode 100644 resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 1fa10ab943f34..17c7319b40f24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -140,12 +140,23 @@ class SparkHadoopUtil extends Logging { if (!new File(keytabFilename).exists()) { throw new SparkException(s"Keytab file: ${keytabFilename} does not exist") } else { - logInfo("Attempting to login to Kerberos" + - s" using principal: ${principalName} and keytab: ${keytabFilename}") + logInfo("Attempting to login to Kerberos " + + s"using principal: ${principalName} and keytab: ${keytabFilename}") UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } } + /** + * Add or overwrite current user's credentials with serialized delegation tokens, + * also confirms correct hadoop configuration is set. + */ + private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { + UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) + val creds = deserialize(tokens) + logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + addCurrentUserCredentials(creds) + } + /** * Returns a function that can be called to find Hadoop FileSystem bytes read. If * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will @@ -462,6 +473,19 @@ object SparkHadoopUtil { } } + /** + * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date + * when a given fraction of the duration until the expiration date has passed. + * Formula: current time + (fraction * (time until expiration)) + * @param expirationDate Drop-dead expiration date + * @param fraction fraction of the time until expiration return + * @return Date when the fraction of the time until expiration has passed + */ + private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + val ct = System.currentTimeMillis + (ct + (fraction * (expirationDate - ct))).toLong + } + /** * Returns a Configuration object with Spark configuration applied on top. Unlike * the instance method, this will always return a Configuration instance, and not a diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 483d0deec8070..116a686fe1480 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -109,6 +109,8 @@ private[spark] class HadoopDelegationTokenManager( * Writes delegation tokens to creds. Delegation tokens are fetched from all registered * providers. * + * @param hadoopConf hadoop Configuration + * @param creds Credentials that will be updated in place (overwritten) * @return Time after which the fetched delegation tokens should be renewed. */ def obtainDelegationTokens( @@ -125,3 +127,4 @@ private[spark] class HadoopDelegationTokenManager( }.foldLeft(Long.MaxValue)(math.min) } } + diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index d27362ae85bea..acefc9d2436d0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -123,6 +123,10 @@ private[spark] class CoarseGrainedExecutorBackend( executor.stop() } }.start() + + case UpdateDelegationTokens(tokenBytes) => + logInfo(s"Received tokens of ${tokenBytes.length} bytes") + SparkHadoopUtil.get.addDelegationTokens(tokenBytes, env.conf) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -219,9 +223,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.startCredentialUpdater(driverConf) } - cfg.hadoopDelegationCreds.foreach { hadoopCreds => - val creds = SparkHadoopUtil.get.deserialize(hadoopCreds) - SparkHadoopUtil.get.addCurrentUserCredentials(creds) + cfg.hadoopDelegationCreds.foreach { tokens => + SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } val env = SparkEnv.createExecutorEnv( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 5d65731dfc30e..e8b7fc0ef100a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -54,6 +54,9 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage with RegisterExecutorResponse + case class UpdateDelegationTokens(tokens: Array[Byte]) + extends CoarseGrainedClusterMessage + // Executors to driver case class RegisterExecutor( executorId: String, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 424e43b25c77a..22d9c4cf81c55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -24,11 +24,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future -import org.apache.hadoop.security.UserGroupInformation - import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -99,12 +95,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 - // hadoop token manager used by some sub-classes (e.g. Mesos) - def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = None - - // Hadoop delegation tokens to be sent to the executors. - val hadoopDelegationCreds: Option[Array[Byte]] = getHadoopDelegationCreds() - class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -159,6 +149,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.getExecutorsAliveOnHost(host).foreach { exec => killExecutors(exec.toSeq, replace = true, force = true) } + + case UpdateDelegationTokens(newDelegationTokens) => + executorDataMap.values.foreach { ed => + ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens)) + } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -236,7 +231,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val reply = SparkAppConfig( sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey(), - hadoopDelegationCreds) + fetchHadoopDelegationTokens()) context.reply(reply) } @@ -686,18 +681,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp true } - protected def getHadoopDelegationCreds(): Option[Array[Byte]] = { - if (UserGroupInformation.isSecurityEnabled && hadoopDelegationTokenManager.isDefined) { - hadoopDelegationTokenManager.map { manager => - val creds = UserGroupInformation.getCurrentUser.getCredentials - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - manager.obtainDelegationTokens(hadoopConf, creds) - SparkHadoopUtil.get.serialize(creds) - } - } else { - None - } - } + protected def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { None } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index 911a0857917ef..da71f8f9e407c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 104ed01d293ce..c392061fdb358 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -22,15 +22,16 @@ import java.util.{Collections, List => JList} import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import java.util.concurrent.locks.ReentrantLock -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.mesos.SchedulerDriver import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future +import org.apache.hadoop.security.UserGroupInformation +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver + import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ -import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf @@ -58,8 +59,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with org.apache.mesos.Scheduler with MesosSchedulerUtils { - override def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = - Some(new HadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration)) + private lazy val hadoopDelegationTokenManager: MesosHadoopDelegationTokenManager = + new MesosHadoopDelegationTokenManager(conf, sc.hadoopConfiguration, driverEndpoint) // Blacklist a slave after this many failures private val MAX_SLAVE_FAILURES = 2 @@ -772,6 +773,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( offer.getHostname } } + + override def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { + if (UserGroupInformation.isSecurityEnabled) { + Some(hadoopDelegationTokenManager.getTokens()) + } else { + None + } + } } private class Slave(val hostname: String) { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..325dc179d63ea --- /dev/null +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.security.PrivilegedExceptionAction +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.util.ThreadUtils + + +/** + * The MesosHadoopDelegationTokenManager fetches and updates Hadoop delegation tokens on the behalf + * of the MesosCoarseGrainedSchedulerBackend. It is modeled after the YARN AMCredentialRenewer, + * and similarly will renew the Credentials when 75% of the renewal interval has passed. + * The principal difference is that instead of writing the new credentials to HDFS and + * incrementing the timestamp of the file, the new credentials (called Tokens when they are + * serialized) are broadcast to all running executors. On the executor side, when new Tokens are + * received they overwrite the current credentials. + */ +private[spark] class MesosHadoopDelegationTokenManager( + conf: SparkConf, + hadoopConfig: Configuration, + driverEndpoint: RpcEndpointRef) + extends Logging { + + require(driverEndpoint != null, "DriverEndpoint is not initialized") + + private val credentialRenewerThread: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Renewal Thread") + + private val tokenManager: HadoopDelegationTokenManager = + new HadoopDelegationTokenManager(conf, hadoopConfig) + + private val principal: String = conf.get(config.PRINCIPAL).orNull + + private var (tokens: Array[Byte], timeOfNextRenewal: Long) = { + try { + val creds = UserGroupInformation.getCurrentUser.getCredentials + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) + logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") + (SparkHadoopUtil.get.serialize(creds), rt) + } catch { + case e: Exception => + logError(s"Failed to fetch Hadoop delegation tokens $e") + throw e + } + } + + private val keytabFile: Option[String] = conf.get(config.KEYTAB) + + scheduleTokenRenewal() + + private def scheduleTokenRenewal(): Unit = { + if (keytabFile.isDefined) { + require(principal != null, "Principal is required for Keytab-based authentication") + logInfo(s"Using keytab: ${keytabFile.get} and principal $principal") + } else { + logInfo("Using ticket cache for Kerberos authentication, no token renewal.") + return + } + + def scheduleRenewal(runnable: Runnable): Unit = { + val remainingTime = timeOfNextRenewal - System.currentTimeMillis() + if (remainingTime <= 0) { + logInfo("Credentials have expired, creating new ones now.") + runnable.run() + } else { + logInfo(s"Scheduling login from keytab in $remainingTime millis.") + credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + } + } + + val credentialRenewerRunnable = + new Runnable { + override def run(): Unit = { + try { + getNewDelegationTokens() + broadcastDelegationTokens(tokens) + } catch { + case e: Exception => + // Log the error and try to write new tokens back in an hour + logWarning("Couldn't broadcast tokens, trying again in an hour", e) + credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + return + } + scheduleRenewal(this) + } + } + scheduleRenewal(credentialRenewerRunnable) + } + + private def getNewDelegationTokens(): Unit = { + logInfo(s"Attempting to login to KDC with principal ${principal}") + // Get new delegation tokens by logging in with a new UGI inspired by AMCredentialRenewer.scala + // Don't protect against keytabFile being empty because it's guarded above. + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytabFile.get) + logInfo("Successfully logged into KDC") + val tempCreds = ugi.getCredentials + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + val nextRenewalTime = ugi.doAs(new PrivilegedExceptionAction[Long] { + override def run(): Long = { + tokenManager.obtainDelegationTokens(hadoopConf, tempCreds) + } + }) + + val currTime = System.currentTimeMillis() + timeOfNextRenewal = if (nextRenewalTime <= currTime) { + logWarning(s"Next credential renewal time ($nextRenewalTime) is earlier than " + + s"current time ($currTime), which is unexpected, please check your credential renewal " + + "related configurations in the target services.") + currTime + } else { + SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + } + logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") + + // Add the temp credentials back to the original ones. + UserGroupInformation.getCurrentUser.addCredentials(tempCreds) + // update tokens for late or dynamically added executors + tokens = SparkHadoopUtil.get.serialize(tempCreds) + } + + private def broadcastDelegationTokens(tokens: Array[Byte]) = { + logInfo("Sending new tokens to all executors") + driverEndpoint.send(UpdateDelegationTokens(tokens)) + } + + def getTokens(): Array[Byte] = { + tokens + } +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index 68a2e9e70a78b..6134757a82fdc 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -25,6 +25,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging @@ -58,9 +59,8 @@ private[yarn] class AMCredentialRenewer( private var lastCredentialsFileSuffix = 0 - private val credentialRenewer = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) + private val credentialRenewerThread: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") private val hadoopUtil = YarnSparkHadoopUtil.get @@ -70,7 +70,7 @@ private[yarn] class AMCredentialRenewer( private val freshHadoopConf = hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) - @volatile private var timeOfNextRenewal = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -95,7 +95,7 @@ private[yarn] class AMCredentialRenewer( runnable.run() } else { logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewer.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) } } @@ -111,7 +111,7 @@ private[yarn] class AMCredentialRenewer( // Log the error and try to write new tokens back in an hour logWarning("Failed to write out new credentials to HDFS, will try again in an " + "hour! If this happens too often tasks will fail.", e) - credentialRenewer.schedule(this, 1, TimeUnit.HOURS) + credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) return } scheduleRenewal(this) @@ -195,8 +195,9 @@ private[yarn] class AMCredentialRenewer( } else { // Next valid renewal time is about 75% of credential renewal time, and update time is // slightly later than valid renewal time (80% of renewal time). - timeOfNextRenewal = ((nearestNextRenewalTime - currTime) * 0.75 + currTime).toLong - ((nearestNextRenewalTime - currTime) * 0.8 + currTime).toLong + timeOfNextRenewal = + SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) + SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) } // Add the temp credentials back to the original ones. @@ -232,6 +233,6 @@ private[yarn] class AMCredentialRenewer( } def stop(): Unit = { - credentialRenewer.shutdown() + credentialRenewerThread.shutdown() } } From 03f2b7bff7e537ec747b41ad22e448e1c141f0dd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 16 Nov 2017 14:22:25 +0900 Subject: [PATCH 1682/1765] [SPARK-22535][PYSPARK] Sleep before killing the python worker in PythonRunner.MonitorThread ## What changes were proposed in this pull request? `PythonRunner.MonitorThread` should give the task a little time to finish before forcibly killing the python worker. This will reduce the chance of the race condition a lot. I also improved the log a bit to find out the task to blame when it's stuck. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #19762 from zsxwing/SPARK-22535. --- .../spark/api/python/PythonRunner.scala | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index d417303bb147d..9989f68f8508c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -337,6 +337,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) extends Thread(s"Worker Monitor for $pythonExec") { + /** How long to wait before killing the python worker if a task cannot be interrupted. */ + private val taskKillTimeout = env.conf.getTimeAsMs("spark.python.task.killTimeout", "2s") + setDaemon(true) override def run() { @@ -346,12 +349,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( Thread.sleep(2000) } if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) + Thread.sleep(taskKillTimeout) + if (!context.isCompleted) { + try { + // Mimic the task name used in `Executor` to help the user find out the task to blame. + val taskName = s"${context.partitionId}.${context.taskAttemptId} " + + s"in stage ${context.stageId} (TID ${context.taskAttemptId})" + logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } } } } From ed885e7a6504c439ffb6730e6963efbd050d43dd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Nov 2017 17:56:21 +0100 Subject: [PATCH 1683/1765] [SPARK-22499][SQL] Fix 64KB JVM bytecode limit problem with least and greatest ## What changes were proposed in this pull request? This PR changes `least` and `greatest` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved two cases: * `least` with a lot of argument * `greatest` with a lot of argument ## How was this patch tested? Added a new test case into `ArithmeticExpressionsSuite` Author: Kazuaki Ishizaki Closes #19729 from kiszk/SPARK-22499. --- .../sql/catalyst/expressions/arithmetic.scala | 24 +++++++++---------- .../ArithmeticExpressionSuite.scala | 9 +++++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7559852a2ac45..72d5889d2f202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val first = evalChildren(0) - val rest = evalChildren.drop(1) + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") def updateEval(eval: ExprCode): String = { s""" ${eval.code} @@ -614,11 +614,11 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) ev.copy(code = s""" - ${first.code} - boolean ${ev.isNull} = ${first.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")}""") + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + $codes""") } } @@ -668,8 +668,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val first = evalChildren(0) - val rest = evalChildren.drop(1) + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") def updateEval(eval: ExprCode): String = { s""" ${eval.code} @@ -680,10 +680,10 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) ev.copy(code = s""" - ${first.code} - boolean ${ev.isNull} = ${first.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")}""") + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + $codes""") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 031053727f08e..fb759eba6a9e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -334,4 +334,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } + + test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { + val N = 3000 + val strings = (1 to N).map(x => "s" * x) + val inputsExpr = strings.map(Literal.create(_, StringType)) + + checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow) + checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow) + } } From 4e7f07e2550fa995cc37406173a937033135cf3b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 16 Nov 2017 18:19:13 +0100 Subject: [PATCH 1684/1765] [SPARK-22494][SQL] Fix 64KB limit exception with Coalesce and AtleastNNonNulls ## What changes were proposed in this pull request? Both `Coalesce` and `AtLeastNNonNulls` can cause the 64KB limit exception when used with a lot of arguments and/or complex expressions. This PR splits their expressions in order to avoid the issue. ## How was this patch tested? Added UTs Author: Marco Gaido Author: Marco Gaido Closes #19720 from mgaido91/SPARK-22494. --- .../expressions/nullExpressions.scala | 42 ++++++++++++++----- .../expressions/NullExpressionsSuite.scala | 10 +++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 62786e13bda2c..4aeab2c3ad0a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,14 +72,10 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val first = children(0) - val rest = children.drop(1) - val firstEval = first.genCode(ctx) - ev.copy(code = s""" - ${firstEval.code} - boolean ${ev.isNull} = ${firstEval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};""" + - rest.map { e => + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + + val evals = children.map { e => val eval = e.genCode(ctx) s""" if (${ev.isNull}) { @@ -90,7 +86,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } """ - }.mkString("\n")) + } + + ev.copy(code = s""" + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""") } } @@ -357,7 +358,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") - val code = children.map { e => + val evals = children.map { e => val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => @@ -379,7 +380,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } """ } - }.mkString("\n") + } + + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + evals.mkString("\n") + } else { + ctx.splitExpressions(evals, "atLeastNNonNulls", + ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, + returnType = "int", + makeSplitFunction = { body => + s""" + $body + return $nonnull; + """ + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n") + } + ) + } + ev.copy(code = s""" int $nonnull = 0; $code diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 394c0a091e390..40ef7770da33f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -149,4 +149,14 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } + + test("Coalesce should not throw 64kb exception") { + val inputs = (1 to 2500).map(x => Literal(s"x_$x")) + checkEvaluation(Coalesce(inputs), "x_1") + } + + test("AtLeastNNonNulls should not throw 64kb exception") { + val inputs = (1 to 4000).map(x => Literal(s"x_$x")) + checkEvaluation(AtLeastNNonNulls(1, inputs), true) + } } From 7f2e62ee6b9d1f32772a18d626fb9fd907aa7733 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Nov 2017 18:24:49 +0100 Subject: [PATCH 1685/1765] [SPARK-22501][SQL] Fix 64KB JVM bytecode limit problem with in ## What changes were proposed in this pull request? This PR changes `In` code generation to place generated code for expression for expressions for arguments into separated methods if these size could be large. ## How was this patch tested? Added new test cases into `PredicateSuite` Author: Kazuaki Ishizaki Closes #19733 from kiszk/SPARK-22501. --- .../sql/catalyst/expressions/predicates.scala | 20 ++++++++++++++----- .../catalyst/expressions/PredicateSuite.scala | 6 ++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 61df5e053a374..5d75c6004bfe4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -236,24 +236,34 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) + ctx.addMutableState("boolean", ev.value, "") + ctx.addMutableState("boolean", ev.isNull, "") + val valueArg = ctx.freshName("valueArg") val listCode = listGen.map(x => s""" if (!${ev.value}) { ${x.code} if (${x.isNull}) { ${ev.isNull} = true; - } else if (${ctx.genEqual(value.dataType, valueGen.value, x.value)}) { + } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { ${ev.isNull} = false; ${ev.value} = true; } } - """).mkString("\n") + """) + val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil + ctx.splitExpressions(listCode, "valueIn", args) + } else { + listCode.mkString("\n") + } ev.copy(code = s""" ${valueGen.code} - boolean ${ev.value} = false; - boolean ${ev.isNull} = ${valueGen.isNull}; + ${ev.value} = false; + ${ev.isNull} = ${valueGen.isNull}; if (!${ev.isNull}) { - $listCode + ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value}; + $listCodes } """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1438a88c19e0b..865092a659f26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -239,6 +239,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-22501: In should not generate codes beyond 64KB") { + val N = 3000 + val sets = (1 to N).map(i => Literal(i.toDouble)) + checkEvaluation(In(Literal(1.0D), sets), true) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null From b9dcbe5e1ba81135a51b486240662674728dda84 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 Nov 2017 18:23:00 -0800 Subject: [PATCH 1686/1765] [SPARK-22542][SQL] remove unused features in ColumnarBatch ## What changes were proposed in this pull request? `ColumnarBatch` provides features to do fast filter and project in a columnar fashion, however this feature is never used by Spark, as Spark uses whole stage codegen and processes the data in a row fashion. This PR proposes to remove these unused features as we won't switch to columnar execution in the near future. Even we do, I think this part needs a proper redesign. This is also a step to make `ColumnVector` public, as we don't wanna expose these features to users. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19766 from cloud-fan/vector. --- .../sql/catalyst/expressions/UnsafeRow.java | 4 - .../execution/vectorized/ColumnarArray.java | 2 +- .../execution/vectorized/ColumnarBatch.java | 78 +------------------ .../sql/execution/vectorized/ColumnarRow.java | 26 ------- .../arrow/ArrowConvertersSuite.scala | 2 +- .../parquet/ParquetReadBenchmark.scala | 24 ------ .../vectorized/ColumnarBatchSuite.scala | 52 ------------- 7 files changed, 6 insertions(+), 182 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ec947d7580282..71c086029cc5b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -69,10 +69,6 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields + 63)/ 64) * 8; } - public static int calculateFixedPortionByteSize(int numFields) { - return 8 * numFields + calculateBitSetWidthInBytes(numFields); - } - /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java index 5e88ce0321084..34bde3e14d378 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java @@ -41,7 +41,7 @@ public final class ColumnarArray extends ArrayData { // Reused staging buffer, used for loading from offheap. protected byte[] tmpByteArray = new byte[1]; - protected ColumnarArray(ColumnVector data) { + ColumnarArray(ColumnVector data) { this.data = data; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8849a20d6ceb5..2f5fb360b226f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -42,15 +42,6 @@ public final class ColumnarBatch { private int numRows; final ColumnVector[] columns; - // True if the row is filtered. - private final boolean[] filteredRows; - - // Column indices that cannot have null values. - private final Set nullFilteredColumns; - - // Total number of rows that have been filtered. - private int numRowsFiltered = 0; - // Staging row returned from getRow. final ColumnarRow row; @@ -68,24 +59,18 @@ public void close() { * Returns an iterator over the rows in this batch. This skips rows that are filtered out. */ public Iterator rowIterator() { - final int maxRows = ColumnarBatch.this.numRows(); - final ColumnarRow row = new ColumnarRow(this); + final int maxRows = numRows; + final ColumnarRow row = new ColumnarRow(columns); return new Iterator() { int rowId = 0; @Override public boolean hasNext() { - while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { - ++rowId; - } return rowId < maxRows; } @Override public ColumnarRow next() { - while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { - ++rowId; - } if (rowId >= maxRows) { throw new NoSuchElementException(); } @@ -109,31 +94,15 @@ public void reset() { ((WritableColumnVector) columns[i]).reset(); } } - if (this.numRowsFiltered > 0) { - Arrays.fill(filteredRows, false); - } this.numRows = 0; - this.numRowsFiltered = 0; } /** - * Sets the number of rows that are valid. Additionally, marks all rows as "filtered" if one or - * more of their attributes are part of a non-nullable column. + * Sets the number of rows that are valid. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); this.numRows = numRows; - - for (int ordinal : nullFilteredColumns) { - if (columns[ordinal].numNulls() != 0) { - for (int rowId = 0; rowId < numRows; rowId++) { - if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) { - filteredRows[rowId] = true; - ++numRowsFiltered; - } - } - } - } } /** @@ -146,14 +115,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the number of valid rows. - */ - public int numValidRows() { - assert(numRowsFiltered <= numRows); - return numRows - numRowsFiltered; - } - /** * Returns the schema that makes up this batch. */ @@ -169,17 +130,6 @@ public int numValidRows() { */ public ColumnVector column(int ordinal) { return columns[ordinal]; } - /** - * Sets (replaces) the column at `ordinal` with column. This can be used to do very efficient - * projections. - */ - public void setColumn(int ordinal, ColumnVector column) { - if (column instanceof OffHeapColumnVector) { - throw new UnsupportedOperationException("Need to ref count columns."); - } - columns[ordinal] = column; - } - /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. */ @@ -190,30 +140,10 @@ public ColumnarRow getRow(int rowId) { return row; } - /** - * Marks this row as being filtered out. This means a subsequent iteration over the rows - * in this batch will not include this row. - */ - public void markFiltered(int rowId) { - assert(!filteredRows[rowId]); - filteredRows[rowId] = true; - ++numRowsFiltered; - } - - /** - * Marks a given column as non-nullable. Any row that has a NULL value for the corresponding - * attribute is filtered out. - */ - public void filterNullsInColumn(int ordinal) { - nullFilteredColumns.add(ordinal); - } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.schema = schema; this.columns = columns; this.capacity = capacity; - this.nullFilteredColumns = new HashSet<>(); - this.filteredRows = new boolean[capacity]; - this.row = new ColumnarRow(this); + this.row = new ColumnarRow(columns); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index c75adafd69461..98a907322713b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -32,28 +31,11 @@ */ public final class ColumnarRow extends InternalRow { protected int rowId; - private final ColumnarBatch parent; - private final int fixedLenRowSize; private final ColumnVector[] columns; private final WritableColumnVector[] writableColumns; - // Ctor used if this is a top level row. - ColumnarRow(ColumnarBatch parent) { - this.parent = parent; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); - this.columns = parent.columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } - } - // Ctor used if this is a struct. ColumnarRow(ColumnVector[] columns) { - this.parent = null; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); this.columns = columns; this.writableColumns = new WritableColumnVector[this.columns.length]; for (int i = 0; i < this.columns.length; i++) { @@ -63,14 +45,6 @@ public final class ColumnarRow extends InternalRow { } } - /** - * Marks this row as being filtered out. This means a subsequent iteration over the rows - * in this batch will not include this row. - */ - public void markFiltered() { - parent.markFiltered(rowId); - } - public ColumnVector[] columns() { return columns; } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index ba2903babbba8..57958f7239224 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1731,7 +1731,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) - assert(schema.equals(outputRowIter.schema)) + assert(schema == outputRowIter.schema) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 0917f188b9799..de7a5795b4796 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -295,48 +295,24 @@ object ParquetReadBenchmark { } } - benchmark.addCase("PR Vectorized (Null Filtering)") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - batch.filterNullsInColumn(0) - batch.filterNullsInColumn(1) - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - sum += rowIterator.next().getUTF8String(0).numBytes() - } - } - } finally { - reader.close() - } - } - } - /* Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X PR Vectorized 833 / 846 12.6 79.4 1.5X - PR Vectorized (Null Filtering) 732 / 782 14.3 69.8 1.7X Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X PR Vectorized 732 / 772 14.3 69.8 1.4X - PR Vectorized (Null Filtering) 725 / 790 14.5 69.1 1.4X Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X PR Vectorized 190 / 200 55.1 18.2 1.7X - PR Vectorized (Null Filtering) 168 / 172 62.2 16.1 1.9X */ benchmark.run() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 4cfc776e51db1..4a6c8f5521d18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -919,7 +919,6 @@ class ColumnarBatchSuite extends SparkFunSuite { val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.numValidRows() == 0) assert(batch.capacity() > 0) assert(batch.rowIterator().hasNext == false) @@ -933,7 +932,6 @@ class ColumnarBatchSuite extends SparkFunSuite { // Verify the results of the row. assert(batch.numCols() == 4) assert(batch.numRows() == 1) - assert(batch.numValidRows() == 1) assert(batch.rowIterator().hasNext == true) assert(batch.rowIterator().hasNext == true) @@ -957,16 +955,9 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) assert(it.hasNext == false) - // Filter out the row. - row.markFiltered() - assert(batch.numRows() == 1) - assert(batch.numValidRows() == 0) - assert(batch.rowIterator().hasNext == false) - // Reset and add 3 rows batch.reset() assert(batch.numRows() == 0) - assert(batch.numValidRows() == 0) assert(batch.rowIterator().hasNext == false) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] @@ -1002,26 +993,12 @@ class ColumnarBatchSuite extends SparkFunSuite { // Verify assert(batch.numRows() == 3) - assert(batch.numValidRows() == 3) val it2 = batch.rowIterator() rowEquals(it2.next(), Row(null, 2.2, 2, "abc")) rowEquals(it2.next(), Row(3, null, 3, "")) rowEquals(it2.next(), Row(4, 4.4, 4, "world")) assert(!it.hasNext) - // Filter out some rows and verify - batch.markFiltered(1) - assert(batch.numValidRows() == 2) - val it3 = batch.rowIterator() - rowEquals(it3.next(), Row(null, 2.2, 2, "abc")) - rowEquals(it3.next(), Row(4, 4.4, 4, "world")) - assert(!it.hasNext) - - batch.markFiltered(2) - assert(batch.numValidRows() == 1) - val it4 = batch.rowIterator() - rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) - batch.close() }} } @@ -1176,35 +1153,6 @@ class ColumnarBatchSuite extends SparkFunSuite { testRandomRows(false, 30) } - test("null filtered columns") { - val NUM_ROWS = 10 - val schema = new StructType() - .add("key", IntegerType, nullable = false) - .add("value", StringType, nullable = true) - for (numNulls <- List(0, NUM_ROWS / 2, NUM_ROWS)) { - val rows = mutable.ArrayBuffer.empty[Row] - for (i <- 0 until NUM_ROWS) { - val row = if (i < numNulls) Row.fromSeq(Seq(i, null)) else Row.fromSeq(Seq(i, i.toString)) - rows += row - } - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) - batch.filterNullsInColumn(1) - batch.setNumRows(NUM_ROWS) - assert(batch.numRows() == NUM_ROWS) - val it = batch.rowIterator() - // Top numNulls rows should be filtered - var k = numNulls - while (it.hasNext) { - assert(it.next().getInt(0) == k) - k += 1 - } - assert(k == NUM_ROWS) - batch.close() - }} - } - } - test("mutable ColumnarBatch rows") { val NUM_ITERS = 10 val types = Array( From d00b55d4b25ba0bf92983ff1bb47d8528e943737 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 17 Nov 2017 07:53:53 -0600 Subject: [PATCH 1687/1765] [SPARK-22540][SQL] Ensure HighlyCompressedMapStatus calculates correct avgSize ## What changes were proposed in this pull request? Ensure HighlyCompressedMapStatus calculates correct avgSize ## How was this patch tested? New unit test added. Author: yucai Closes #19765 from yucai/avgsize. --- .../apache/spark/scheduler/MapStatus.scala | 10 +++++---- .../spark/scheduler/MapStatusSuite.scala | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 5e45b375ddd45..2ec2f2031aa45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -197,7 +197,8 @@ private[spark] object HighlyCompressedMapStatus { // block as being non-empty (or vice-versa) when using the average block size. var i = 0 var numNonEmptyBlocks: Int = 0 - var totalSize: Long = 0 + var numSmallBlocks: Int = 0 + var totalSmallBlockSize: Long = 0 // From a compression standpoint, it shouldn't matter whether we track empty or non-empty // blocks. From a performance standpoint, we benefit from tracking empty blocks because // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. @@ -214,7 +215,8 @@ private[spark] object HighlyCompressedMapStatus { // Huge blocks are not included in the calculation for average size, thus size for smaller // blocks is more accurate. if (size < threshold) { - totalSize += size + totalSmallBlockSize += size + numSmallBlocks += 1 } else { hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) } @@ -223,8 +225,8 @@ private[spark] object HighlyCompressedMapStatus { } i += 1 } - val avgSize = if (numNonEmptyBlocks > 0) { - totalSize / numNonEmptyBlocks + val avgSize = if (numSmallBlocks > 0) { + totalSmallBlockSize / numSmallBlocks } else { 0 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 144e5afdcdd78..2155a0f2b6c21 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -98,6 +98,28 @@ class MapStatusSuite extends SparkFunSuite { } } + test("SPARK-22540: ensure HighlyCompressedMapStatus calculates correct avgSize") { + val threshold = 1000 + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, threshold.toString) + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = (0L to 3000L).toArray + val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) + val avg = smallBlockSizes.sum / smallBlockSizes.length + val loc = BlockManagerId("a", "b", 10) + val status = MapStatus(loc, sizes) + val status1 = compressAndDecompressMapStatus(status) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + assert(status1.location == loc) + for (i <- 0 until threshold) { + val estimate = status1.getSizeForBlock(i) + if (sizes(i) > 0) { + assert(estimate === avg) + } + } + } + def compressAndDecompressMapStatus(status: MapStatus): MapStatus = { val ser = new JavaSerializer(new SparkConf) val buf = ser.newInstance().serialize(status) From 7d039e0c0af0931a1696d89e76f455ae4adf277d Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 17 Nov 2017 16:43:08 +0100 Subject: [PATCH 1688/1765] [SPARK-22409] Introduce function type argument in pandas_udf ## What changes were proposed in this pull request? * Add a "function type" argument to pandas_udf. * Add a new public enum class `PandasUdfType` in pyspark.sql.functions * Refactor udf related code from pyspark.sql.functions to pyspark.sql.udf * Merge "PythonUdfType" and "PythonEvalType" into a single enum class "PythonEvalType" Example: ``` from pyspark.sql.functions import pandas_udf, PandasUDFType pandas_udf('double', PandasUDFType.SCALAR): def plus_one(v): return v + 1 ``` ## Design doc https://docs.google.com/document/d/1KlLaa-xJ3oz28xlEJqXyCAHU3dwFYkFs_ixcUXrJNTc/edit ## How was this patch tested? Added PandasUDFTests ## TODO: * [x] Implement proper enum type for `PandasUDFType` * [x] Update documentation * [x] Add more tests in PandasUDFTests Author: Li Jin Closes #19630 from icexelloss/spark-22409-pandas-udf-type. --- .../spark/api/python/PythonRunner.scala | 8 +- python/pyspark/rdd.py | 16 ++ python/pyspark/serializers.py | 7 - python/pyspark/sql/catalog.py | 7 +- python/pyspark/sql/functions.py | 222 ++++++------------ python/pyspark/sql/group.py | 49 +--- python/pyspark/sql/tests.py | 221 ++++++++++++----- python/pyspark/sql/udf.py | 161 +++++++++++++ python/pyspark/worker.py | 39 ++- .../spark/sql/RelationalGroupedDataset.scala | 9 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ExtractPythonUDFs.scala | 12 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../sql/execution/python/PythonUDF.scala | 2 +- .../python/UserDefinedPythonFunction.scala | 13 +- .../python/BatchEvalPythonExecSuite.scala | 4 +- 16 files changed, 490 insertions(+), 284 deletions(-) create mode 100644 python/pyspark/sql/udf.py diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 9989f68f8508c..f524de68fbce0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -34,9 +34,11 @@ import org.apache.spark.util._ */ private[spark] object PythonEvalType { val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 - val SQL_PANDAS_GROUPED_UDF = 3 + + val SQL_BATCHED_UDF = 100 + + val SQL_PANDAS_SCALAR_UDF = 200 + val SQL_PANDAS_GROUP_MAP_UDF = 201 } /** diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ea993c572fafd..340bc3a6b7470 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -56,6 +56,22 @@ __all__ = ["RDD"] +class PythonEvalType(object): + """ + Evaluation type of python rdd. + + These values are internal to PySpark. + + These values should match values in org.apache.spark.api.python.PythonEvalType. + """ + NON_UDF = 0 + + SQL_BATCHED_UDF = 100 + + SQL_PANDAS_SCALAR_UDF = 200 + SQL_PANDAS_GROUP_MAP_UDF = 201 + + def portable_hash(x): """ This function returns consistent hash code for builtin types, especially diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e0afdafbfcd62..b95de2c804394 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -82,13 +82,6 @@ class SpecialLengths(object): START_ARROW_STREAM = -6 -class PythonEvalType(object): - NON_UDF = 0 - SQL_BATCHED_UDF = 1 - SQL_PANDAS_UDF = 2 - SQL_PANDAS_GROUPED_UDF = 3 - - class Serializer(object): def dump_stream(self, iterator, stream): diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 5f25dce161963..659bc65701a0c 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -19,9 +19,9 @@ from collections import namedtuple from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import IntegerType, StringType, StructType @@ -256,7 +256,8 @@ def registerFunction(self, name, f, returnType=StringType()): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - udf = UserDefinedFunction(f, returnType, name) + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 087ce7caa89c8..b631e2041706f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,11 +27,12 @@ from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import StringType, DataType +from pyspark.sql.udf import UserDefinedFunction, _create_udf def _create_function(name, doc=""): @@ -2062,132 +2063,12 @@ def map_values(col): # ---------------------------- User Defined Function ---------------------------------- -def _wrap_function(sc, func, returnType): - command = (func, returnType) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) - - -class PythonUdfType(object): - # row-at-a-time UDFs - NORMAL_UDF = 0 - # scalar vectorized UDFs - PANDAS_UDF = 1 - # grouped vectorized UDFs - PANDAS_GROUPED_UDF = 2 - - -class UserDefinedFunction(object): - """ - User defined function in Python - - .. versionadded:: 1.3 - """ - def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): - if not callable(func): - raise TypeError( - "Not a function or callable (__call__ is not defined): " - "{0}".format(type(func))) - - self.func = func - self._returnType = returnType - # Stores UserDefinedPythonFunctions jobj, once initialized - self._returnType_placeholder = None - self._judf_placeholder = None - self._name = name or ( - func.__name__ if hasattr(func, '__name__') - else func.__class__.__name__) - self.pythonUdfType = pythonUdfType - - @property - def returnType(self): - # This makes sure this is called after SparkContext is initialized. - # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. - if self._returnType_placeholder is None: - if isinstance(self._returnType, DataType): - self._returnType_placeholder = self._returnType - else: - self._returnType_placeholder = _parse_datatype_string(self._returnType) - return self._returnType_placeholder - - @property - def _judf(self): - # It is possible that concurrent access, to newly created UDF, - # will initialize multiple UserDefinedPythonFunctions. - # This is unlikely, doesn't affect correctness, - # and should have a minimal performance impact. - if self._judf_placeholder is None: - self._judf_placeholder = self._create_judf() - return self._judf_placeholder - - def _create_judf(self): - from pyspark.sql import SparkSession - - spark = SparkSession.builder.getOrCreate() - sc = spark.sparkContext - - wrapped_func = _wrap_function(sc, self.func, self.returnType) - jdt = spark._jsparkSession.parseDataType(self.returnType.json()) - judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.pythonUdfType) - return judf - - def __call__(self, *cols): - judf = self._judf - sc = SparkContext._active_spark_context - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) - - def _wrapped(self): - """ - Wrap this udf with a function and attach docstring from func - """ - - # It is possible for a callable instance without __name__ attribute or/and - # __module__ attribute to be wrapped here. For example, functools.partial. In this case, - # we should avoid wrapping the attributes from the wrapped function to the wrapper - # function. So, we take out these attribute names from the default names to set and - # then manually assign it after being wrapped. - assignments = tuple( - a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') - - @functools.wraps(self.func, assigned=assignments) - def wrapper(*args): - return self(*args) - - wrapper.__name__ = self._name - wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') - else self.func.__class__.__module__) - - wrapper.func = self.func - wrapper.returnType = self.returnType - wrapper.pythonUdfType = self.pythonUdfType - - return wrapper - - -def _create_udf(f, returnType, pythonUdfType): - - def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): - if pythonUdfType == PythonUdfType.PANDAS_UDF: - import inspect - argspec = inspect.getargspec(f) - if len(argspec.args) == 0 and argspec.varargs is None: - raise ValueError( - "0-arg pandas_udfs are not supported. " - "Instead, create a 1-arg pandas_udf and ignore the arg in your function." - ) - udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) - return udf_obj._wrapped() - - # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf - if f is None or isinstance(f, (str, DataType)): - # If DataType has been passed as a positional argument - # for decorator use it as a returnType - return_type = f or returnType - return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) - else: - return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) +class PandasUDFType(object): + """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. + """ + SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF @since(1.3) @@ -2228,33 +2109,47 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) + # decorator @udf, @udf(), @udf(dataType()) + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + return functools.partial(_create_udf, returnType=return_type, + evalType=PythonEvalType.SQL_BATCHED_UDF) + else: + return _create_udf(f=f, returnType=returnType, + evalType=PythonEvalType.SQL_BATCHED_UDF) @since(2.3) -def pandas_udf(f=None, returnType=StringType()): +def pandas_udf(f=None, returnType=None, functionType=None): """ Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object + :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. + Default: SCALAR. - The user-defined function can define one of the following transformations: + The function type of the UDF can be one of the following: - 1. One or more `pandas.Series` -> A `pandas.Series` + 1. SCALAR - This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. + A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The returnType should be a primitive data type, e.g., `DoubleType()`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql.types import IntegerType, StringType >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) + >>> @pandas_udf(StringType()) ... def to_upper(s): ... return s.str.upper() ... - >>> @pandas_udf(returnType="integer") + >>> @pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... @@ -2267,20 +2162,24 @@ def pandas_udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ - 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + 2. GROUP_MAP - This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. + The length of the returned `pandas.DataFrame` can be arbitrary. + + Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -2291,10 +2190,6 @@ def pandas_udf(f=None, returnType=StringType()): | 2| 1.1094003924504583| +---+-------------------+ - .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` - because it defines a `DataFrame` transformation rather than a `Column` - transformation. - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. @@ -2306,7 +2201,44 @@ def pandas_udf(f=None, returnType=StringType()): rows that do not satisfy the conditions, the suggested workaround is to incorporate the condition logic into the functions. """ - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) + # decorator @pandas_udf(returnType, functionType) + is_decorator = f is None or isinstance(f, (str, DataType)) + + if is_decorator: + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + + if functionType is not None: + # @pandas_udf(dataType, functionType=functionType) + # @pandas_udf(returnType=dataType, functionType=functionType) + eval_type = functionType + elif returnType is not None and isinstance(returnType, int): + # @pandas_udf(dataType, functionType) + eval_type = returnType + else: + # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + else: + return_type = returnType + + if functionType is not None: + eval_type = functionType + else: + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + if return_type is None: + raise ValueError("Invalid returnType: returnType can not be None") + + if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + raise ValueError("Invalid functionType: " + "functionType must be one the values from PandasUDFType") + + if is_decorator: + return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) + else: + return _create_udf(f=f, returnType=return_type, evalType=eval_type) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index e11388d604312..4d47dd6a3e878 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -16,10 +16,10 @@ # from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import PythonUdfType, UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -214,15 +214,15 @@ def apply(self, udf): :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` - >>> from pyspark.sql.functions import pandas_udf + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -236,44 +236,13 @@ def apply(self, udf): .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - import inspect - # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ - or len(inspect.getargspec(udf.func).args) != 1: - raise ValueError("The argument to apply must be a 1-arg pandas_udf") - if not isinstance(udf.returnType, StructType): - raise ValueError("The returnType of the pandas_udf must be a StructType") - + or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "GROUP_MAP.") df = self._df - func = udf.func - returnType = udf.returnType - - # The python executors expects the function to use pd.Series as input and output - # So we to create a wrapper function that turns that to a pd.DataFrame before passing - # down to the user function, then turn the result pd.DataFrame back into pd.Series - columns = df.columns - - def wrapped(*cols): - from pyspark.sql.types import to_arrow_type - import pandas as pd - result = func(pd.concat(cols, axis=1, keys=columns)) - if not isinstance(result, pd.DataFrame): - raise TypeError("Return type of the user-defined function should be " - "Pandas.DataFrame, but is {}".format(type(result))) - if not len(result.columns) == len(returnType): - raise RuntimeError( - "Number of columns of the returned Pandas.DataFrame " - "doesn't match specified schema. " - "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] - - udf_obj = UserDefinedFunction( - wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) - udf_column = udf_obj(*[df[col] for col in df.columns]) + udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ef592c2356a8c..762afe0d730f3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3259,6 +3259,129 @@ def test_schema_conversion_roundtrip(self): self.assertEquals(self.schema, schema_rt) +class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + udf = pandas_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), + PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, returnType='v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_pandas_udf_decorator(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.types import StructType, StructField, DoubleType + + @pandas_udf(DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + schema = StructType([StructField("v", DoubleType())]) + + @pandas_udf(schema, PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf('v double', PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_udf_wrong_arg(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaises(ParseException): + @pandas_udf('blah') + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): + @pandas_udf(functionType=PandasUDFType.SCALAR) + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): + @pandas_udf('double', 100) + def foo(x): + return x + + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + @pandas_udf(LongType(), PandasUDFType.SCALAR) + def zero_with_type(): + return 1 + + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + def foo(k, v): + return k + + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): @@ -3355,23 +3478,6 @@ def test_vectorized_udf_null_string(self): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) - def test_vectorized_udf_zero_parameter(self): - from pyspark.sql.functions import pandas_udf - error_str = '0-arg pandas_udfs.*not.*supported' - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, error_str): - pandas_udf(lambda: 1, LongType()) - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf - def zero_no_type(): - return 1 - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf(LongType()) - def zero_with_type(): - return 1 - def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3570,7 +3676,7 @@ def check_records_per_batch(x): return x result = df.select(check_records_per_batch(col("id"))) - self.assertEquals(df.collect(), result.collect()) + self.assertEqual(df.collect(), result.collect()) finally: if orig_value is None: self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") @@ -3595,7 +3701,7 @@ def data(self): .withColumn("v", explode(col('vs'))).drop('vs') def test_simple(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( @@ -3604,21 +3710,22 @@ def test_simple(self): [StructField('id', LongType()), StructField('v', IntegerType()), StructField('v1', DoubleType()), - StructField('v2', LongType())])) + StructField('v2', LongType())]), + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_decorator(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('v1', DoubleType()), - StructField('v2', LongType())])) + @pandas_udf( + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -3627,12 +3734,14 @@ def foo(pdf): self.assertFramesEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + 'id long, v double', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3640,13 +3749,13 @@ def test_coerce(self): self.assertFramesEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3659,13 +3768,13 @@ def normalize(pdf): self.assertFramesEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3678,57 +3787,63 @@ def normalize(pdf): self.assertFramesEqual(expected, result) def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - "id long, v int, v1 double, v2 long") + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', StringType())])) + 'id long, v string', + PandasUDFType.GROUP_MAP + ) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.groupby('id').apply(foo).sort('id').toPandas() def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum + from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType df = self.data with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(df.v + 1) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid function'): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'returnType'): - df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), + PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py new file mode 100644 index 0000000000000..c3301a41ccd5a --- /dev/null +++ b/python/pyspark/sql/udf.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +User-defined function related classes and functions +""" +import functools + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string + + +def _wrap_function(sc, func, returnType): + command = (func, returnType) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + +def _create_udf(f, returnType, evalType): + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "Invalid function: 0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) + + elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) != 1: + raise ValueError( + "Invalid function: pandas_udfs with function type GROUP_MAP " + "must take a single arg that is a pandas DataFrame." + ) + + # Set the name of the UserDefinedFunction object to be the name of function f + udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + return udf_obj._wrapped() + + +class UserDefinedFunction(object): + """ + User defined function in Python + + .. versionadded:: 1.3 + """ + def __init__(self, func, + returnType=StringType(), name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF): + if not callable(func): + raise TypeError( + "Invalid function: not a function or callable (__call__ is not defined): " + "{0}".format(type(func))) + + if not isinstance(returnType, (DataType, str)): + raise TypeError( + "Invalid returnType: returnType should be DataType or str " + "but is {}".format(returnType)) + + if not isinstance(evalType, int): + raise TypeError( + "Invalid evalType: evalType should be an int but is {}".format(evalType)) + + self.func = func + self._returnType = returnType + # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None + self._judf_placeholder = None + self._name = name or ( + func.__name__ if hasattr(func, '__name__') + else func.__class__.__name__) + self.evalType = evalType + + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + + if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + and not isinstance(self._returnType_placeholder, StructType): + raise ValueError("Invalid returnType: returnType must be a StructType for " + "pandas_udf with function type GROUP_MAP") + + return self._returnType_placeholder + + @property + def _judf(self): + # It is possible that concurrent access, to newly created UDF, + # will initialize multiple UserDefinedPythonFunctions. + # This is unlikely, doesn't affect correctness, + # and should have a minimal performance impact. + if self._judf_placeholder is None: + self._judf_placeholder = self._create_judf() + return self._judf_placeholder + + def _create_judf(self): + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + wrapped_func = _wrap_function(sc, self.func, self.returnType) + jdt = spark._jsparkSession.parseDataType(self.returnType.json()) + judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( + self._name, wrapped_func, jdt, self.evalType) + return judf + + def __call__(self, *cols): + judf = self._judf + sc = SparkContext._active_spark_context + return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. + assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.func, assigned=assignments) + def wrapper(*args): + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') + else self.func.__class__.__module__) + + wrapper.func = self.func + wrapper.returnType = self.returnType + wrapper.evalType = self.evalType + + return wrapper diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5e100e0a9a95d..939643071943a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -29,8 +29,9 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles +from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type from pyspark import shuffle @@ -73,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_udf(f, return_type): +def wrap_pandas_scalar_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -89,6 +90,26 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) +def wrap_pandas_group_map_udf(f, return_type): + def wrapped(*series): + import pandas as pd + + result = f(pd.concat(series, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + return wrapped + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -99,12 +120,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = f else: row_func = chain(row_func, f) + # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_UDF: - return arg_offsets, wrap_pandas_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: - # a groupby apply udf has already been wrapped under apply() - return arg_offsets, row_func + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) @@ -127,8 +148,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 3e4edd4ea8cf3..a009c00b0abc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -449,10 +450,10 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, - "Must pass a grouped vectorized python udf") + require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + "Must pass a group map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the vectorized python udf must be a StructType") + "The returnType of the udf must be a StructType") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index bcda2dae92e53..e27210117a1e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e15e760136e81..f5a4cbc4793e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -148,15 +149,18 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { - throw new IllegalArgumentException("Can not use grouped vectorized UDFs") - } + require(validUdfs.forall(udf => + udf.evalType == PythonEvalType.SQL_BATCHED_UDF || + udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ), "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { + val evaluation = validUdfs.partition( + _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index e1e04e34e0c71..ee495814b8255 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -95,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9c07c7638de57..ef27fbc2db7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - pythonUdfType: Int) + evalType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index b2fe6c300846a..348e49e473ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,15 +22,6 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType -private[spark] object PythonUdfType { - // row-at-a-time UDFs - val NORMAL_UDF = 0 - // scalar vectorized UDFs - val PANDAS_UDF = 1 - // grouped vectorized UDFs - val PANDAS_GROUPED_UDF = 2 -} - /** * A user-defined Python function. This is used by the Python API. */ @@ -38,10 +29,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonUdfType: Int) { + pythonEvalType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonUdfType) + PythonUDF(name, func, dataType, e, pythonEvalType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 95b21fc9f16ae..53d3f34567518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.python.PythonFunction +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonUdfType = PythonUdfType.NORMAL_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF) From fccb337f9d1e44a83cfcc00ce33eae1fad367695 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Nov 2017 17:43:40 +0100 Subject: [PATCH 1689/1765] [SPARK-22538][ML] SQLTransformer should not unpersist possibly cached input dataset ## What changes were proposed in this pull request? `SQLTransformer.transform` unpersists input dataset when dropping temporary view. We should not change input dataset's cache status. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19772 from viirya/SPARK-22538. --- .../org/apache/spark/ml/feature/SQLTransformer.scala | 3 ++- .../spark/ml/feature/SQLTransformerSuite.scala | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 62c1972aab12c..0fb1d8c5dc579 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -70,7 +70,8 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) dataset.createOrReplaceTempView(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) val result = dataset.sparkSession.sql(realStatement) - dataset.sparkSession.catalog.dropTempView(tableName) + // Call SessionCatalog.dropTempView to avoid unpersisting the possibly cached dataset. + dataset.sparkSession.sessionState.catalog.dropTempView(tableName) result } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 753f890c48301..673a146e619f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.storage.StorageLevel class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -60,4 +61,15 @@ class SQLTransformerSuite val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) assert(outputSchema === expected) } + + test("SPARK-22538: SQLTransformer should not unpersist given dataset") { + val df = spark.range(10) + df.cache() + df.count() + assert(df.storageLevel != StorageLevel.NONE) + new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transform(df) + assert(df.storageLevel != StorageLevel.NONE) + } } From bf0c0ae2dcc7fd1ce92cd0fb4809bb3d65b2e309 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 17 Nov 2017 15:35:24 -0800 Subject: [PATCH 1690/1765] [SPARK-22544][SS] FileStreamSource should use its own hadoop conf to call globPathIfNecessary ## What changes were proposed in this pull request? Pass the FileSystem created using the correct Hadoop conf into `globPathIfNecessary` so that it can pick up user's hadoop configurations, such as credentials. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #19771 from zsxwing/fix-file-stream-conf. --- .../spark/sql/execution/streaming/FileStreamSource.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index f17417343e289..0debd7db84757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -47,8 +47,9 @@ class FileStreamSource( private val hadoopConf = sparkSession.sessionState.newHadoopConf() + @transient private val fs = new Path(path).getFileSystem(hadoopConf) + private val qualifiedBasePath: Path = { - val fs = new Path(path).getFileSystem(hadoopConf) fs.makeQualified(new Path(path)) // can contains glob patterns } @@ -187,7 +188,7 @@ class FileStreamSource( if (SparkHadoopUtil.get.isGlobPath(new Path(path))) Some(false) else None private def allFilesUsingInMemoryFileIndex() = { - val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) + val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(fs, qualifiedBasePath) val fileIndex = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) fileIndex.allFiles() } From d54bfec2e07f2eb934185402f915558fe27b9312 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 18 Nov 2017 19:40:06 +0100 Subject: [PATCH 1691/1765] [SPARK-22498][SQL] Fix 64KB JVM bytecode limit problem with concat ## What changes were proposed in this pull request? This PR changes `concat` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved the case of `concat` with a lot of argument ## How was this patch tested? Added new test cases into `StringExpressionsSuite` Author: Kazuaki Ishizaki Closes #19728 from kiszk/SPARK-22498. --- .../expressions/stringExpressions.scala | 30 +++++++++++++------ .../expressions/StringExpressionsSuite.scala | 6 ++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c341943187820..d5bb7e95769e5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -63,15 +63,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) - val inputs = evals.map { eval => - s"${eval.isNull} ? null : ${eval.value}" - }.mkString(", ") - ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - UTF8String ${ev.value} = UTF8String.concat($inputs); - if (${ev.value} == null) { - ${ev.isNull} = true; - } + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions(inputs, "valueConcat", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + } else { + inputs.mkString("\n") + } + ev.copy(s""" + UTF8String[] $args = new UTF8String[${evals.length}]; + $codes + UTF8String ${ev.value} = UTF8String.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; """) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 18ef4bc37c2b5..aa9d5a0aa95e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -45,6 +45,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("SPARK-22498: Concat should not generate codes beyond 64KB") { + val N = 5000 + val strs = (1 to N).map(x => s"s$x") + checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From b10837ab1a7bef04bf7a2773b9e44ed9206643fe Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 20 Nov 2017 13:32:01 +0900 Subject: [PATCH 1692/1765] [SPARK-22557][TEST] Use ThreadSignaler explicitly ## What changes were proposed in this pull request? ScalaTest 3.0 uses an implicit `Signaler`. This PR makes it sure all Spark tests uses `ThreadSignaler` explicitly which has the same default behavior of interrupting a thread on the JVM like ScalaTest 2.2.x. This will reduce potential flakiness. ## How was this patch tested? This is testsuite-only update. This should passes the Jenkins tests. Author: Dongjoon Hyun Closes #19784 from dongjoon-hyun/use_thread_signaler. --- .../test/scala/org/apache/spark/DistributedSuite.scala | 7 +++++-- core/src/test/scala/org/apache/spark/DriverSuite.scala | 5 ++++- .../test/scala/org/apache/spark/UnpersistSuite.scala | 8 ++++++-- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 9 ++++++++- .../org/apache/spark/rdd/AsyncRDDActionsSuite.scala | 5 ++++- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 5 ++++- .../OutputCommitCoordinatorIntegrationSuite.scala | 5 ++++- .../org/apache/spark/storage/BlockManagerSuite.scala | 10 ++++++++-- .../scala/org/apache/spark/util/EventLoopSuite.scala | 5 ++++- .../streaming/ProcessingTimeExecutorSuite.scala | 8 +++++--- .../org/apache/spark/sql/streaming/StreamTest.scala | 2 ++ .../apache/spark/sql/hive/SparkSubmitTestUtils.scala | 5 ++++- .../org/apache/spark/streaming/ReceiverSuite.scala | 5 +++-- .../apache/spark/streaming/StreamingContextSuite.scala | 5 +++-- .../spark/streaming/receiver/BlockGeneratorSuite.scala | 7 ++++--- 15 files changed, 68 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index f8005610f7e4f..ea9f6d2fc20f4 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.scalatest.Matchers -import org.scalatest.concurrent.TimeLimits._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} import org.apache.spark.security.EncryptionFunSuite @@ -30,7 +30,10 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() { class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext - with EncryptionFunSuite { + with EncryptionFunSuite with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler val clusterUrl = "local-cluster[2,1,1024]" diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index be80d278fcea8..962945e5b6bb1 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils class DriverSuite extends SparkFunSuite with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val masters = Table("master", "local", "local-cluster[2,1,1024]") diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index bc3f58cf2a35d..b58a3ebe6e4c9 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -17,10 +17,14 @@ package org.apache.spark -import org.scalatest.concurrent.TimeLimits._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} -class UnpersistSuite extends SparkFunSuite with LocalSparkContext { +class UnpersistSuite extends SparkFunSuite with LocalSparkContext with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + test("unpersist RDD") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index cfbf56fb8c369..d0a34c5cdcf57 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path} import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -102,6 +102,9 @@ class SparkSubmitSuite import SparkSubmitSuite._ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + override def beforeEach() { super.beforeEach() System.setProperty("spark.testing", "true") @@ -1016,6 +1019,10 @@ class SparkSubmitSuite } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. def runSparkSubmit(args: Seq[String], root: String = ".."): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index de0e71a332f23..24b0144a38bd2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -24,7 +24,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -34,6 +34,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim @transient private var sc: SparkContext = _ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + override def beforeAll() { super.beforeAll() sc = new SparkContext("local[2]", "test") diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6222e576d1ce9..d395e09969453 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -102,6 +102,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi import DAGSchedulerSuite._ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index a27dadcf49bfc..d6ff5bb33055c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Seconds, Span} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} @@ -34,6 +34,9 @@ class OutputCommitCoordinatorIntegrationSuite with LocalSparkContext with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + override def beforeAll(): Unit = { super.beforeAll() val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d45c194d31adc..f3e8a2ed1d562 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,8 +31,8 @@ import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.TimeLimits._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager @@ -57,10 +57,13 @@ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with LocalSparkContext with ResetSystemProperties - with EncryptionFunSuite { + with EncryptionFunSuite with TimeLimits { import BlockManagerSuite._ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + var conf: SparkConf = null var store: BlockManager = null var store2: BlockManager = null @@ -1450,6 +1453,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private object BlockManagerSuite { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + private implicit class BlockManagerTestUtils(store: BlockManager) { def dropFromMemoryIfExists( diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index f4f8388f5f19f..550745771750c 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -23,13 +23,16 @@ import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.TimeLimits import org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + test("EventLoop") { val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 519e3c01afe8a..80c76915e4c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -22,16 +22,18 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import org.eclipse.jetty.util.ConcurrentHashSet -import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.sql.streaming.util.StreamManualClock -class ProcessingTimeExecutorSuite extends SparkFunSuite { +class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler val timeout = 10.seconds diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 70b39b934071a..e68fca050571f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -69,7 +69,9 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} */ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with BeforeAndAfterAll { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + override def afterAll(): Unit = { super.afterAll() StateStore.stop() // stop the state store maintenance thread and unload store providers diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala index ede44df4afe11..68ed97d6d1f5a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala @@ -23,7 +23,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -33,6 +33,9 @@ import org.apache.spark.util.Utils trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 5fc626c1f78b8..145c48e5a9a72 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -38,6 +38,9 @@ import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val signaler: Signaler = ThreadSignaler + test("receiver life cycle") { val receiver = new FakeReceiver @@ -60,8 +63,6 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { // Verify that the receiver intercept[Exception] { - // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x - implicit val signaler: Signaler = ThreadSignaler failAfter(200 millis) { executingThread.join() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5810e73f4098b..52c8959351fe7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -44,6 +44,9 @@ import org.apache.spark.util.Utils class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeLimits with Logging { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val signaler: Signaler = ThreadSignaler + val master = "local[2]" val appName = this.getClass.getSimpleName val batchDuration = Milliseconds(500) @@ -406,8 +409,6 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // test whether awaitTermination() does not exit if not time is given val exception = intercept[Exception] { - // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x - implicit val signaler: Signaler = ThreadSignaler failAfter(1000 millis) { ssc.awaitTermination() throw new Exception("Did not wait for stop") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 898da4445e464..580f831548cd5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -24,18 +24,19 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ -import org.scalatest.concurrent.{Signaler, ThreadSignaler} +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.util.ManualClock -class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val blockIntervalMs = 10 private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") @volatile private var blockGenerator: BlockGenerator = null From 57c5514de9dba1c14e296f85fb13fef23ce8c73f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 20 Nov 2017 13:34:06 +0900 Subject: [PATCH 1693/1765] [SPARK-22554][PYTHON] Add a config to control if PySpark should use daemon or not for workers ## What changes were proposed in this pull request? This PR proposes to add a flag to control if PySpark should use daemon or not. Actually, SparkR already has a flag for useDaemon: https://github.com/apache/spark/blob/478fbc866fbfdb4439788583281863ecea14e8af/core/src/main/scala/org/apache/spark/api/r/RRunner.scala#L362 It'd be great if we have this flag too. It makes easier to debug Windows specific issue. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #19782 from HyukjinKwon/use-daemon-flag. --- .../org/apache/spark/api/python/PythonWorkerFactory.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index fc595ae9e4563..f53c6178047f5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -38,7 +38,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently // only works on UNIX-based systems now because it uses signals for child management, so we can // also fall back to launching workers (pyspark/worker.py) directly. - val useDaemon = !System.getProperty("os.name").startsWith("Windows") + val useDaemon = { + val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) + + // This flag is ignored on Windows as it's unable to fork. + !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled + } var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) From 3c3eebc8734e36e61f4627e2c517fbbe342b3b42 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 20 Nov 2017 12:40:16 +0100 Subject: [PATCH 1694/1765] [SPARK-20101][SQL] Use OffHeapColumnVector when "spark.sql.columnVector.offheap.enable" is set to "true" This PR enables to use ``OffHeapColumnVector`` when ``spark.sql.columnVector.offheap.enable`` is set to ``true``. While ``ColumnVector`` has two implementations ``OnHeapColumnVector`` and ``OffHeapColumnVector``, only ``OnHeapColumnVector`` is always used. This PR implements the followings - Pass ``OffHeapColumnVector`` to ``ColumnarBatch.allocate()`` when ``spark.sql.columnVector.offheap.enable`` is set to ``true`` - Free all of off-heap memory regions by ``OffHeapColumnVector.close()`` - Ensure to call ``OffHeapColumnVector.close()`` Use existing tests Author: Kazuaki Ishizaki Closes #17436 from kiszk/SPARK-20101. --- .../scala/org/apache/spark/SparkConf.scala | 2 +- .../spark/internal/config/package.scala | 16 ++++++++++++++ .../apache/spark/memory/MemoryManager.scala | 7 +++--- .../memory/StaticMemoryManagerSuite.scala | 3 ++- .../memory/UnifiedMemoryManagerSuite.scala | 7 +++--- .../BlockManagerReplicationSuite.scala | 3 ++- .../spark/storage/BlockManagerSuite.scala | 2 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 3 ++- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++++++ .../VectorizedParquetRecordReader.java | 12 ++++++---- .../sql/execution/DataSourceScanExec.scala | 3 ++- .../columnar/InMemoryTableScanExec.scala | 17 ++++++++++++-- .../execution/datasources/FileFormat.scala | 4 +++- .../parquet/ParquetFileFormat.scala | 22 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 7 +++--- .../UnsafeFixedWidthAggregationMapSuite.scala | 3 ++- .../UnsafeKVExternalSorterSuite.scala | 6 ++--- .../benchmark/AggregateBenchmark.scala | 7 +++--- .../parquet/ParquetEncodingSuite.scala | 6 ++--- .../datasources/parquet/ParquetIOSuite.scala | 11 +++++----- .../parquet/ParquetReadBenchmark.scala | 8 ++++--- .../execution/joins/HashedRelationSuite.scala | 9 ++++---- 22 files changed, 117 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 57b3744e9c30a..ee726df7391f1 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -655,7 +655,7 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.streaming.minRememberDuration", "1.5")), "spark.yarn.max.executor.failures" -> Seq( AlternateConfig("spark.yarn.max.worker.failures", "1.5")), - "spark.memory.offHeap.enabled" -> Seq( + MEMORY_OFFHEAP_ENABLED.key -> Seq( AlternateConfig("spark.unsafe.offHeap", "1.6")), "spark.rpc.message.maxSize" -> Seq( AlternateConfig("spark.akka.frameSize", "1.6")), diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 84315f55a59ad..7be4d6b212d72 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -80,6 +80,22 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val MEMORY_OFFHEAP_ENABLED = ConfigBuilder("spark.memory.offHeap.enabled") + .doc("If true, Spark will attempt to use off-heap memory for certain operations. " + + "If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive.") + .withAlternative("spark.unsafe.offHeap") + .booleanConf + .createWithDefault(false) + + private[spark] val MEMORY_OFFHEAP_SIZE = ConfigBuilder("spark.memory.offHeap.size") + .doc("The absolute amount of memory in bytes which can be used for off-heap allocation. " + + "This setting has no impact on heap memory usage, so if your executors' total memory " + + "consumption must fit within some hard limit then be sure to shrink your JVM heap size " + + "accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true.") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ >= 0, "The off-heap memory size must not be negative") + .createWithDefault(0) + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() .booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 82442cf56154c..0641adc2ab699 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -21,6 +21,7 @@ import javax.annotation.concurrent.GuardedBy import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore import org.apache.spark.unsafe.Platform @@ -54,7 +55,7 @@ private[spark] abstract class MemoryManager( onHeapStorageMemoryPool.incrementPoolSize(onHeapStorageMemory) onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) - protected[this] val maxOffHeapMemory = conf.getSizeAsBytes("spark.memory.offHeap.size", 0) + protected[this] val maxOffHeapMemory = conf.get(MEMORY_OFFHEAP_SIZE) protected[this] val offHeapStorageMemory = (maxOffHeapMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong @@ -194,8 +195,8 @@ private[spark] abstract class MemoryManager( * sun.misc.Unsafe. */ final val tungstenMemoryMode: MemoryMode = { - if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { - require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + if (conf.get(MEMORY_OFFHEAP_ENABLED)) { + require(conf.get(MEMORY_OFFHEAP_SIZE) > 0, "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") require(Platform.unaligned(), "No support for unaligned Unsafe. Set spark.memory.offHeap.enabled to false.") diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 4e31fb5589a9c..0f32fe4059fbb 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.memory import org.mockito.Mockito.when import org.apache.spark.SparkConf +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.storage.TestBlockId import org.apache.spark.storage.memory.MemoryStore @@ -48,7 +49,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { conf.clone .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), + .set(MEMORY_OFFHEAP_SIZE.key, maxOffHeapExecutionMemory.toString), maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxOnHeapStorageMemory = 0, numCores = 1) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 02b04cdbb2a5f..d56cfc183d921 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.memory import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf +import org.apache.spark.internal.config._ import org.apache.spark.storage.TestBlockId import org.apache.spark.storage.memory.MemoryStore @@ -43,7 +44,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) + .set(MEMORY_OFFHEAP_SIZE.key, maxOffHeapExecutionMemory.toString) .set("spark.memory.storageFraction", storageFraction.toString) UnifiedMemoryManager(conf, numCores = 1) } @@ -305,9 +306,9 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("not enough free memory in the storage pool --OFF_HEAP") { val conf = new SparkConf() - .set("spark.memory.offHeap.size", "1000") + .set(MEMORY_OFFHEAP_SIZE.key, "1000") .set("spark.testing.memory", "1000") - .set("spark.memory.offHeap.enabled", "true") + .set(MEMORY_OFFHEAP_ENABLED.key, "true") val taskAttemptId = 0L val mm = UnifiedMemoryManager(conf, numCores = 1) val ms = makeMemoryStore(mm) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2101ba828553..3962bdc27d22c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -69,7 +70,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) + conf.set(MEMORY_OFFHEAP_SIZE.key, maxMem.toString) val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f3e8a2ed1d562..629eed49b04cc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -90,7 +90,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE testConf: Option[SparkConf] = None): BlockManager = { val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) bmConf.set("spark.testing.memory", maxMem.toString) - bmConf.set("spark.memory.offHeap.size", maxMem.toString) + bmConf.set(MEMORY_OFFHEAP_SIZE.key, maxMem.toString) val serializer = new KryoSerializer(bmConf) val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { Some(CryptoStreamUtils.createKey(bmConf)) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 6a6c37873e1c2..df5f0b5335e82 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} @@ -104,7 +105,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") .set("spark.ui.killEnabled", killEnabled.toString) - .set("spark.memory.offHeap.size", "64m") + .set(MEMORY_OFFHEAP_SIZE.key, "64m") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3452a1e715fb9..8485ed4c887d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -140,6 +140,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val COLUMN_VECTOR_OFFHEAP_ENABLED = + buildConf("spark.sql.columnVector.offheap.enable") + .internal() + .doc("When true, use OffHeapColumnVector in ColumnarBatch.") + .booleanConf + .createWithDefault(false) + val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin") .internal() .doc("When true, prefer sort merge join over shuffle hash join.") @@ -1210,6 +1217,8 @@ class SQLConf extends Serializable with Logging { def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) + def offHeapColumnVectorEnabled: Boolean = getConf(COLUMN_VECTOR_OFFHEAP_ENABLED) + def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index e827229dceef8..669d71e60779d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -101,9 +101,13 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private boolean returnColumnarBatch; /** - * The default config on whether columnarBatch should be offheap. + * The memory mode of the columnarBatch */ - private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; + private final MemoryMode MEMORY_MODE; + + public VectorizedParquetRecordReader(boolean useOffHeap) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + } /** * Implementation of RecordReader API. @@ -204,11 +208,11 @@ public void initBatch( } public void initBatch() { - initBatch(DEFAULT_MEMORY_MODE, null, null); + initBatch(MEMORY_MODE, null, null); } public void initBatch(StructType partitionColumns, InternalRow partitionValues) { - initBatch(DEFAULT_MEMORY_MODE, partitionColumns, partitionValues); + initBatch(MEMORY_MODE, partitionColumns, partitionValues); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a607ec0bf8c9b..a477c23140536 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -177,7 +177,8 @@ case class FileSourceScanExec( override def vectorTypes: Option[Seq[String]] = relation.fileFormat.vectorTypes( requiredSchema = requiredSchema, - partitionSchema = relation.partitionSchema) + partitionSchema = relation.partitionSchema, + relation.sparkSession.sessionState.conf) @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 2ae3f35eb1da1..3e73393b12850 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.columnar +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -37,7 +38,13 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override def vectorTypes: Option[Seq[String]] = - Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName)) + Option(Seq.fill(attributes.length)( + if (!conf.offHeapColumnVectorEnabled) { + classOf[OnHeapColumnVector].getName + } else { + classOf[OffHeapColumnVector].getName + } + )) /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. @@ -62,7 +69,12 @@ case class InMemoryTableScanExec( private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows - val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + val taskContext = Option(TaskContext.get()) + val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + } else { + OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + } val columnarBatch = new ColumnarBatch( columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) columnarBatch.setNumRows(rowCount) @@ -73,6 +85,7 @@ case class InMemoryTableScanExec( columnarBatch.column(i).asInstanceOf[WritableColumnVector], columnarBatchSchema.fields(i).dataType, rowCount) } + taskContext.foreach(_.addTaskCompletionListener(_ => columnarBatch.close())) columnarBatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index e5a7aee64a4f4..d3874b58bc807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -70,7 +71,8 @@ trait FileFormat { */ def vectorTypes( requiredSchema: StructType, - partitionSchema: StructType): Option[Seq[String]] = { + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 044b1a89d57c9..2b1064955a777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -274,9 +274,15 @@ class ParquetFileFormat override def vectorTypes( requiredSchema: StructType, - partitionSchema: StructType): Option[Seq[String]] = { + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { Option(Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)( - classOf[OnHeapColumnVector].getName)) + if (!sqlConf.offHeapColumnVectorEnabled) { + classOf[OnHeapColumnVector].getName + } else { + classOf[OffHeapColumnVector].getName + } + )) } override def isSplitable( @@ -332,8 +338,10 @@ class ParquetFileFormat // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader: Boolean = - sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && + sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) val enableRecordFilter: Boolean = sparkSession.sessionState.conf.parquetRecordFilterEnabled @@ -364,8 +372,10 @@ class ParquetFileFormat if (pushed.isDefined) { ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } + val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { - val vectorizedReader = new VectorizedParquetRecordReader() + val vectorizedReader = + new VectorizedParquetRecordReader(enableOffHeapColumnVector && taskContext.isDefined) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) @@ -387,7 +397,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(parquetReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b2dcbe5aa9877..d98cf852a1b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -23,6 +23,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory.{MemoryConsumer, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -99,7 +100,7 @@ private[execution] object HashedRelation { val mm = Option(taskMemoryManager).getOrElse { new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -232,7 +233,7 @@ private[joins] class UnsafeHashedRelation( // so that tests compile: val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -403,7 +404,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap this( new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d194f58cd1cdd..232c1beae7998 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,6 +26,7 @@ import scala.util.control.NonFatal import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -63,7 +64,7 @@ class UnsafeFixedWidthAggregationMapSuite } test(name) { - val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") + val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 359525fcd05a2..604502f2a57d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,7 +22,7 @@ import java.util.Properties import scala.util.Random import org.apache.spark._ -import org.apache.spark.internal.config +import org.apache.spark.internal.config._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -112,7 +112,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { pageSize: Long, spill: Boolean): Unit = { val memoryManager = - new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) + new TestMemoryManager(new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false")) val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -125,7 +125,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - pageSize, config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index a834b7cd2c69f..8f4ee8533e599 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.HashMap import org.apache.spark.SparkConf +import org.apache.spark.internal.config._ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap @@ -538,7 +539,7 @@ class AggregateBenchmark extends BenchmarkBase { value.setInt(0, 555) val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -569,8 +570,8 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") - .set("spark.memory.offHeap.size", "102400000"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, s"${heap == "off"}") + .set(MEMORY_OFFHEAP_SIZE.key, "102400000"), Long.MaxValue, Long.MaxValue, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index 00799301ca8d9..edb1290ee2eb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +65,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +94,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 633cfde6ab941..44a8b25c61dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -653,7 +653,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -670,7 +670,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -686,7 +686,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -703,7 +703,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,7 +742,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = new VectorizedParquetRecordReader + val vectorizedReader = + new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index de7a5795b4796..86a3c71a3c4f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -75,6 +75,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -95,7 +96,7 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -118,7 +119,7 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -260,6 +261,7 @@ object ParquetReadBenchmark { def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -277,7 +279,7 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ede63fea9606f..51f8c3325fdff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow @@ -36,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val mm = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -85,7 +86,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("test serialization empty hash map") { val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -157,7 +158,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("LongToUnsafeRowMap with very wide range") { val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -202,7 +203,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("LongToUnsafeRowMap with random keys") { val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), From c13b60e0194c90156e74d10b19f94c70675d21ae Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 20 Nov 2017 12:45:21 +0100 Subject: [PATCH 1695/1765] [SPARK-22533][CORE] Handle deprecated names in ConfigEntry. This change hooks up the config reader to `SparkConf.getDeprecatedConfig`, so that config constants with deprecated names generate the proper warnings. It also changes two deprecated configs from the new "alternatives" system to the old deprecation system, since they're not yet hooked up to each other. Added a few unit tests to verify the desired behavior. Author: Marcelo Vanzin Closes #19760 from vanzin/SPARK-22533. --- .../main/scala/org/apache/spark/SparkConf.scala | 16 ++++++++++------ .../spark/internal/config/ConfigProvider.scala | 4 +++- .../apache/spark/internal/config/package.scala | 2 -- .../scala/org/apache/spark/SparkConfSuite.scala | 7 +++++++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ee726df7391f1..0e08ff65e4784 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,6 +17,7 @@ package org.apache.spark +import java.util.{Map => JMap} import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ @@ -24,6 +25,7 @@ import scala.collection.mutable.LinkedHashSet import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.serializer.KryoSerializer @@ -370,7 +372,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, settings)) } /** Get an optional value, applying variable substitution. */ @@ -622,7 +624,7 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.history.updateInterval", "1.3")), "spark.history.fs.cleaner.interval" -> Seq( AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), - "spark.history.fs.cleaner.maxAge" -> Seq( + MAX_LOG_AGE_S.key -> Seq( AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), "spark.yarn.am.waitTime" -> Seq( AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", @@ -663,8 +665,10 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.jar", "2.0")), "spark.yarn.access.hadoopFileSystems" -> Seq( AlternateConfig("spark.yarn.access.namenodes", "2.2")), - "spark.maxRemoteBlockSizeFetchToMem" -> Seq( - AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")) + MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key -> Seq( + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")), + LISTENER_BUS_EVENT_QUEUE_CAPACITY.key -> Seq( + AlternateConfig("spark.scheduler.listenerbus.eventqueue.size", "2.3")) ) /** @@ -704,9 +708,9 @@ private[spark] object SparkConf extends Logging { * Looks for available deprecated keys for the given config option, and return the first * value available. */ - def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + def getDeprecatedConfig(key: String, conf: JMap[String, String]): Option[String] = { configsWithAlternatives.get(key).flatMap { alts => - alts.collectFirst { case alt if conf.contains(alt.key) => + alts.collectFirst { case alt if conf.containsKey(alt.key) => val value = conf.get(alt.key) if (alt.translation != null) alt.translation(value) else value } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala index 5d98a1185f053..392f9d56e7f51 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala @@ -19,6 +19,8 @@ package org.apache.spark.internal.config import java.util.{Map => JMap} +import org.apache.spark.SparkConf + /** * A source of configuration values. */ @@ -53,7 +55,7 @@ private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends Con override def get(key: String): Option[String] = { if (key.startsWith("spark.")) { - Option(conf.get(key)) + Option(conf.get(key)).orElse(SparkConf.getDeprecatedConfig(key, conf)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7be4d6b212d72..7a9072736b9aa 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -209,7 +209,6 @@ package object config { private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity") - .withAlternative("spark.scheduler.listenerbus.eventqueue.size") .intConf .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) @@ -404,7 +403,6 @@ package object config { "affect both shuffle fetch and block manager remote block fetch. For users who " + "enabled external shuffle service, this feature can only be worked when external shuffle" + " service is newer than Spark 2.2.") - .withAlternative("spark.reducer.maxReqSizeShuffleToMem") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 0897891ee1758..c771eb4ee3ef5 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -26,6 +26,7 @@ import scala.util.{Random, Try} import com.esotericsoftware.kryo.Kryo +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{JavaSerializer, KryoRegistrator, KryoSerializer} @@ -248,6 +249,12 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst conf.set("spark.kryoserializer.buffer.mb", "1.1") assert(conf.getSizeAsKb("spark.kryoserializer.buffer") === 1100) + + conf.set("spark.history.fs.cleaner.maxAge.seconds", "42") + assert(conf.get(MAX_LOG_AGE_S) === 42L) + + conf.set("spark.scheduler.listenerbus.eventqueue.size", "84") + assert(conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY) === 84) } test("akka deprecated configs") { From 41c6f36018eb086477f21574aacd71616513bd8e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 01:42:05 +0100 Subject: [PATCH 1696/1765] [SPARK-22549][SQL] Fix 64KB JVM bytecode limit problem with concat_ws ## What changes were proposed in this pull request? This PR changes `concat_ws` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved the case of `concat_ws` with a lot of argument ## How was this patch tested? Added new test cases into `StringExpressionsSuite` Author: Kazuaki Ishizaki Closes #19777 from kiszk/SPARK-22549. --- .../expressions/stringExpressions.scala | 100 +++++++++++++----- .../expressions/StringExpressionsSuite.scala | 13 +++ 2 files changed, 89 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d5bb7e95769e5..e6f55f476a921 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -137,13 +137,34 @@ case class ConcatWs(children: Seq[Expression]) if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.genCode(ctx)) - - val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String) null : ${eval.value}" - }.mkString(", ") - - ev.copy(evals.map(_.code).mkString("\n") + s""" - UTF8String ${ev.value} = UTF8String.concatWs($inputs); + val separator = evals.head + val strings = evals.tail + val numArgs = strings.length + val args = ctx.freshName("args") + + val inputs = strings.zipWithIndex.map { case (eval, index) => + if (eval.isNull != "true") { + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } else { + "" + } + } + val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions(inputs, "valueConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + } else { + inputs.mkString("\n") + } + ev.copy(s""" + UTF8String[] $args = new UTF8String[$numArgs]; + ${separator.code} + $codes + UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args); boolean ${ev.isNull} = ${ev.value} == null; """) } else { @@ -156,32 +177,63 @@ case class ConcatWs(children: Seq[Expression]) child.dataType match { case StringType => ("", // we count all the StringType arguments num at once below. - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};") + if (eval.isNull == "true") { + "" + } else { + s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + }) case _: ArrayType => val size = ctx.freshName("n") - (s""" - if (!${eval.isNull}) { - $varargNum += ${eval.value}.numElements(); - } - """, - s""" - if (!${eval.isNull}) { - final int $size = ${eval.value}.numElements(); - for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; - } + if (eval.isNull == "true") { + ("", "") + } else { + (s""" + if (!${eval.isNull}) { + $varargNum += ${eval.value}.numElements(); + } + """, + s""" + if (!${eval.isNull}) { + final int $size = ${eval.value}.numElements(); + for (int j = 0; j < $size; j ++) { + $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + } + } + """) } - """) } }.unzip - ev.copy(evals.map(_.code).mkString("\n") + - s""" + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) + val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + "int", + { body => + s""" + int $varargNum = 0; + $body + return $varargNum; + """ + }, + _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) + val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + "int", + { body => + s""" + $body + return $idxInVararg; + """ + }, + _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + ev.copy( + s""" + $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxInVararg = 0; - ${varargCount.mkString("\n")} + $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; - ${varargBuild.mkString("\n")} + $varargBuilds UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array); boolean ${ev.isNull} = ${ev.value} == null; """) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index aa9d5a0aa95e3..7ce43066b5aab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -80,6 +80,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") { + val N = 5000 + val sepExpr = Literal.create("#", StringType) + val strings1 = (1 to N).map(x => s"s$x") + val inputsExpr1 = strings1.map(Literal.create(_, StringType)) + checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow) + + val strings2 = (1 to N).map(x => Seq(s"s$x")) + val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType))) + checkEvaluation( + ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow) + } + test("elt") { def testElt(result: String, n: java.lang.Integer, args: String*): Unit = { checkEvaluation( From 9d45e675e27278163241081789b06449ca220e43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 21 Nov 2017 09:36:37 +0100 Subject: [PATCH 1697/1765] [SPARK-22541][SQL] Explicitly claim that Python udfs can't be conditionally executed with short-curcuit evaluation ## What changes were proposed in this pull request? Besides conditional expressions such as `when` and `if`, users may want to conditionally execute python udfs by short-curcuit evaluation. We should also explicitly note that python udfs don't support this kind of conditional execution too. ## How was this patch tested? N/A, just document change. Author: Liang-Chi Hsieh Closes #19787 from viirya/SPARK-22541. --- python/pyspark/sql/functions.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b631e2041706f..4e0faddb1c0df 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2079,12 +2079,9 @@ def udf(f=None, returnType=StringType()): duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. - .. note:: The user-defined functions do not support conditional execution by using them with - SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no - matter the conditions are met or not. So the output is correct if the functions can be - correctly run on all rows without failure. If the functions can cause runtime failure on the - rows that do not satisfy the conditions, the suggested workaround is to incorporate the - condition logic into the functions. + .. note:: The user-defined functions do not support conditional expressions or short curcuiting + in boolean expressions and it ends up with being executed all internally. If the functions + can fail on special rows, the workaround is to incorporate the condition into the functions. :param f: python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object @@ -2194,12 +2191,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined function must be deterministic. - .. note:: The user-defined functions do not support conditional execution by using them with - SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no - matter the conditions are met or not. So the output is correct if the functions can be - correctly run on all rows without failure. If the functions can cause runtime failure on the - rows that do not satisfy the conditions, the suggested workaround is to incorporate the - condition logic into the functions. + .. note:: The user-defined functions do not support conditional expressions or short curcuiting + in boolean expressions and it ends up with being executed all internally. If the functions + can fail on special rows, the workaround is to incorporate the condition into the functions. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From c9577148069d2215dc79cbf828a378591b4fba5d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 12:16:54 +0100 Subject: [PATCH 1698/1765] [SPARK-22508][SQL] Fix 64KB JVM bytecode limit problem with GenerateUnsafeRowJoiner.create() ## What changes were proposed in this pull request? This PR changes `GenerateUnsafeRowJoiner.create()` code generation to place generated code for statements to operate bitmap and offset into separated methods if these size could be large. ## How was this patch tested? Added a new test case into `GenerateUnsafeRowJoinerSuite` Author: Kazuaki Ishizaki Closes #19737 from kiszk/SPARK-22508. --- .../codegen/GenerateUnsafeRowJoiner.scala | 31 ++++++++++++++----- .../GenerateUnsafeRowJoinerSuite.scala | 5 +++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 6bc72a0d75c6d..be5f5a73b5d47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform @@ -51,6 +54,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { + val ctx = new CodegenContext val offset = Platform.BYTE_ARRAY_OFFSET val getLong = "Platform.getLong" val putLong = "Platform.putLong" @@ -88,8 +92,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})" } } - s"$putLong(buf, ${offset + i * 8}, $bits);" - }.mkString("\n") + s"$putLong(buf, ${offset + i * 8}, $bits);\n" + } + + val copyBitsets = ctx.splitExpressions( + expressions = copyBitset, + funcName = "copyBitsetFunc", + arguments = ("java.lang.Object", "obj1") :: ("long", "offset1") :: + ("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil) // --------------------- copy fixed length portion from row 1 ----------------------- // var cursor = offset + outputBitsetWords * 8 @@ -150,11 +160,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s""" - |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); - """.stripMargin + s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n" } - }.mkString("\n") + } + + val updateOffsets = ctx.splitExpressions( + expressions = updateOffset, + funcName = "copyBitsetFunc", + arguments = ("long", "numBytesVariableRow1") :: Nil) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -166,6 +179,8 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | private byte[] buf = new byte[64]; | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size}); | + | ${ctx.declareAddedFunctions()} + | | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { | // row1: ${schema1.size} fields, $bitset1Words words in bitset | // row2: ${schema2.size}, $bitset2Words words in bitset @@ -180,12 +195,12 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | - | $copyBitset + | $copyBitsets | $copyFixedLengthRow1 | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 - | $updateOffset + | $updateOffsets | | out.pointTo(buf, sizeInBytes); | diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index 9f19745cefd20..f203f25ad10d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -66,6 +66,11 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { } } + test("SPARK-22508: GenerateUnsafeRowJoiner.create should not generate codes beyond 64KB") { + val N = 3000 + testConcatOnce(N, N, variable) + } + private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = { for (i <- 0 until 10) { testConcatOnce(numFields1, numFields2, candidateTypes) From 9bdff0bcd83e730aba8dc1253da24a905ba07ae3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 12:19:11 +0100 Subject: [PATCH 1699/1765] [SPARK-22550][SQL] Fix 64KB JVM bytecode limit problem with elt ## What changes were proposed in this pull request? This PR changes `elt` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved the case of `elt` with a lot of argument ## How was this patch tested? Added new test cases into `StringExpressionsSuite` Author: Kazuaki Ishizaki Closes #19778 from kiszk/SPARK-22550. --- .../expressions/codegen/CodeGenerator.scala | 44 ++++++++++------- .../expressions/stringExpressions.scala | 48 +++++++++++++++---- .../expressions/StringExpressionsSuite.scala | 7 +++ 3 files changed, 73 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3dc3f8e4adac0..e02f125a93345 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -790,23 +790,7 @@ class CodegenContext { returnType: String = "void", makeSplitFunction: String => String = identity, foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { - val blocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - var length = 0 - for (code <- expressions) { - // We can't know how many bytecode will be generated, so use the length of source code - // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should - // also not be too small, or it will have many function calls (for wide table), see the - // results in BenchmarkWideTable. - if (length > 1024) { - blocks += blockBuilder.toString() - blockBuilder.clear() - length = 0 - } - blockBuilder.append(code) - length += CodeFormatter.stripExtraNewLinesAndComments(code).length - } - blocks += blockBuilder.toString() + val blocks = buildCodeBlocks(expressions) if (blocks.length == 1) { // inline execution if only one block @@ -841,6 +825,32 @@ class CodegenContext { } } + /** + * Splits the generated code of expressions into multiple sequences of String + * based on a threshold of length of a String + * + * @param expressions the codes to evaluate expressions. + */ + def buildCodeBlocks(expressions: Seq[String]): Seq[String] = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + var length = 0 + for (code <- expressions) { + // We can't know how many bytecode will be generated, so use the length of source code + // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should + // also not be too small, or it will have many function calls (for wide table), see the + // results in BenchmarkWideTable. + if (length > 1024) { + blocks += blockBuilder.toString() + blockBuilder.clear() + length = 0 + } + blockBuilder.append(code) + length += CodeFormatter.stripExtraNewLinesAndComments(code).length + } + blocks += blockBuilder.toString() + } + /** * Here we handle all the methods which have been added to the inner classes and * not to the outer class. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e6f55f476a921..360dd845f8d3a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -288,22 +288,52 @@ case class Elt(children: Seq[Expression]) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) val strings = stringExprs.map(_.genCode(ctx)) + val indexVal = ctx.freshName("index") + val stringVal = ctx.freshName("stringVal") val assignStringValue = strings.zipWithIndex.map { case (eval, index) => s""" case ${index + 1}: - ${ev.value} = ${eval.isNull} ? null : ${eval.value}; + ${eval.code} + $stringVal = ${eval.isNull} ? null : ${eval.value}; break; """ - }.mkString("\n") - val indexVal = ctx.freshName("index") - val stringArray = ctx.freshName("strings"); + } - ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s""" - final int $indexVal = ${index.value}; - UTF8String ${ev.value} = null; - switch ($indexVal) { - $assignStringValue + val cases = ctx.buildCodeBlocks(assignStringValue) + val codes = if (cases.length == 1) { + s""" + UTF8String $stringVal = null; + switch ($indexVal) { + ${cases.head} + } + """ + } else { + var prevFunc = "null" + for (c <- cases.reverse) { + val funcName = ctx.freshName("eltFunc") + val funcBody = s""" + private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) { + UTF8String $stringVal = null; + switch ($indexVal) { + $c + default: + return $prevFunc; + } + return $stringVal; + } + """ + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + prevFunc = s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)" } + s"UTF8String $stringVal = $prevFunc;" + } + + ev.copy( + s""" + ${index.code} + final int $indexVal = ${index.value}; + $codes + UTF8String ${ev.value} = $stringVal; final boolean ${ev.isNull} = ${ev.value} == null; """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 7ce43066b5aab..c761394756875 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -116,6 +116,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure) } + test("SPARK-22550: Elt should not generate codes beyond 64KB") { + val N = 10000 + val strings = (1 to N).map(x => s"s$x") + val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType)) + checkEvaluation(Elt(args), s"s$N") + } + test("StringComparison") { val row = create_row("abc", null) val c1 = 'a.string.at(0) From 96e947ed6c945bdbe71c308112308489668f19ac Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 21 Nov 2017 13:48:09 +0100 Subject: [PATCH 1700/1765] [SPARK-22569][SQL] Clean usage of addMutableState and splitExpressions ## What changes were proposed in this pull request? This PR is to clean the usage of addMutableState and splitExpressions 1. replace hardcoded type string to ctx.JAVA_BOOLEAN etc. 2. create a default value of the initCode for ctx.addMutableStats 3. Use named arguments when calling `splitExpressions ` ## How was this patch tested? The existing test cases Author: gatorsmile Closes #19790 from gatorsmile/codeClean. --- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 8 ++-- .../expressions/codegen/CodeGenerator.scala | 22 ++++++--- .../codegen/GenerateMutableProjection.scala | 2 +- .../expressions/complexTypeCreator.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 6 +-- .../expressions/nullExpressions.scala | 10 +++-- .../expressions/objects/objects.scala | 26 +++++------ .../sql/catalyst/expressions/predicates.scala | 6 +-- .../expressions/randomExpressions.scala | 4 +- .../expressions/stringExpressions.scala | 45 +++++++++++-------- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 20 ++++----- .../aggregate/HashMapGenerator.scala | 4 +- .../execution/basicPhysicalOperators.scala | 8 ++-- .../columnar/GenerateColumnAccessor.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 6 +-- .../apache/spark/sql/execution/limit.scala | 4 +- 20 files changed, 104 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 84027b53dca27..821d784a01342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -67,8 +67,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addMutableState(ctx.JAVA_LONG, countTerm) + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8db7efdbb5dd4..4fa18d6b3209b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -44,7 +44,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addMutableState(ctx.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 72d5889d2f202..e5a1096bba713 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState("boolean", ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) def updateEval(eval: ExprCode): String = { s""" ${eval.code} @@ -668,8 +668,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState("boolean", ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) def updateEval(eval: ExprCode): String = { s""" ${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e02f125a93345..78617194e47d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -157,7 +157,19 @@ class CodegenContext { val mutableStates: mutable.ArrayBuffer[(String, String, String)] = mutable.ArrayBuffer.empty[(String, String, String)] - def addMutableState(javaType: String, variableName: String, initCode: String): Unit = { + /** + * Add a mutable state as a field to the generated class. c.f. the comments above. + * + * @param javaType Java type of the field. Note that short names can be used for some types, + * e.g. InternalRow, UnsafeRow, UnsafeArrayData, etc. Other types will have to + * specify the fully-qualified Java type name. See the code in doCompile() for + * the list of default imports available. + * Also, generic type arguments are accepted but ignored. + * @param variableName Name of the field. + * @param initCode The statement(s) to put into the init() method to initialize this field. + * If left blank, the field will be default-initialized. + */ + def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = { mutableStates += ((javaType, variableName, initCode)) } @@ -191,7 +203,7 @@ class CodegenContext { val initCodes = mutableStates.distinct.map(_._3 + "\n") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. - splitExpressions(initCodes, "init", Nil) + splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) } /** @@ -769,7 +781,7 @@ class CodegenContext { // Cannot split these expressions because they are not created from a row object. return expressions.mkString("\n") } - splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil) + splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil) } /** @@ -931,7 +943,7 @@ class CodegenContext { dataType: DataType, baseFuncName: String): (String, String, String) = { val globalIsNull = freshName("isNull") - addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;") val globalValue = freshName("value") addMutableState(javaType(dataType), globalValue, s"$globalValue = ${defaultValue(dataType)};") @@ -1038,7 +1050,7 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b5429fade53cf..802e8bdb1ca33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,7 +63,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"$isNull = true;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;") ctx.addMutableState(ctx.javaType(e.dataType), value, s"$value = ${ctx.defaultValue(e.dataType)};") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4b6574a31424e..2a00d57ee1300 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -120,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, "") + ctx.addMutableState("UnsafeArrayData", arrayDataName) val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index eb3c49f5cf30e..9e0786e367911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -277,7 +277,7 @@ abstract class HashExpression[E] extends Expression { } }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value) ev.copy(code = s""" ${ev.value} = $seed; $childrenHash""") @@ -616,8 +616,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { s"\n$childHash = 0;" }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") - ctx.addMutableState("int", childHash, s"$childHash = 0;") + ctx.addMutableState(ctx.javaType(dataType), ev.value) + ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;") ev.copy(code = s""" ${ev.value} = $seed; $childrenHash""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 4aeab2c3ad0a8..5eaf3f2202776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.addMutableState("boolean", ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) val evals = children.map { e => val eval = e.genCode(ctx) @@ -385,8 +385,10 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { evals.mkString("\n") } else { - ctx.splitExpressions(evals, "atLeastNNonNulls", - ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, + ctx.splitExpressions( + expressions = evals, + funcName = "atLeastNNonNulls", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, returnType = "int", makeSplitFunction = { body => s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f2eee991c9865..006d37f38d6c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -63,14 +63,14 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.freshName("resultIsNull") - ctx.addMutableState("boolean", resultIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull) resultIsNull } else { "false" } val argValues = arguments.map { e => val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + ctx.addMutableState(ctx.javaType(e.dataType), argValue) argValue } @@ -548,7 +548,7 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue, "") + ctx.addMutableState(elementJavaType, loopValue) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -644,7 +644,7 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState("boolean", loopIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -808,10 +808,10 @@ case class CatalystToExternalMap private( val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + ctx.addMutableState(keyElementJavaType, keyLoopValue) val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -844,7 +844,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { "" @@ -994,8 +994,8 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key, "") - ctx.addMutableState(valueElementJavaType, value, "") + ctx.addMutableState(keyElementJavaType, key) + ctx.addMutableState(valueElementJavaType, value) val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => @@ -1031,14 +1031,14 @@ case class ExternalMapToCatalyst private( } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState("boolean", keyIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState("boolean", valueIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull) s"$valueIsNull = $value == null;" } else { "" @@ -1106,7 +1106,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, "") + ctx.addMutableState("Object[]", values) val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) @@ -1244,7 +1244,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val javaBeanInstance = ctx.freshName("javaBean") val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) - ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "") + ctx.addMutableState(beanInstanceJavaType, javaBeanInstance) val initialize = setters.map { case (setterMethod, fieldValue) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5d75c6004bfe4..c0084af320689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -236,8 +236,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - ctx.addMutableState("boolean", ev.value, "") - ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val valueArg = ctx.freshName("valueArg") val listCode = listGen.map(x => s""" @@ -253,7 +253,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { """) val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil - ctx.splitExpressions(listCode, "valueIn", args) + ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args) } else { listCode.mkString("\n") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 97051769cbf72..b4aefe6cff73e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -79,7 +79,7 @@ case class Rand(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + ctx.addMutableState(className, rngTerm) ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" @@ -114,7 +114,7 @@ case class Randn(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + ctx.addMutableState(className, rngTerm) ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 360dd845f8d3a..1c599af2a01d0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -74,8 +74,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas """ } val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions(inputs, "valueConcat", - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcat", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) } else { inputs.mkString("\n") } @@ -155,8 +157,10 @@ case class ConcatWs(children: Seq[Expression]) } } val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions(inputs, "valueConcatWs", - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcatWs", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) } else { inputs.mkString("\n") } @@ -205,27 +209,30 @@ case class ConcatWs(children: Seq[Expression]) }.unzip val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) - val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs", - ("InternalRow", ctx.INPUT_ROW) :: Nil, - "int", - { body => + val varargCounts = ctx.splitExpressions( + expressions = varargCount, + funcName = "varargCountsConcatWs", + arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil, + returnType = "int", + makeSplitFunction = body => s""" int $varargNum = 0; $body return $varargNum; - """ - }, - _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) - val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs", - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, - "int", - { body => + """, + foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) + val varargBuilds = ctx.splitExpressions( + expressions = varargBuild, + funcName = "varargBuildsConcatWs", + arguments = + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + returnType = "int", + makeSplitFunction = body => s""" $body return $idxInVararg; - """ - }, - _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + """, + foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) ev.copy( s""" $codes @@ -2059,7 +2066,7 @@ case class FormatNumber(x: Expression, d: Expression) val numberFormat = ctx.freshName("numberFormat") val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") + ctx.addMutableState(ctx.JAVA_INT, lastDValue, s"$lastDValue = -100;") ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1925bad8c3545..a9bfb634fbdea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -76,14 +76,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + ctx.addMutableState(ctx.JAVA_LONG, scanTimeTotalNs, s"$scanTimeTotalNs = 0;") val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.freshName("batch") ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") val idx = ctx.freshName("batchIdx") - ctx.addMutableState("int", idx, s"$idx = 0;") + ctx.addMutableState(ctx.JAVA_INT, idx, s"$idx = 0;") val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(colVars.size)(classOf[ColumnVector].getName)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 21765cdbd94cf..c0e21343ae623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -134,7 +134,7 @@ case class SortExec( override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") - ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, needToSort, s"$needToSort = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 51f7c9e22b902..19c793e45a57d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -186,8 +186,8 @@ case class HashAggregateExec( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) + ctx.addMutableState(ctx.javaType(e.dataType), value) // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -565,7 +565,7 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -596,27 +596,27 @@ case class HashAggregateExec( s"$fastHashMapTerm = new $fastHashMapClassName();") ctx.addMutableState( "java.util.Iterator", - iterTermForFastHashMap, "") + iterTermForFastHashMap) } else { ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName(" + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - iterTermForFastHashMap, "") + iterTermForFastHashMap) } } // create hashMap hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, "") + ctx.addMutableState(hashMapClassName, hashMapTerm) sorterTerm = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm) // Create a name for iterator from HashMap val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm) def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -774,7 +774,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.freshName("fallbackCounter") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 90deb20e97244..85b4529501ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -48,8 +48,8 @@ abstract class HashMapGenerator( initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) + ctx.addMutableState(ctx.javaType(e.dataType), value) val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 3c7daa0a45844..f205bdf3da709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,9 +368,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val numOutput = metricTerm(ctx, "numOutputRows") val initTerm = ctx.freshName("initRange") - ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, initTerm, s"$initTerm = false;") val number = ctx.freshName("number") - ctx.addMutableState("long", number, s"$number = 0L;") + ctx.addMutableState(ctx.JAVA_LONG, number, s"$number = 0L;") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -391,11 +391,11 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // Once number == batchEnd, it's time to progress to the next batch. val batchEnd = ctx.freshName("batchEnd") - ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;") + ctx.addMutableState(ctx.JAVA_LONG, batchEnd, s"$batchEnd = 0;") // How many values should still be generated by this range operator. val numElementsTodo = ctx.freshName("numElementsTodo") - ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;") + ctx.addMutableState(ctx.JAVA_LONG, numElementsTodo, s"$numElementsTodo = 0L;") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ae600c1ffae8e..ff5dd707f0b38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -89,7 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName, "") + ctx.addMutableState(accessorCls, accessorName) val createCode = dt match { case t if ctx.isPrimitiveType(dt) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index cf7885f80d9fe..9c08ec71c1fde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -423,7 +423,7 @@ case class SortMergeJoinExec( private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. val leftRow = ctx.freshName("leftRow") - ctx.addMutableState("InternalRow", leftRow, "") + ctx.addMutableState("InternalRow", leftRow) val rightRow = ctx.freshName("rightRow") ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") @@ -519,10 +519,10 @@ case class SortMergeJoinExec( val value = ctx.freshName("value") val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) // declare it as class member, so we can access the column before or in the loop. - ctx.addMutableState(ctx.javaType(a.dataType), value, "") + ctx.addMutableState(ctx.javaType(a.dataType), value) if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) val code = s""" |$isNull = $leftRow.isNullAt($i); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 13da4b26a5dcb..a8556f6ba107a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -72,7 +72,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, stopEarly, s"$stopEarly = false;") ctx.addNewFunction("stopEarly", s""" @Override @@ -81,7 +81,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") s""" | if ($countTerm < $limit) { | $countTerm += 1; From 5855b5c03e831af997dab9f2023792c9b8d2676c Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 21 Nov 2017 07:25:56 -0600 Subject: [PATCH 1701/1765] [MINOR][DOC] The left navigation bar should be fixed with respect to scrolling. ## What changes were proposed in this pull request? A minor CSS style change to make Left navigation bar stay fixed with respect to scrolling, it improves usability of the docs. ## How was this patch tested? It was tested on both, firefox and chrome. ### Before ![a2](https://user-images.githubusercontent.com/992952/33004206-6acf9fc0-cde5-11e7-9070-02f26f7899b0.gif) ### After ![a1](https://user-images.githubusercontent.com/992952/33004205-69b27798-cde5-11e7-8002-509b29786b37.gif) Author: Prashant Sharma Closes #19785 from ScrapCodes/doc/css. --- docs/css/main.css | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/css/main.css b/docs/css/main.css index 175e8004fca0e..7f1e99bf67224 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -195,7 +195,7 @@ a.anchorjs-link:hover { text-decoration: none; } margin-top: 0px; width: 210px; float: left; - position: absolute; + position: fixed; } .left-menu { @@ -286,4 +286,4 @@ label[for="nav-trigger"]:hover { margin-left: -215px; } -} \ No newline at end of file +} From 2d868d93987ea1757cc66cdfb534bc49794eb0d0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 21 Nov 2017 10:53:53 -0800 Subject: [PATCH 1702/1765] [SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python API ## What changes were proposed in this pull request? Add python api for VectorIndexerModel support handle unseen categories via handleInvalid. ## How was this patch tested? doctest added. Author: WeichenXu Closes #19753 from WeichenXu123/vector_indexer_invalid_py. --- .../spark/ml/feature/VectorIndexer.scala | 7 +++-- python/pyspark/ml/feature.py | 30 ++++++++++++++----- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 3403ec4259b86..e6ec4e2e36ff0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * Options are: * 'skip': filter out rows with invalid data. * 'error': throw an error. - * 'keep': put invalid data in a special additional bucket, at index numCategories. + * 'keep': put invalid data in a special additional bucket, at index of the number of + * categories of the feature. * Default value: "error" * @group param */ @@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen labels or NULL values). " + "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " + - "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + "or 'keep' (put invalid data in a special additional bucket, at index of the " + + "number of categories of the feature).", ParamValidators.inArray(VectorIndexer.supportedHandleInvalids)) setDefault(handleInvalid, VectorIndexer.ERROR_INVALID) @@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. * - Specify certain features to not index, either via a parameter or via existing metadata. * - Add warning if a categorical feature has only 1 category. - * - Add option for allowing unknown categories. */ @Since("1.4.0") class VectorIndexer @Since("1.4.0") ( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 232ae3ef41166..608f2a5715497 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2490,7 +2490,8 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ Class for indexing categorical feature columns in a dataset of `Vector`. @@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja do not recompute. - Specify certain features to not index, either via a parameter or via existing metadata. - Add warning if a categorical feature has only 1 category. - - Add option for allowing unknown categories. >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),), @@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja True >>> loadedModel.categoryMaps == model.categoryMaps True + >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"]) + >>> indexer.getHandleInvalid() + 'error' + >>> model3 = indexer.setHandleInvalid("skip").fit(df) + >>> model3.transform(dfWithInvalid).count() + 0 + >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df) + >>> model4.transform(dfWithInvalid).head().indexed + DenseVector([2.0, 1.0]) .. versionadded:: 1.4.0 """ @@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja "(>= 2). If a feature is found to have > maxCategories values, then " + "it is declared continuous.", typeConverter=TypeConverters.toInt) + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " + + "(unseen labels or NULL values). Options are 'skip' (filter out " + + "rows with invalid data), 'error' (throw an error), or 'keep' (put " + + "invalid data in a special additional bucket, at index of the number " + + "of categories of the feature).", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, maxCategories=20, inputCol=None, outputCol=None): + def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, maxCategories=20, inputCol=None, outputCol=None) + __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") """ super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) - self._setDefault(maxCategories=20) + self._setDefault(maxCategories=20, handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, maxCategories=20, inputCol=None, outputCol=None): + def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, maxCategories=20, inputCol=None, outputCol=None) + setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this VectorIndexer. """ kwargs = self._input_kwargs From 6d7ebf2f9fbd043813738005a23c57a77eba6f47 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 21 Nov 2017 20:53:38 +0100 Subject: [PATCH 1703/1765] [SPARK-22165][SQL] Fixes type conflicts between double, long, decimals, dates and timestamps in partition column ## What changes were proposed in this pull request? This PR proposes to add a rule that re-uses `TypeCoercion.findWiderCommonType` when resolving type conflicts in partition values. Currently, this uses numeric precedence-like comparison; therefore, it looks introducing failures for type conflicts between timestamps, dates and decimals, please see: ```scala private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) ... literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) ``` The codes below: ```scala val df = Seq((1, "2015-01-01"), (2, "2016-01-01 00:00:00")).toDF("i", "ts") df.write.format("parquet").partitionBy("ts").save("/tmp/foo") spark.read.load("/tmp/foo").printSchema() val df = Seq((1, "1"), (2, "1" * 30)).toDF("i", "decimal") df.write.format("parquet").partitionBy("decimal").save("/tmp/bar") spark.read.load("/tmp/bar").printSchema() ``` produces output as below: **Before** ``` root |-- i: integer (nullable = true) |-- ts: date (nullable = true) root |-- i: integer (nullable = true) |-- decimal: integer (nullable = true) ``` **After** ``` root |-- i: integer (nullable = true) |-- ts: timestamp (nullable = true) root |-- i: integer (nullable = true) |-- decimal: decimal(30,0) (nullable = true) ``` ### Type coercion table: This PR proposes the type conflict resolusion as below: **Before** |InputA \ InputB|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |------------------------|----------|----------|----------|----------|----------|----------|----------|----------| |**`NullType`**|`StringType`|`IntegerType`|`LongType`|`StringType`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`IntegerType`**|`IntegerType`|`IntegerType`|`LongType`|`IntegerType`|`DoubleType`|`IntegerType`|`IntegerType`|`StringType`| |**`LongType`**|`LongType`|`LongType`|`LongType`|`LongType`|`DoubleType`|`LongType`|`LongType`|`StringType`| |**`DecimalType(38,0)`**|`StringType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DecimalType(38,0)`|`DecimalType(38,0)`|`StringType`| |**`DoubleType`**|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`StringType`| |**`DateType`**|`StringType`|`IntegerType`|`LongType`|`DateType`|`DoubleType`|`DateType`|`DateType`|`StringType`| |**`TimestampType`**|`StringType`|`IntegerType`|`LongType`|`TimestampType`|`DoubleType`|`TimestampType`|`TimestampType`|`StringType`| |**`StringType`**|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`| **After** |InputA \ InputB|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |------------------------|----------|----------|----------|----------|----------|----------|----------|----------| |**`NullType`**|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |**`IntegerType`**|`IntegerType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`LongType`**|`LongType`|`LongType`|`LongType`|`DecimalType(38,0)`|`StringType`|`StringType`|`StringType`|`StringType`| |**`DecimalType(38,0)`**|`DecimalType(38,0)`|`DecimalType(38,0)`|`DecimalType(38,0)`|`DecimalType(38,0)`|`StringType`|`StringType`|`StringType`|`StringType`| |**`DoubleType`**|`DoubleType`|`DoubleType`|`StringType`|`StringType`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`DateType`**|`DateType`|`StringType`|`StringType`|`StringType`|`StringType`|`DateType`|`TimestampType`|`StringType`| |**`TimestampType`**|`TimestampType`|`StringType`|`StringType`|`StringType`|`StringType`|`TimestampType`|`TimestampType`|`StringType`| |**`StringType`**|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`| This was produced by: ```scala test("Print out chart") { val supportedTypes: Seq[DataType] = Seq( NullType, IntegerType, LongType, DecimalType(38, 0), DoubleType, DateType, TimestampType, StringType) // Old type conflict resolution: val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) def oldResolveTypeConflicts(dataTypes: Seq[DataType]): DataType = { val topType = dataTypes.maxBy(upCastingOrder.indexOf(_)) if (topType == NullType) StringType else topType } println(s"|InputA \\ InputB|${supportedTypes.map(dt => s"`${dt.toString}`").mkString("|")}|") println(s"|------------------------|${supportedTypes.map(_ => "----------").mkString("|")}|") supportedTypes.foreach { inputA => val types = supportedTypes.map(inputB => oldResolveTypeConflicts(Seq(inputA, inputB))) println(s"|**`$inputA`**|${types.map(dt => s"`${dt.toString}`").mkString("|")}|") } // New type conflict resolution: def newResolveTypeConflicts(dataTypes: Seq[DataType]): DataType = { dataTypes.fold[DataType](NullType)(findWiderTypeForPartitionColumn) } println(s"|InputA \\ InputB|${supportedTypes.map(dt => s"`${dt.toString}`").mkString("|")}|") println(s"|------------------------|${supportedTypes.map(_ => "----------").mkString("|")}|") supportedTypes.foreach { inputA => val types = supportedTypes.map(inputB => newResolveTypeConflicts(Seq(inputA, inputB))) println(s"|**`$inputA`**|${types.map(dt => s"`${dt.toString}`").mkString("|")}|") } } ``` ## How was this patch tested? Unit tests added in `ParquetPartitionDiscoverySuite`. Author: hyukjinkwon Closes #19389 from HyukjinKwon/partition-type-coercion. --- docs/sql-programming-guide.md | 139 ++++++++++++++++++ .../sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../datasources/PartitioningUtils.scala | 60 +++++--- .../ParquetPartitionDiscoverySuite.scala | 57 ++++++- 4 files changed, 235 insertions(+), 23 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 686fcb159d09d..5f9821378b271 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1577,6 +1577,145 @@ options. - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. + - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + InputA \ InputB + + NullType + + IntegerType + + LongType + + DecimalType(38,0)* + + DoubleType + + DateType + + TimestampType + + StringType +
    + NullType + NullTypeIntegerTypeLongTypeDecimalType(38,0)DoubleTypeDateTypeTimestampTypeStringType
    + IntegerType + IntegerTypeIntegerTypeLongTypeDecimalType(38,0)DoubleTypeStringTypeStringTypeStringType
    + LongType + LongTypeLongTypeLongTypeDecimalType(38,0)StringTypeStringTypeStringTypeStringType
    + DecimalType(38,0)* + DecimalType(38,0)DecimalType(38,0)DecimalType(38,0)DecimalType(38,0)StringTypeStringTypeStringTypeStringType
    + DoubleType + DoubleTypeDoubleTypeStringTypeStringTypeDoubleTypeStringTypeStringTypeStringType
    + DateType + DateTypeStringTypeStringTypeStringTypeStringTypeDateTypeTimestampTypeStringType
    + TimestampType + TimestampTypeStringTypeStringTypeStringTypeStringTypeTimestampTypeTimestampTypeStringType
    + StringType + StringTypeStringTypeStringTypeStringTypeStringTypeStringTypeStringTypeStringType
    + + Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 074eda56199e8..28be955e08a0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -155,7 +155,7 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1c00c9ebb4144..472bf82d3604d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -309,13 +309,8 @@ object PartitioningUtils { } /** - * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- - * casting order is: - * {{{ - * NullType -> - * IntegerType -> LongType -> - * DoubleType -> StringType - * }}} + * Resolves possible type conflicts between partitions by up-casting "lower" types using + * [[findWiderTypeForPartitionColumn]]. */ def resolvePartitions( pathsWithPartitionValues: Seq[(Path, PartitionValues)], @@ -372,11 +367,31 @@ object PartitioningUtils { suspiciousPaths.map("\t" + _).mkString("\n", "\n", "") } + // scalastyle:off line.size.limit /** - * Converts a string to a [[Literal]] with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]] + * Converts a string to a [[Literal]] with automatic type inference. Currently only supports + * [[NullType]], [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]] * [[TimestampType]], and [[StringType]]. + * + * When resolving conflicts, it follows the table below: + * + * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+ + * | InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType | + * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+ + * | NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType | + * | IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType | + * | LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType | + * | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType | + * | DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType | + * | DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType | + * | TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType | + * | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | + * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+ + * Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other + * combinations of scales and precisions because currently we only infer decimal type like + * `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. */ + // scalastyle:on line.size.limit private[datasources] def inferPartitionColumnValue( raw: String, typeInference: Boolean, @@ -427,9 +442,6 @@ object PartitioningUtils { } } - private val upCastingOrder: Seq[DataType] = - Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) - def validatePartitionColumn( schema: StructType, partitionColumns: Seq[String], @@ -468,18 +480,26 @@ object PartitioningUtils { } /** - * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" - * types. + * Given a collection of [[Literal]]s, resolves possible type conflicts by + * [[findWiderTypeForPartitionColumn]]. */ private def resolveTypeConflicts(literals: Seq[Literal], timeZone: TimeZone): Seq[Literal] = { - val desiredType = { - val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) - // Falls back to string if all values of this column are null or empty string - if (topType == NullType) StringType else topType - } + val litTypes = literals.map(_.dataType) + val desiredType = litTypes.reduce(findWiderTypeForPartitionColumn) literals.map { case l @ Literal(_, dataType) => Literal.create(Cast(l, desiredType, Some(timeZone.getID)).eval(), desiredType) } } + + /** + * Type widening rule for partition column types. It is similar to + * [[TypeCoercion.findWiderTypeForTwo]] but the main difference is that here we disallow + * precision loss when widening double/long and decimal, and fall back to string. + */ + private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType + case (DoubleType, LongType) | (LongType, DoubleType) => StringType + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index f79b92b804c70..d4902641e335f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -249,6 +249,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha true, rootPaths, timeZoneId) + assert(actualSpec.partitionColumns === spec.partitionColumns) + assert(actualSpec.partitions.length === spec.partitions.length) + actualSpec.partitions.zip(spec.partitions).foreach { case (actual, expected) => + assert(actual === expected) + } assert(actualSpec === spec) } @@ -314,7 +319,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha PartitionSpec( StructType(Seq( StructField("a", DoubleType), - StructField("b", StringType))), + StructField("b", NullType))), Seq( Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), Partition(InternalRow(10.5, null), @@ -324,6 +329,32 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha s"hdfs://host:9000/path1", s"hdfs://host:9000/path2"), PartitionSpec.emptySpec) + + // The cases below check the resolution for type conflicts. + val t1 = Timestamp.valueOf("2014-01-01 00:00:00.0").getTime * 1000 + val t2 = Timestamp.valueOf("2014-01-01 00:01:00.0").getTime * 1000 + // Values in column 'a' are inferred as null, date and timestamp each, and timestamp is set + // as a common type. + // Values in column 'b' are inferred as integer, decimal(22, 0) and null, and decimal(22, 0) + // is set as a common type. + check(Seq( + s"hdfs://host:9000/path/a=$defaultPartitionName/b=0", + s"hdfs://host:9000/path/a=2014-01-01/b=${Long.MaxValue}111", + s"hdfs://host:9000/path/a=2014-01-01 00%3A01%3A00.0/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", TimestampType), + StructField("b", DecimalType(22, 0)))), + Seq( + Partition( + InternalRow(null, Decimal(0)), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=0"), + Partition( + InternalRow(t1, Decimal(s"${Long.MaxValue}111")), + s"hdfs://host:9000/path/a=2014-01-01/b=${Long.MaxValue}111"), + Partition( + InternalRow(t2, null), + s"hdfs://host:9000/path/a=2014-01-01 00%3A01%3A00.0/b=$defaultPartitionName")))) } test("parse partitions with type inference disabled") { @@ -395,7 +426,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha PartitionSpec( StructType(Seq( StructField("a", StringType), - StructField("b", StringType))), + StructField("b", NullType))), Seq( Partition(InternalRow(UTF8String.fromString("10"), null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), @@ -1067,4 +1098,26 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha checkAnswer(spark.read.load(path.getAbsolutePath), df) } } + + test("Resolve type conflicts - decimals, dates and timestamps in partition column") { + withTempPath { path => + val df = Seq((1, "2014-01-01"), (2, "2016-01-01"), (3, "2015-01-01 00:01:00")).toDF("i", "ts") + df.write.format("parquet").partitionBy("ts").save(path.getAbsolutePath) + checkAnswer( + spark.read.load(path.getAbsolutePath), + Row(1, Timestamp.valueOf("2014-01-01 00:00:00")) :: + Row(2, Timestamp.valueOf("2016-01-01 00:00:00")) :: + Row(3, Timestamp.valueOf("2015-01-01 00:01:00")) :: Nil) + } + + withTempPath { path => + val df = Seq((1, "1"), (2, "3"), (3, "2" * 30)).toDF("i", "decimal") + df.write.format("parquet").partitionBy("decimal").save(path.getAbsolutePath) + checkAnswer( + spark.read.load(path.getAbsolutePath), + Row(1, BigDecimal("1")) :: + Row(2, BigDecimal("3")) :: + Row(3, BigDecimal("2" * 30)) :: Nil) + } + } } From b96f61b6b262836e6be3f7657a3fe136d58b4dfe Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 21 Nov 2017 20:55:24 +0100 Subject: [PATCH 1704/1765] [SPARK-22475][SQL] show histogram in DESC COLUMN command ## What changes were proposed in this pull request? Added the histogram representation to the output of the `DESCRIBE EXTENDED table_name column_name` command. ## How was this patch tested? Modified SQL UT and checked output Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Marco Gaido Closes #19774 from mgaido91/SPARK-22475. --- .../spark/sql/execution/command/tables.scala | 17 +++++ .../inputs/describe-table-column.sql | 10 +++ .../results/describe-table-column.sql.out | 74 +++++++++++++++++-- 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 95f16b0f4baea..c9f6e571ddab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.Histogram import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -689,9 +690,25 @@ case class DescribeColumnCommand( buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + val histDesc = for { + c <- cs + hist <- c.histogram + } yield histogramDescription(hist) + buffer ++= histDesc.getOrElse(Seq(Row("histogram", "NULL"))) } buffer } + + private def histogramDescription(histogram: Histogram): Seq[Row] = { + val header = Row("histogram", + s"height: ${histogram.height}, num_of_bins: ${histogram.bins.length}") + val bins = histogram.bins.zipWithIndex.map { + case (bin, index) => + Row(s"bin_$index", + s"lower_bound: ${bin.lo}, upper_bound: ${bin.hi}, distinct_count: ${bin.ndv}") + } + header +: bins + } } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql index a6ddcd999bf9b..2d180d118da7a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql @@ -34,6 +34,16 @@ DESC FORMATTED desc_complex_col_table col; -- Describe a nested column DESC FORMATTED desc_complex_col_table col.x; +-- Test output for histogram statistics +SET spark.sql.statistics.histogram.enabled=true; +SET spark.sql.statistics.histogram.numBins=2; + +INSERT INTO desc_col_table values 1, 2, 3, 4; + +ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key; + +DESC EXTENDED desc_col_table key; + DROP VIEW desc_col_temp_view; DROP TABLE desc_col_table; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out index 30d0a2dc5a3f7..6ef8af6574e98 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 23 -- !query 0 @@ -34,6 +34,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 3 @@ -50,6 +51,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 4 @@ -66,6 +68,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 5 @@ -117,6 +120,7 @@ num_nulls 0 distinct_count 0 avg_col_len 4 max_col_len 4 +histogram NULL -- !query 10 @@ -133,6 +137,7 @@ num_nulls 0 distinct_count 0 avg_col_len 4 max_col_len 4 +histogram NULL -- !query 11 @@ -157,6 +162,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 13 @@ -173,6 +179,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 14 @@ -185,24 +192,75 @@ DESC TABLE COLUMN command does not support nested data types: col.x; -- !query 15 -DROP VIEW desc_col_temp_view +SET spark.sql.statistics.histogram.enabled=true -- !query 15 schema -struct<> +struct -- !query 15 output - +spark.sql.statistics.histogram.enabled true -- !query 16 -DROP TABLE desc_col_table +SET spark.sql.statistics.histogram.numBins=2 -- !query 16 schema -struct<> +struct -- !query 16 output - +spark.sql.statistics.histogram.numBins 2 -- !query 17 -DROP TABLE desc_complex_col_table +INSERT INTO desc_col_table values 1, 2, 3, 4 -- !query 17 schema struct<> -- !query 17 output + + +-- !query 18 +ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +DESC EXTENDED desc_col_table key +-- !query 19 schema +struct +-- !query 19 output +col_name key +data_type int +comment column_comment +min 1 +max 4 +num_nulls 0 +distinct_count 4 +avg_col_len 4 +max_col_len 4 +histogram height: 2.0, num_of_bins: 2 +bin_0 lower_bound: 1.0, upper_bound: 2.0, distinct_count: 2 +bin_1 lower_bound: 2.0, upper_bound: 4.0, distinct_count: 2 + + +-- !query 20 +DROP VIEW desc_col_temp_view +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP TABLE desc_col_table +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP TABLE desc_complex_col_table +-- !query 22 schema +struct<> +-- !query 22 output + From ac10171bea2fc027d6691393b385b3fc0ef3293d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 22:24:43 +0100 Subject: [PATCH 1705/1765] [SPARK-22500][SQL] Fix 64KB JVM bytecode limit problem with cast ## What changes were proposed in this pull request? This PR changes `cast` code generation to place generated code for expression for fields of a structure into separated methods if these size could be large. ## How was this patch tested? Added new test cases into `CastSuite` Author: Kazuaki Ishizaki Closes #19730 from kiszk/SPARK-22500. --- .../spark/sql/catalyst/expressions/Cast.scala | 12 ++++++++++-- .../sql/catalyst/expressions/CastSuite.scala | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bc809f559d586..12baddf1bf7ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1039,13 +1039,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } } """ - }.mkString("\n") + } + val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions( + expressions = fieldsEvalCode, + funcName = "castStruct", + arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil) + } else { + fieldsEvalCode.mkString("\n") + } (c, evPrim, evNull) => s""" final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; - $fieldsEvalCode + $fieldsEvalCodes $evPrim = $result; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a7ffa884d2286..84bd8b2f91e4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -827,4 +827,22 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Literal.create(input, from), to), input) } + + test("SPARK-22500: cast for struct should not generate codes beyond 64KB") { + val N = 250 + + val fromInner = new StructType( + (1 to N).map(i => StructField(s"s$i", DoubleType)).toArray) + val toInner = new StructType( + (1 to N).map(i => StructField(s"i$i", IntegerType)).toArray) + val inputInner = Row.fromSeq((1 to N).map(i => i + 0.5)) + val outputInner = Row.fromSeq((1 to N)) + val fromOuter = new StructType( + (1 to N).map(i => StructField(s"s$i", fromInner)).toArray) + val toOuter = new StructType( + (1 to N).map(i => StructField(s"s$i", toInner)).toArray) + val inputOuter = Row.fromSeq((1 to N).map(_ => inputInner)) + val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner)) + checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter) + } } From 881c5c807304a305ef96e805d51afbde097f7f4f Mon Sep 17 00:00:00 2001 From: Jia Li Date: Tue, 21 Nov 2017 17:30:02 -0800 Subject: [PATCH 1706/1765] [SPARK-22548][SQL] Incorrect nested AND expression pushed down to JDBC data source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Let’s say I have a nested AND expression shown below and p2 can not be pushed down, (p1 AND p2) OR p3 In current Spark code, during data source filter translation, (p1 AND p2) is returned as p1 only and p2 is simply lost. This issue occurs with JDBC data source and is similar to [SPARK-12218](https://github.com/apache/spark/pull/10362) for Parquet. When we have AND nested below another expression, we should either push both legs or nothing. Note that: - The current Spark code will always split conjunctive predicate before it determines if a predicate can be pushed down or not - If I have (p1 AND p2) AND p3, it will be split into p1, p2, p3. There won't be nested AND expression. - The current Spark code logic for OR is OK. It either pushes both legs or nothing. The same translation method is also called by Data Source V2. ## How was this patch tested? Added new unit test cases to JDBCSuite gatorsmile Author: Jia Li Closes #19776 from jliwork/spark-22548. --- .../datasources/DataSourceStrategy.scala | 14 +- .../datasources/DataSourceStrategySuite.scala | 231 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 +- 3 files changed, 250 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 04d6f3f56eb02..400f2e03165b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -497,7 +497,19 @@ object DataSourceStrategy { Some(sources.IsNotNull(a.name)) case expressions.And(left, right) => - (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And) + // See SPARK-12218 for detailed discussion + // It is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) + // and we do not understand how to convert trim(b) = 'blah'. + // If we only convert a = 2, we will end up with + // (a = 2) OR (c > 0), which will generate wrong results. + // Pushing one leg of AND down is only safe to do at the top level. + // You can see ParquetFilters' createFilter for more details. + for { + leftFilter <- translateFilter(left) + rightFilter <- translateFilter(right) + } yield sources.And(leftFilter, rightFilter) case expressions.Or(left, right) => for { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala new file mode 100644 index 0000000000000..f20aded169e44 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.sources +import org.apache.spark.sql.test.SharedSQLContext + +class DataSourceStrategySuite extends PlanTest with SharedSQLContext { + + test("translate simple expression") { + val attrInt = 'cint.int + val attrStr = 'cstr.string + + testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1))) + testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1))) + + testTranslateFilter(EqualNullSafe(attrStr, Literal(null)), + Some(sources.EqualNullSafe("cstr", null))) + testTranslateFilter(EqualNullSafe(Literal(null), attrStr), + Some(sources.EqualNullSafe("cstr", null))) + + testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan("cint", 1))) + testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 1))) + + testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 1))) + testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 1))) + + testTranslateFilter(GreaterThanOrEqual(attrInt, 1), Some(sources.GreaterThanOrEqual("cint", 1))) + testTranslateFilter(GreaterThanOrEqual(1, attrInt), Some(sources.LessThanOrEqual("cint", 1))) + + testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1))) + testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1))) + + testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + + testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + + testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) + testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) + + // cint > 1 AND cint < 10 + testTranslateFilter(And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10)), + Some(sources.And( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)))) + + // cint >= 8 OR cint <= 2 + testTranslateFilter(Or( + GreaterThanOrEqual(attrInt, 8), + LessThanOrEqual(attrInt, 2)), + Some(sources.Or( + sources.GreaterThanOrEqual("cint", 8), + sources.LessThanOrEqual("cint", 2)))) + + testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)), + Some(sources.Not(sources.GreaterThanOrEqual("cint", 8)))) + + testTranslateFilter(StartsWith(attrStr, "a"), Some(sources.StringStartsWith("cstr", "a"))) + + testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith("cstr", "a"))) + + testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains("cstr", "a"))) + } + + test("translate complex expression") { + val attrInt = 'cint.int + + // ABS(cint) - 2 <= 1 + testTranslateFilter(LessThanOrEqual( + // Expressions are not supported + // Functions such as 'Abs' are not supported + Subtract(Abs(attrInt), 2), 1), None) + + // (cin1 > 1 AND cint < 10) OR (cint > 50 AND cint > 100) + testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + GreaterThan(attrInt, 50), + LessThan(attrInt, 100))), + Some(sources.Or( + sources.And( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)), + sources.And( + sources.GreaterThan("cint", 50), + sources.LessThan("cint", 100))))) + + // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data source + // (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100) + testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + // Functions such as 'Abs' are not supported + LessThan(Abs(attrInt), 10) + ), + And( + GreaterThan(attrInt, 50), + LessThan(attrInt, 100))), None) + + // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100)) + testTranslateFilter(Not(And( + Or( + LessThanOrEqual(attrInt, 1), + // Functions such as 'Abs' are not supported + GreaterThanOrEqual(Abs(attrInt), 10) + ), + Or( + LessThanOrEqual(attrInt, 50), + GreaterThanOrEqual(attrInt, 100)))), None) + + // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10) + testTranslateFilter(Or( + Or( + EqualTo(attrInt, 1), + EqualTo(attrInt, 10) + ), + Or( + GreaterThan(attrInt, 0), + LessThan(attrInt, -10))), + Some(sources.Or( + sources.Or( + sources.EqualTo("cint", 1), + sources.EqualTo("cint", 10)), + sources.Or( + sources.GreaterThan("cint", 0), + sources.LessThan("cint", -10))))) + + // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10) + testTranslateFilter(Or( + Or( + EqualTo(attrInt, 1), + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 10) + ), + Or( + GreaterThan(attrInt, 0), + LessThan(attrInt, -10))), None) + + // In end-to-end testing, conjunctive predicate should has been split + // before reaching DataSourceStrategy.translateFilter. + // This is for UT purpose to test each [[case]]. + // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + EqualTo(attrInt, 6), + IsNotNull(attrInt))), + Some(sources.And( + sources.And( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)), + sources.And( + sources.EqualTo("cint", 6), + sources.IsNotNull("cint"))))) + + // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), None) + + // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + EqualTo(attrInt, 6), + IsNotNull(attrInt))), + Some(sources.And( + sources.Or( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)), + sources.Or( + sources.EqualTo("cint", 6), + sources.IsNotNull("cint"))))) + + // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), None) + } + + /** + * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]] + * then verify against the given [[sources.Filter]]. + */ + def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = { + assertResult(result) { + DataSourceStrategy.translateFilter(catalystFilter) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 88a5f618d604d..61571bccdcb51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -294,10 +294,13 @@ class JDBCSuite extends SparkFunSuite // This is a test to reflect discussion in SPARK-12218. // The older versions of spark have this kind of bugs in parquet data source. - val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')") - val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") + val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") assert(df1.collect.toSet === Set(Row("mary", 2))) - assert(df2.collect.toSet === Set(Row("mary", 2))) + + // SPARK-22548: Incorrect nested AND expression pushed down to JDBC data source + val df2 = sql("SELECT * FROM foobar " + + "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") + assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) def checkNotPushdown(df: DataFrame): DataFrame = { val parentPlan = df.queryExecution.executedPlan From e0d7665cec1e6954d640f422c79ebba4c273be7d Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 21 Nov 2017 22:31:46 -0800 Subject: [PATCH 1707/1765] [SPARK-17920][SPARK-19580][SPARK-19878][SQL] Support writing to Hive table which uses Avro schema url 'avro.schema.url' ## What changes were proposed in this pull request? SPARK-19580 Support for avro.schema.url while writing to hive table SPARK-19878 Add hive configuration when initialize hive serde in InsertIntoHiveTable.scala SPARK-17920 HiveWriterContainer passes null configuration to serde.initialize, causing NullPointerException in AvroSerde when using avro.schema.url Support writing to Hive table which uses Avro schema url 'avro.schema.url' For ex: create external table avro_in (a string) stored as avro location '/avro-in/' tblproperties ('avro.schema.url'='/avro-schema/avro.avsc'); create external table avro_out (a string) stored as avro location '/avro-out/' tblproperties ('avro.schema.url'='/avro-schema/avro.avsc'); insert overwrite table avro_out select * from avro_in; // fails with java.lang.NullPointerException WARN AvroSerDe: Encountered exception determining schema. Returning signal schema to indicate problem java.lang.NullPointerException at org.apache.hadoop.fs.FileSystem.getDefaultUri(FileSystem.java:182) at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:174) ## Changes proposed in this fix Currently 'null' value is passed to serializer, which causes NPE during insert operation, instead pass Hadoop configuration object ## How was this patch tested? Added new test case in VersionsSuite Author: vinodkc Closes #19779 from vinodkc/br_Fix_SPARK-17920. --- .../sql/hive/execution/HiveFileFormat.scala | 4 +- .../spark/sql/hive/client/VersionsSuite.scala | 72 ++++++++++++++++++- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index ac735e8b383f6..4a7cd6901923b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -116,7 +116,7 @@ class HiveOutputWriter( private val serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] - serializer.initialize(null, tableDesc.getProperties) + serializer.initialize(jobConf, tableDesc.getProperties) serializer } @@ -130,7 +130,7 @@ class HiveOutputWriter( private val standardOI = ObjectInspectorUtils .getStandardObjectInspector( - tableDesc.getDeserializer.getObjectInspector, + tableDesc.getDeserializer(jobConf).getObjectInspector, ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9ed39cc80f50d..fbf6877c994a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import java.io.{ByteArrayOutputStream, File, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration @@ -841,6 +841,76 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + withTempDir { dir => + val path = dir.getAbsolutePath + val schemaPath = s"""$path${File.separator}avroschemadir""" + + new File(schemaPath).mkdir() + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + val schemaUrl = s"""$schemaPath${File.separator}avroDecimal.avsc""" + val schemaFile = new File(schemaPath, "avroDecimal.avsc") + val writer = new PrintWriter(schemaFile) + writer.write(avroSchema) + writer.close() + + val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") + val srcLocation = new File(url.getFile) + val destTableName = "tab1" + val srcTableName = "tab2" + + withTable(srcTableName, destTableName) { + versionSpark.sql( + s""" + |CREATE EXTERNAL TABLE $srcTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$srcLocation' + |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + """.stripMargin + ) + + versionSpark.sql( + s""" + |CREATE TABLE $destTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + """.stripMargin + ) + versionSpark.sql( + s"""INSERT OVERWRITE TABLE $destTableName SELECT * FROM $srcTableName""") + val result = versionSpark.table(srcTableName).collect() + assert(versionSpark.table(destTableName).collect() === result) + versionSpark.sql( + s"""INSERT INTO TABLE $destTableName SELECT * FROM $srcTableName""") + assert(versionSpark.table(destTableName).collect().toSeq === result ++ result) + } + } + } // TODO: add more tests. } } From 2c0fe818a624cfdc76c752ec6bfe6a42e5680604 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 22 Nov 2017 09:09:50 +0100 Subject: [PATCH 1708/1765] [SPARK-22445][SQL][FOLLOW-UP] Respect stream-side child's needCopyResult in BroadcastHashJoin ## What changes were proposed in this pull request? I found #19656 causes some bugs, for example, it changed the result set of `q6` in tpcds (I keep tracking TPCDS results daily [here](https://github.com/maropu/spark-tpcds-datagen/tree/master/reports/tests)): - w/o pr19658 ``` +-----+---+ |state|cnt| +-----+---+ | MA| 10| | AK| 10| | AZ| 11| | ME| 13| | VT| 14| | NV| 15| | NH| 16| | UT| 17| | NJ| 21| | MD| 22| | WY| 25| | NM| 26| | OR| 31| | WA| 36| | ND| 38| | ID| 39| | SC| 45| | WV| 50| | FL| 51| | OK| 53| | MT| 53| | CO| 57| | AR| 58| | NY| 58| | PA| 62| | AL| 63| | LA| 63| | SD| 70| | WI| 80| | null| 81| | MI| 82| | NC| 82| | MS| 83| | CA| 84| | MN| 85| | MO| 88| | IL| 95| | IA|102| | TN|102| | IN|103| | KY|104| | NE|113| | OH|114| | VA|130| | KS|139| | GA|168| | TX|216| +-----+---+ ``` - w/ pr19658 ``` +-----+---+ |state|cnt| +-----+---+ | RI| 14| | AK| 16| | FL| 20| | NJ| 21| | NM| 21| | NV| 22| | MA| 22| | MD| 22| | UT| 22| | AZ| 25| | SC| 28| | AL| 36| | MT| 36| | WA| 39| | ND| 41| | MI| 44| | AR| 45| | OR| 47| | OK| 52| | PA| 53| | LA| 55| | CO| 55| | NY| 64| | WV| 66| | SD| 72| | MS| 73| | NC| 79| | IN| 82| | null| 85| | ID| 88| | MN| 91| | WI| 95| | IL| 96| | MO| 97| | CA|109| | CA|109| | TN|114| | NE|115| | KY|128| | OH|131| | IA|156| | TX|160| | VA|182| | KS|211| | GA|230| +-----+---+ ``` This pr is to keep the original logic of `CodegenContext.copyResult` in `BroadcastHashJoinExec`. ## How was this patch tested? Existing tests Author: Takeshi Yamamuro Closes #19781 from maropu/SPARK-22445-bugfix. --- .../joins/BroadcastHashJoinExec.scala | 15 ++++++----- .../org/apache/spark/sql/JoinSuite.scala | 27 ++++++++++++++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 837b8525fed55..c96ed6ef41016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -76,20 +76,23 @@ case class BroadcastHashJoinExec( streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() } - override def needCopyResult: Boolean = joinType match { + private def multipleOutputForOneInput: Boolean = joinType match { case _: InnerLike | LeftOuter | RightOuter => // For inner and outer joins, one row from the streamed side may produce multiple result rows, - // if the build side has duplicated keys. Then we need to copy the result rows before putting - // them in a buffer, because these result rows share one UnsafeRow instance. Note that here - // we wait for the broadcast to be finished, which is a no-op because it's already finished - // when we wait it in `doProduce`. + // if the build side has duplicated keys. Note that here we wait for the broadcast to be + // finished, which is a no-op because it's already finished when we wait it in `doProduce`. !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique // Other joins types(semi, anti, existence) can at most produce one result row for one input - // row from the streamed side, so no need to copy the result rows. + // row from the streamed side. case _ => false } + // If the streaming side needs to copy result, this join plan needs to copy too. Otherwise, + // this join plan only needs to copy result if it may output multiple rows for one input. + override def needCopyResult: Boolean = + streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput + override def doProduce(ctx: CodegenContext): String = { streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 226cc3028b135..771e1186e63ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} -import org.apache.spark.sql.execution.SortExec +import org.apache.spark.sql.execution.{BinaryExecNode, SortExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext { joinQueries.foreach(assertJoinOrdering) } + + test("SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin") { + val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1") + val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2") + val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3") + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.JOIN_REORDER_ENABLED.key -> "false") { + val df = df1.join(df2, "k").join(functions.broadcast(df3), "k") + val plan = df.queryExecution.sparkPlan + + // Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj + val joins = new collection.mutable.ArrayBuffer[BinaryExecNode]() + plan.foreachUp { + case j: BroadcastHashJoinExec => joins += j + case j: SortMergeJoinExec => joins += j + case _ => + } + assert(joins.size == 2) + assert(joins(0).isInstanceOf[SortMergeJoinExec]) + assert(joins(1).isInstanceOf[BroadcastHashJoinExec]) + checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) + } + } } From 572af5027e45ca96e0d283a8bf7c84dcf476f9bc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 22 Nov 2017 13:27:20 +0100 Subject: [PATCH 1709/1765] [SPARK-20101][SQL][FOLLOW-UP] use correct config name "spark.sql.columnVector.offheap.enabled" ## What changes were proposed in this pull request? This PR addresses [the spelling miss](https://github.com/apache/spark/pull/17436#discussion_r152189670) of the config name `spark.sql.columnVector.offheap.enabled`. We should use `spark.sql.columnVector.offheap.enabled`. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #19794 from kiszk/SPARK-20101-follow. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8485ed4c887d4..4eda9f337953e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -141,7 +141,7 @@ object SQLConf { .createWithDefault(true) val COLUMN_VECTOR_OFFHEAP_ENABLED = - buildConf("spark.sql.columnVector.offheap.enable") + buildConf("spark.sql.columnVector.offheap.enabled") .internal() .doc("When true, use OffHeapColumnVector in ColumnarBatch.") .booleanConf From 327d25fe1741f62cd84097e94739f82ecb05383a Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Wed, 22 Nov 2017 21:35:47 +0900 Subject: [PATCH 1710/1765] [SPARK-22572][SPARK SHELL] spark-shell does not re-initialize on :replay ## What changes were proposed in this pull request? Ticket: [SPARK-22572](https://issues.apache.org/jira/browse/SPARK-22572) ## How was this patch tested? Added a new test case to `org.apache.spark.repl.ReplSuite` Author: Mark Petruska Closes #19791 from mpetruska/SPARK-22572. --- .../org/apache/spark/repl/SparkILoop.scala | 75 +++++++++++-------- .../org/apache/spark/repl/SparkILoop.scala | 74 ++++++++++-------- .../org/apache/spark/repl/ReplSuite.scala | 10 +++ 3 files changed, 96 insertions(+), 63 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index ea279e4f0ebce..3ce7cc7c85f74 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -35,40 +35,45 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + val initializationCommands: Seq[String] = Seq( + """ + @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { + org.apache.spark.repl.Main.sparkSession + } else { + org.apache.spark.repl.Main.createSparkSession() + } + @transient val sc = { + val _sc = spark.sparkContext + if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { + val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) + if (proxyUrl != null) { + println( + s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") + } else { + println(s"Spark Context Web UI is available at Spark Master Public URL") + } + } else { + _sc.uiWebUrl.foreach { + webUrl => println(s"Spark context Web UI available at ${webUrl}") + } + } + println("Spark context available as 'sc' " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + println("Spark session available as 'spark'.") + _sc + } + """, + "import org.apache.spark.SparkContext._", + "import spark.implicits._", + "import spark.sql", + "import org.apache.spark.sql.functions._" + ) + def initializeSpark() { intp.beQuietDuring { - processLine(""" - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - processLine("import org.apache.spark.SparkContext._") - processLine("import spark.implicits._") - processLine("import spark.sql") - processLine("import org.apache.spark.sql.functions._") - replayCommandStack = Nil // remove above commands from session history. + savingReplayStack { // remove the commands from session history. + initializationCommands.foreach(processLine) + } } } @@ -107,6 +112,12 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) initializeSpark() echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") } + + override def replay(): Unit = { + initializeSpark() + super.replay() + } + } object SparkILoop { diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 900edd63cb90e..ffb2e5f5db7e2 100644 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -32,39 +32,45 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + val initializationCommands: Seq[String] = Seq( + """ + @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { + org.apache.spark.repl.Main.sparkSession + } else { + org.apache.spark.repl.Main.createSparkSession() + } + @transient val sc = { + val _sc = spark.sparkContext + if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { + val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) + if (proxyUrl != null) { + println( + s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") + } else { + println(s"Spark Context Web UI is available at Spark Master Public URL") + } + } else { + _sc.uiWebUrl.foreach { + webUrl => println(s"Spark context Web UI available at ${webUrl}") + } + } + println("Spark context available as 'sc' " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + println("Spark session available as 'spark'.") + _sc + } + """, + "import org.apache.spark.SparkContext._", + "import spark.implicits._", + "import spark.sql", + "import org.apache.spark.sql.functions._" + ) + def initializeSpark() { intp.beQuietDuring { - command(""" - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - command("import org.apache.spark.SparkContext._") - command("import spark.implicits._") - command("import spark.sql") - command("import org.apache.spark.sql.functions._") + savingReplayStack { // remove the commands from session history. + initializationCommands.foreach(command) + } } } @@ -103,6 +109,12 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) initializeSpark() echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") } + + override def replay(): Unit = { + initializeSpark() + super.replay() + } + } object SparkILoop { diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c7ae1940d0297..905b41cdc1594 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -217,4 +217,14 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test(":replay should work correctly") { + val output = runInterpreter("local", + """ + |sc + |:replay + """.stripMargin) + assertDoesNotContain("error: not found: value sc", output) + } + } From 0605ad761438b202ab077a6af342f48cab2825d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 22 Nov 2017 10:05:46 -0800 Subject: [PATCH 1711/1765] [SPARK-22543][SQL] fix java 64kb compile error for deeply nested expressions ## What changes were proposed in this pull request? A frequently reported issue of Spark is the Java 64kb compile error. This is because Spark generates a very big method and it's usually caused by 3 reasons: 1. a deep expression tree, e.g. a very complex filter condition 2. many individual expressions, e.g. expressions can have many children, operators can have many expressions. 3. a deep query plan tree (with whole stage codegen) This PR focuses on 1. There are already several patches(#15620 #18972 #18641) trying to fix this issue and some of them are already merged. However this is an endless job as every non-leaf expression has this issue. This PR proposes to fix this issue in `Expression.genCode`, to make sure the code for a single expression won't grow too big. According to maropu 's benchmark, no regression is found with TPCDS (thanks maropu !): https://docs.google.com/spreadsheets/d/1K3_7lX05-ZgxDXi9X_GleNnDjcnJIfoSlSCDZcL4gdg/edit?usp=sharing ## How was this patch tested? existing test Author: Wenchen Fan Author: Wenchen Fan Closes #19767 from cloud-fan/codegen. --- .../sql/catalyst/expressions/Expression.scala | 40 ++++++++- .../expressions/codegen/CodeGenerator.scala | 33 +------- .../expressions/conditionalExpressions.scala | 60 ++++---------- .../expressions/namedExpressions.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 82 +------------------ .../expressions/CodeGenerationSuite.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 + 7 files changed, 62 insertions(+), 163 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a3b722a47d688..743782a6453e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,16 +104,48 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val ve = doGenCode(ctx, ExprCode("", isNull, value)) - if (ve.code.nonEmpty) { + val eval = doGenCode(ctx, ExprCode("", isNull, value)) + reduceCodeSize(ctx, eval) + if (eval.code.nonEmpty) { // Add `this` in the comment. - ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim) + eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) } else { - ve + eval } } } + private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { + // TODO: support whole stage codegen too + if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val globalIsNull = ctx.freshName("globalIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) + val localIsNull = eval.isNull + eval.isNull = globalIsNull + s"$globalIsNull = $localIsNull;" + } else { + "" + } + + val javaType = ctx.javaType(dataType) + val newValue = ctx.freshName("value") + + val funcName = ctx.freshName(nodeName) + val funcFullName = ctx.addNewFunction(funcName, + s""" + |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | $setIsNull + | return ${eval.value}; + |} + """.stripMargin) + + eval.value = newValue + eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + } + } + /** * Returns Java source code that can be compiled to evaluate this expression. * The default behavior is to call the eval method of the expression. Concrete expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 78617194e47d5..9df8a8d6f6609 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -930,36 +930,6 @@ class CodegenContext { } } - /** - * Wrap the generated code of expression, which was created from a row object in INPUT_ROW, - * by a function. ev.isNull and ev.value are passed by global variables - * - * @param ev the code to evaluate expressions. - * @param dataType the data type of ev.value. - * @param baseFuncName the split function name base. - */ - def createAndAddFunction( - ev: ExprCode, - dataType: DataType, - baseFuncName: String): (String, String, String) = { - val globalIsNull = freshName("isNull") - addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;") - val globalValue = freshName("value") - addMutableState(javaType(dataType), globalValue, - s"$globalValue = ${defaultValue(dataType)};") - val funcName = freshName(baseFuncName) - val funcBody = - s""" - |private void $funcName(InternalRow ${INPUT_ROW}) { - | ${ev.code.trim} - | $globalIsNull = ${ev.isNull}; - | $globalValue = ${ev.value}; - |} - """.stripMargin - val fullFuncName = addNewFunction(funcName, funcBody) - (fullFuncName, globalIsNull, globalValue) - } - /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. @@ -1065,7 +1035,8 @@ class CodegenContext { * elimination will be performed. Subexpression elimination assumes that the code for each * expression will be combined in the `expressions` order. */ - def generateExpressions(expressions: Seq[Expression], + def generateExpressions( + expressions: Seq[Expression], doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { if (doSubexpressionElimination) subexpressionElimination(expressions) expressions.map(e => e.genCode(this)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c41a10c7b0f87..6195be3a258c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -64,52 +64,22 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val trueEval = trueValue.genCode(ctx) val falseEval = falseValue.genCode(ctx) - // place generated code of condition, true value and false value in separate methods if - // their code combined is large - val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length - val generatedCode = if (combinedLength > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - - val (condFuncName, condGlobalIsNull, condGlobalValue) = - ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr") - val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = - ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr") - val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = - ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr") + val code = s""" - $condFuncName(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!$condGlobalIsNull && $condGlobalValue) { - $trueFuncName(${ctx.INPUT_ROW}); - ${ev.isNull} = $trueGlobalIsNull; - ${ev.value} = $trueGlobalValue; - } else { - $falseFuncName(${ctx.INPUT_ROW}); - ${ev.isNull} = $falseGlobalIsNull; - ${ev.value} = $falseGlobalValue; - } - """ - } - else { - s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.value}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.value} = ${trueEval.value}; - } else { - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.value} = ${falseEval.value}; - } - """ - } - - ev.copy(code = generatedCode) + |${condEval.code} + |boolean ${ev.isNull} = false; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |if (!${condEval.isNull} && ${condEval.value}) { + | ${trueEval.code} + | ${ev.isNull} = ${trueEval.isNull}; + | ${ev.value} = ${trueEval.value}; + |} else { + | ${falseEval.code} + | ${ev.isNull} = ${falseEval.isNull}; + | ${ev.value} = ${falseEval.value}; + |} + """.stripMargin + ev.copy(code = code) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e518e73cba549..8df870468c2ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -140,7 +140,9 @@ case class Alias(child: Expression, name: String)( /** Just a simple passthrough for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + throw new IllegalStateException("Alias.doGenCode should not be called.") + } override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c0084af320689..eb7475354b104 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -378,46 +378,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with val eval2 = right.genCode(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. - - // place generated code of eval1 and eval2 in separate methods if their code combined is large - val combinedLength = eval1.code.length + eval2.code.length - if (combinedLength > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - - val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = - ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") - val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = - ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") - if (!left.nullable && !right.nullable) { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.value} = false; - if (${eval1GlobalValue}) { - $eval2FuncName(${ctx.INPUT_ROW}); - ${ev.value} = ${eval2GlobalValue}; - } - """ - ev.copy(code = generatedCode, isNull = "false") - } else { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = false; - boolean ${ev.value} = false; - if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) { - } else { - $eval2FuncName(${ctx.INPUT_ROW}); - if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) { - } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { - ${ev.value} = true; - } else { - ${ev.isNull} = true; - } - } - """ - ev.copy(code = generatedCode) - } - } else if (!left.nullable && !right.nullable) { + if (!left.nullable && !right.nullable) { ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = false; @@ -480,46 +441,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P val eval2 = right.genCode(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. - - // place generated code of eval1 and eval2 in separate methods if their code combined is large - val combinedLength = eval1.code.length + eval2.code.length - if (combinedLength > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - - val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = - ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") - val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = - ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") - if (!left.nullable && !right.nullable) { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.value} = true; - if (!${eval1GlobalValue}) { - $eval2FuncName(${ctx.INPUT_ROW}); - ${ev.value} = ${eval2GlobalValue}; - } - """ - ev.copy(code = generatedCode, isNull = "false") - } else { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = false; - boolean ${ev.value} = true; - if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) { - } else { - $eval2FuncName(${ctx.INPUT_ROW}); - if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) { - } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { - ${ev.value} = false; - } else { - ${ev.isNull} = true; - } - } - """ - ev.copy(code = generatedCode) - } - } else if (!left.nullable && !right.nullable) { + if (!left.nullable && !right.nullable) { ev.isNull = "false" ev.copy(code = s""" ${eval1.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8f6289f00571c..6e33087b4c6c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -97,7 +97,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual(0) == cases) } - test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") { + test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") { var strExpr: Expression = Literal("abc") for (_ <- 1 to 150) { strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8") @@ -342,7 +342,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { projection(row) } - test("SPARK-21720: split large predications into blocks due to JVM code size limit") { + test("SPARK-22543: split large predicates into blocks due to JVM code size limit") { val length = 600 val input = new GenericInternalRow(length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 19c793e45a57d..dc8aecf185a96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -179,6 +179,8 @@ case class HashAggregateExec( private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") + // The generated function doesn't have input row in the code context. + ctx.INPUT_ROW = null // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) From 1edb3175d8358c2f6bfc84a0d958342bd5337a62 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 22 Nov 2017 15:45:45 -0800 Subject: [PATCH 1712/1765] [SPARK-21866][ML][PYSPARK] Adding spark image reader ## What changes were proposed in this pull request? Adding spark image reader, an implementation of schema for representing images in spark DataFrames The code is taken from the spark package located here: (https://github.com/Microsoft/spark-images) Please see the JIRA for more information (https://issues.apache.org/jira/browse/SPARK-21866) Please see mailing list for SPIP vote and approval information: (http://apache-spark-developers-list.1001551.n3.nabble.com/VOTE-SPIP-SPARK-21866-Image-support-in-Apache-Spark-td22510.html) # Background and motivation As Apache Spark is being used more and more in the industry, some new use cases are emerging for different data formats beyond the traditional SQL types or the numerical types (vectors and matrices). Deep Learning applications commonly deal with image processing. A number of projects add some Deep Learning capabilities to Spark (see list below), but they struggle to communicate with each other or with MLlib pipelines because there is no standard way to represent an image in Spark DataFrames. We propose to federate efforts for representing images in Spark by defining a representation that caters to the most common needs of users and library developers. This SPIP proposes a specification to represent images in Spark DataFrames and Datasets (based on existing industrial standards), and an interface for loading sources of images. It is not meant to be a full-fledged image processing library, but rather the core description that other libraries and users can rely on. Several packages already offer various processing facilities for transforming images or doing more complex operations, and each has various design tradeoffs that make them better as standalone solutions. This project is a joint collaboration between Microsoft and Databricks, which have been testing this design in two open source packages: MMLSpark and Deep Learning Pipelines. The proposed image format is an in-memory, decompressed representation that targets low-level applications. It is significantly more liberal in memory usage than compressed image representations such as JPEG, PNG, etc., but it allows easy communication with popular image processing libraries and has no decoding overhead. ## How was this patch tested? Unit tests in scala ImageSchemaSuite, unit tests in python Author: Ilya Matiach Author: hyukjinkwon Closes #19439 from imatiach-msft/ilmat/spark-images. --- .../images/kittens/29.5.a_b_EGDP022204.jpg | Bin 0 -> 27295 bytes data/mllib/images/kittens/54893.jpg | Bin 0 -> 35914 bytes data/mllib/images/kittens/DP153539.jpg | Bin 0 -> 26354 bytes data/mllib/images/kittens/DP802813.jpg | Bin 0 -> 30432 bytes data/mllib/images/kittens/not-image.txt | 1 + data/mllib/images/license.txt | 13 + data/mllib/images/multi-channel/BGRA.png | Bin 0 -> 683 bytes .../images/multi-channel/chr30.4.184.jpg | Bin 0 -> 59472 bytes data/mllib/images/multi-channel/grayscale.jpg | Bin 0 -> 36728 bytes dev/sparktestsupport/modules.py | 1 + .../apache/spark/ml/image/HadoopUtils.scala | 116 ++++++++ .../apache/spark/ml/image/ImageSchema.scala | 257 ++++++++++++++++++ .../spark/ml/image/ImageSchemaSuite.scala | 108 ++++++++ python/docs/pyspark.ml.rst | 8 + python/pyspark/ml/image.py | 198 ++++++++++++++ python/pyspark/ml/tests.py | 19 ++ 16 files changed, 721 insertions(+) create mode 100644 data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg create mode 100644 data/mllib/images/kittens/54893.jpg create mode 100644 data/mllib/images/kittens/DP153539.jpg create mode 100644 data/mllib/images/kittens/DP802813.jpg create mode 100644 data/mllib/images/kittens/not-image.txt create mode 100644 data/mllib/images/license.txt create mode 100644 data/mllib/images/multi-channel/BGRA.png create mode 100644 data/mllib/images/multi-channel/chr30.4.184.jpg create mode 100644 data/mllib/images/multi-channel/grayscale.jpg create mode 100644 mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala create mode 100644 python/pyspark/ml/image.py diff --git a/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg b/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg new file mode 100644 index 0000000000000000000000000000000000000000..435e7dfd6a459822d7f8feaeec2b1f3c5b2a3f76 GIT binary patch literal 27295 zcmeGCcUV-*vM>&>VaQRUAZY{yCFh)xBoZV?5rzSVFat9~&Zrp3ARwrqf*_JqG9p28 z5{Z(t1Q8?%f}s4?pxeFAzUSQcp8G!U_k4exhFPn+s=B(my1F{8qmiQ-;KT(jT`d3t zfdHxCA8<4aaB2D@odG~!AK(W7fE1+B0Z zpO6qwBv4`_0G|K>8-Wo(rR{^oIrySs8V;UbST+JPfJN8S!3BwOft^QSkuE5hi5CKn zbnrl8acl%w0COJcg2Xv^!0=oMd6=my0;Z3^VbES^4-k)n>HA<2KAteNGfWMG6NDKc z{9smSjJpEN6^HYZ7ZLIE^Aq+&;5?7ngwYrm5hThPjq!BAA<-xi90KEs6>>m13Bl1Q zCnO#RDn$g4ipvR0IS9fW*$7Yo$A4cb8-YK>TEhVc^F!iXVeUvA7KTQ_oE-29JHZ^$ z{xGmUm;(yw2{yyi$I~-F5GE}xITxNRAqrC$))1!BbP4dnAstYHu!{~j1k4D5L86^N zx)BcTf$(%d!hfNtG|*mN2vC-X2Rc9jrXvg!6NAA$VP_>_QX*o~FeFO&oC3^37$zqy zd5o6;8U7?(1oN%0yrc96S-GnwF*@iVe^Ni~tkB5HJ8>AnXNT0W^RDVG{rc z1OOg@4~T(*JQxrU10(=YXFwD$?N>EM7&O}11cCGMGIVrF9&>z#&n`B0&<_F}H@kmEKP~Dd71Dguut2cB1=>oeV)o z1AQE%pTu)&c;QfgG#xaJBj$I;*afTqJEMU?Y5bW%;r`5Abo9XfrWm;3e1B&&Jw49< z&VbhcQ?xqV-Q|}dj++Mn#@gx{pf>oB0vJ2NU`}WsM^*F;K6{+BG2npu8&2Khzw)VL zoJ`G7I4y2t54=bGPU;>`uz!X(#(Ln6@kah0s+fj9_{n`i&qQO+JK!AfEjU)v$i)cz zhcDq_JTLghs|{R4IR6y$w`xtnSoqJXHGzBl#2aDYs#bUj$LRzd3~Lsy2(WWOzXn6% zALgUR6O2Kf)X_K`+S3D#a`{z+@(&7Lg}>mb{=gd}U0nZ4p!kCT>h_x-AFH8jKXt%I z{NuR(>wO%@d5>w-zbXY^D&fGfzd^T-apd4Eah!f)38{{;_~_4ZT%P40^U&?%JlXLR zd_az6@!>rd^#5O@8$=)LwR4Qs0YsdgL5c_XBtFJKeqkP^XZ*&9gA@p04L<+ikvhiU z$HOnovCKU%58%h!Pp$v{gf8RBf5-mJ|KHG#Z9$5c{m1AePz6UT`1@ry4m<=L&zi@) zr|?-Qz70RM$7ccRO8~$abPVEK`(M*vaCkmE|39X`F#IO`87=^nj|V;$f|E8t?1=V3 zIbr`cTH!b`kanyl-o=0W5s3Vo9|G#X{q?6B#=qV1xCz8q4_xxY#B5}c>eZ2kwMU2LP%Z1;uQqu%4-S8L(%Aj)a!J)MgCe}{#{@zClNpb}Q*=FDVBqWjb^D<{7>{4m%&$(t8~-0M`dF7grlOxb#10-fQwNuS z;G%|u6R;`5AE%4eG1b5LN7y3$1N{%oWUgpTpqdA`+WHmIXnrx7aA+@Z62~I`8mGuT zz<~B&u#}F+0r0=zDZnmv{X6^d2s(?e2%K+0cUIXrhJVhJ4976ug^!~Q;pyYN89t9U zG7X>xIDxadBY*(#GrKDY&G7y<(aK>vT>fB_od`al{l5ICBnG}6`3us1R`)Y84E zd0Yjh>O~|9O#lu%PjGGoCqS5`l{JjyEjaARz|js4Fb6o+OW#BbzXSnvH8o&ZkO>d} znRbVNj{gB69U~a*PyGJ~qJ?{5z)2Ef1aeFG;k@whH4tWSbU%jg;$aM^6o@ax!!AGJ zhGU+e@E{&`@r75Wa`Y*4!V1$O)k?1aF=!3Em{$PdSXo)5kO zaDg0f>Ub=R3Ezi*Sq>u+ zaAACdf2aQ=!*9-i4g94(VSIhkVsg^bLgK<=;O^>pIK%Ls1}?JEJ{TC*%K?sn3H{x% z{?`TnqSjyZ5Htoi2N(njtjZkp0{l)-$IpLw5X6I70&uqq z!=p8fO<`~!j4vMkH7gQ-0J4A_ARj0OUILXsEzk(G0$sp6UAgwg{(lfAm0d}1e62}1ndO71R?~o1Zo6&1f~Rb1TF;L1c3zC2;vCt5{1Mu$gdxaF%d`@COkU5gU;Jkvx$u zktLA}ksnbMQ7Ta`Q8`fy(GbxGqFrKQVkTlfVi{r`Vk=@d;vnJ};s?Yf#0|uK#Ph^E zBqSs(BmyK#B!(moBsh{Nl5~Baa}zOI}RgLOwyhO+i7y zMIlFVi2_LxPH~swIYm3gEX6)09VI`d2Bj^f4`m!>9%TdNDCIU46%{X)DwQ=Ajw+5S zpQ?#!ifWIVo?4h%m)eOsjQSpR1$965It>}k85%X3%QS&BsWhcDy)>Urke%Q?p>e|D zMCgfoC#p}pKe0ngPb)@iNb5-(L;H-jgLavYgpQX^lg^3m8eKMB6Wu&Ll%AVjgC0(Q zl|F~Qh5iErF#{iiE`u9GG(!=?TZRorIz~xGb4Gv0dyKCcXPBT&yiB@G9!&8}FPVm! zzMW)0sdf@^^47`Xll>>ZGP5$PG9#F8GnX(AF@Ix$v1qcmvm~-qvP_;LJjH+N(kZ`F znWtJ#ePU%`RbYj)-e!HtI>tu8#?NNV7Q~j#*3GuV&cUwD?!}(Y-o*ZigNZ|x!;Rw( zM;*rsCq1V!CzA6HXFcaf7$Zy#<^j74Yldx{W;?BO+UNA6({E27aPe`OaYb^Ka7}WP zbIWo&bKl{9!@a@7$)nE`%u~QKc82VX>>1ZHX=mEbeC6fiwcx$MTgki3$IPe87syw@ zH*uEgtjbyMvyaaXo+Cadd(QpbgLCir3HYV>UHR|v_X-dQNDCkZ9tiXa5(&x)q68lc zz89htQWNqODiE3#W)i+A94TBWye`5cVl9#=(kAjlR7%uCG)Ht?j85!=Sfp5u*p@iI zI9&Xm_@D%pgqB2@M3uy*q<|zsGE;Iyie5@z>W0)CsqfM<(irJt=@l6snJY5)W!}p& z$QsGU$hON7$*IeQ$<@m3%S+4SU>*}KF81-`XuNv|iAsP+mq35;F-#-6V zlTOo2GedJqi%ZKzt3Yc*TS_}fyFrIYM^7hF=e;hwE?hTXcjJQ0g^&x)dX#!5dKr3i z7x^z@F4pK1=wHxJ)}JsqV}LTKGz1KF4U-Kgjrfebjb2?MzGQUi-laujG2>w4HWPXi zdy{7-yQb=<@uuTuXU%-f-k8&v+nVQ_f3-Mokz_GrDPkF7*=2Rg3TahkO=4|sool^g zqiK_3^TAfiHp+I`j?d2DuH*8l%O00s+tb)P*q2@*x?*wV*_A^FLx;x>JC3@J_Z>Ij z8t^pus*|eI9j6t95+Vt)N;;Q1B>iP+(fy_W|y6L!OxqWpvbkB1? z^04qI@g()U;#rNNL%E}x(Hv-hbf1@iSCrR`x4d_X_Xb7}lZ%C5?XXoi25_;{?Q_=W zy3eeyvhO|LJwJ25Qh!>1lz(>sf55GPr9iE~CqaZk@Sr!rr-LJcXG7FN9)&_e9Yf!Q zafe+C`w*@j{xpI-!abrVQZzCt^2=3=t2NiyuZ3Tmzpj0~Ac{H)6ZQUv!i}t(gg22l z-`6NnQ~3B!phiBFSglLC|G zll7A;?!fNczVju;KBXg7GBqoWJk2L<=I+J2Rp~tG3F(I!t{Fr3)bEwtXS;vv{?`Xi z5Bf7zGmEp>v!b)UJw!end8GZQ;xX^z)NJBx-|WR4^PILPvQP4JPvu7Ee$PYY%{(=J z+MF+wpZ|>QS$qLOflt9op>5&2qVq*n#X`l8N|;Kbp99Z*o_~CC<;Czzy_auFWlM|8 z&XnCRr!T)<0jUV6*r;@^oT;*^>Z`s`{ia5tru>!gtGwFNwf9~#zK*XWue(-vR3BKs z)8N(csnMlz{*C>c@g~cr!RAZNZ(Hf4ptYdhpSsyk&nE4!q-D!Qe* z%X_4H%HK-At>~5Mt$HW-uBK11udZLEzj5IFK;6 zqnl&CV+Z5m6GRiYCr?bKPO(iro<28SG9x|ndiMNm_ngVx_&j`mbpg9@_~F_j)ne)r z=TiQ%_;T%v*2=q&HXlE%daWLOidv&xyT5*Rz3j8<=k5*5jSric&7-ZDZI=^A#?RxGWe7(KLviEdfcE9bL#kZvczeAG4yWh`!ulb?(W8%p3=;$wNA2R%!T^I=m z*QvkLKg0bd{!f9U*%P+E){-c!u!9rY5g`mmdy4owc!`J!i;4h>D*oW=&K-e+IU>Lv zu@c|sSM_`_q>~b#g_ORizLy5V6{!=5L6`&@n8E|y;c`xVD#}#kivIHco?hU2xdY7K z(*uQ-_gCUNmM#y%c(e!??3e}TuEh28NM0U)H(&}LlYyts60%&N5HdxKle4_>d97bE zz?>4-Usk?;636H8;ObmXL{wZvTwDm`5W)taa1Q=LC@lA}jGsczgKK;Y^7zaWh8GH+ z$NAutxIiV2wbS*&!6iY3HP9Hu?_031 zll(7B8-sp>M*X)<`NeW9;_o&57hB>TgG^Bb-;00g)t}hEA^%xdJHdZ zMUlT(qX;7Kga4Pq{!t8Hx0(m;Z*}86Sl+_{<)XypFZAoIg-iMW>Kcjse|(Qb{{QA8 ziTuB}=l|VJ68T>u$Uk1_frm=}v#X?k+M&O@=D0)uOaI|J6^F!mApW0s=l?*vKNb1k z4p$lDw=D2)G5^QA{$s9x z%L4xv^Z(y@*PkZ^2o!ic;0K-q9F2pQmjn<9{{L5k5<&5an2-=kL_!Q6xPw1ZG72(M zQgTuf5^^eXatcZ?k&sc-P*GCjVZ0E0aeOK8pOTb>6wmSRrK1LbmK11*2tpxT00Auo zN((t^1~|c686xlk1WbP%m_P`igha%kO5_wEL-`3%eJGRwRF(+;5h5f6qyvPsM0BUc z)QIUXIgoI9Gl)kdJ|g8-f6>Tj+_%9a;fRSOBWF6v%yQ}sFW=d7{E|}AGO}{=8s{~& zv~_eZn3$TGTYy(ZaB#2Y;)-;`;(UDl`~w27Ub`N3Ml(z5c3%Bt#`H%-kgt!?cco&5uYL&NV!M#pC7<`+IJE-kNo+}zs!va|bjZ~q&< zT@V2Jvsu4e_P2J?g6$#z`+$%H-!2G&ANYgP5)z#jBc@ZkMB?C0&m|r~%AlV3=tUzL zw}kNqqa&t|oQX$r_RJ=}X~!-5pEWG=zV`9d$YjXlJOHi;`vh62i4op`N}UQ908@~sT1C}GtT;MUS0dF zcIFbb$~Dt`CX)Im);6LvlWvJZrK*hCmAzB>1fHn-W>441*k$vxU+xEbA3h~`rWL^I zpU(unUCCZd(^)&XN}?F&N^rAMIVyJypMxm3~H zvrw^za$^C9=i-S;3^$8ZUPvqPOTUC|sN#4v6kjYObSykm20-)vYX*nX^A!Zbe8|(f29! z{h~vaA{nbU{Gsy9_ zxsFP4o0^QKNh^|Ab;L*4S$OpK zgNfr3B8F-G;Q~({UvE9PO8$kyJgNDmX@MiKRAP-BA$OdR;7r@eH=d?2uWQS|j3osc z6O@dGrsqY66pYN3P?)Ru?L7I>+qUuU(tSr`fwaBs`Gc?rgu`SR~l;AZJ!>zj`Ujz)-F@ypct{i$~K!?n)EIPJZ^K3uG@pYPOn72z_Fl zyxKuX{-Q%dS29wz53c2V=3ZQMxO}PjQ=)X!HcF&7C04r|a~7feFy{75oisA9Ww$qP zY9;j~({^y5Q0IDv_`O})6a|{Yvi(q0cqDwjFPQsbMe?kj@URBAE*e2;%)*M5xCW%k zvhZBZBzcE>cHO`w_st#8x&;weo}3@O^H1E*eQUKnnBH+=1G3G(W|&eMgmwvril zyCTcV3vg>u<)+sFTq@OY@*_Ws$4JTKs!cpH`Bi0+`Y_SwlnLZQgsiY>4`i7h6}`OHqE&!$CNfrnow4u zV8AdZID5sEc_H39MM!3FM8;*^d`M0(s=lEjpDODwA)-nCp+vAN3zt)AmB&ToWto$ylK%|DG~D-G^$<6E^o0vgaakTda78G&W| z8BzJC?5D*z6?9<}Cz9WC8=s(V7Mpod-23UgZ}B&;ORQB?X%AKVRLNe=SzF+dFx9o#_7lbI;f3lB}iI-NQb4 zETuecNjs2)ye*H0@70svwxJGj&voa#w?0UR93R;Ty#smsI*#N*%*y$!S9Yw(YWr8; zOoj4hzEAh(s`{6CZEnT5d?}#SjPq9i5NLkQmtTJ9W}&D=+*znHmDE)y(5LgT*0d{! zB?ck<3)bG~Rg-RFr@&B~yjcr}fvD*k^OErq2HT-)r6}1IV`5w&CSuR@-N@OMZ@1s+ ziLnWBd&J2OzIt_nk->*26X3|Zy|AMaH^n&6`%P(RxQI_xP^kO%C<}Aqu2$2eu0@#8%gW`7(gMlfsD~C$s+7fU_f|B$q4$`l!3IAXmu*`g5{qHGuVY9YeFVgu z`rfZKjuISp?Ce4236L^V#_VUhT-Kgqm&ab>@6$?r!W;O~njmyJFC=FD?F6>JHcEuq zzqI+JxOSfZ1IN^pV>&OjxnY2oSlmwXJeno=LSI>Y+34z);3LIsCagQJz^3ua*c!&@Ci5pt37 z^Pw$Sm(Ij}J_5|-JTQ|2i+WjO3UiFbnWb%Wp2B;jDHI&Zt_7T(a*plH34Uylh=z&4 zyXz`$BDXYi4>*dlFOycvbUYcq*i3olx&9N34BN~SN`7E>lLpTHQO85xwBw4-vI_Zr zSCf6&*SKQI*4nC?jqeNI33(j@F#^&Z z=}tAl#x|W#oRR}Fv{k#I1lE$2kXbTL0jC3~zn}f)Zo#gdMzi((-R*N$LQaEcyt)Gm zH13P1lR>0UykhZXuTmn8J&<&j$>bbzc%3vE+RIRhsGAlP>pe}*!{^p@GQE7Bk5KyL z`j%L)pnB@fF@^y4(Y>KfJ-zpmB8=Zt=X8a+UwR12&~Bev3WqE5V(+`;f5K*k9RZR4 zhq6_Bm77^jffwH=>p$rgLl)7-Clb20x<@i>u+PQBY#DZv&5)1nt!(fL(a-rbR+h&W zjd^|4?Cc)kmNvV~qotPE0ro4K3s+sn_c+M{PKrasMzl7MS6Os`<7uSdh4hMK0#&$i zn=gaz*7}6l*1Y_){G3}U(k1ZK)&a`m?gqq5e+J>Cs-i?h8^dg#)7QBgY7-X8^2-Ii z=s>E-au4MPaf|Vvb|fLUafk6mTHpLNY>MEC{rA2<$Q)#dQTgD`V9g(pb`%kso5&Iyz{0MjxE*Ft~AMVMw>pVSc-&w=^9QeAb zFKs3tvHiL@X*__Pf8mt51-O8aT}E8S873aJ3BW$aq`Vnhvb68OchyM4Bj#L6jHuIdouyX4g`fvI=n;&#n;rg!cYg$J&aerXfchkt7 z{jr9vQ`*bU4OXsW64&+)dZr6h>A9*h8(%11TuY(`h@MNtTuBkow7X`K{E}spzGb(b zqLwkF#NPVjNptJKf{Up}`!BD%P`qN{Jlo#!Nv=e`?+8G?iZu?po6)OT>TmXhSl1$XYVfX9m^Ex3)})DOx&Im7hgP z$lbe9Ym#Rj2Wxeo>bx5n7c6aKE2yEQx$7WNdEbeSjD-12uB5xcl35IYJz}$e(>4Zy0U+W2Sc@-ZD?7~?Qr(l01pA-fn^K9aI-Ap7^C z^=oS?%PL=#m#eX*(b=Ed=H`iQqe!4`Off$#*3qil+9zrGr6huv*(oXL77YExGRJ>D z$z`(Ak6PK>gSgQlvcbtyMP@YE_TDVIagDoqz|F*FIdCDN0GS5staz*OxFpja*T}3H zQ!?c||FrPIhm92G;=Jb{GcvO{N+_`*`P?x9x&AYrFPk%LOs#iaN%C1eZ^E6&0iNm}Ubt3U5P}QS~IZ1N0Qm!~@dZsu_ z@69l>GP~%Px8&d7GOD6kEy;~Q><*%DtE^=6;9+;^SRTdBRczp=bhvfPnv?xndZD>V%&wf{E zoRS@K8C03&t#W0Wdf%JqkhI(mq-Tnfk-t2f(8#=CF(E#-r`?z}l3HUQQ+5QLuzlTq z1#Xri{HQFR7N#lg=}6wjojTjsBwz3w4Y-c*_Y1czkobkrs}fg zY7dW^cx-pi!u>&ZrfSK#0UtESo}>QPzU4sbTY;;dSgMxgfMU{p#pT&+h#bl{XM`4& zM`f6x`Iqd7lzN{kORR8kf-|N!)%I{-Y$*>B=~%vGyHIi2C>z!4t}5v z|6UucfwHtcl)Cw>=AnQ2?${`hB;g(tH}9u54YNR3=1t7)N+{nv;FE1wxp;8mO9^%V z5r8J%et%%t;Fl^^;OQG}rJEi$BScQq@cd99sObo(L`V97U$i6U*i~E|t=1l;ip7mu z`O!~bmt&WRBnr~;6BupZQFE#KK1P%gKU=^-v&oURNunBW+J}BTDCp#v>EYIC7KJG+ zgr1wvh9tJnT(#EE+y3e{oa?M$6;@hvqueOa+qz9(NOeG5DgFsmY-;-4K%Lsto|jbz z{^!X_uNA> z`0PP1<#J%!$iT&+;=|R78uPO~ti%=3+y`3)w2{=y$%Is+`z54}oM zUKPhK&)Rj1Ejm5Vvd@gP-Dmreic9C?QZXJ2YJsV=bCXfGFD*LBqg(wCLpCGzW%TQl zD58S|?WVO+yivnU&bJ;5TRugG4w_vcK0Azg#Hq6Wjb==v>8tgP2AfzBwU<@SRb^W< z!f-6xGYrGCibW5nwJ`%oJL?^pyY6GoSqChKtl&zZ>`~PlGdTzNy~bj;8hM97kHAKhCfKnK(kB(Ee$3OufTvE0ZB= zWcE%HWC9M%(77|WZk0+Vm!aZa_mZF3xz&@e^KA4?-3wv3`yf@D{~Zjc+O2+PgvF}P z_At6qj)#*AYzq_14+6_Qm6k5vY}Bxd^B{K;rler!<8d$JiC*1a~P z!D7k%TE1S~wRJu+?8`VA?i18YXaCWN3xkG2aekQT*vqh3(jomG?v_c1K7+)QpdH-dx=63svjA*B@ zuvJ^4gk+HI4BJ(<*v6#el)A<;a7uq8KMo0-%XAw>@hyfF zx;e`-e)Jt7G%1n^FZ|L+)hkrrRItaXjVFZNc)UN|M01@(Wbb+|4W<99H@hn2hbmId z9dmNEDzNI;!`wdSczoOUPv};OE64@sH_SXo%WX~{(o4)C;tHhlt{&P-T2{E2{J~*+F83UpcPOUZpKxMbKM@#KrJXYG?4c=pP-%w}X5yEzD62JClNNfpA zti0oCyAJUBZrGd&^&n;zN^;9udoj0ME3*Mt8Z3UB&T_(?4{ zBBJKPb>Fekw}rNImfAO?BRP07KDTo4PG--7J8MH^*lzn_Yiy!ULPTse^W~F(pOImS z|I%#QHDyu%z$Lv+#-iO458*ROD&jNaJ@oS79}Kw(VY1LllMaE@fko>Z;^pz=8-gmd zJQXiC`J(2NJ{!sy*$0TzNL+kfX|7Glx6SaN&eH0H=atTe8S#{FXZH^D_ljm%3d07z z*o4uJ?1RUTokGQ2y7XE6vnRPf@I|&vvM@+u8=>dlJ#AZZL6^HCtw%RC@(8LF)}uOb8YRhifha!&bEij*$ zc8<4&ab$lnjA$xYVAJ41&c|fY0aP*PClhV)!Zf#WhV1qacY&$P6p`M-Gd4UHm2O#N z<@O}V-Ha!((Y@}ECj!z4hSh4ss^OT|r_?F)Q&e*v90nH(aLf6k#Tr0$_+L^P9 z+wt1mD$JO}o_Ud&-Rea9;YG^D>kFTZ#3J2WaP?6toQmA@YAUxBEjD(?FPKLL-HQ(D zSgZ;UxR!HmEpC17ofb83;=AG#iPz2gX={C)^g?J%HeKRSY34QC<<;ES4`uFDO`40R zNxVX-A;x!=Xu2h0;3n#Ql3)4-OER|))X_rvwbw$|Os~W~=5^rpmu)rZi^(+1@*i)j z(T@Lerq2d_>l%qq@xwY>`V+cMj8_^1-d^fWZ8$NN_k1HY`kB>=vcVJakGg5CEjNY2 zJ2*1ii(#)#fNoLejR2u zFA^wCvte0hQFOu6=8K%6P-%dBkpwmak6Fli$dOfiNV9QB1l7!&lBJ2RP}0_8K!$F_9UjHI%T7pNpBb47}?dNEIoerGr!4y&0Y<;Uds z!s1B3mULK_h1jbO*mj`$b*ugKsMiX!nsr7b-VQX{0;9zBI`uPIp_0jS=uagNb-l%E z>Nmc;=isaf##oO=#NgOrNsxg(eVoaF#;LCS<)w?glWe%0A6-1)DrU)!b|C6=4U$tI zm??5|#MY0hEJl#Qlj)0wfR=s_H6IHg2{tX&-B_ws$)pr)^HQvy?atcS& zkfFn6Xp3jCB5l~vmx6>9@*AuI+3~vbiPM#CDY*LeALU5P?xtwHrS$`|o4D6)>!82_ zl#mROme|P9wE-Kyyq#QPw!lj(pZ5n>pB%nc`L}JC?OD_26QR>ee)6$9DURPXc8q!pe_=*1mlYs=qvW@59PT?s_Wf zMUmb2cG*wn??OY*oG9ZSO z)HZ62Le?Z+km)k{Vb3OwS$f&N&@Ay3&R!H&GbVjhTMHk+-Yc2@-`c`Z_i)f`xJnN_fV-ucfw4`g2GB6x#n={C! z9v(*f+(dD|nUg(iX(G5#tp$r>6C>PgW6Mejn{{$$dLQu^aRXpfBVQBr?qCg3$&*|O z6?P{}ZOt4d=HsGcX(Ztz{L04W-Fv~_!tVBN+u)e}TInXIW{qx9u$siatK;ZXEp`MS zziGBG%@og#4iXh*uk-I!Dh-xg=e%j^8ag zAg0UlAa1Rd+9H>90hiZ1Uu4kC z9hbGWcr3!9a(Y|GQ%G7%OFcGMkTTw`x}lXAn@zLUws}s_nsZO55dHxgRaWlXZufAB zWj!zM!Irl@U(@`-yllwu8t>;C2Z$F5TYm5Rjc=wFq4Gm0CqWb=u{@7!)smuj`>xy% z(-1qxHJV(Xu<^9H_%w>fxh@<3(J^V)FRQxDhtDX&p7A8Ga)iebYZ?uLTfCCn7UnB# zq33F2T6j0!wDyr*u08@RG8;TsV!jJB2-b%A$7a18>3O@z%PZKVL_G}6NPR!Yz7XA$ zlCA}d{rKI>K%xrwWp`bUDWAus7rPwG{ev93Nv}6p^qIq5q4jzM5>YXng2`x8H=IV6WtiZi~NRt^gL-f%_e%-w4E5cz4paV4r1DoW_Wy`=kt&t23} zDsHgytLBTz@xl3g&LxH(!`8sZvi_gBN0`H3j=m`xsM^1IsQV)#h`H`Mf2y^_lxLVH z&CRjY2i{InU8~Yn<}_i93(>gO5f;u35Px!i;6CG9&GaJx#d+(}BwyMlXNv{B;Iq$X zsy$^sy(Onse9Rths^rU-I?pEa7IUaQn6Cc4)O~3^+JBy9=^bvR#bwGr+9-9k``~=E zx9IqgSrBAAzG=Ex-(qNh#-rDJI1E|OY1`uM%^9s|TIna|5tP1@{@Gcaj%7oyc*ye^ zDft4fgKfx8X18W92Uk)m5?$WvL){BYJhi+;q?$OW|^5sF{3=-d<*Q%AN2Bc@hiEP8E-y7?A5f5`^gT3#;d8)vo~C%0~0D^ZGkF4wCRu-ZyxGKf=ZS}=2R zh-px9)baQc4fQX0NDKK@wZZjC=NSHyL#-96jSxWe-M` z4&q2TZxn@Ag*8@w-R#+zo~z4zZDBMZtV&0h6F?txUGeHG-vGJ!h^rYNbhG7@X!?jt z3BOObn_O;}aB(flt$=5xf#>V`(yBDCdWaQU1+3uD4hHle&ueUxx23|)g( z$L|{1toCQS8`Ic^GL{LMDPWE6nXx}*S=?Niy|?$moZgU{n5pba8Q~YDXA{1mch;Q| z#p&BwDSnXyy<&;w# zS-EuP%SFYDB%Q01WandZ65Gv;kbLm+v))kq6U19;>{8o3wMN z-K3bjceU_)>J`=7un4-zTva~_xfC;h)oz+8l?|i#==zN(^@z0qPhmP0TK~rrE2l1U zXmo$25KL_G{a65g*|X-9zI{K6BBp}jhw)zT4Bf}eGBvL@)`og-J*rB4>=%DQT0|t} zbfuL!ZDNhn&E9go=f>-1w8>#gw&&r(=I9H_NsHPKrrf@c)o4rS4CqQmm`lN{Pf#+N zT^4;92kfiP1JnUDCy^3v-aQ{~@g!+PrV}Q|Er8!nIP{_(WZlL^ukkMG+4xVn&O^+d z>0pF2YrN;xt&~^Z02f}cFiUlOwD#P&lhob~j}wE5hxy;<(^`FBuS5-5xs)>WPTy)Y z(`GJ&4r%vvz1xyw{(dDk;(n+c_dB1Jj4!e7>Xqe* z^#g^DuJ4{BdCjJD#(NzxYM%^rJ6aqkR(96}M8*oO91_#tbSIs2B)t@;Cd=bQGY8wW z#-8x59~mF?gnAK1e)O&AOA?po_^~vdQ8-NLO!}@uI>{26ue;i zFs`+#HIn3csr8uNGe`#}t|}OW+HxEo0kNokjG=^l`>WDHXaY4hTtyJK;Im>??~uyE zAR!sk_AYLzUY1>~ZR_!j_*rHL`1i4JVncYVHphNG%V2XuVNk2>LR^sd=l9bYGIh3R|dBZy^N&KE!w(Z&MXvhhV%UiH10XRCY7|w+t(HAu~)}8d(^)+-lJr> z19|^6S2wq}6jQ*T={uis<5QBB^1?k+KXxVl;_%nIeN;6szw9!)snC5YwRE01v5@dg zJVZ%}d?; zcJ8h~$7=>6~D4owtNLBs?3XXqa)z zlv{VZUYyOF7{WM6!Kz2`*{BJlQcC?XkapRy$YUDWuPWU^^O*%;VS}{o9Rd7HCGR^^ z=26u@IBry02Ngipo`3v?dhI~;;aTT!&`nCn_f`gkeP1{?Yy(5e!^k^m?QUl!S&%xW z-8qVC*qM*pSx|=VT~&!FXQ?QVdNKy_in!^-Et@FEDZfSR4@+zu>-sdN^t}IxW$nm9 z^TmU-@Am4_1=OsceOO53SlHf6ejmoEcemZGb!&3V(@3aVdQYMhvbrPjgX*bPM(%pqOY53=fzBY<#o&u&h^E~wes{+8mUd4U3(g6&+J zT3`9V@iwlU6Sa2(>X11?+&oui0kz1t-5-?%L%2LfBlZm$rk~Gnh)rjV1bG&G^)tCV zs&-Juw#YiVt?y1CkK?M)e@LPbCLg;dr(zAXxGQDnT`p^K#7UH8xguC`rwmrhB_&{k&*+Q)Uu zUNJzevashwq}CaoWO3OuU&D)7r&<$}M-C@8X_{HqJ=5shS$R^n$%Msv_gvV1uz8Oa zeMNG)W_uTUbmYLRCsB-ZB1k)tFtX^Fv*5+(Vz@ znymx1X;qtls%VBu?)KDtc3PvxCQ@bEdGgzCfppEFjZIh5zL|lg zU?F6IHZyZF=)D}{{N=-!{QL{S4{5{Yl&cTJE>yz>tg;=;# zSgeUP-A-#KS;ks;iZJc?w(+dh^Q#Krzm>4V7;=MOOch*;>Z}t>3iJV^hwUd9a*W@Vl-sTdfOtbIWukF-sewXp}+(uuY<@uiX^%tz`?> zW@VH7=lJ`cHt?Osc)!mirjhM&YnxquSi7}ghGbHGC0)ACdb4B9viK9_7a-8dbRW_= zS(g`vT$s3Qt4nc0o#CRdVxUir|v0T5}p4Dki__`UA_DwU}M|p-UXX` z6s|s0{Y9^OI1gW5Ir+iim9J=5)}8rkW(VDe{?dylU8<^IER&=zLL}pVL%;ZlM@qs zk~6M_#9WCJm~`3*mNBP0)B4DT{HZhRI`jnbyFj*xc0%fFkkRV_)A7pWA8(k%c`7qI zXIv&IOxHc>Z8sAq9R=~1BZ-u-=+2z=4b#lpI3ItDgMJU(VG%BDgua-yBRRh@_`_qD z<8;M<;9U8ELAJ*S#r?x_dwJ)GhrHTdbUp2C7Q@7Y(qo4&j)1H7g@1HMkl&Qm2FB^GMo|oEgdU^O)ps>FnQh5A~MabW&%A_ZC8Q+N?cN z_uH1f6xa;i;7pP8Y?2OH=IM1e#*sMYQ+D&%ukjP4>?Q4iDp_1iOjt4Y~ zgPqNn*eow+4mqCHQV5lV*iq1l8Y%{w-5kES^LYgf%BM<-blG4csC>}V2cW2@qn#Jw9!%8JqOm*$h@{=L?FX+J1Z z3di>1ciE&14ps}on2>(bA9L0$XeUn#zGe@uOmL{`AbDz}{y4tvwR#TBEwX~a9?amZV3sGzS%5sj(~>2$s}Wgo=?Xf9c?>PxYCy5MUzWI z587F_ERt*oA~g&~NLcdOY$@DF-5L8&ddSkpkLrFHhsZFZx7$3x2hMQHcPQF0JM|>w zoN?B<`|T!cJA$)AGRU~XpOk=oc4!-fqc040h9%qu%vUemM74YlgX{GeNj4fIQXr6>6SK(;n-&w z7(Y+1PTd7`9x9gBahn#&~+f-bsy)Bw#Qg)8UEF7ziLbv-5pPs>OllCQV*_2L(gswE!$10Ci5I< zxd5`M2mtGok)FVwymEP>ErEqKzmlJK zX*BPZN6x_VFnV^bt}hDQT{H@eyK%H`Vfk`+`hWH6)t`dqoVL*+2eN0M^X*FDau!-c zdBY=a>|vRJ1CiAFbM&ng@Z!%dBa^H1h3ng=wRJFjG`RECqgMHHhvnoBLFwzrttdPP zcGl%~vveR8+l=QQ{-fNE=9e%Wm7a?&w5uAVirc!TFsFs*raSlTis?Kjs9fk%t1R~x zpSuwmBY!A3BxOlcf--aSV+B_jV4!rmXTv6G7k2h+w(O?urBIW=9YUP-KcTM2!9NFM z)h{8M>NyfU<`SYccSk4XDgoR^OB1zv207zpOKQM#S~jDp8^7&Ux`tk!bkmi%Q~U~F ze(#v6PzxO7D8_5M`)$^xrCZwEHlCMD<~exTe6=8CgMh%Qk1>cw0hF^3_IewShLc+9 zP~WY_v24+*3n*6At|y;q&+d{kz_E6Sw;5Ilr2N1iU;Uo{0A%}%xUVk!O&+14+4-i| zPqK~P3#lPSarTK!at23y{Gd6HX=CK_uUO7e-o-}k*&ZpQ_>%tsTGtxPSy{(-BrbtmtY$nneNL{b6*y9bwi zpDcyQjpB#QY-A5DR1x(L!(Z8Bz&;%CsM2NA?sbS{4dz^GHz)0KI}9|wd?GnSmsnd^z|ZuG;FJ5%52T z^gUM218W8RcIIo?A}1`XaVdFMc3W``mfYDVb7r(9ZPmJuT=SZq8GD0`r>GU8xh?cP z*GTxgcdk5db#rlW$^xrPA#$V8WaJ9$v`-d!!FM+pub%unsJv4z%v)*B4SH9D^_C?U zE7zetD?(3W()CX___+F3)wR!B>h#CmwEjGr(GtS;t+>p&he6n|a5^GFf0 zQPhk?4XH8}NWb@&5qAlKeH- zuPs{M;@K=!78Rai;#3Yu%V(7;N6xs<&CsqZ&!yFFmtB)nxQ$ls;Q7)P2)j28+lFz$ zUO~?z+P>@f#o&7nh?-Q^vTR$+HWJ$0dKGQnTQCIqOhyhlU%WCwR{0;sJ_b4tuXSf} zV;7f}jN}1^8R{3DVYdUGcp$b-XR|AF4^U|QLmcUM8raDUd0>pZvm!E?2nTWKagn!n z2;I9JCb{s@Zef@oAj+?hfp(HP!6P_Kj&a+jPgCDoXicX#lJReafnbVM;FE*wJx+PY zZi2dvGvSrfS?_CgUoaLMu*^6({DwFlxgG0B?Vy?Stey*)FYckYnQ}M$zdB%JXl^mj zJY>~Np9M{IaU;BI=1NPk2nIdz(C5%_Yw0~h;Z>|wlgu|Xi~xoIu};htGjAkaEw5VmmKaz1X7h{7iKjU^*&c$hKplo z3mwvVs7Q}!!kpj}gZT4->}m^tg*U`pd2u$@11O*$mm~w8!=TPFlA|DzUr5}pl?IKz z);|o!>_pOC>$gP>q5-@}kgn#Fb~b^x2Xhgft+x+F{i1vyqv)#)+FW;6(Z~Y9Zxewb z0EdyCw#d94sGt=A*kw*lRBARpTOMoSZ-Tm=#e{aS>8*8rBxTU!cgP3c``dPm=W}D8 zdJa0z4t~n|mc8N&T?XDw3LRSF5MJK%T(!26QGmg~l0_gM2O#7P%eZnqQ^Fs$cY)2c zk;h}JXgY<**^pkq?QF2J5t)^v++$_O?!u|q-!Nt1C!%;?_L%r>ZL&>vZxY<;b4%yk zwVtDF*AhD`Jdfl@!zq(vGs)$aW+W+P5hi}pU528MIlun^f^7J%!^5_kwxOabX;Y%h z9n#9OiDNkcE*S3vZUaW8Nh6RC6HWUL-*|q`85++vvo=xG23+z&lDU3BY<#1qI6POi z>YwmXKZpJvw6T*%v(@h+N4SZj{qjo!s8otl7;RwK%BX$HtA=>-&J-W;QJp0$ByBUo zH(I2wB}px%y$GurVgfS=r3~jPesF+*_yc}Y9Ohps(_(O{{RqL3GfDGy$+WCVn(@eg-+f;`MMS0=cjywo-5uoZ`I-;sGFx|P?f)j=X zLCo_`+K-hPegx4QMZLJZM2Ov9uxVvfP*nM7geeLeY1j@4BRM!3IrJ?V?j>b!-ZnTS zde=$ePlB=}gUo{rG4h7pkH-T(nZU0>@c#h8w_C9>K+zoj_Z9A}{V){@(o z7$3@{OFhFF$I^f>Zglefx{Aisbar&@nm{HLE@)YtI(zwppIo;%_n49P08 zjoZ6@YtH;#@c#1U8~0?demd9G6Zko0`^lersXxQ0rs|9Mik?yP9|U;A_I0?PX(q9W zp^=VEY^P!W01Em40QQ0Xoplc zZ9m?Hes5K(8o$6zCtI?T-tNZU=H6Z&DPj*B{Z#Q&D0d$ST7p zs8zuG%t#3Czp1|){{Y~Q{{Rm@EyZ(lrT+kFcrRR@Hi&8(b0yL^UaQIhGmsFZLm2d8~$D8eL$W$qZWJ-Ys(x8{z;D#~+6i~VS zr**G~c2|)P6KOhg>b9>G(8XySvpP2AD#16F+z#MWAYH(#Dafz1d^7(52FzXPvF($@ zUK^25U8^$g$Eh)&*1bc)AMkFUhApla`)gab@I zUaP82;*DZksUf($Vi_Uxi6SsT0kQqz)N_D9KS3-EMl=J-lYYGudoH4DjzZ}*G-Zz2G<{jZfN$?As;M_ z+t$BtrvCtfDSR-K{ZyK)ZP1tUoS*!3R;Gpj00e0GTc^J3i1jP1cFIOSMWRt*?GMJE z2}aMi%uYJ=JlC>(HPbAekj}1i-vr{nSG2GAB}c+NH}2@SE=P8`KjB?=gZl$~Fwi#n zeiycVz2%YE{{T9lX2H})=O=?aH)Avl<-@yfI;JbpJQLxYR@Ea}H0ztB&H_Om;y(;m z=r)u4Ie0fiMO$AC=+W)}0HGe(PL5fSLi9rqoDpA{uKS17dt#MFM#l% zU)|#a(x+d6kgxo*{W%rwqfCq+yGiy0Bks}2h0f0!UxCo8f#ok9IOe90;GMWn^(*cO z>0Z9p2tR~#%_rFumi8TK@*(rH$y4CPvbQmQm2TJJ4ZH{aT6OfVPHRQZJ9|@7)k6K= z^m3tdvxl|#eGR|ekAYg2-WW(g!tKX;=?%L2ezeknj`;1~j!@<@qS971AIhamMNlg1 zEO0l6 zw>am}XBqxXTr>;)*B%M0h-d`BAlfHqk`@JZ%HNG%Bci zD4+#6RsJunDPVf`In5MM0s+M_kNZCPqKW`R*S#Tc{Q@ HV1NJFdVH0S literal 0 HcmV?d00001 diff --git a/data/mllib/images/kittens/54893.jpg b/data/mllib/images/kittens/54893.jpg new file mode 100644 index 0000000000000000000000000000000000000000..825630cc40288dd7a64b1dd76681be305c959c5a GIT binary patch literal 35914 zcmeFY2|SeF_c;EH!Pp{25g`=HHe;KyjV1dUDoc?VjBSRoXHAGKS=yvfmQ+ag2u0RJ zh_dgJ>`Qk4XQuVNewOd|_j`T+-{0pkbKQIHx#ymHpL5SW_ntF5?{@})-I}TzssIE6 z0j`68z|IiBt>T5X0RSBx;3xn9RAAj601DB>{r(Bo)Ph(pNMq_pU9%UoE*74Fs=kF z&RLEdE-cE;g>kmRS!11T<+x88sR|*v(egC1hmPwh8=2^^BcZEs5REgO`e8Y)&b)!50dys4DX6{ zb`#dOw8g06T%9bztd2U)73+g@cC&Oezz}3bi1=W$Z_WR_y&Klb0pliY<%+R%!?}KM zLoG{i1IgR|zB>@)Wa$ic31WW3prQU-!dN$qlcKAurT0(Lmc$~+7})STYkxuG9}kX5 z1Z3S&UJ4;ALdp>7{w+sVj#!Md+Yfd?@8@2b7xs=$s$=<=+#TQ{X;`1;yrn;g0nuL;E2*|Jkqp8}_U36a2gN`M+h)`itTH zkfHymAL3MeH*fxF^+Rw!Z}ipb#an__xCPeP250frlKJO0uHTrYiq`h-1UHN|5omz1 z`e|7Hkix$=CVppl|BKe>-&tjUG`*>1zZ-?biTiFU67&D?Xz`cBB)0yADp(?IiK4=w zY5H?>{0CquG4VgcmXMPC4fv-`0`~gnOnIM#Tb}>S(En+n{?5|>MectQ{9mB|$}0s|!T%GY`}=YJh3pE!k-rhUZzBB{U@56T zmtE9v!T*t6(t-8sarcML{+DDIe0>_At_CkU^ofdF*V5_dN&BZZ`%}UHt``1DZ7Ym- z64u8N+zhQ;v3O9W|Gew?+ZFs@*SLQ-M545QW5n4~$PwcPUOfGVwZB(df4KMfLGAZ< zN`F20zwhDi)S{S>sFaW>+z2iuEsl~Fll<>ei+?z{|5rupq_gG!ogMvufQ;J<9`1gbT3km*+hPndEsqr1UmM<>2xE~UBxeK{kZY?yFUIn z%gyayFry==nr@#7ji5_E; zU-;X9k{|d8sg9nMk`i>_H*fGRQjzE<<{)9R{{X)vKgZ}J4eV`w|aYRMJBlZCLw#gK}e!tamfKUJ+d9_L0 z9v~K>XfXNO6R`;F#{0!!c88Cbc&!4ipKBhWGLh;z0D%cMseW8?N~S;-CO z^tJqz4Fg-dZ}p_XP!b%mRv5x5M{NSw9P*ofC`ff6cn|o6kqR2ef1wStb;Y^ke*>Zf zec`rPXN-8Epl+n2 zeIAE%Bod_h4*m@?wH?mYN6`^$`+cSuz78G?CHM*kTf(d{HkR&=ZeW($19aE?2|Dqr zngsqWI_-H|Wt=0<^_R(I{7PR*?MDUJ1klAf6X%-R4TslrcOzha6&xBzP>}urr8|#v z0|n|2a9Z#c+x;G$w3hgZ9RYOU7lNK6<-f*@Fp=^^5hkr;@a zg#Uc$eWyfP;;Mw-tTB)u*KbSu8x?}8>vw*L%kz7W+RDzq@J*Zp0~|@QeAR_uW752& zlBlpd*|=iC98u}~G7n!@#m>+!>u+&2&IDmgYux#-VKrWscoAac4FHDr;4O(UEnMd@ z;98gCn=Py4 zD@|INhTtq>vN3HVnWaY14$DN>3Q6cZB{10i4vLd200qNEf-N+n3CB-mC=3akZF5FboY z#1si)f+?{tF(il&ro>*vkRUZM1>1rtsVz!UP#i7}))G^YTQL0!ks{@h#1sYc38o|* zlmwWULJ;9n2oy0D138DIK$^sq*cOE(wnZU{I4C3$2ZbWxpx~rDks69b4TU21h$7KM z5jjDK!X*Vo1wjl!2|-CgDM6$lN)VhHI9w2p5QK{f!o>yQ5`rKSTuKm*6hwfy2tkCH zAVORaAt8v66hue~g3O7D%BUg~;o>SPqAC)S$5jD1@@2tjHhV z7~w&a5;dh;U`=}Ss4uEv%xxo6>4Bi8<(1;lanQ=0Fb|8se>_oI9G3w zcaeX3Ufi%Yf3a4HHQyK{W~Fs7)>up8_A1A%i?JoGBW^BI&G!ZOwbIDd(wTs_bOqzP zzSV?rL(oAtM#`!WiUR`5bJ|cm*bWa6PA*c6caMeqQi~VH55u<8p z1U@#g#gtdLQ)FUE#OR($b1cMl;g%afeuy?BItVsIPXLrLFy7n;&MQU{$LIa z%XGyMz6E!F!&h{+bi~s!h)E=UelHV|oWAJWvAldI_4Pw`)vDQT3 zSFK0lj0D!kINK88Uo{>HP6|^6w+vUTl?E85=L)j&HKvXvJ|q-FOD8ZMl%S7sC8^b) z6~@2L8j&XWh`@>IH{MMC5yH>qhMGYuECm-N_K+@-^d#m4_2u zt>9V%vtKi$O45W&|E?*C+;4O?LIPAF5}*!|0F{db3e+f|dXWN^3#elepnjA9wVi|* z9L$4oP-`MUttBQ3>RC`3iNR5V5};0$5Cb)?C{Y!G+DQt;l_HiwoeOF%qAo|sAS6{# zQewwNMODSblqDrpQBnxy<6_E4goL7!BwP}V0R1Cp|Gv5aEs8%^6yHq^aAUQ^ODlhM zDH(z0fE+hyR-MOikpzGgr>tx!#-(Fv1x6Wyam&BpV~D$=9QV(SkVFbpF<2Y#uT{TB zjEYMmB*Bjm=ualf0kKydP^3VQv*mYWI9x~+B_t|g1ecITfFCi_|B9vrd(hGO`@uUo z5e*_Dp0tKDc*X5(h0!=J$L;QpwU$0Ep)4kOTtZP&QB)ZYKYknqR~18viK?Pd#}Ud1 zRn)(r2%XM@+T*9aM-10=^~8B#oJq0NPDJH*9_##7AkWN{6K{h^it% zjU^=oUhS%g3WVO2wm;_k!8r?HsSzdDn~7H2e0m zupT&kgpXg~sJMir6cQz^d|X9UOE-Pcd?_R} z>~eTSTztZ{>o;!RO1yVJEj=UiLDs{k`2~eV#U-WBURKw9}#i~t(gqHt-u=c@Yd%PaNMF|XTsQTlSvTFKSOHw?p|lgf+jIWEg7ZZmd3 z4Oc7XMc+a9fC47WkIj@ZCean-{rc%E5vjO!?L4n>K!0<#?&kg{|q#T(e3*?OC z2w>DZ)#MYCF3ynK)s`F?U<&xmiN)7X8VI{|Y`e3_4yVP~nOCL*mZA-H{&%oZiLtc^ znMAAN{*{&y1;)Lr<5$nIPj60+o?FoE#s$0#+@lMd)NX+-WHMK+Gc;ah=QJD;2Owp= z*8XnuEq$Al>7vrvsM9==8}kY@Y|9f&bA*^TiD8LeNX~XSI{-#@H}k<%^fk`89l*lW zV|pN$2^e#{o8&&8l$#uICgPma#9X>z@LDT$A}YNoPQshp#vrwwExy-!xB3@ypNq)M zye3GZ!w0s|rj$!IRjm)Qpv5d3@GQSzv0nWlh`zxz;MOv29*nRXN(dZU&+sJLBTO;cR zc`bTFDdfUwWRhdGc20~1^UQQ5bBxE+1`0m z-q1i`B$}W=#u1>Kuuvj9pKO2J;>+o_r1Gr`t>jm}RGPYO4WuLF)hA{mPWo>7-?1ap zcU7yve!N|`1DM`7r(P$_8l>G;+X2i2;Z*gXz0)j1b^w}y-umTP9ocL^)oP?t&pe@P zsG?9*dl`N8!KlIhWMu`$;N~N(0b6c!2u9UXbQ;#)CCNS0RcLwmPps4;^m_SnN)kfT$I6jXjr>%ld@o@Jb;Q zdt!9?_3125Wu;tA{Q7Vp!fXe4r|qw1E^QQDIyI$HJhmyLPcGENxFwLXXu!;}GO6<- zh4NZ&!BG2MG!3lR`EG#~9+WC1?!+g7Mwd0)Y_ZDRDTAPR1qLXee6PVn=gqC7QqQy; zut#FcONI?*YPMP_b-XNU(A~9DFAwyZ>l6fF;;*Se1ycf5{al)X3;AkMm6~?sj8n#F zx^4bR3+7Gv;%>*PYcHD_H|9A5aLQ+?&JJ7yxSVc9Y7R8Gngaw+>&cL7iefUjthWiY@(Q!FTv*Qt&PXvOJ}vqfhur$k~v_@=%Hk8TOfPo8QVhi<`Pp&S&R}YiytZ!?Q_7&uch-^ zeoS?t=zJF_qN}ZBrPV&$OfLtdlmckXx04$aUj=R10cd#3P_Qaovzlkkz^R4P%O;SF z&|<#Kdj|4Kr6b_^wBd1T^r$|JVdJ=0lp8~6zFJA<%B2R!8oFySHwa5F2a{KxPYtz% zb27W_0QsCheN%4az}$Vndi&>Mo$kkU?J z5EBVZqstn}pn|3NH~SQRsb-B^?_4)2x=7PI&nAL0fQ-7L_)x%n+w&&vt@J{&9bj^w zkHdT!b*b7!$qv9{E;2lDg#+5|>s(2{0~FjZoKNeNkYmuc5b%>r-ENhZTRhsYCU->&Qa5XbHjuYmjmR1v z=)R!X`gnA^WsiL^4gI?lW2$xSgn$s=RPP}jGgHNtX=P&_(+h>6d>e!wQ&XigdGUz!qGAV|Ih>)A;Q0;h|HpA!BK6rCifUyTBLgMQECJ6AS#NbmCh43J{XR*O76p zTf@n1wMwv>kk*#hTb+ltWsjC#cUgKPInth$0`m@*U`1ux9^~&U?hmBDQ)!BI%8Um^ zGaVMwb~}<@rj;zHcp<&cUWF}eN>>-Tod68p_TIw2oy&Z$A5}M{$oZuU+>!v@fR~XI zNYE%!?MAg^%sP$-a&`nPKd4pT_D@I|-K@Pvi1sf4Lec#q4`50D0gV^aFHiSm%kFAk zO6M6%p}bU=JE3bdv7G5R5!-zy2L23e3s_Vi^f9_)FF}EBF~vKhYxK#lEVPI56son( z>l*CunNOkXLM=@!Nzg_;+PctDGa;W)dnOUM9N1i3p~L1qxz=&JIN(Biy0|c!gC97m zxxqm}AxA-OZS7w`FZU6Wy|4pF&j<7NW3ru%_4k8h4dr zP@R%~*98+6*4!3!4`*H5ltlq7eCEz6>-HH=o#=z0$@7@=nIQ6$CWo~De2lprv_-hw zLEL9)&%oLR-p0qyw+clPgZl`YA+Zy|l5^l4|euOD3-mGz%IIB6(fee9GCiCeraORr%@+fKvV{)@(=XSe8DvGt zr$^`HAYU>SN09EjPJQa^$MW9NILwt`{Qyd5h2czI?TXwqnVZRTtDKAVm0wuCBXBL_ zovg?q%E`-bte@CawGe1I)Yn|T=pCG!U5R?TEMzdH*C9cB6DPh_(d^9_`N};_8!tuG z3p=wL3Q2IZwmx^2;mIcF@Pk}8ufdq63KLH2D|`*SYQ?*fbypwu3%2o?$hWgQ6spd> z9rBb??n+wRoJt1mt*HB5Zj(ZoTjE>x9b`8cA3i9J3Ct-S3ga0=v9G+bzf)T%l>Yeo z7ySY$O}hN6nGL0^RhVAge(9jq3Dy*<&2uCAc|E+xc-h`rA6DoIr7+8&XIt|?<`~S@ z?9cL+d)(y0b%R@g)_G%M?3~94l*@g3fc?Wf^U|}-W;)0Dp1R6_*N?+kAvd^OmKx*` z4o55^!-oz^59Lg6wg@c*wboX;-PJo03hS_@0oJXysq2Pbq}un(Uh31!(y^316=&Yd zI9^|2Jo!wM%^8s2CK`lq!ts!dK>NbCSFM3yc*Y6(irypsE_;@yax$&8ij zZQmv@{o{9#TpAOBUWXOxSzh@GN`odaP-nQ-&O+D$I-510%ZwBW#F9*~4CgsIYo}=Ys8oDtO%*QyOEhKJ;CdLc^CgK8DQC{Cwg;v(6@vgPyR~ zM$w~Q9`i*9DXcf5KoiWjoOzT?#$Q~ZD)Zsz%?H4TTJ-Z~el~vg;oVoS(_!L31Is<%2W{*@Ayr7eTwZ$V9U#0aD{K=s&(v#K zdBDRCDx{g)kJ=I_7?-&n)6yB3yuKDh*69^9b*6kI}U94Gm;8b8SUtLR%(v z3}Y4(0Z$o+MZT3HkGZgE_{eSQG`oJM^JXE+$^B}0l^Uz~=sgOAWYaC=>H1B$1r<1* z7XB=H_q=XY%S-2}EV!$isZ|i#P{Ry@@AfunA2c0^?$)BvIV1!7I}c0c)p zk(nJJ|Bdd2t_-h7b*3UU8~fgg6o}CG+7BaBuAC|jGF{uxdE_i!y(fesysExVUtV_y zxHPVn5WuUH?ZB~iYy2c4XVe^mvWk;G*-)!1nih5E*zokL=l5%^Uuso5y)0_;W~aZu zAP!k}PM#;|z2zz>El-7DqH4lDpY$tMpS7O!rgS#Ljs9H5=s#k#jt4+DNBgu35 zRnO3&1()EHiMx_K_I+TnHPV+VegtoV&%D<2`Cty7QB3oFS%;_4mpS&dyEmh^WtXp5 zmxU1GSzO}6w$gdCmX10Yh16*R+EKxL(E$gjU$DESOw-;?&K5WmyVMMbwkv_hspYJJ zwmU_EEPI_~oJU$TPn{WoMkfeV`^!}%0C&w9`nD71!ivS;u?4(LGwKWe62Lj7KyP=8 zy00gAS}EhR1WeFVN^apotC~14Sjo@#T$V0cPT=O+=}l}ka)Q%B6FA8_t0KVW=br#A zV9?V9m6tH z){m^d6TLch_6=SPG?wM~p>xpMkK$+jH*`Ao4Zq8__R0vKAMYvQl}3(+n(U6MNLfDs zZnZ2^d+pcnyz0%+;ETPLs32Kzscavz7KW1bnXf%Io4YC`@_DeIe07>L(yd3XJ0SYd zeouSqy6&N)@q=#Up`iu=;hFwkxj~z%5nrOlYTq^>!sMzBt3fBL*2e=;)T1wFdV7rL z{8Y}i=U$4e#=Xgu^O~Gonmyh*wU%#f0YW6Qe;{wMBH|I$;@crR6H{K5m^l?2sQv|81Loq>$=QgU3W!b~a7Z{QO*{D+0 zlPi=j>c^1vAD@@Bw3{=u-`09s7}52nl=hfl!yaabLZLz{1AH|rYb{=$7l3ItI zKCNRiLcq351JudqiOm=6934f+nWOtPrwLNGDok=^!d+1%U>VSLD;I=qp#Af%~ zN{3u#U&$_aey7iAVW0D-lT>)GHdFBeW)$TwF2WkqhFqoM^ZWvazHBn(>bJXFv^&;ldbXb@a4^VzZUYPBhSv>mRnFrp)7vQe=K<}3(1{$de^FLIv_6Gdcmi?D|T`= zLZ_ne(3HY1C~gOMcz6G38VuI;QQ9yp8<4V<^)DFHxt9>>Y0-Fj+v*iSe(u7TbiU-w zNu7i6t%1%POhC$_R&W7*z%@}0Kkwmd-Uj@Tv0L?1jmzr+$*t`{wXV`0OBtcPVw#j{ z_Q2WRn8vhGBd@+~pY|73Qwppf8aAcdEygwjOkEtoV_@K5l0K9bAF%$8m30|5W$*-= z*`5v$Mdv;TFLY?sDDrwbtg`vpbfCdgldk~4mKD@WhS&0DC!5&_Z+zx8VZ_@QdCMK( z;xk{-fXkeD%l+v%>crWXiIE2ad1S7$QWwPogs8TfiL3B^dVRTdLZ0fo&*m&M){W<@ zu~Awripr?5ta;Lp$gJvu z7nhO;ymttSTRW)^6QpW7H^Swb5~y_D@N_{o{fKTXFBvPNv0v|#8+2s}A-!iPk7m-A z2~+Xz*3se9-@jtXm899mc%Yi+V5ZP^E);q|wJ z?OVU<))?wf#m?ma~LYmf=N;xO@b`k zl8;zb%Xt5>sI^bEwf?D?Sv7xnwV)b8^lZv*MLMm0G)i$|uPj#?hTNm$b9$!Zrx~j* zYIzCe-jqFjbx&m(P5Q~AIOdB(ifbW0!m;U&okHCwH`J1bPS^x0<$T`F92*cdXKa-t zZ{YO`@S(qEFY@_X(VD}a!92^-AR9V-+Gb5`t!{e~J%#z)N_jtL1uc7<8sHh#5tk;f z!54oK!WaEH;6UoV0?|(nbo`xpAsBq4hrpZG3wK@dYK{wV9h=*yTs*7`qecVZhy6+% zBXe|n2QQhcaarM6js(I%m)@eiWFqdeqiO%EQSLW8Knnw}y}6XYIM+EUp^Xed1pemW zOin;`;-woasuua?N>O|28yslSAC8PBP^kK!4dc3NB*6aYz$bYj`!0@yb63TTd$onr zPh%K(U36$_-;BK%wWgxRCCpNUvm6ca%HD5=IhyOVm=z&Y@-nmpiCnyb7)#R@XlS6T zPYpk(0CO0`MI5NRW<=ekzZ8fFvGp*MX@-tJHZ4GUWR>a>0=)b z?n-iB%4JwRwxMBC5ro@HcUzcrM0@(Ej@ZHPJhh6uA9R%-+0sbO(MQ)npd+YpYp_S~ zs;Tzoq*?D)Bi~JkeFdP9xY{(@*ln+jhd$sb*7eTIKynrfQ`e-W&3qkBkA%(sR4{Ku0zyy^<+v{(mdDH+S37kw%bqAkHyqQVDbC|LEybY3d}Zi2S8>D z8*j;7Xs71BwqG0=lPR6HxVi((okEAi`wDGk7fv;K_0!@Qndg~r4?e7Kf3!Bbn~+>k z$Nz$*wqJBVXRao|IeNM_V9r!8(9mmkF!aVqwtQb&`J8q$eM)_)izNzBzYxpRgYMf~ zmUJ{Pg`r_B*93FXik|05YtJF5 zEi_jpQtkEY4)!)VHwW;W*)EFl##9nm(q8C^5%!&3>;E95XgA6(R~!DoT<}8Co2CQx zsb3DphU_`Eo3QZkp5Is%1-(9n%ajG}472!UwGY!gaZ`TxHH9N@oVI^X0FST}@0kKA z$IsBE@UCRN^PLh0i`iq!dSeSUVSEepZL@FmjvivGwN;zh0?SLx#Rg zf6Mtuct(-J?U(9psxH%EAv=O(Q36gKj`p`&bSYGd`#6!Q8 zlU3#E4a@SU;ZnzPXnBqn?U!-oAt$2*B;}8InJ?D*)0CGNMry5~?v+Kq8j2lHr0Uj# zvzk4Eduu*uymtJk=F`~4vpYZyTeFN|%=A^ZV1}8Bw_@T~@Oj791ZgRSLfE2>DmwTe z3H#QQU#=H4$T)^1N)GHkepP<&o7(2rH20l&E}ES)fgNVO#DQ968>$Rni8B{BxUx~D zEyR>c(`M*i*#-~8J_sIe$8i;hcD0-!ca~s{E)!%N-qrj#rf^mE(t6lH<~fHgX1L3!9LfUfuD z+YeLF=})H^IdAYrw5fgqMzvVAg1gh z98|~4DGWK;7_BSDV-YGP(I@HS1SqTx5L7&oRxEY)@#Io{xp({_z9%h0a9dX%XmGFk zt9lkWiunbF_XRcJMA%B8(f6Ta=;O+WUS;;!`ll+~vucK=CVE`Ky^Nqgz|6rA)KWg+ z&F5+}3Ota1;)i@VSMYIrOhS4qK+`pMp&8&B4A|dlkeiO=%#$OI@Ao^ltN}DTq;2=M+D`+~>V3~C#O@${n zb4a%qjcyFwk8EkX`zWSGyb5$$N&7@qEP<-j&)=$!nIkpRY|0e#?%H(W^W~@|gxQ%M z-P^;Z;T`vDKY1N(e4XY=9^u9A)0WEv-tPpWnF0n=(NCG{K3|!?=fu}io?PwI__MLi-`Wpbj9E%^8-f2_hmFDmp>A(isdVpOK0#S7in0&-_GmzDi+%aqMlR5rH5_iP4ozf9=Wu;lP40Mg;Y0VWma~uE zYT1rgh8=hu_lZm(z#}_^g~xeDRDGilcA`DHgn8Mv`VrT@w&3y_`q_z^ zbH1aR%Fu<)(!=UUiiMoJ-)49A1sA;v)lwF8$=8T&rx za7w%+WPgu(?&PKLnB_PWO~`o;bpk6dx3-yw6*NBJa@PqJ9>IN2&xp*96-KC7&?Icg3-@+t%UADRu5> zp#~Z*v05~e>?zga3-OlzZAtvv+uBnWeC|m)&>-K zJ3!vNLOp$B`f}I4CpRZ-OZMuVS3P{*l_EK$9j9#Q;5=<6cZCz1^xPL_oVH&bQ#dZn6c+3G zilu5TkKI1ZvhSH)pqvWeB2T_*IaC?;?Ar6@(8T;t*Imoh+~cZL(;nsNd_rtgJM&zS zksb_WeZIH)z~?Uu4K55fnBG1xc=$&4b#Nsl(2#BK!V^B31iY0*%j!jlTvdS;Rb_Fc zEQ@L!Pslaq*l_(a4=pOZ^6uF~yskyEQi#``_FX5m<7#QR@Q)))&yzs{m5xfPkr#vk zmrQ}Fj82=-!QP|2>N$Z#f%(Ajp_A$g4fHh_`0OjzL)Ng>0&p3)wR8zvMYpzQ`4z=>42M( zR(NW$s(ZMpDh=!VAtT!?zChTKlmp8#D-yM7U7ZdY-8GB7=G4lHmC{4T@#XNp`sw7j2Jx;NU?Sc+jx_WkW6@y>XkmVysi(; zD=zjqaqS>%Lczk3(_VVTqBc;>ntf@iuJgW2fzM)nv-$}4&8FC?_%(fUOj&@b4Qjih zkcSTR{0H3y-Q|F@sC>e3l0Qd+bKS%%)|ZjPg~qiJUtyYw=!-Ux?vus zT&h8ik>KxU66}Ku)A?>Zk*`j|Tg>U76?qAIg)Kf~hF6#F;HsbXjR%7_HPjSA`5n_U zk{XVzX*>>hbxQ9jET04KIX5|?RkFXg;J+!mXmfloT-~t8zZ~9=3Bqth*<5 z!D%Q&e#%K+x1-uwU;jwUZ35LZZ(}a@+2}&*J#C$RHnjLIhS99ZoaOgVDriE_AjGvB z)22PTwYWviX%*(VKXJ&I@sw8T}mS1E`-#JXK$E0Y4uQJGv6E z$Ep*Tx}lm(m)I*sM!qnuwePxsnN^zRd!3%fWC2CGdY32D7H8&`-xWtMt#$V5Sc!LE zO7soCF!so#AWrk~!xypY25(c{0j`12mPQ4s$Fe5kK{nCD3sTX;SnEC<*JVb0P=hbi zF!@E%)NpFJtTSrXZ7REV@iW@eb|Iwi+3_BuE2u%Qr!|(o4OTr@)nB?f>v(4*JFN8i z2q_uo2<=0evlS!6g?8IUU=Hi`b=T_;vaeb10ea2~haQRU){`^wZ`t#Rr|gktP}ekj z!h~l$z6O2Dxl$5<=?yc`nl-1W;H+$=8?0C)T}MTq#Wb9-f0)U7;()^AOMr3l+1I9= zFZRcU`WXr&mv*-S@1( zGp>21{>_w^D}D3WyZpHDFV~*M0izVaUC3>*-XYw(@|YJnv}n z>y$&W8tj1Yt2F?|;o$xP;3%ETH^S`(<9zet)=@3p^3|f(k!Cc*dz2+vl_RrItmJ|4 zJI`@-2TndOzWXdYUe(JsVodg){L0|Q6Mua(*4nov){M!o8JIVZ&kuZ_%&s0f)h?1Y zXA>p<$ib5FI&${hqUp)#)2DfPJA?$gCM z_qI^v!?s9R1?O7{tu0|WbhYG-JtkCxab0F{TG{P+r)nOFHOXAH!W?C6ngvt`%uH7= z>iEb=Mt?$n2yhn_WD(`Hg>Jn6s0XkQl>sgR44YK_bb0N9)0PnfdTmbk^TKaHayt19 z4)M67T+>U}N+)1JOA{|~G5e)Gk{zrbzUr{<^brgyBjtv^GK3{d_}UJ2XXFt z_rlXh35^Ulj`rJlVl=NZv^5_)S=i-hFy3Ze610xs*m|)fUz>@k-}0KaxQotf*<^O8 zLLVp(*t&wtvF~T@aE5HKn?B}dnLXq8K1kK;R{cj#_9SQ-`Bw2<^2n~_nkj55hmlXn z6~?2g0iK8nT0y^+Cr9z#tY93(;k5wsr7_3EDBU4l74Yz-Po9_LV=xzxfxYY!lR2iR zTxtgu&p|Cquq|z8B=ZnjhIhg9z~HB;?uXC z{@^StufdCRFxQ?Vy6>w#!jo&_L6`C2`Ke(uhXKxV^e}5-$!E8scc#2guXRhnn{S2& zTuM%gOd3pDKmBkM*vfL(3-agNFZQVaO|A_!V}`7?e!z6y zSRg|;n#@=6?Ka=)sV>vjS<8y{nH#Rs97bPG9XuH{7+lYtFOOU`pQufS`D$bs-`DbU zIX~;ff)vwX+*8L9(CzguU{JM^8*MAnU0G{ETR15q{`|a?PzGLdkLj_CGE68~G-~tM z=avxk+=&>@)sHXp?EKl{kjCC!%}hZs*}!{xHzVEH=?F8B37@V|w-(M-!>vO*fbXMH zc9cQbbv9!T6zjpt1B(`no0KO*RJnHJf`C3D#|h-PkaeBQ3H{L0>jg9%H$&mrv5N&2wYwksOt&pR%~ zq+`b!OVXrbN@)il;uPqI;^T`3KS{TmgB_NI#2r@mirH=HdbJ@&lvwJ)F_}(YfHEY<- zb9`M-7toB`U@$`6CD{`f?uS2KChy>pH#l>&|7o>E!6=zkf_lZdtImro%=9n!)Zkw} zr`$MO{LYOp|Jh!XdY&l%;rcE{IGNCf$ytxG%Qtfn57`7*ZE-1u;Nj`Rr&yu%<;WW@ zk1@(jR$~IYs?f=RRea*&bHf#<=npA3l#{1Bhs;?W56oTOEJET0lm#uD$R3~BU0}}m*7?`UY#(;rkn4+_$A0EFKXB6{Y zaE_wvMbQ>=EhwHpEokzT3gWN$b~)?ugr2ZYX6$E07WB>fO+SZDd8XA`nUd==S&|K8 z$@!F5b8466G(H1Jc0#K@tSf#~`)!Un>!He1`AIS_TgeQ#Qb!gVv>D3tWtQuLrgE z1T<-v0y(NY4p>(fr)^QS|Z)Qhv>Q%IF zzm-h6veILsbmqf@MINstCa)7E6s}gM4n|B=s91#c!=62ff?l~iQmhhn?H+vk(9;n8 zbwahgyyXi)(=xN{V!AdJUMF|LMxpu_C-2*=2Hq$^MOlmnANk&^baG-Z+wCQvtUYyF z5Nbo5gv8iir_tJDmL*F^*+9;jAau~II`LZ4d~HOOH@o&+<9=HF2b2b+ig)JJ%V|Db z|7e$I`lqVCd^Quw#}N^cRTt^iH#?j;0)wu)XgxMRQ_${ri>t)z{J@IGA;m+{ZD!;n zp`QsC{XRlp%PJ2(ka#q3#Nkv&-0|1`2CAZZeQ!oSJ*wv#ak-`wrCd-%dNh;CI-VnQ} zac}@ zK77Ica&;qq^y{4tG7oSCQ`et9Z=cP^6g_LpVK($NsOiymW)7z8RTgITq}w}ES#{#>i%YxGi`$3G%C(LmZQQWV(i+DEy=5-uz<4;}OSv5)A8@st za}mDA?ZeSFC%irKK3fJh9AdmG>O(Kez;2B{HX1;-s;;Rl z+vc5DV;F#yO;&(>Bu8D)|K?U=0e1h$>Cox%2vuzE{X-g@hYZK}*5R2q7WTIV%CBok zCf6o0tjT>s_VBEC1Qq)mLAZ0%2*WuSWj6K5rj7NB&NgENPve*qpv&ToIN}{p1 z(SnXj!CR1}T*}oO^VcV}yyZ1jr^`dOd3xm5FA%R6)|Cf~tFt-l*6%wc7B~=WZRLCH zm>AjH&&Aw^38q|wsL>IU9vmG|wM!h`H|fN#pZ)N$W)n@$I|Y}8K|HeVAo}R)>|k`+UHIZGcqK?iRmCd%s#ZFYV0q5e&8P;?0^U<-QiHk}IIm1}mcd~Q3MFZ^+?4_MQc zpH+-=0Ye=j)6c@9(w^=BDxpZ<5|0Y&HvsQ$S8tZ7DKmej)ud-3iS@;&x+1FaiAM(5 zQzADrY(I}+CEw7ML_69Ja_iB*)0{{Rbg+(7{xs0Y;6mp1WSn}wB1$jlgm8Qctl03Aqd z0yxHdQ*L#2gyE%F<80xEM{Jyv--Fh?bHQI1?e88tdxu%BqarsEOS&}$dmaGEBdl3G~7gm z1X&$M)g&B!eSJBg4qL~557Yd8r!}qRw&>*DBCuelJg(42dSn(HWDtGp=1FfE|w|?B;u`2uny_3YA6^40y7c8PF;Asrcp~!duvH6X=G1{!zWNa%E zN|WFo9k>zeb}Z2O`)*?RXmWEHR-XLe`d>uP&_&lP$*B?aKK@3YH_L z(s6^h5H{c*bJwOni?!K&I+s&j3tuKkeB3&nvB?+?y@@04=NV#l;|CRm;Z1!t7@=)J zAp1;a8Rkr|X!k3(n9auJE3|?@1Lx#!004F3OIf_TdsVZzX;*F+i0%|d9N_Xcuwq6@ z<8kMV^7_w*qSkcd3|6pQZ*ErMK}nFGyTRa--;s>+d+0u)8cBC0y3HKOk+~j047|Gr zEO_LVBdJo@?UPk+e2qb`BUv8KIkM~{C7c`*;YK%g&Ozs%PZR<1kBa{Q;GJG1)xIf7 zCxvg~(DVz(Yk@WVau2hN=gUx_tc|$f45;anxdOJE{t1iYtycAR{6f?24XOrTJ4cyY z8+jlJSP{kv1pE3|QQ^;zUu5{5sCahc#NIgZH;B^ecx~-}vSN}r!k!Aj1YUH3fCqm3 z5KaID1MHf2#a%>civ{kRCY2j5{{SNJ+q}&A%AmL2CInz~jkh*1#AHwh$l90eZ}CgR z*No7=hjkL8XlO2Gj!6{rkgpCf2t0=DdFxs_N5DTGT%>QPUqhzvZdTMZuGtuDOX-Z{ zrZa*&^YtFD@!!LK9k*F-?xK=lV|f-gifo)JJGR66umwrS%%ZU$_JZ(j{p9m5t!Xry z+!80$vj#eHY{ z9(*;MRrsxU;z@*UYoXe;v~Q2Ub1$c3wjy$Rk1=cK{{RVVvEMC}I&5jGYg(aLwCyz{ z4J4B}$X41&P(O&0cm$K)zWDel@YsAN_-CbPZE!B`?c?)rZa~1gx^uo+2m2+-3OguK z+JHK+Hk`L#y`& z%-uGuE!k~SIHS0Nb#vGs3qC7Yk~?4 z=jTVYX4WQKT?%{nySaUmrJIJu9B!bM? zOrA~y5_8mWYtyCido*#!Bug#yAh5a-{^eJU?N%FeoM&q}KZzB{{A1Ff_<4UF{{Z&4 zk*Bn&1+kuIl1oMjxnf`#=R267+AyOT;1fpE!~*OzcN0RzS(YY_`M2BfN@1EIy!8yB z+#qGhIiL?$(R@>Je*(s$;bcNfN9O+Re(+XNxd|8;EUHF*F@xB`(rUV8ypJSz3krGX z$xY%H83e4IN#JvVoDgwdLEy&!0E8}o^sn?8u8O-&a!D6`H|&%EB?BC9VB2wzO7tx! z#~N>lw&}N1T3XEU22(ZA+jO|VDw~+P`ygy}_aGfzrm2LO%-(gz{{SgrAsfq#j;hMc z#O~mLy@@4w4Ey_^?b~&aY?kt2RE8m=Z2jTDLO|!1KD>3TYqYh~;__Q^Llh%)O>$Be z000O=oPWK5a92DKDAK$|7oNUIX;phIyF(i{sAd6Ja7F<^p*4;uJIV2K)Ev&{|k zEf|F+Nr_H{T{3aU8%{Da!5FWcPy7>m;pUO?a?ivG9<6I{8>P%zR5GM`Xt)HBO0qMm zXKQ(C;Q;_`3zV;?E%mtW?$Y8lMLt<|opX}S^75O*ag5=($4>P7pA5s|%`!!nIT%g4 zA-XNYp5z{*Ipd${Kp5W+{68L(B26FeKEbAxM3_;}tM(LU_f@B@ZY+xSv&No@4o&6D5o73co|6ZHwL zd?j$I(R`(htv&-trx?!OGmJ6$Ks$(0@{vFpe;O@U#d=hbrli)gLlUSyU_!b1#_n)) z#sC;2qh}(#x5FO}E&L^H+LLNc4yhEPTMOH!j%-K(7-)ln#~B1=K?JUA9zWVd%Xrp5 zW|5&sXuQqO84NMNRXnx<8^Lz#bsFw%0H1Y-YQ>Nr^)` zA}{ucF5t`+zU#b!BqZ%p_k*EapPICfiFcn9z9HM(Gqa93=W?*amI21 z81sz-;zXhan(iBixWh{u#6fGxRz>6G8&u^&m1YEdqY9XF-}o=$TsP49a>zGGqFKJo zRijKgtB`TZru77nIKu)8kN`YyI&kU1%)-XWge-I4B6Mfr(en;e`DgFh(a2iJ^&Kp!6bKmP!NcbB?r zje7cP#J4zxIh$-)$zV@VK~aoj8Ej`Y>XQEe!9V;1aLe{xK20A`+FssUc^Ia5{GrC} zthg)aPdMveNl4Z@#hfi0F58y#H$3ER1a1TtB!je)G2MG$Voh&EST5Q)e|c{X(6N^J zRP99=9Wl2Dlhk127yur#{{RH;{gFH;@K3-NULduZC>tVrK7 za6X(6aLdT|=zgy7 zkBhB*WuqnKt2Wcbk+H+X-`!#TdJ#Yy%H28V-kyi${{TEyXwFZ6Q<|Ah^`Hrq?LBc) ztL^^)^;8Ji=O0Ru4p4ObO#o%e*az1&U9x%U#aNk7ztmNEKJ5TkcPf6h8#d8XOP{Z` zK_?UeT5w43ROLg@U(T<);QpUVf>1x&{HOyi;c|QS$u&-Bom=j%+zi&F%Yj_)#BUX+ zg1j?#s0CT3jdt6{#{yYMJ6G2n5^>lX0P-K(16r1U3onOlR3{?} z?yZtY;w*9r2jpKHd?2^@mRC+TDuX3B)dr@KXTgYWsN~3QPD>~(0P08p0hY-pCqJl1z>Rwp(n8)`s#pLE z1#`|vs66$-;MNC>z6a^PJkun!@dk&a_=43VG3~yEXK3?<4vcc@cH;zcFfqsmfIeXO zZ}GWqboQFwPqH@9C7Rwr5PWBDTp(~mW9Cv66MzUOJc7! z0X@$JqI_?YRyOHzD;Z(Q%dyK}Bw#LUr zUX0nv$W`cgVB1Ej_%q{2#EB!iu<)*(q{g7-Zyi?QmmkI?Dp=zf%7KA_z}sGf`#ksq zL-2os?Nd+t6^^ZUYJ&P}csAUl$P}2PVg9XoIor+$Pze?9eh<61iqUQsJ+1Q}oz7%@ zFeh;tIOmRo^&lP@@NeM<#BU3(SWi320)G zWwn`{u6*Tw`$!#B0uLh@CxTB(-fxKdZk7y_NoyL0#HFNc=YTLa<%#qJ55|*J@pO8U zZCMsI8EgaPZt6i854m|hi^hFu0}sSLBE8V`(w7X<1q|he6z2mRa(K_ac+Gs*`$71I zX|Gx)+jAbFBFE(UrFbE*Mghi2L6O*;b`|cPFY&#ty_L+W(KAL4#tM!j=YszL!baig z&PQ%f%l`lz{xIqP01SLLCb@RWcdcDYVz9R2G`XDr010g1B8{x=kX2bh{t!Ss!{S$r zG;f7AvKgcjU2SOMjUx-Z<+4>t1h_8*7tlTAjq#SJx~Il0WFGIt3j790B;&mWSb4u4eMvY-;Kc z&G%c3iU9BY9VF6NDTXlncD{b@F`dP+w<;D%a7%Jnf=J4*wCLXwE)1|tk}aO6BNT5g zP(0n`TMjnoX~MDRBX;o{GZvG^d=H~)S_4K|rDX>&Ixlhas)oSE;=F^@9D7%xcyn8i zTRu};+S|@uRaXTF##C(J?Z}rIE5L3~0)Rb&&&76rBP}xBN9H7P2gb}=$GN<=RmaL1 z7`E`cUpP5tU^-6&>sH#1v2Pp8aU_iG4a6wRF@i@9?UqxS12zF8a8FF54bcIfS%|I3M;z`ojd9w)HU>PPZ31?A| z0c?{PJCu{3e$w>_^xZn)Z!GQPjs{d^n&~{oFOq+D!5MviV5A;#fw+Kqu94zPdreLC z3(qZV;fs8WDK3fvsb@I(1_m*=s0@1a=)5W7Pxx6LBz3pk*06luY^YQ}IW5yWk4`WS z4muj?Cip|(-G1&Shf1}M2F%CExhZbs{H2&FF%|%}2^n1RU5AJK7o_OcUuUtkYgp~~ zNb4Rpv1TN&Ip^jqNj&F0IiLzQ7LDacGOJCA4)$U51IJvCZoRQxx5BR&w}bpc3#^~G zlbh2s!>NR&IkNlIXTgApPUp##{s6Xw7Ly&JWA=pb1oQ z_2!f&N8bF+2)G0MY5Pd(^Z}M*@9R*Q@_!C7R^2mDu)@#S zlc)o9GX*Epp68(>rF&e_6@lV@+TkAi7dX3G3s~a1LE^d$IQ%|@RWAW}R>#5@GYdO*^Dd6&oFU54SuApZa_P6w@QXj-ASSsrFn<<3gB;#H934nDk}e$)Zf zT;DV?x0WMkn+S*-k@Iu|JoG&EsjjWO`&)$l-MXTAY-Er3Z^+3V_~7s|PZg7->mGDU ze6&}IRUiGVH#`&2jQ&;FYe^of2J2?uhZ~~CP`(KnJ^FX!r?miZ_Ti&~HS=L5)tQt6 zp!yI_oS(DRSOwRzu<0$NOsBh=K^Pt=2`K^ZYgCaV<;08N!&~+Vv0Ll2} z;J7?n@Xt`UxP@d{m4xx|W6s=lY-0c~Uw(R5uo;#1L zyaDl^ZBteQ?TFFt^5kYuF>RX}3P)972pu}?BPXZjpM`uyb>SUiJKJ3EQHETQ9A`M= z>5zL6E9tL-UmWJveQ#?d z!3xW8*hU}%0cFQ>IOjgxR}0{;+Rsi+Gs=?s*T_}<&r-8so@94XtkE_${pBMW`LY1cF~^{a6>Rt;XzEuzt$lun)&gd&D+Yy0)C&4bW|*5=(h` zacw+`yMun_DCA}@*9Cx3PEPD?pbwrrJFIB-R-Rk#n^d0A%9;lh|bAJmb0P zZvHOl?Hp-sVQXh3WUl!nGDdU!MMnhZBcK3hjzC}2&j+p|9U-W;5eg#s%4b9=}MsLldCV)R2XZYo!*hA#`Rw5V_EIXw| zbl~l72+0JT_QxI0V7C3Bbej-~Z+UMTFYgE!D$MPWa$Bx=9l8P6x9*pRzhxf*=#hbG z;NJ||?d=AYZx??>L0vAJ@H@knC48R`=`oc#R!dmagX%b-55e6_;)jSnBj`V7(R??m z>z5K0B3msoNg;_q+PEsX_3hYWy>Yi63%(}VZr|ZA6h|DLRdm@%KAi`!ujsiv1*gUT z04^|aJE*Gno)EQ-2#-&;ZKJvh0Q{N#hG4#j;aIcPEcHze{uj88=Fa-ec{7%6(NE>c z3%17d0yLj8hmY>BF27UgI^$f&V-z}!H(1*hs@v)^O=}W1La1M}7G1%P(4kNq<7gti z2ll)0&Z)2XA?@vaGSW|C-gTwTqnK_j-yGnDA(6P<8v?8VI1H*+;yp|D#D9h}{{UwA z`W<6NlkQ$#-s$sPeX+qOEg=phB<;X(6CrY>jL-+z5%`u%T`$Rw-erar$|P3bXO03} zBYK&VHN!C48$1!mP&aKghSXw2`#VE$$+}47S8O&iyMh4i*P%p^;jl2JtK_XW_U!m~ z3$4Y!jC?Z*wL4OG)NP|-xQ5^>1W~xIKp5SU*&{X6cz^cg@LsGVve5iN4~Qgo{{W8- z6LzO@F8)YWw^wtIPnm~q!hk-Mxwu*5UHn_45|QK>-M&%RC-Vij2H=gu5%@I8;@=uu zXli_uc_W-{%SRHh^B>T0(0lFr`Ok!OkBk2R8(c@>?}xg^kEYyUe{6WT9X2S3aF)k) z(~+F*XyV*+^N#-jQ~u8%vse5oI~KUqwGZu`a^XVHb|6?+9d<^eB~W{k#{(pq0QEZn zvD>Pug;fAAJOC=>Uo&z(ywf9uD5+I*%>X?(V1EjgcLDrgg+@Mi`qP#T!>IhI0-G)e z-TZ|@>w(wt#aLy?sAON8y#QB`4aYTJRN47{6=m0@MA3hrS^%bF=b-1(n$hFw=~jyU z?tb?jDmfFcufKW#r*RJ9T#v+G46lfE3yWD4Ngc#&j=h*3wc8cQs`nOaDs%j(1M+kB zvhb$4`z&~pOXQ8$L6TBd>6O|meDqm1gRywV>>jz!csyp2`*7%Xx@;%yk;@@<9!eEF zlgUyu?l4H}+lu{T_=)g;OY!EHHMRAuQ^PWqWOdxE4_w#6pR~995mUkbEtcZ{0K?Yp zpj^nsv1!C4ob^w-KL98K=Yf7IS?ao0mvFZc+(UnE3~Z@~#^KIzc?1wh#z%g&>Ru1l z-saljt>$};#IogvULlT0C%*?Ij!EN+_=imWojxUg$uPc?BG0XOax0HI3FBPNWs*0p z(~x6+qzn+pBimy_{j*!)^}@$^jVmiRnH#s4Bd!l0Fa=IK`udsx`pe=!?EB!~jWeyr zt!pCRhz+mWt>rMt5a%F{RQ3KPU@~!DU*n&EdS}D9w%F^M_0q#8_OzNv>?-fNm zi5s#>$<9H|dFjcN4%JCguI5 zeKF4m4}c|;7zB*%3IOVBdrAKQf;ISsG?@O?aQ7GNxt9@qr1jwaee;U_a_~>WZ41Df zBpQB~r`p+B+NeZ|7}XU(_5-i2be8%RpdaejGy(XF;(z!g7me(4_K>WKr#o0S8y&d8 z=i0c5{{Y~fpA$R_erC1O@7hSm-?t5%ejdK{`{r~DQa{y4kN0|-2Zro{{{W9@Pq|S* zAAJun1vPY&%He&!0k`(UAJqKm?71Mk&KOHqs5Wya$@VDY6 z#f9YXiD$R)-TcA;2-?IXOiHQj6#oEI`;iulZx7sp6+X#|u(s;B`A1%WPzTC?v7hYI z;!Q6`)U}ToS=?#YdZ~sxeLhuGXkv&z?_B};Ws#+9spKL?4jUnUhSUBYTH7j(szCmA z)d(@)sp(D;=cwXEnKB!yc)rE4sF`ufv?k5B7B9G$npF9+M)Eb;h9!tu_$iAIxdAM3a# ze;n6IV`7?huaj>YTf#a=8iY~!0gAJYqtczT6abOOvw(Oaq>eqLf1J`PkbT8Hct6g7 z6&C~Ao(322>+4EJ54|WK=h}cApEn#*&d1@qwcS6 z)iieM2MRuv0Y=wOXrmjqayy!^uwZmNRiK$4j%liZp5ICUhs(j~x}6$ip4Bej%cm5G zm>$1{00DNNS_E1CT+@*-Pu|a=siKX3Y~gt0j=t0Z7nh21r+(dO*yrx?oO4P>2T#g? z8RhLi;ZIoq0G|}#nf2>V3A9iIeqIJK_*1rXO3~-?r3=sb;(!rUZ9e^}N{Gp)WBc59 zqhl{a^q>S#2_E$IO!m(JPzTrRP9NvmfB-r4^rr!|Vuj>)qz4^Gxd8U=!nkb+I#|!lQsWbj@MHB!L_~Y=Q2Q*PY z4prQKl&#a|{x0-UKo3>VG^2*c`4mw=13usL)|F01-lshK(M13Se)fCPqG##rMHB$x z9uLx=w~l)redwZq0zPlAwE;(C*QFFt1DBDXY}1o~Pqh?K08UhYoEl#&W54G`6ad_X zIM3%wC01+$-yJBTfE*b~?fFr?N$W)v0U(nev!q``vZFwZ329y=P|DnziQ4th49L+56ege&%HSWEODYv5vkD zfQ5wxkjC5qCv$*nkAmHt0RRI7fD`}#IK#Zg3t(lQu`sv)*-j<_4*;iFSpM7pS7SZR z`d?#bV`DuHWCsHOJIo|Egp^Zy?5-|he1#oRg1 z0MGo_7Vu(DhOoZ?|&pI667IxVpLf`3D3B z1&4&j{`2---1`sl$tj;w)6z3OfB82z?`uA`ps=X2s=B7OuD+r1duP`V9KO4!cW8Lz z*XY>z?}@oTe^AAG7|qW&a<$c$jvbV$J~@@ISj) zP6aU!Rvxy~;&SY~+RuUZ(2LjPqt9HrpOjP4aTcWTVuSCs?@x})V8yxXoBx^iKP~(J zXIRYtm1Y0iu>aSt2>>T63v=>Vc>tP#;}mD|L2Cn5ze-{geOyA_#(3hd`-}LgEe`NbcBIFEc5&FA7 z0IUnz2b>pQyw_I`+>Cmpbd9H;Sif^gzSG$D6>&0#eWs;5h?Mw4Dt$M(~j zYEgbMluK#9MsEXuH` zazU^g7>J|>;BI*dlJ5J_3$FjY|3LnDfl2^A_LE|Wab+b~UcJ7k@y3cUv(=3m<`ep2 zGe$ooAssGqn#$dXUo=Wzz#p6dFx%5#94GO{)GuGG1d@XEz2y;6l@~PCOkNWnd4$OO z&+Ho+SteJ5US~(l;MG-z^zUc}d1);)tS`|{0O*_PSEK=g6OcOp+a19oyTUc|_mULT zwspmf>wF=0Tr6TqGi;Hs3(UZN4Ux;}(8Cdr-T@Uk`4;v9{L^cA2j$P|7 zR&E5--glEr!|jiPU>*0HJ&zQ_Lo>g(idJu|>`CWa%z)Da5nG_gVA5Hx+#)y>-x87C z7dswQ*4K6*{k0@rvo_56kvTrjf02`2uOq{6O^@E1TBGXcZUdPbABtmetZ%BU$#=I3 zBOV#IaG_+OXp=Q^_Usy7ZHHwL)?w2s`7~DcZ&Q%Nsg8R}fwKuV#un0mL+gzjzbSc@Uob_?^uy7Al@ih`dA!_1jz+>A#(m z(!bSn#=_hs=v7N5L>P!2{TKPtZD#R$?Pi(;0(LlMwD5do_)fCK_!Zj?Gw+Xb?ftiR zY^M?kRIFt;uShT1`m!8QJn?^z;*`-Y*YE_M7X^zK#*<-^k_ zoIl$Kgc5q^|KSW}C9$qzHuS$JZ>$O>o=X#WpHh?f%?DT~u)m9QD{>zzZ+Jab5}|Mc z;PyTEb$5uSG8*68nK;3hucJcdiSRM1z`JL_IxHD?AGkYu+CDJDajjnMo;?A~N?S9{r#a%cd}JK@xSq8TXDhWLRMFbErhz#D z#5IJw$MtXDExojAy*)?Oj!ilNJUsyr+^8pj)1C(-y(fT8jh{ijxhkhIeT*^Edh8hd zBF;y*e_u|Qul2GL>4FRET1wK`CjJCq%V$V_?xGsAV7t&E@r_#(DRjgUUR>US3cFK( z_}KeEnurJJ*}xurKLK3W*iG8qg>McXZW`JiNW<{@?{+sYe;A1Z#(B_04KStQc;DCN z^O{Yn84lpfued$t1CBLnxN=nQ^V6Go*(0&68r8no>|>N!tu1u-4qe@wjM{|$Nw{la zG=SuCJ5~+&*mmGgFVh!!q4B}AzGt3}R-y>de`-=X%w9O%gjfbWxV>)J;< zxt(L;P8qtZ4PL9$JyvH zWm0fn65FIHK3|Tm7I*?s|2tf$_VHi!cd!FVBrYZj_PuNK)Y`mf06#xLB`jz_mbsSh zw<@ID%5 zRYTztE%JIWUXAkv;M%RaJzYF-6cFy%kk$6*Z}7UeXJdb_7Uif!lj+*F0@4BG=tLES zYt$O8BDO9^7)EfjjvNS7*FOPx+QKjDG_~Fgzdr#~9=kzn44_rvBSY4SisLX3%tr%W z%z>rt1%aqpr3wWHjocv3ve0ei8*Jo*7z&b6Bdatwe2yUbGbY|jBqiwR72TxXLw^&5 zD^VD9_X5^Em>hx0uwW!s6*f7HQ^**;@ZW?0{Rr6fw;L9O6TqBZJDg`7HR4OK_QWx; z5~Zy3y7th`b=>~5?&RRUABIhVi2-iyAmrX*SeU8c%!gOk+CG2r410;`Ulw8L5MOm2 z@@ihCT+4iOJcXA4M6V)#f%I|{{px_i>f98@rAw&^)tg>9^7M&!IX5=nvCk1;o9lk2 z-gk?8KoMgvlO=mZmpGa08?sA3C^^XAvheBJC6A^P9U5}`&a}@g0}Zo`o)1L@-L|mp z5gvN%oT*xFHQ;1_0yqT(Z5NGQsS2@DqkfS6-Kp2*wckUB9h>++nykMtbqHq}UNTU& z(-oCGf8cilSkC`^Oa9ALuc_wH?Hmrdt=_&tI|u&x$kfA-69AiU)0LS@bWY2(#g(ox z?fF0w{#ay~RXtcJQYiH00{yS`DZbPeqe8wzPiALV@}BufD)X5j;f~2+@3rh_nwviv zA^(mdZ<6uHUqZt za_mTS)s1VN<~i4QSy_mCniP%f-ld>h;+E#T(--Lx(%T@$Ir(>YCgPr?xJgUA^YrZ! zovUx;Kh}Xw^D)B?-xI%uzV|vgoF5zdbY6OyCv!hVD72pdcFn-Gz9gyd)6ih+RnIKr zgC>NbkuCK{tbg4VgP7jX{2^KWZt8FjUAvgqK<2`ow8@vfBz^uBBLQ`8sM!iS&a}EA z;ofI&scx_tM?Dh#nVd8}%XO^0_wSwCHy_jM0Dh=Dk? zSkeFNu%W^n)Nk@$z97@vH*%ZR5iV+!vp4YOQfa}PLA3C*uJ#a^G*2gI42-Dw%2sbn zroW`pDvK0Id>R%&r;$s%60R2&hwjlHA8s#Ad~K?TmmN>_Jpm|b6AMoORKHP#iBNwe z=FQ&Q>3;CmufR3HJiUI8u6h)dOc-;NSIm4~$0=Q_A|e|RC$zn4;mO0XrMpnk&K2Yv zEL=e(FIkFuq8IteZ){4w31md45R8D?Vr2Dbs_`vcS~~CA!-Ox7yfFf2_=Q-g(MdZJ z`G#_%Uw3-bw!9oQe7ZSgfpZ6u*p2aD{(t;+K7aZ?==(yhyi4c7O<^B-{cgy5_kjJ3 zRpd%V9LuxIU{Pzwn8;XoWcZ+D@8900Cae1^{ekm#j^SO$n)wDdjp`Ip-@p01jQM`e z7l?%rjD%dUkG;o7DbFUte{3b%r@mZIR!RBW)EF8StmdjHVy#{K&J^F%cbOHiyE_|- z8r{^FMnU4sGLtS@c1@IaI`~7y#i~x@*bVjDDT@->9fzX}sn}Ez2(9~irIG!3N;D4E z&1afz<|JQ|{IEl3RyU8xy` z@Dinkra=dOansse8JMMC{-_5#FcvX&0wAjS?dHf^3Z1zqE%ypQoe@?q*m~&J7qqOr zlhY~sl>E&vWwPd60<|vCmjwvf$yU~BlERgy-9iAZcLr1xNqoNeF^xRf?&gnkBA4jn zOHo$KAz0md#n^=8Ieu^kUUkM?n_ztb(g|RF*YCf8#x_@k&n|JO}89pLb?kq zr!WvsUv70-$#Xo_D_w!!+6>rLd$XHSlbOQ7ccmYtzR*42zN1XMG|WElY?}^(sJUz& zh93TeZY79Gf<>lLlqCi%OLg(i76PlT`Z_~rPk7{;Tw{Jsr5o71n`V_v;yTUJEAFh8 z+zC>Q+pIv3f)(1vQNMZvWgh;@vVKkGlz>d;-uIN$M)$GMo;5QQ)K&uyUH^-u64n@ZF?d zaT;H%?RGq!>0@H11_E4;9tV4LiX}L1G%BnaBW`1D+9~|Pj0~c(D0Xc!Zb~%WlHUJc zSMyTJD&K5YuNr=%apkpc_vzJfuute3b)1!I`Jztg>#!jOSIZEuM-%r4@%J(^V{f|l zkQeZr%EAO?F+1mP_~IXhg*4}yEqsQhec~kqJG-Hdt_RB)pP@AWcy-`vo0lqQA9DR7 zaX1_ARTFaSvN=4``lC3TXojrAh}?Q~52a$YqqGHI(L|9X>DBSl5*1`5Po>}LRMIEI zVTUctA)us(nQ%e#g-tJD>#N+DReHu1g| znA^h-?oxmq7nt!rRI8|jm|1W5avO5{Kp-o@-NpDT&uJb4sfG){BBl#xG^2hF+^|V{ zq90@Rq|p9;ukvYgmMzSERWb&D>0l+d2BA0o(8&YNntW-VhY=?F_41)dB1!K@61^W# zFBYY~iL_7>MV6@&IFlv%Ri>Z!AKSiu0(k3eCJ5IvUpwl4c^#HcCbsxl27FtW z%Y=PSe0%1!i3=;GFp^U>Pn3dL>)BG3$vJ;jeg>Qa7cS&ES&W?TO6)8=+*zF& zs!bW3#jcOk%Y3}P+gS4Et&Z>gtenamqx=w$NR3^rpeHROVhF2t0(e~JU{{{&iA>X! zV}O!K49%}9mrZC0U9sb}ET->yypV>iFjOo0k(&B-3ZV>eq%czfQ+57aS;hsR%o*Us zzjYtt@12Gs;;{S_IwL^-otMhinY|y{IiT+8&$A+aPZLH8ooz}x*f3=+*V4?};tIhc zQ<}s}fAxd6Yq~O+cm+d=)#5lC@crAU1WWDi`isk3+o^%%cjxfN48=Z!(!R^b;{+%? zfJ%#ODeA)Zb6S|THMK7tyNQAHdz0vD`$Cb5;lBk0pFm?&4q--wJn~ePmzf5C>ND2;ijdnOt8m0; z0_@I^ca`I1wyJwtVS!~t&4XL#KN*MH9Kgh-Th*^F%7@mTCukpvX80Ws#}A3}nCxp@ zuSl>@R(iVT1Xw_X`5KM!>FiOjK+lC_jD}nxCT{5U7L#v{`^{`D4dvFC_b18sh=IL01X5qBN7NIIsbekXRn5uM zZX9oeeG;w~r>Z%kP!4GTDC*DK+Z`{yyfJ=oByn^@nX^LR^4*!=RZ%KV??q53mqo^& zZrZZhR%9PXr$}JoRG@fc!~xqVs?JgbRz2^XQL?R9SgS`*YM68-4CxB1w>?ezp$!eF3K|d% z6qVEX2d44sMpa?w-vy)&bJkw0en;ZjqZ_x~N*=2Udrx>=fX&s$V97!~QV=8RaKL{x}AlO|p zb{1@%udIfkr;?z{Tz(@XwVOB@zBD75ZhSUl8+9>uodnyhec}*PTemCI0g|)%dAvZ$ zem>6A(++}PB*A0honY9+6%7*usnq}%kF4UD(%3a0c5aBXy5xBBxE$g&ILYS>WU_wN zj&nM1sl~_6h0plgtaWl)<7SyI+;v&8~2j_8WK$o*wTfeG5@uLRzY&2nNl1 zscX~S&BuK6e-Zl-02+c!oY%!Z)^|EyTzN;Xr_?xxf(Ny?U_pS$461=`vf+{3xWf`d zGj8tbbUFPurwX8N{b_w0{Bq~%XXh^H&H9nJ@fyF-b$9On>Da1_PH0E8z16xQ;{5ZS zXN$@Og&s}H_L4>V^lcmJd9>rC*vr@0g&vBqA0k98!maUK&zno~1~SQt@z1kxL*{ip z5W=M(CJDBkn>XUsuU>84@Vp_P_R*uDU&Q)?O)#*HwOpWm8491DPo6VV^8SQA(=AGV z)H2;gMgKklC>L`+pVxfqfGN{xq27FqTjCp35`fMu>@%0OrfsQ`{LfXvK$ljH7NZtR zj_l(h{vq(=vr`T4d{sHgOU9Ae56jne1e*|gRH~s=cw19Kv*UK`1orkN(dH{DG0fag zGeKTJ(0MexP&M<`7D9EX-W)L(h_=et6@A|GC(D~v&iQ+Cu2@J!$owwx1n>=;4Dj6V z_HCFD+M$OOd~nUDAr{)f?+CzMf@ItMcQjQcPbd2k$5=#YsOdk_3uS4)SMEn|1>)z% zrlY+D-%X(b;2nt^;IOEzCH#Dn=pWK=>N;YijwRWPbgHa3k4sAQI{N;=u&czmRWlQu z^&G-S2Wb4nohszf<)aHmWNBKyw0js2G9{QSq?SWKaEF<)r3V{6?X@H_gkO}=yo;rE zU|gTtXRsiR+;;J*osM_CT}BJ+sU;N1Cim1414>|Qz|j+Dlh*^(0L$~b{H?K(u%EH_$KQ8 zx&oSk8E$c3$RXNrAv6X}#a(2y6bXP2fj<+Ql!e8z9zPc&an4Sf%c# zpiv@ngAuB~!ZHiY46PQeS0)JzX7S-UlQTI7Tht_Z$VVfx6i^P>Ijib`WN?+*R%P#y z9%3tU_be{*Rc_gMQ^J)HLcuP@WBDd^5^!vvyNQjAh~WI5I}@*D zgLL}uZzxM(DD8Bq07>M_5eL`bWSf%A$DXr2aED}dE+(-#R&*zFB1SG=`SK&R3+<;R zFos+an`nHg<(C^g8dH^W{3LdHD6!08Ii%#M>7uvS7$*qfI>th8UeJYT#S6i8Wc6iWtWB4r@%9ivO zq9)Spq0gE4v@3Eom(4*iLY_-g1TtG!P(sTK3tJS*SVMVr;L|z`6?R49;8%zJ(mths zf9_2D{Pv5!>3&&X7FM&1L{=7yQk_R_IU|-_5+2vC%L-jgEG3x6kfTu|&_RDT;{a(t zXj8P(^v9o>QE;IwW934}-iP>vh27du=pON-k$?D7KPZp+;PO6_XWYvVKM5(Jog;Tc8 zM1|rf?>jHuy!q^(i*Q6jGWRRx4t_Q?>_gVR$?P|Oj^(7{G@2fc`&G?(+}AgYmAJkp zI~vP*oI3Dsbm2jfw$O|`UMMUYoHdJS^a-Xce&1E@kz#lMGD)&k&>4_^6G z@NTXF$^r!ikX7y`qXHyKpLlxx)L1R98<1zg$xp?#^#W%jmKT}^8cZD~1oUsu+6C-W ziNn;vek}hIj}D(yxWka3m2%$18Mn&LGrsraQ-10RBh-P}kQVGCZ=Wr`55Ii1bUfHQ z*pln<`DZB$g-H7gDE^r!>N+_inqKPnIOn%v{KIAO^L5@c zvp#2x*TRB!3g@d37n`6jre0oeL}4g91#@+zUY$9Dp(c@P#`CH11W<)&#%!rmrN6}J z)YCrR@ev3nFyVA$_2{!fFIro6iO$KJ_}>P{gqtL6(qXOD`1N&boo0m7)V#&=h!t5#_nIfA0L`Q7!m9 zUg($`2yhWzPW`|tP^l?NYrrY|hlQ{{tPBVo#y>Ixgw}L3XC)L^#^X7EN#J>e_R|o)L7>S+F4Hk zWs^gns?Fp}H#E~1k=60VE(;>fXg8a^y1vI)JBHMoHHig*uZ7>a6J`wMFjQN#=7#lY zSGv{4@h{uO;WZOkVXFb*upI*MeuRbx<6XsRSj!9_jTvB1w??&AZLG%Q`sSx{&eP2u zH50q3B&2cH&B-k5MpX>@=&KulL~g$fJI5+sOCJb4>fF>Amp8ig9nc!MnqvavAwZ*m z-b5ip@KjEL_R+k|K&`B{*xY-uAiAdZfy(?VP+2mwggJOs$2y6w8(Y64w7rf1y1RYp zRru=wbJuH1{i2Nva-sRSAi&O zKd=1T9&7)8)rJ?EV8UPPQM>&syWIq*2G(hj_+e!=K z`kNC7CKoiK4q-{LFRQJP)!Ik6Nd-v9j#-wpOG~UmCwX zvRR>idl(g^bxPp-baP^Ntc+5<8A*8uB77*@F&kzUAtP1s+#NdZmHo4*8wc5olWKkA z`Ao$DS*DGFnr{)kDIwDqh`aXvP|rz(Z2fBYKETh+F>^rom=x(uq9Ku}&?w*%JS^0u z@6N;>kn{Ez{fuj>^%MN`mNSGIJRGgUF@*e_;5dpaEuGad|KkOn)*5FOiERx&yz05) zQx=yWBGBsbF5K&E`otOQC?k)+UA|X18usVjJG-9}8Kx-qnKAwjDJY6dj~YMNR1-7y z{jw|W?mMj_@@Z3(Ep|DS50HPJEACV>Qd%m&WcFXDA8*d%(|vkzl4KKeoYHr8fha1m zH#bl|0?=M+^~R@r{hVGAq4U%=u$el;9?pr#@_C~&b~P@ow96xUFLf%V9*%-0>UuY7 zfeJuR9*>Nwu2)rszmx5@U5`y8k5}7^Qy0>>CI=wqp@Q5+G>ngNUaVca1(Of*B5YQr{1|nxJVo2Qfsr3gNMXR&*=fKQC?R-#FjgClZ1@)_#-8c&rfG5C}k>vgWx}jTkz8>aP73st^o^ zF3q%R;_`gH{m7(+*Ar~9Ah+Wg`&&>A-atF{SFG|c) z5u-Lg0%K&2SP||J&$|9rXp)Jn^t;JqpOAS@X4&ZdTI#=V)7Q!1q{Lv+i~$v83n)Kf zs(oH52)0Y38m`J#f|@?$rtG^9RPWdY1rb6Kl&@wNx{2w+Cx9m87;@~;EMes1MHqJ< zic+`02xCIJmSu}{AQ+(lxBJ=uy}OXt^L1DFS7^^-zT5q4a^iugx8U%Zp&WhLQbrF! zT^aw_qElg&qwVHVPK2_=A%nHl^zdS@vh((-Rb(I2r3WuDOAGlu;?AlLwwZ~+0s{N_ zxRp$!wIkH(f9i##b;AbdN&2mN@}idNAybf$s`uyvAKEbDEKbfA2J7IXxT)RSym7t+ z^rS^z`|sH&<))>1g*%82T#l)X@)LPS0#STx#y){-EFFv#q$e-f@jOHp#l7 zF=FI!b%vGEhzXRW>8b6|Z#GS6C}Fp-H<`gqUGs2c^eLvXl~(ezBU^smb{qb(y=}wl zOlx1z#v-|IYHKD}^qE=;Jm8Q#+DZ>D&(hd?6wc&APIspFR2{rp(=fxBJRl*{1~&lA z**8l(_-PWP!FD}$s*WedCv+QWOZ<}3-&eIJy^fC4XgC}((ztcclT>0P0r!Dj3!$z* ztUP|lHzg>~_uAa{J?rUgN&t$b`BsSdT z!!AlCE*n7=9(nctv@VwGRk=ySYsZ}ZjtRHp^p&#+y8MK%B2!d++B{-L68T^Zij#{` zzsJeZ;8IxHzcTH2`=C$c3H$ZUSUk{MnJq$A7Fc)ma55L;D|dc-3l$q1i}GLHrIfxA zaTQ7h8HTCrdF)n4JACv=s9&Fuy@Z%+oZaGLAq3LL?cO|loM^-J*|`&Z8NP+}u>`H7 zJov8Y&1Gqg;fKbd1B-GFH*EcQM2rGKU{tNVp|YA?Wwp`Scl#?VlgLjwvC2-LW__`w zKy{<=O~b&@oT2ol9`z3kp6Q-7rb_w^`eF+{dH|8OFp+}+r{$kGao?XY0}T(~_9<&l zlcQtD;X9NTcNtHg2REJSx6Ds_ymNboY3AZ4T$0=Ks=0J|#bduoHcr6lpSK=i=oUO! zmWueK{q>rZ4L(9=T=$xLN=_@JW&7pW_YEb-fY5#$~m1aJB47d-uya1GgGbX^#y zuq`gEK#bJ!8^ear|NZBQw;fwxQKah`e3qnm><*Cs+%*d3`EdTPjT69!`;2ver!G5r z4^=ZQE@0t=Iy5?V__V=wKlGQyq2s5S+~VGsg_|7S3TR!glbvUY3M5}|5Y|&;iE;b| zLRc+%sG5sz{ zU@+i!kGP(gALb8WlMpr#p?Z7Arn=DTZ65p29D-%X%Z$&9^K`gknZwNd{F(u|OcoqDmwO6tBE8>G+1Xo^5wcQw1mDoC7_h_lxVv zw>7Z;UyfSyP#(O)etgo)TRZ;ZU(bvDJStB{bJS^RK_I*8JwkAn$MBb3{a@p~_5SU9 zvt`YfprAA;yVl7C>_Z_xF0^19{h&N~E*O`ekf`)?;%sdJi znMxr1>)!~w^_tXwWzdCxiQaPpxK^T+Kg#^bO8SJ}Dbtz|5ur#gqrt_lLN6ANxy&p$ zjF~dLXvR#O4)U5P9o+C^s5j(l3E%ylBoDq`j`!j=V$_Mw55wPndD`=Ko=%D#n^bp4 z4IXFJDOs3%;?`d791iEDI@S3$2QIDZ-rD2xn%^LR14JBQH|HZAYOawKQhQhJH0-Nm zAY?JLD(vlZgiW$MVoFvIuX!B$`6t%1f8~q-aFS};`@sO=LRzZ%R4#NB#?-Q%0IV@? zlDS-hK)^+!rH6yRZvVvOdo@I!2xXIDdahF-sg-(~EFh>8YHI#yIQW5l4(T^T#BZ#Q zx^ByS8y!4032j~9axALn66ythbZ8Pcx_9|DIQ2hQs!+t6cFJ1YW=kvz{`Q4V&(PPzW^{GGgDASL) z@KW>x7qTld;%;TbLj0C*@VXAx)$7);q)d;7SAZ2D2p)R4#ZoBNbme)~qnb||!_oWt z`@4rfg%a*o7RqGb)}|}iTye|#qUhL)Dg@buP;cv#7W|UC$O~c=2p-@dymfQRZFJk9 zCe3skNf3y17(R#2iYv1xK4^j7_kDv94Y=%8{_nPnh zWU1V+<0Q{_uEMnAy2A;;q&Z?x`*fRzZp6Uzi1<0fKiOR?dI`}^@1@XwD0CDc&xOGb zXj9yOJ`u4u^~++!OJLq7bTfKbqV#%xzC)#RtHmdZkGEvP5~5zPo&fD8jEZp`^L#1^ z*@AnS)r>frAatg@a_CE@nZcmKB?76Czhh~Pe;^~kZ1@)W4ONHhr(T?lY$d>O=Dttf zF_+&&QRcyIkFTCH(x(=LBrRmy0$}1}!05l{ANe~=i~gQ;mHVTmwe=klSh>Z|rS*c_ zi<{2?o2$+L_-5<7ttXWSw%vw)Oxu{l9397HZ8E)$EBz(=DGW@eP!omqX0aOml;~OIGrx5d5n8eDdV1&Jp zuZ+~B`Pr!V4c0rLq_^tI9c&0?0rVCDGKIW}Tir@|hgYvO-&q!HHL~n~bGNcFjuBG( zH|w(!4&>Cy|4YAzb#ay@l>i|LL$uf8)1-=iXZ92lN+F?Gnmv z7vHB^LEZR3$z<;g{6X!iS`i$jmCLd ziyS|$EQkk_+agtdg_!0dWtvi>mxEXJEKlLf*G$YP>MD^C><8J7F)#Me04AJ%9Dm)> z?JUhS&H+syPUgKQYn=uZ-WM zZe*Ao7!4{)ctvIzux-F{O831Gt}tiol3 zI+x-x(`4DI)$Ri9{>WM|OP`6Nb5G$HY%{%ue5B`$4sO@fXU-XQDt%uFm@NMma>nwE zmc%3qV{Nv>xa$jw#EiyKbW&6w?sH7L9L}<9gEu~6@WMMkj&KiB5vBZm$V1`wFv?r!nC{UF7N4qkqU7+@^|7xmQwGrM%@vgL{yqH2)Pd1HW^MnGVjZ{X`~|h=vEW3kY2RT91rHhCQ7e z=55+e!l?V9E+#b9gYC|bGd51IzPf+@yIx16E)7pz*hTi7A$9z1$G;XHp1SK+6jftV@$v7k%M}z0CR@Q?OqZ5B5xkNs@k zw?l-r=IuBx&$#|0<54M3I^S_0b6c@}vLOoJ??=owCKy1b_dnFhC|^vsr9Q7J5>|0t zbe&K5pu-l8d0O5xOTJ!(@I#p%-+Pkyy;83ocX83_bw&AXR{Np#B+f2cHa^Z8on^ea4Ip>R}#i?SJV~BA5pMQ-*b3-BiD}$3y z%-Xg;PXw6R@tEnrqZ&F-$?B}Ngz{K^F>O{_<5l8GzyqMsV_5Rxub>Yp5@- zpj+taw&95!g+<>`M@)LqbL8VoVTgr)Tdg}RAkQ@ap{PqKf8yp)EqPfcfcIb`1N1r1^4f*xI2Ep7O$vRJCJD2j{Y6HL13ocpRu^Rb5nm2VcJp9|Ky9MNWl7P{EMaO_l za&ZLZnW5flHP7gK(o^UIFQX%C2ceWNSg?M9&JAwK891LG(u0${i65J|W&GSVeP7@; zln(SRpyc~8{Ni(5iWD#Nes602G55eljqJ51tE@ryv%+Q8{QUIwXL6VGyScf#WD?E6 zFC7>nkrWMgD1#2QZZlExbsL3n?ahHEr4=_GJwrL-#A_kWL(JgIjtl&r*FNi}FWw}~ zr8%qW_8a}tef4Yq4*kNny(h%?XtnK6YmBahv6xG-0iqJMK3DQ}jd5Cc%`~XayrJgD z;zwj`Ipv#Mvgoa%`z5nU?K{FtAZz5DSLm1d3PNn4asCQC;|^sY_l98m&8%3hjTkAQv|6}@^% zlend+P@_$;))%t>`esO;_@qkdA0`J8h?_WW8|((ku%?RmgF- z^z4l-&c6EN_M?g@U6AO(`W9W`_BWmd67^aUdTCpvJ!yoceiuKz;`3$@UiLM90S@Xn zI2T0)N%2oHoQZ0lkR1xrc(UxS*YsZ^MOs@%Qw}7T%rBw}19k7Lo8G$Ze7?pJCJWFV zc!w9DK&S*J0zwMZ)O=EEJhDteTK44zX*@;MJ{3?~}V3ZL z@Bt=s;amz{jf=#j@3N@qyi&p3N;!jM+XG*|=uwhP5NCLM;al-~y442_57-%Q*~Ui-Y!XLt1rWTt9KN8`^Wo&QuC&zgLedj=FYx{XMYJRCaL_T80(~$yh>sO)wE(D9=m?? zcE0t`M{1K#6f5HFw(8ZQ@)8xsef_BDFaepl_secFKTNS&tBNieIHgr5$}p=4y6c6t z)g$4LQiadtc;7+*7YQ2iXXv_LxxIRP3x%L-cE4}003ZD27H$Cw-`rkuw5@o#U;D6D z<&8DikkY4oXP#ZqUIGF@Fw3R7&s0YLdC+GFn5sWzvfoj1PfcDTr4jID58Wws>dx?FkEf& z)u+FMX@)NDtY2VkrUkqS2?3;$N)$$fLPfU++$Y?h_+863bKN5cO|!-l%UIDq>f08* zDTe(+e7V?{8f3jS3g+;`#6)YYt5ttdUNU;UGqKrCk5EM&L`B< ziNY$}W7k|h^@2SoAIJ7^X#Iu4Lq~moj{kIj@1renJfls7IT!@%r?}!4$q=DE@(xq9 zT$q<|KfK=A7!K!jkr<}`i59Yt+=T60@h{B`{d29#)Amp2O|~Q|y1Lcn_pd8saF^t> z1JYBB{RR8uwT2GjFnuw5i$=>74Cdci)$+Thw@KN3_!_u@P>II3)mbJRe2ZeNehSct z$~)@UM(V^pms-tEjZ~FO4=y~1T)Rt!uK=C{`x_?FR~GS8Og5c~a@9;c z&!(B0NfmW|i}V;t+~^zwfgl7;?gi7o*OZMiIpHOf?U3bjyhgCj$x?@p;MSBr@wctN z%wEenPBKD|p#+ZQeu1AMbD=`JxKPt`A4hAQKHAeFwU{J1!?eWeugmp_&KJXi{_pra zx5Z{s_D$@d_{h$rB}C2M82!Q@(h^1>JWOQKW?EF2L!W$nn`Q7=i^=AZ!6}<+`yRJ7 z1E*Mk{3I?Tza@;}^zI&v&U0hMCUUU#wuU2|I2Y85V)?OvRb$zqyBFQ-ak%+NG#|Rv z_R6S6O1BzzIqc7uuAPp*3O~f*DP!JVZ{UlRuhUENqMLp zb};oIO7rTG+XsJIWX+~1{R)>lh9z87qW^2SRwpZVfR#%}nemyJbG4Iv{{tTjlkj@k=X|FWMwRbpuBNnb^R^aLD*P!|b?AwNQ1(aXpjDUyu57cDSX z+ArBq(-xO+)#r-&d|As%DAdma?cr-xglvg!()7oau|cB=A!{loD{5cQExh$uhz7&` zS%v*ZFx#n;nK*1x;JL5fSptz6YD{M9tw*p~xb@Sc$mcN}-@;&r#n;AZ9@*02?dP){ zDLDX_HZpOH?{}~^d1zI`v(^7&Q#2`QeV1OGR;6#ggbhp9b@B`o>=AR#_c9bJz+3MY zEshFGr$Vlqsm*TiFx6$VPz-VS#&*ZmgzOTI_BBfM;qD=9w^+l5S1?OQAMEzc;dzaN z6rT>euQpJEDgU$n==ZONGc4h5u@ditFwTtR8*;rrk>GovG(s>L&q=@z<81OJB>` z15+i`BQ(D^O43z?>c5%%VJ@Lg!Nc^Z+;?x$@Qx?7LCBY^ltF3(Kt0G8P^_H5gcDuG2Tca!Fk5>~7XE~nLb$m{F zEq(iy>Bt+ayTjd<`bw4aF+q{6>Z&d%)-^poLTZq338LTeiO(8#UG6!2-jkIHQCcC? z4^AUawe_xKWN5OFMK?tZ5sP9S)95_)uX*^MbDBECWm;TUBdd{h(63#XmecJoyL+{1 zocQH8*;K*XS^V=?k`l^aV3kgrO+M@9ne)ai<&ACwGR>mTUJD3LWqZr2{sCVm=6!$p zU}+F~fPj@zOjDjF1)Sd6__^MF`U9v{S%^NqCd(HAkS+3e@h`BMIzLo%*1k1G#7@Vy z1-xhwmsE%1&QO9d{w|>d$(VhKb8xhldhn9P(9**1(coXXQyHD#d=%DHdGNKWG42zD zX{(#IO{o19>9f?rW@AgPq(IWQOZ564QV5y8{Wx1s=@VsbD@1LOuX*)1%M(YQ4jj;W z7cWtV^%xS^I`GT|JD|g!^ThqOu%f{CAw-Gj>16#&G5>0GqDCqV%dQr>unJ6QuK{WM z`3joTs)rkTY5~mJM)kF^sHfvu(I&L z$xt8rCfZe0d~AW%hz^GE$pHIrrP};n{vQ=wbzD!=@sP?rtPTPr8Q0fC)(FDBt`3yU*S3?z!hY@r?(36Bq6qmma+m z3Lx%{wC^w9kGHwdXtG7eF1l|5-9%yP)Bo1^6a0E*pm>^Cz`kQ?_!a)hpE+(QK`Hsmw@xVxtA$Q#t2wTC?&m5d?Y4NE)SfczsfFHl<&mVDFa#&fIA zUvHY594mW>bql^7oCLrin&I*uvl2Zl33ZxFuu#KQOKS_i6I`3J&HW|^hDh6LV&uxv z4vzHfnEB$axUI*z)@dOv{QZ(2{{piwc5GqXr>J0;rGFBvOR)+&f#^x7NI_(6a%q%c z)1rj^yuzdQoh{2KzR#j_GB_QRZU7V4uR4CayMp$?EuZ+6~VvTx%^lYY?^Bs(q>)LBJB7nZiEDfy{Vd>>l3^>auhtTUHMeIL?u+B5i7qSaeK0=pJ*m(6p;Z^MxlOc#uu!Ti z?}JhCSnF)`$gaxTr50DLxPnEXnu@iV~^zo)56gqLR$S)^}dEotKj*a z*ohbAi8jm7X(GUJA=H=tyo5S?VyP}1u3ni?hR-O37)hkQ^3WA_f9q^q>1;T}W0w9% zEG1?@zR?MQ{p8PrUF~#lGu$F-^3U?JLKn;bk%hqxWH$AGr&BE}wsh(LY;cVGf|lTp zYp{ktAVn8=*h7o$oH-rc+?CwYy{G$FasSOwphA6yg>cEgi#Io%=YK0@+j;RkAFaM_ zapOP&+SksB8Ep`_@Oa!9XUq~*>mD||;8w05OB^E^U(6(M~8<8~nWO$vFTK2|bJvAG}kh2nnYyZeTx~eW0N#X~`IQpu;en{cay>=-| z-bZ=J&us!Ov+kqK>iKOF_m3|+%CcMnkk)>B3-e!pi6{iM2 zbvq=)xbh!oS#b2M3-g(ha9KNeuO8-jTFrrXH%*Uv!lN6f7f}z!U7q>`hO5xrWAiR% zG__A^lHr+KST~!^dSqob=RP{}kifaK4!AkYHVTSd|B*QwV+Yz`oIk?$5sGw{#}_X0 zG2V@Tb7!;z?$ic*+i;v6Y+ec_W|^m8bj3ACE4RIllltiEvKVa-q#1#kEL%Jq8uk>w z5@Lyr+^A(kf%I_JRN9KV+le<=H)A8)IiwJ$-CMWRp(?NTTcKkNufN*8z5eMdGC!f| zj$YtRzfXsBe_ew|@D4_yp8YJbg!&T#P;ps36U-GcMBj)qKvp{hE6uQmIL`AMrVNN} z4hNnQ!NVoOjlHqV{rHkB0*LaOT-eUp#4X39KC|oYnVR~;fvB8kytJ#WRvn=|nisAT zP0-sc4^7j8`ls2Y#)iE%0j+Q~;*sA1yNvBegBE2l5t*y6is!+%`FzB&d&*y-<6b=F zi_z1s-@@C?RAp_=1Mv$omFJZ5w?LCO4!m{Uo#O|~d8QW{9*9_)bUI)}ek00XlCvH& z8a&M3ihXx-$Tdb9wXsompw(tX#)~1)Q?-v*TRYynThR+?*1Won$+yKVT?L*SZXr{o zdYjC2z3QbPV935}_K3*rnwcGaFL$PJqE@eZi2A&&n^!nUIjOPkJ0Fp7n(CcM$I^H8 zu~2{2*S<$f>rdKGbNtg?5zp%^Hlv@p0?u28AzYAwjIanvGl!e=`(j8y2oKFzEp6?!V^sS6*8a{0-mXTa(55uW(a|CJ>orgLL~)lmWc1%7d9U(;riaAspF~Yn@xYXWBKxihFpc-FT1+~sX3Cizy`!|lPA~2E=sUE{KR=hj`@~$Zn0~EOTk6=OnaK|$r2MM8{5q^V% z%B(;^1TG2wzh(N32mf)-)R&7^r;kpvWD4*f=IANb3!xQ@4+%Q(<|B4^ds|N{U=0bE z(iyp$OdVLQdNOZ}!d?%kRp_r?X42XD z6xL?fVuvJ=DN*kb^wI%+zHAE|p%Q_MxB6g#ZI8sYrF3hv}<|2gB_93#q}) z+(2^ZW!Zp@O}6wjxfh?q1HVnX8{f|26cr`oAb~%x&@{Mz>FKK$M>WWTE`nXMN{;n% zob|)Ch;0MN5cJCEv5B7p%)<5=MEp;6qFPzQXR%PJqdtWQT-(^D&IokgoaM?$EuLMJ zD9aOiW2XUPjoH+E^z{xSG;LTc0a25agNKennqLG@N5hZJ=bcehNtPuC!T=mmf{V|yP zc7QBVVbZNi@-aV^oncpWsd#h5_|YQ8FM35*+TWUMlt#S7mKn5!goU1=wT8w@lS^};BTMm@$_j~!jUjLy zOYtY3meS&G_3G0Wt37w_1S{<+Htqs;sahz91X|LLZ=TVn%TV#BR+WGmZ2-k<%hy>1 z6g3P*C|_%`hPzk+hM#Wn=LOF_8?OFPuI!DWa%x@JG99zGk_|68m6M|S6o|aUuQL1# z6zn%g%j>4KwO6Zl86xMZ?P=a82ArIc30_WPUk+Z9a8@xHe3dhiiK@3BQK;8H-98IL z$jF<9s~1qj-kc4$8qB%Ox^^84F~_N09&gRqqym7Rp!2(BOr*l-{Pa2nBTf|^-HNvXEeNL~0D{{5@{ zqR44W92?{9^Ip8|6weyF=Z%YhWVsFRDFAC1dx zV=EM9H=qh$^MhlxlrbIcVCj%H%PY86PT>ZZbFxvE#8X!6e4+PgCBt~>6GClTp)|eV z)wW7RyW%Cdr*M978w-OM+h%c|9$&Sh|wxq>=#DjuySd@Wl4Mg1vk>IM5ni$5v4s-&-nvFiZgNtwfADLVJ9gQx{A2#Ic z39R6;qY}(+T1)y_F|B%teBdvRsqJh10XzCzU1sxVzwY_%37lZf*JMEqH)X&Oa}wz4 zmb1s&RSy$}OLR{OjZ!DW0wHsfYh`Lm_8ztIT{=)xRJ*IMPcoT54@ zWVkj;b!E=%7`;OF>%7MB7ik~+@*3RLt`S+!LTvO6@a9YD=BlLf>;XDmx#Jy^Wn=l^ z8y&JsRmQ^2)p3hW0qLyVMem53jDr_>kjqdWkE?(7YA4Vr4c{tZauku*No!%;GelLW4(Mx zJy_Z+VtV(#wNWN^>03=}4>Yj=36xQTu^J+C`l$K9`37iJA+dgQY_1H0q1F3^%gj_! zHig@%OQ$v@CsUi$8{#0_g<)8hk}|T{;w}HVI)>}@SUhFJ;-Zrb>(MNo<&-RM99L9x z{fgtvyHu-^{Z%f`>0W@TG9TGZ!8}U9*g#gR6G35jnZpOQ5iOn)^}C+psxL~%9_ft@ z`W9QySv9M^8EBZ!J_QFvVU!W`z!80$vTnKcx(~s|U$!1SFwR$xG;GXRYHrB0d(`eH zxJ{}J4_r(H(}YSyd(5WM1U30CLalQY%fWP;X8jzxh0lZ;(knQM`N*&C$k$4Y2g5Jl z(M$$a$4gDEYE>~A_Q=t|wy!JD=@t@8}z_qxn^4k{-vxN*E=4=N#^HW7reM1Uu$cw~=cE|ij8O?$wQPTWkoyQ$d>v^4> z?jV5#C=KQeT&t=49?x}fI|1^8`%z4FoUz#qs?vK*!?fW0eyNSVPll4Q)|X;><{C}q zgEI0iuHojykK3}Gz_c#Xg`jEAV4nZUJ4>e$@%O5<-N_a*9CBU|T5Bb=7!IW1rS0jp z7Ss{)rQJr)|$h%3wj4LxZ6Bi zV{K;sz4b6BDS3d7oKmdPH)>UXU7LX}x{G(GV0GKmrE2o<04Rb&+uebY0Afb@quHXV zJpXGA0t1<_4E@^9mOg`{&>p%`9ZWV#DCR!V~8F z%06JeAhVz19wHF_w7=q=@T-SUytb&kz1k9WAskJJhV_h`+xQ_N$xIXPpa%FD0v#7< zjcu@=onuqCm`yk&Mdd076h4u1h>;icYM}OGC$R3Il4JQ3v|yiRn%FDaSv!;_Y<0ZV zN_br(<@ZkN{5a~5mW!r~@W^7Rmya-_zQ%Fipt&9MC;|uQxj5{vs!@hReJy?IRW zsluH=JC2HBknLQFS$)F|1;^(L;$*&w#a(yQGzb*vFS~yBYdH;{UuO=Vzd)|pN1pe$ zfH3EZxIK1atTRjL2hF+$#_sA}D|H$K$}-L14v|Y><ocS0z{ z9sl0(<0EaHh{OD~JT~raw=CecaIRc+iR$v8erYEMj%rABSUH?){MlcO6k zKvA)^1&q@caO=C-U)cYZ;9X{*uA7D<49gpbY>lv2tDsDPS^?nE6Z@H*7hDC^Fs@LQ zMr&!n;1f@+Nj%K$|J~W(j1gBo5MiduM1MIW99|Tu{3LgOn!#Xv(@TKtwV3KVuBe?V z+}A=S8`Oh$nL#>cwFUh$>!{F=3WyJ6*LwnrfZR1wrm(s{HyE488Xmh1DwXY_V~kBn zk^Exu0wt2d@=Vh{Nk5U zf{f1%52V#9`?!<5Ydxo@B{rlMU;BbHX>L|***x{9s^g)0+(q>yoz>uZ8AYBr;Gw5C zu5TTRYBpgy3F>)(bhQ_vy^amWBXUp|=lb*{tjF)0x0n>RaZ`I+zq1u9yLpmrS>6nK zMFKOF7Ri6n&R6NrF@C0-_Yl1XVmVwVjnr?~axkFmH+$ddDJh3HP_zaIJS%miP(?Yw zSPqOpU;+8$Pn&Zb0e@V%<0?{SGj(Oi11`Wez>qF$XF}8q1XF@>Kpw9eubu`=AJzag zFc%x0F7ngE!%#`PHdfSlem`~_OSO9BrZGPw+rj!tWh+bAxD&>9vVbHQS~@@IuX6_u?n$n0Q{p*k_1Iku2II^v^P8${ z&dufP2~NBX?5lOJ$-{OKg4*x@{75kJt+MmlZ2!MKMy`xAK8ALF;F5q9!h*&Qu4=Gi5);^%WutD@=Idf`M{R;S6W(*J5?aFu${@hmf%P&@0 z`-H6y8(szxYJcotQkI<Ru>K(T_ z$m5J1s|cpUjzyE+SoyhY3vTJhZhJh3Hz81$-NEueU0MLtr6apLdnwi9$`5pzFKNAS zT)3$}&aWnEI`2jnsh<{3HJmlj%xrZmt{A3Q&cpONA1(!#2Xgh~(qm5GDwcO;OvB0< zZZAl@e(!|}245N-&|NS+91;QB#_2!9@f@pp4nYEKiO(;#cwt_Z-BD{5#`AoIs$I?wf@F3`!nqEGV5o4#E02XSH3Pc z9dXHa^2nW71cZ2060qc4CoVDyYbPLw(;1sv9r09n_LbOkKvlD;IsQu#k*7J5mpqfY z+Cxr>$g;QNDZnkQhrcX}V?McYpAO>y|1S5}sHENa>S|Tz^|{pq?_0F<7G+CYMoc5&QofPv;n>v` zE^E}nW#AvDkqom1_rLL;z6ku`8SU@-Lhl*ogiQV&UsfYx*)m`X%g)i4U^Vdjl~-$F z**bkGD$nkO0Cyg-Mo*5IUkSv>tKY2r{^k5b!Z2(dt9$ACvU-iTib}R`1t`A(VN;1# zBMaQ(j- z)U)YMtD~WsTL*MP=4hAxV<8whw^cCCu>*Ce0JZ=Ck}^6Fo>Ob8lv`>a8P4ZSWH_5cfy*9NmsHP)5NzT@u@T6SHsV-~b=}A-?1{-nCra$% z?3G;Pxk#QvUPhgcvCM+^O1zphtnbPOlV9Z2NvVaF9zcAxQ&c|GDUq5Ft3RP!N!T>)hkPr3Q|lxsITwZY9V(pe534>RBmAPnMvqsIPT0mu`rF0E z2YSM>K<8otE$J>S8yx&W`-8LX((UN4fIXX8xmUGMFQJaaWp0qvpszBfWz=po|M8Ih zwP$c)08}p+CuWK;i!z*wj&m1f*#@MzILcR*^wW4Oe}iKErrujIv+(WNAt2_h8)~vu zH;`)2xc2hx;&OwVLf*8t$e&$051ki8Bd{R;5;926gjr*2+`$P&vNX?6M=DVLg`0y10nZEhnval))r4w64!erH=##!R^j19+>9M z3bu5UCuA&?hg_nvgavlLT(%rz1P&L zLNFD#8-flJ7%kt@?|b?80Jm==q|oqUmzTk> z_h!v3nh$!JR>@~$`g*+m39MUIVy6io=z8bhI2`z?OS~AGawe2nz8{zJJ1-UOR~tD9 zted(xDV_9~!7e{Yl(Hc^ZSqwH;D{=Rb`I}XNBm{(U(Vq(cFR7M66(9eZo}Tfg#+4m zZh`b-f8bm$hrx3$*uk-Bf>Z0kn9EWKw3!o*W4zw-mKW)=E&qv>iGbJ`A_or2lw*#i zP8)cIS1+FJVJsu>TgIA(P+;Kuf6%xDrY|ssRK{AC>n)=Uzf#YraJv>yT4^kCTAB; zH#EhSY{9h)LB|;Y@O+F3f$5E`(-?=>@O#8Und18UTmgxJEnbVybwZ+ZeF|O! zfAUq3uO_ooq)#ruQ+7#F$Q)e~k2XnTYgXNZ$gD9RgTm1MISOyFj^oYV4#tOrtf_42 zpVk*HGS2wGi(k+jC`e|foM2otvCEj>Z}q|PrHeNFl$VI*g$^yv-*csn>XPF3C{<6t z^8c|)-;ioZeEIxLCc^y{`##9{&kdYc2WR@Q+=MJ+E_CH#PbcU}NJLrLL0JLw)eXMA zg81!GNj^7k>K1$0Fx^G<2528+4)5T(k{FN>TEF5#zNHGA{5!|0=pTPRO8f*9Ar8b* z`iY4QbyCKz?8*}Ny7}_ZHK9wN6*YiUM;vV}kGi@J%(V9U6|&PQ3LSQ@d482GC3awKEz zHnNDko=7nT^uCW;^J;*#lnXL^EWL1QeWwLHc8eZP)m}mcd@32F70>aTch@c0|3HMU zoHzhpF<3)-!FX0usbon3?VP`VW(qhOv0q@EaV+5e*0%ru)RQ1J2e?FcE2W+q`J=zE z>BSfR03)?@+_tL<{zT7o5ahKeohHqsl-i%ysfAWwc@T-XTUUsY;h{WI zgyT35E7jMMh9_*}Au_}8pRMCIrw?u7%-9C^PQr~ehF8K5>0t{eA!uTYDKHi@H9AUo zkK(7D^ZED5tuN@Z6gt@SA!jqKKCMrXnq4BRRPxUVgv1p*S4O)T2#EG%>XuXZ7+}@( zCNm~ppzl`0CDCaj=~CUCSYL>Hp;9$aUW7k#p}6$ua%Jr8xu@lOQN?{_LohqnNFGK; z*#Xt*g5LtD4FIe@^Eqr*JFS7{fPsT2K;jI4T%7PFX;~pwV*ktNwIrOUy<*kxY$xDe zx|L63ke}~|vcK-kRnwIjhVfhOE%0#2GN}Dv(RzC-=1$X*=}jyE9bQ#Sl8~z!gcW)+ z8UCz9MaWTE8_?m0^SxDid!WR7ii8O)zTIU6aW$ecNJw250Cc4^+jSEcP*T`aq4lG{ zSR5qH5f|SbT|PekwH?TCAbmZbw=o2?);^Nf(mHXAfo}ZBVSZtiWqpj7ZnE_w@CZ&? zzd~j?Z@WFLOZNYegm{@%)PB?fMiXrGE#V#aFYY7HargZ{rzS+$UdPH_k4k6eo9k-aNeLqNv{M`IT9RDCLH3S z`*jo*55GilVe77_qD$P8ELTm)O^WUs-(^82*Ru!Zlz5L$ST3i0NG+*7vq>-KeXS<5 z0Z*-w_N6nxaAvDA7?0wA%HO+{vc)4kp7A1eQ2#)zhZfe&peJsc>3s*x>GBpjP{tmy zrOK%IU&7&9j~T>9%PqwX99h4EfT=QaKfmWp;W57z?VPlYLVxs9Wz+ys!(G+mcgVBtU3YNL8(P#m(#W*Cy{04R)4u7(v z%3w(WqZq4`9CxH>1FnXArKL;w^^Yu0esDU6r}RIT@6nbWqsu#h7YovWDU{%948DFg zqAE#~V$pl|kRLf)i_aV(Pcv{X*9U)Fz9ijux@gC5gZ&gldSS5H7j|2oNAxaEAqgI|L7q;Oml2eoeprD`tAkPi(v;koF=;vSy z0H~@0*Z=?k=JOmO0QGr<^4$KXJS_oa0WVNc{(JtHprWJxCoj>^P|-17VqpB&F|n|* zFfm_YVqm<&ef0_(=UFhYUgP27y#84UXY%w5K!AxN0vJL? zVF0`!KtUxydFltyKI@A9jOag5|4UF_prWC_d{*fd_VWTL{ULEv2 z4?rV8C#2_y>*(s~8yFf{S=-p!**iFTefIY8^#l2bM?^+N$Hc~g)4rxdGBUHW z3yX?NO3QwfS2Q#>L7Q7z+uD2k`UeJwhDS!Hre|j7<`)*>8=G6(JG*=P2Z!eum;bJ= zZxFY4|KUObp#DFw{x`D!4=#ddTrZx_01e|mTqrMmpF1i68ah4iOG0T)3=0n;2EI_t zH!`W;>wB;m`L)i7Ej=e+kpKlYn9l!$_CJyR{{t5Ge}(LS1N(n*!2vj^D9?w7N&t`o zJV>517uA>};ExJ|)kdp>tEYbtY%C{hgzB`mTmC>$F)$KZFl6PXz0oL3&sk!C*1i^x z;0i2P-h|HpcL{ZoeVhC_;d9CF_ahXtAQ8GJz)Ne_XlhZbZXnS-NS$03h@ z(gU>G*@_WeKYs{nWNfg8@V64Vn;U!Go;I=2m#j9+#T~d&?a1oXR12Z1B>9bk2JTu| zJk6aQuBt6;B?@tL><$n!()e|?g=RGkbuwNWtgG1>f~6u4FZ!!0Z+z}?^y2Ye7vIZy z>1Df?P=DoZhE}RJsQ%s6s9FSW(S6=S6!dB`KUH z@#fNmyAWZBAMd|AQ1u(#t0Y>UzUE=c9VIwCnQY2#I@A_}u>0wnmzk7{e@ZHfx&lm9 zI1@gcR&4PKNDEI;9L0 zttgTD)K|HfeN(D14)3{@K9okX_4_k496$beFk4aBL9^82px_s9cb( z7ynz_PW0it!O>_?2W8@qdkBX1-Cd9)**~XYH}TWsYR8Xy%M3nWeo%TOU&%6xC+rn` zVwV$X>EKdaTBCZEM!t7GLDP%(t!**7)GtI2YU_YhtTYQgt6L#%VB0H1*5icc^C_-d z=X-Hvn|&u#r1GCDE|kMDgfP3KXuSXzyrowVr|ndAYCb*nh1Fh_Zh3yXEKn+_>|r;k zwO?vWyT<-YEIL=h{1S!Lwe?nMj~#yx(N#3fkoUU4*ENPxX9l=LICKx==8&o53Rds* z#5zu=dM(|##^|5)?prG@mryhRTy`1F4=yV~=r-D|Kjd4tXdC+MWmvFzmO6$++dFfm zFtLCibP7)ZsGy>ck~JT>eqgkZd$5TMwa}NBV!jlArK>UMD4WS#06 zs>4Hw+>WwE9Y*oEw61oG-bs#uCV%CSRIP&rE!FTNKL0y9{b)yQt&(4c2LvK=G(P#% zx20J7TW^(xX#cCOR5<^GGc0{%cuy`a|Kp`t+zRiz<6IsX5d1tlcC$F`SAwJD67RPuKm&I%uR&GQ zX3DgJd8l5T!OnZ##>*hKL;0PuCxBT~!h4q&+M z0-j{*3Opew-IVY@Khok0>}cCUUaEcUSvtL)ZYZZS?eoEQ0R~OBu+S`0-6tD~d>q&3 z`AaPvED%M-<#GE2D7eht6W2U%?4!Y&?dmIcBH{VnfLSWm5^uFLzbFIEt^`rd=S>ftbEM0miFZ|%cxr6cbIFUldS(=~(FK?9D z*yEe4MiF5Gluc3S>BCm*!x|V9iSG?wefHq~NGb(gMfSqFYHlnDkY#W`65Gjt~U!QiWb{NY;dcRiD4d0GCW#mEA+$*Yo@`RH|1M=M9k4k>Bj^A`&8epO1$8*~~FTxnl&abzLTV?0aIxU(BPD9Uxg) zy2AP)Bykx8iu;i@$>*F&vOiW~+P~I`_F&7gYfx30=Z-n8Uhl zk-YOjqowfaW#7G@pGn*+A{Tf>_XH>9Ev;3Z@Irp3O%wAQ#7a~8Q+-iaFS|WfO^$jz z$M}A#7ZytkbY&lwd;)CGt3@P6OXPkBYETr4Rcb86#vJn9G$=-MQKU9lFSyndp^5^y^U3K; zaUd?QCuRhslHp%T6Zwr7k#zG{rIFvnRHoDhENo18Jn8C8i&C(b@kc^?kfw+Sp`Sr# zCvz~ZBnj>TIJDDOqdh8{-2oEtHX-l}Aj|c-ksdy`{|#ME?{;ah0Dl)c>((jl>;1+z zJ0r5ww1S9r=JO_dFXb}?Ac zCR5waWZ<&%kvDGP-HJLe=;27a&GkIv8;)2)KOAG*No3i53||+4BS-to>ZtS>Mr_ zrVE}u-bniQ0dhk)u9Ia&uQv4}b0OsoDmypX@HutB=8V5g0s#HSdjiobs2gURj^e_ zE-qCHyROaAcw1KwgJot*gJrhZqSIQyD41T27i+)FiI|PCu`CixoT_CtU;hg03GN;C zM0wbZT582R*~4(OJE}!RG;eaFmoCk=-gCzWRz+T-$=@>&rU(UJN@a>Lz6)9)*Z5pw zIIYK9{!y_&!y))>(+oM8I%x(Ey`nffTO$ul#YJoJHxD0QO~m0lf?TWl@Z%xGeB{2Q z=w%ZkWL}=iRi6O=LLO-4nVO``DuIbZD(*~qxKcK7U;EmT!015e%*5fw$YtUl*vNa2 zQ!bH%u>fvhmdXupx8lurZijOwTPVIgsRw*GM9qpgrAEO#5W3 zCl8Z(4-=t&DMtz5sfn4@;!PJ})QysP0(_?0to)sDaM8?v{J{46EIblklw1at+kOjM zIhu!>G}F84U@8lF6=-H#r<{3Rk(Hw`AK9C4U_7i2=AS=`sh1h;1yGb~pG}7hC|7Wc zbxk;hNyWTuJNP#OoAIC2^7BDACv==*4?eH1M$stRF}*PuqQ2eB&+{W=Pf{B%N^s&v zR9@}!9}-t zTU#)oYmENdX#wb*|M7)MJ@~Ie`iG%Z&Lu(mbJcP_cduJ3?ml+~A5RYFD4+?wRBySc zG$zCQn&hvqqq5yA-Av+&fJs#TeiZN<`K-Xjc|zaT79|+AiDDb29811qPeC?`;*&nx zI~X^zn?)mWkXs-#TSEdbaOK>- z=a*Q%5d7I0WkSorTqg0=J~bEmb0$NC6_p~pu{lyZ@k*TZ`uw4SV<^e_-y&C4peoah zrZVuFtJ;x;XcTRl0nj3&tlNc!<4j^a@DsCR7e5&r;dXx9l)UJ3_|BVCy0m%9E$^fz ztJGM~u%>VHvJnU-PMyVhG)Wt6Sv9c%1ede>m3^hHCc0yS^3h)frvVz7QAgM(dO+S~LB4EmCGZV=K?@;Y8J7j*d#s z3}S@*t+~-6*%WkcfmDwk3=}o+ITZW<&e$^CIcJIDDr{|SFX)9!&y-D#T%J^)aAZ4d zG$%$=iGn1JaO8JM;n>5_Uye@zB5p28+Wk7kvK1LfHy)I!9nODU>7ePnYCVEpq)@$` z;p)fq`tOP5it=Ex!&ye(-EI-E*+0ED>J+>WM@5s_)FK_-As4Eo?orq<`1=-s^~-Ry zi}I```g3dcGV$4NdFh06jWw7(Dp9!%(n-osqb`Cqgo!Ka=QpO2=Sz*JJRQtAREg+9 zs5-Y092ygj1jE~$M*i)4k?e|SQ;js8SSl-pJEr(L83($C?XC_g3xePN^Giws=6lDQ zVZ>@i35Q;A2%U!AYAN(tiy@-2DxgsKLxls0y{kB4R|Rx^;kpXD1PI%_w|tLY%9&8l z!bYuj5hQt`{%>&lgxhRC)P1r_PEO_^sRH-->&Q`z?3W~Nz-UzBY)a0tv-GkQTtiQP_gk=6>? zBk0h#+9cbUxe`lMo2V&*%(`e663`74$C6NyY_ST-mJdzvJshuf3~DGYObOEb{a5fB zsRV7ECUMmj7sLRJPNBjgK4LmY}?xtt6j4n`&EgnXiPkf0F8=iGtuaB z6nT57do6_3FaC!#lXK?e!p`uAj{^^h-~HxoiP5HH8m%)@1Dkl* z#qIV$)(|yVPj^#LZR{~L&q_Y)q{^G~m4u0!OG1Hlvv9U?@fV@q!|Mixd(+NepD+i4 zKhhG~uTep{@emN@AW5hA%C*<~b8~l4IHvPvhukA9cK{6>OU4>y;1Y3>ICdSg>Y&`j z5rlp1fcRPN=Vc0BKJb^!DS3#Z76h-_fSSf=n-*evn*s>M>{m891{H~2z^a4H)=PxR z?5(V@vQHj07jv@fKT0I2)qH3>n@6Ovv87X*Nllu zKeIP=hTreCWJwx-X8!1TfH#e+4Y1WGu$@So4HP2^os8A%mYSXbt`{m6&6^H{Jo1tj zxNW3_$o{Zce%@aE^W)NHJUbFt4U)b(#Nj=!2h|w6nB2bH{GjkLAU{E!e9N7>P7tUu zzQ%SDkTkpb1i%#6^l~Y3aO*zCu>S*uj8MZ1s7Y!L}c5?iZp6j`q>*e|6La$psfS7uEA|z=!I69*J#UZrtHeRDs!8ZU^nd&ht$Zk@B<; z{s(-nyXpg~juNOnr1M^!urxFmfr(e-6z_`kz*cph+cIEcO|Z#ZT3#dM(nGx@?;#-~ z`w1Z3cn>5^GdEdf>QDvJyOxA~8IwdGkU~&-_nv(kdqn{a@IqX}(GxNwZ-gI18?2hTw4U)_uo;UnkWnsPMH0_G} z=WB+{s(liJ0=%47SJc7v!=b4?Xs=HlA$+6XdKQjHjQm)qd|Q27wb>B=(P>`*;j$f* z**TaezLRcg!`&dHfsB%RdnD?R6ZB1X*V#1fj{cV318;gD9NE=5h zy~6@t(kIB#j=`tsdP)4oTtZgx$o-NZ2{Nl$Qk?7gM9tJ^jJ`$5>q};(5=e`ac$;?~ zX(Gn@?n2`Kbp(-WG}O&ixiKeERywt6Z3=3TL*(Spgn+z1fL45Xj;?xt#(c|>>j$Lfpk|*1K^Y$ITcD`d9JJ z%%eW%94j9G<#PSol_E~=zt?lPCKsLz&_a0$CV}7}ee=4k9fMXTu`Juyx?@Qb*i_`NG>Dz7gjaBe-wX2Bwx*wDtZ90Ds^< z&Vg70G&;@{0c?Xq^idMZ{X_Ee;^x%XSvFkBR)G=H4xATjBV?Agkt&5`o<+NKndx!u zJDGJX`Lngo|GF_&l)arB411eXOypeEM=;Yr#MGB-$O7w!cEW_&3BcdxscgK;GCpV{ zOa+8QZNs`uyyR9HL$lI$3>X!48Ug*te9I9Qu_WqD7cPb2GwvAR*<*>XP%U$Y8%sjr zw76163m;Gj0<-hrD6p}hM=z#KC3urgph%go#8=j{5ICmFV$ouMEU5j3neO*@XHu~S zRN$%?PK^|Zfr?NlClT`;BZl&0OKT_oVa`U_OcXQL_bj`F*RhGy0&!s~;rwo&UaoV( zZImt5u@M*_fpPoyg+WAh~=Atv&xOM+?UqPf z`Yh>NEOUXxY_IW>?6^%fsIRvZH+11=2mT_6MNzqVP|vjBrHKI|O(w!6WeE}M%5{Z# zf@~EpaqzLQG3Veo5a8hpy`L0p;_G{Xo5(RzqdDXS&Dm;+3lvqH-U)b3(&DTpQUGNu_oLp$l-M%SG6U0p44~V7taR7h$m*L z&Q2ydB}@(=xXnSfyo#rj$WIJ-_DfYIodN$@lEn?m(RBRHh3qQ??kn`2AR!jnQ*a!<5ev0s$*J6(~ajJHcv zDl!ajGU<8HLdd?nl(`H&7};9NMmz!5IU0?fR+Yt_Gv;qdh3YGt*8G=mov}UdW0rn) zS^f!+93j4%k2taJJ{$Pj2nu$tf3bwcrnsFBL~Xi|W_&jYn($ty+20%BD!FwJ^7!{7 zruV}5E*r(hvucgoDm!r#P3Z3xOWJh7fvHxInwHdSL25;*@voJM>L<4k?8k)s?5ku3dh6Q`y6&1I2X#dGHcyGWz1;ZhABLna{ zifo!CbOq1kF_VHfq{6g}>UEHlu+~>R(!*EACslq!WstJk8ujhqbRR5V+UsbE;Icct zuQtnBBz>n6FilN(uimAv^h`CvGl`0zOm}Z zE(-*l6E2NW3C9<&kR@bci%&`mPW7d;ZM;}Yqti1To}V4l_8!ks`dW^#^Jje{RA zHp%pJV7VUBGUGGUwue2Q#C%Dj5=y98Ve-|SQ>S~i`^+F~b?MsV+|x*&D&wv>kvX2? zbfklDI0F9lkhlnR)1I={)uP@8%tY=ybl5;U&m!kaUVr!S@F-1g91$4DWP{txsNs?E zjSzQ&yQWj-2ED?FsFvn#^XeXPVOVF+{%fblRzMJHp9YK&QNEidZ>l-&Ioz1#T)p9* zV6|N7vFl3bEXPq5x0G{rv{`#>y5WXkF&|uu4#9FxYU(05Ta6KD-TL5AD=y7>HP8+n zZC(L8YKn*0zr7E<>S{N*Q9GAq*;em$VH_asCi*-f)eBcW`M&9bTg`dQxltF=WZe0Z z8)Dj)B*nJ-cD9(wWP}cPmxQ-qDjC6w_r>{>{+8N~X6Yp3!fxOe#*NZmMR)GG{i|D>?1nYc^;7{R!3`qX-)vb*8vd0E@ zB|>zlz#HU1D8L7jAslmO2<>c(=fGxMk`_r9j@f8CZ`r)*D}ZBf(DuQw1LdZXrt2U{ z>H}N&kn-fe{@JD=y*E&0FFtfliIRSIxY=N>#|wFipvG82kzNG!riZq1)!{CEATlxW zIc6>^M{JuZ&@Y{8>>LhRTB(fIn(TtWh#-X0!YrITHtT;d=hT`SS}JQ@7?r1(nPEG{Q-;#%HZKEpf?cw8Sy?!-Bb=)| zj7VdLJuTW5H|A6isgyf)0^#mOyi!1A=Q^wY<40a>2|BBNxk?^(dy zjmvnsgY|>^A1Z*D-#%0`6pdTfe)ohx?C{P)^cxigvPdY$VvD*z$F07NIF*9T)iv-e z6XO?)-SE(H*s=9oXIhbdD-XRzzR?tN6 zAfxKp`!TwJ1DVPOQCCE$pT8RV=oB!hEcU}NPmnqC1J7x*!Fc0G0uV5cmuMuvgpRxU z?GvDpGrjGQot@Es^+uvnZRVMXTV2deo>b?HbisJC@yD8kUj8ZMtoI5k`D0W5?fR(wTTyV`wiRiUN%cCsp^bH3l!{to5Kn1 zkEeyesHWxbcVxemFcAXd7T;E0D8*U%9gn&RE;7z?kza$nmU8@_0KTpSeT+r)u<@l_ zNZ^Gsn$1Vuy&VVA4#~WS>}tEVxz7b`4kYANi`zJV$9Vmwq@%6!wQRLhIzm&{-4upa zXvo>FCmJ6Nd)`y}(r)`c0gm+dT6*1{6Nz?*nzBCivN)Ejn|cmQY+7Wf&H7$g^^7L; z^oxLbzF(bH)}^uQ=adz9R$vq{+3_JElK^oA`M;WK5|f*CYv=1b zLY5C{roqh%>+Vt$LJF*mc$!1-rlFeJVQtv$)xRXw`{s$K@%YJUX`cW+7UCig-RtTv zmzr8l-TEnlQ{7zan+mUYmGgd3>r_(LjrshjxMkSgh@xOJ40+Bjt$#B^Z=L_=c6=1D zZ4X_M8&kvY?B?vRWoT5UQ~e>*ldpjzK9-e9V})5D|E!(3?oHGVia`PyiT<@we-PCXC# zJ^k-<=j^{)JV&2vA3X;SL75d{-Zm!SXtr7Hk2^qqCrPMkx5eqEnIdj2W$hYUriRb!CXQPMh=Fz(} z4DGG$9nhuEHVt;o9fp&mCzmwr?3tO=9g~x^D7SmTWJvzWU=}3ZAMc~YI~5OvZC^jd zKi*rg;4-S3(D|rlG8<%Bs2BSSH7u)Faw38Yq&BSOzDEYd!>7gWe zk-EuW6y|fpHKPgCZ*C;(9B(jP8WA7VYfb4E`=(m5Y;5M3wz%MMM_t`NPsCAgJb~LZ z+AA2tG+o6@%SDn}y`T^_#T;i7;M1aPt}y$zK;xV{`d?0^li<97v8vGiv0t#^o3Zjk zq4V`hH;2+NgD9nm>oI+oWOlh+spvrNVnm2EcuT21sqgR5=%nlC4tBfNmV7ZMLn~<9 z0b_eAjk>690X8wI9-dZla0u!ll%tdfF->!D5gu+IUMVw~FM_3b8b z0T)+0MS*@xt59o=3rC?APdt`!(C1TO^cqWhR-{E+dk7&2jyus$VJ(`t?JtT(h?PB1 zbzJ=+4wV~nwy>0WQa zR!n4tBW_NprS2~YB`So=1vQQOwVkwkL+s6?uOGz_4pu@LFALnN{S;cR%}Rn?txRPG zm%KnO;+~)ot1KXAl*~XV%P2mF+of>$sj0X974}#x^US~I)`rK4HgjfU^3LXn0U!%5 zM@oKC<69i1l<1=%Y?XomGqKrh$}vvSHMN;5R?eo(?KJg7i+kN|j|dP>y*SwLAzO8V z6l{?aMP^C+%lV%VqWru=J&Ib0)7aoVvMVR#bB>U@Hwa|q#yR$8P-<&uJM#qiMdJk~ znvTuMg3nEoo$BOqmG%v!v;e0AW zO1}}>3{@=9nby;^^Q3_NCd+!+56>x~uUldO%jBD>`aO1+o0=X~JMA7q7(}+^!b|xfc(9uNRGyw37Oe zXGQ$!u`!0*qVVA2*;Jnw=J0n)94f3k_Y;S0`7SYgmU39Se1Dn&)eg)5AnzP&g{J(G z{qFn+gCQAvT0HdfYRD<#CxB%7zQW~b7uju~ba`NBQ2xS&AB*>|gIi6tKLpzScV%}I z6ApH{&F^V!R;6gV%YG8~-cXfTh0kad#2SSMj^^udsQJM%zVISwe?z}xcZgs9mSAT$ zg}m!xGJX5|fWu`m>79r}^o1F;+*QH#x#&ji>%oamGfN++pjW?a<1r-NG$JskJG$lH zj~8y_z4x<`F5xt}^D4f4sh6Zk;zY*yYlF@iaxR=~>Xj0ts$1i44({saAxa*48y~Fg z^MO%gC55C_D|3FoSGz|a5m!Y{KDD@2K0@>6i>G#2_rR76(WErXFC=5Zhhhi2Nqj!d zGoC&$Htv>?9!KawQ^>;oKuVR?t=EAH7+C`GquE6N%Kk#Aq2@y_4}p+rk;=(gNNfw- zEuxDb0h!mCqko*5t~u^hwkTDG{G8baPa7w(JH2(chA1a01?Z*i;AQI3lX-{n!#EvR zp3OMWd-{>MBlWR2nB@r&5r2IFIoTX^0{w#l5L&wqzQKQi;qNqOijOOlph%I?4whc^jKEE8VCk1Q(( ztHkwZ|KOb-hXe>&HU4*Del5^Lhj!B%lRDG7EEP;+MCahT74hbsPp&w|Sb-ZjcoA&cvY5gStVTB_?xhc&1=gxEkj?nNOb)#>J^W3`L9i}2aH^<{^fi>XKwQ^ zu#VD;3sMt;XwcJH6d4%7$KBjth#6Tq=%bF&H>UN`u$;MU!%M34%?Z7+ZR&qU>DY#$ z8*t|P=+PrvH4OW8-hS3e#FV9&BmB7}7*$Rqc3$w_pi_4NhE8>e+K_KyxNu8e$9V;g6f1h##rVa9x?lR@WXH^^4BsIex6p>!W3cseClGb#TG) zlKpZ4U&-;?d0SuEzfYa5BTo`(w*Nr<0Fgv{U(jG1TejjjPsknj*|5gQMf7lUg_C~P z8vLjezK5oWy{eJ|i~OvcjH2*qjSG?7OGA#}zgP2H`gNThZU9#-bAXp-l{uV!=(fu` z<*Ih8{rXqx00*)){s~}j6UtL85>ni42}`E1i;=RGdU>c%dacoCWk9*$5w+IFnp#uF1N}PBXg#zw*+a@6 zyu$KUnowNtKEaaoGrXq$&nN_ZlU(rl3Q1+2?7x=8SYY0Ewi>=}D#^ZLtT}zmb^01I z_s0m^;El{=d~Azsh-Zz;pilzuN4{^gM!|g{XyUZ%1!1D|@l8w0MVs7YCM#n>RVMI4 zbJHH|cTnE`U-6#<6h38H4xX;EdXA&K30eP2aTy5Y$hT*34sRpznFHB-V5wUCB@RM+9dx)mQ%}K!L04r85@Bp`%lV)~} ze1;P36r~|o$8@qpXs1XL6(O}br(lG|;hGORA)1KjV@yybyzSq>I|0jU4s$lW?HSR# zc>@DuJ()708vs6z(?uyUH;uuoS$oB<{`X; z^NF1I-}5vEg`<7+v0)H`^cLK`PoCdnsP1<2HiAfS)#Wb#4O=(Qx_bZE^=UGsu-^(Z zXJryouzbUX%?=R%L;7xfuF;=T`zKU_zpp(c`rkd7o4Ca0fMKTuZiFKB*~PIX!D15+ zFW+&0zUSzhjN2V<@=u1v=CcJ07uOQ9>!Zsu)l_Z)3}o5*>Brj!oAR-FUpO?sCa%2Z zKQ)gw(rbH#q(7f4_qIYG*q)m^$+(_mQ=)H;R)y%UN(Q2XbX>I5UbZHLgGstP4!jZCrH?`}qcalYRTym43+cem*=k9f;LJTXpX3za5fXC>r!aEZx5C!3$8OZPdFkZM$jwF( zy=fO4UGc5Y8SDtVF(l}^I`LgtNygQ)$)C;HoGRMn?Zdrr5bAR{jXmW}S{(Df*-OyN ztHET<7rlO07?VGc%s=WJaP_h*5dmNFxAsM{6?cWZiF-YUz=+@NfL4C<$L$-Zx!@&- ztk*y?GP+QvUDiYXNysT!w6A$zn=>K*Ku)SWZM8W{?DXjv_pBb)v-Iv9Z ztM)2B6Ryl|OWlWY$d-1*sV(`#RO2mZH9Bxu8I%O`xK2jCi9D`iMyPX+^NTu}2~Hs; ze-ji(F}HR#r(##+Em)w{^OcRC!91Fx6&8b^sd-ravSLx7;B-kCy~lN#(1n7PoV7fd zqyGZsqLKr&+}(}_3e2mP;_8~^mYSq)f?u06+#rdPf9b7ZJYxAxn?+H^j{p&-oA108 z1dgRXwHg}=emk*AXuc&cY|-J1SrAhIuuX>F37(7fH#j`9&KpT+k7QFGm9&ICV)MIF z16}G}36;g%hA4HFz-_v(e2^XYj7SQnx@lR?0z`Mi9%duOhG6-!w z3fFaDi)dKst94%(Gp&HRG1BHwE!2sdR6*e0@CsLE>WVWP#IXrVPG8~rl1&Vb}J69VGQd7+7Kn`6XZEg9=;gt1g#C-N13Tvz!{nlY>)(>sm zZiR@r%940$rPJOg0Kt_xaXfn@ABHmZ%hj{>w2SSP{==;uMa$%Hup?eL(W)Qi;DtRY z5w`hX33-JUzJ*&*Q%A>)71-d9i2ZMxyXQbD<|Qh1ZR8&0IK+G|>U@ed-{YKLZMsIB z*X@U&tm5=ld~sdU?$jzw<8cgP&LHwBlIA0pPXp(-^z~d7Gtw_~`h;)DrSQ7wj*R5P z9w!Bt{qsvo2W$)f43s$X?VktL>e+%y2>afsoPwN_7?LnFt`DuFUVn^SSTD%~%6o1V5w(7U0ILG+M3jD!p zN3!(s*_+4qeWTCILG$7Db&6c-@(N?7nK17dTRM6~xD-QT$l}^oX9&&}H;H{0E94!L zee)=@Hch{NU~0b|ni;ksxhUXeLL9plB?dBccrieoSE?Lq=NH_1w$?9D@V@mK_b}vf zaCJ!;9FX#AMfuGiRE7qecMu7^WqOs3c)U)`T39k6qDNm7Ja5v{G%JYrS_M8)1^>BI zvXQMYC8vAuG~~)OgUSH}+3dmVyTAV4uB{A)nPVeW*SQngX~TDHJ6fi=zMX$;(gyM~ zbu9gyt!sE){1uv^F(Y6c@j;{hA?Z-=Ee)}Y6^l0^3$d%!LZ|W>fz?B6Q;=0JYtY7R zNt7ExiRrUk^Gwz;ml>}+X#uZITl1&m5nfRe_^9=1m$;vh6pr8YulDqvyxo%0g(J!< zda7K^;+D2B(u?sdH?B=5-!*9S59d$r5*UNQ1g&$`TTw&(q!6KQ>o^o2Qlbxubb9H` z3B}#^Sq5w=_p%IiGDEpXv=P~~@klWcdheQKG8MTj zm>wEI(F8@*svIEm@kza?>S7vQ4MTL`w%1O?iOTw(tT|OGNM3PSZ_H zauXxy31IAChSPr5Js3i0ub&Z*oBq2VcktUK#ywYci6c>t@ejs_EiCfR_%Z^%lbEW+ zzVE4z7Oe=o>Sr?%8ufd~`!rSnMmXSe<8xN22r4&TT$H1M3yd{XWe3*`vMB0E<}u=K zFebHaFRk6xPd?X0(Z-&K^r&hpQUp$~nJo2G+_MPlB}49RLAy*VA0G>gdhFJn=8tVq#Oct5VCO^dhy_hl#ZG^S+<*MgbrJo$NH*NHX8~Y zslr<8Lsm`J15lE}6TtppwFFwp>C__ivH71+Vu zOskc9^`%`i`*$M70V$_rA;Z&PntCYsaHP}t)6EAF0;7_~oO*#F^r8p>x~9o=zeT!~ z41;UhP^tCmM{0e9u|>I}o4@TOl+tOEWF~BHoXBWN)z#8c!-TL`PVNQ68^Uq=Q2G7t zV`d6C9=0h+TLB*2f)+EvR?){13nb{MOsS3f`X%XuP2d(iX2<*`>=B#T&1>@~*Ib|S z+I?jryGqjPyQ1t(5!c#u`k1C1Q!{R+w9lcYbN1qc`XGD@x7C5$yfhO(@Twu)x4t|kDCH1xddUGEKY#y_lpWVdOgVtxmHJbCb+M7D->$4XW zUqo8pt8RP9l`&cuDr6bF6d97xrEHECck#-8C>7T)7aqVPcgl|7IS___&6wvgGCtJg z1{?R@eGumOy+9~`zeV=70>fa|hjpV35Bo@TFIFIgLYZkNH-pFzrV{u2Fd#jQhy#CZ zMS;n3NegGlbs%@WLV;)D#dGPDeGs}g_baFD#M!BdYGZ`l;nA@hn`usMl65;aMIW`GFkh)rQeMpkQ#x75tFaTC9p!$cdv3V@6)|l8$@ByS-r=1_}XQ> z8XK>X!y=qC7N2)>=veh!$^gN)N;<8t_6R%w`8yrtI^hLusdtoBZ~L>IzyMXs>&&gIexM?2RIfg!?&z-9&%i_tUVp>NGc}A|c5tzKwQ?*00_>HG0!vE^EyK zdf;faG372dG!y6Yn}M`=#{yCfOYlRAlI^;hoU-4i@ma2>F2`1*VY=~3LX9|>pE?77 zM8+IipK})Us-rth9+6gi;qaE!ckvFmQ8m)Uk+P*!1}>p2QLwpibV!SIk)8$RL7+jK z_wj}l7nIy}3-q3wbG-|&?aZDHGn>?HAV6lczMysByxYN?c0U+^00jxhbp5#tP-krT zkmcY~m`HfYvqEeplttdkPSw!LUe|ub?QOC*r!KqGy=1j@m3I?U;~n%;;h(l6ydYZq zJ$R|kMUA-`%4;hV`rUrLyCqwgts3oA)D1AUhq(DwkZ49Y!-z=T4ARz91lA`>xZ4vk z!q?cP%N;W=1CvSajpGe%W=e3WHVw9td_4bT>pJ6Tk0p^HT|iZQ@C#4V}pT&VVp5I#T9MKEnK+w}7dd{x>mgsRe` zRO}AWag6dDmSh#sHa-E) zRFdmW9{W}?y6Ri*$>tc^W|+UEn30?9z8_s(;)|gFj^atuuqa)O`zx*X)SF zA#n`_ytCwc*Sr2I0pysfrC<6}-Oa%`!0YB`Iaq}%go`zq#?tL1EqCiukhqVJ#Ofya zC$5OnAgVZdAj|nlo7;|n<=o)F?2;=J9@83Km>Nn#a^`DUZ>mYPGo3+o5+_EdV`DKM zvS%7YXquMf5!&ic6T}<)ben55ik4^vrx8QgFYpONuF&x!#*+bHK~okmRMqvZ6nH) z%_NsqgvtXc+ajkt_JZ+*x^IO%8{wO)ZBoX0()-2<3thz=@wLsh>smoFjH&Y3H#sM03<%uHYvtQ7 zjQWRyyi=-LTUds;x(JY@=r>;9c{*el!AX;QcmzHv0PBEA;Ul^l?bGwg%??=ai!%8*U8Ecw>yR z$QhPQDkEd${{VxUwu|C_+C#whHyRzrm*Gzo_=4Fry*EaJICRVCWrpD`Ls_JS& zqOys%5T-e2Tq<86Kd9PPmErG)I&3y}kZAfHt<=CpZyLvJmojZs!l1_XDu|!}P7y~y z03!}(%)+F(S0XnixQ*>u_xW)WrHWy8jE)CxdRM`J@KOH&jsF1fk@#E0nhm|we{R+^ z+rJOnMkX%}rr^TrHEo)yyXygQ^_Y5G0ot+{Px zU0U3~_A@*wJaD{7NO>9=71;q|7HCidG|?mcW&1LCdi(a1{h+)}u3cYWYI15;O=F{_ z=^Ol+Z}&|z3@l(xaRSKVH8IB=5yXr{4TV}-`UM#2bLmYl;a!)CygeU-?CtH>PSNzc zIbO;nAHJGP*rB!OuglaguAC6xcJ_@hFYuI&*X19aGy4u{&r-@T^_EXxx#^N%? z!m$0?NE;iZX7(P zD^l>k#D5QHv%{&|>RN4tYa0N`G?95xBr`x&nmIQtPO&nbt+h{QCS$6*Hs(ujBMad@ z&)7AS8r&;F#U?Jw6(kY7tGI?a&LkLP$`Ei#I6uCh@J@MSxc!PgAIh=$es-K>lPR-% zG_H4n{ni*GIUz_HIIrjb0Knh5)NP|pvNFOgztNlbW>}fO)$}_E$Xs9?X9JQypwHM2 z!}}-pCiq<*{{T2m;k!G$G$0v@p$XCqs6He3FiN?hSTi_HP zyS`FLJ(+!RMIH-z6UQ3j6%88c+Ixvha>UJk3gaDusn31>+?9+{wxQJH{AT$0oi@_J zzp|{*@%`EF4;+8CKPr_k%1AaZOq>JBA5LpK#mvv~dgp4YmDBB_&Usix_9&ZuQg!*07bb(W=w$KNa%KxgU8f< zc7MS{w6mu8^Wn{`Hy2(V)h5=-vxd=D?ipiyJJtUHT70&e(%v~*Reazr-ODnp##E2= zds>2fs}_#XqTDddCBbO?#rc#tP)^1pY1#{6jyb^}oL}%$e~9*P`$qgC`&Fhikag`| z@uE{5cZ9Gw^X%_bP zGovP+bi!tnDV9 zrdZ9r?VXjnEQLZyiw)z)lYz1`o&yny@jLd5n)~A1(fD@9`(oC{OK%$6M+9PD?SZRV zzM*Y+L`cD1Z)Hhjj%}WENeMfN1Z^no{eN2^UlTna_M!NvrTC9f(Pp@s;?GZlJxf$u zhKHrr_{{Ut$fST9DpN}@$HHNF zG{)36ji-H=2(m1)$t~8%(%d@lHuDjHvl_kQZvps&;^%=aw0k{U!oCr=w|!5@)9x-L zH#&X1x-Hg_ikABs)uD#sDS{W=ix4iaARv3k?0*KY;7=G`=&P++_*+oau5WeUAKu$T zx#7JwPQJ!G8*r8+Gu}2b(!;0oS4dS1IJ_GS2-YD^Oeh;?1 zT}|beR@S^ZBQEgKvjw<#q*!iKEr@)?m^^t@5w?NNDzGT+}vDF@*K)ML&aV^_=|3IeLq~fdAu*;ts2kD*Dr5o zd%F+xTj~72HI`94Z#BCz#c+~EBs0Uce7OvfMwhrKyIA^L;;-#f`!wp2YF7UM4g6E# zo6i|)c0$TKEelba!a3oF<|#^vVF&DY5j1{$S1N#`z>fRkU8l}`ZTnk%aqz$F%VXlL zYgq8@``_@FH&FeO^%eBni}ol^gHWZDEcpk#UzRG|O{qG~zdZ-brt^StN~&@Aqbd;N$V*#J0Llzo=en z)@^I3Jd+}(bqRd9Y+eyA*wMU`e|Qn&iXh*;w@)SHD%Dm^J1(J4nzh;I{uKCeuXyXk z`h|~(*ym_VEW+N*>K9gw?H}&c2NG>mj^TiLUS|m!c{y*tKW9&bns0#``sr31 zNLTH#=r=|Y945>yw3sC45^qbjkbnT%?a{YM=AR0`X6S^zCjQL4w_Rgg(iSQALbg)d zT!^EF(8h}2=*6Q9p*67G%R18m)jdgINf#@Vq zanB>T@{?_wN;y3g{`AmqgfKtXoak?KGJZ)IbvyW{1%9Vg<@N!@MnuO zt#89}T}LJ$wU2eoGGV5%To)0sT<%5rwmRh)E)=#hX?WYlx?hE_^c^|ly1bLb8n5kp@0;)5W{yQbbY0+kRpX=6A*)0L|h{0pQEaIM!V~R@1`zh2-PT`%TISqcSl& zRx60?meNtaLlm>yJdHHYg}+PuSK&QV;)lik2Tr>d5=(0x?XIORS-!<~whSPgD9Ws; z(_Jp$eq^^$Fvv*w@BRuCrbjQ2ykU7~1XXe7|5_O5ShS=i%zwN3zb}NrpRg&OE5cMfNH|Twr5vFa>@bABa~zGWeCOYm<5Q zw=+C5S$^+-*!M^C6x=Y3gJT3Fg=T(saNoDT@J_8EY(Hdg2gogBwOOp%MYoCfMQ&Po zq%ok{r8XBSwUjcluqD3xD$;JM^eIIuN3=@q6ekK>CuzqWNC%In@}?88s)h%GN%!OY zdsKw9Tff6O{{T3u-e2#=0K)X^_?N8qhZW%!4!S>c0Q>2?=6=WZpQ^*u41z<30nM*@nJz~uZ-1ZU!{!+Bsc zX%={JIXRy2b?L_6^O~!0%W@<}W%XKJ^|!`e*!}G!3%M%2@8%(Bm2Z06f-FMJZ@x=vNY1-HSJ7k#?CwNO=PwZT?4( zUe(t|b{;wboCDUfY_#a2nVczEAFkIt{YUxl_-I+jwlxQF?azPn(z z&eqaJ_7~H1`(Kfm+FGpeBuQ@$_7SW`<&|UtMTS+J%kvNHr=r-{e%jh(nxw-`@Xf+% zy8fK8k#};|cM;2FvUS3^iEM859mUGa99Hj=fns<1%R%th{3hDc=(?ttadm4Yw3hm% zl=muhzqQ=kM|bC3U1m0d%-CARGE7nM@^JtCYXvtJ>Eu&yrO1t12g*JaJv2SS}?kBHKkAtRr_QOTx0l zy6V^BHKVl9CDKUOQEK)#a$eej4iT0JA-9QURw@x@P*eo3b@+SxPF`q#w-1T_CTd)W~Ih)Fe-%ucHf+~nOU6PAKkyp<%D4>2P;II}p* zx%s>C2wP3@3gXwn?F^bc8nynRXF3*+;&r&axQ5y@a*8E8tVO8{yKX0ylz#EUdxygB z+JnXVwy|fS+()D8b7~g$*EzUW65kY^q-o)b(agmnH#3A=R$b8ChCWPgOp|yY_8GbO zTRPrrnr!#pBh$=lCXH%kj`AD%W`^DyRymGg`{zIrAag92UPvlZ37hmY;qSxm1b@Lf z{{Up&b5i(;Wu)7^_zuErt5v>R*;%dh+X!x@AG>RNK&mb`M5^&$U9yoQ$TV_{mn!Hh z+f$$TANy!{mru|d$KsM+9TL?ewXpGgsS-zNb}g;d8t(cWA%c5}WGd6!Lh*@_1I;fZ z+&&(D+Mf%4EqLcz(0niBF9*%2LuDQLSqn)G#k4HWWQy%)3nIn?l1i@K>$@dnQ}Vxq z{yna#;av*XTJWEZJV_>skyvS;5WFd-U0h9Y01LU}Xguc#!ISNOxR~#AA{P(l_SU}; zJO}ZY;d$1)L2-5B4RiZ4?)&}`?Ki^e2&M9dr|Q>?6;@*rJMUt zP*k|_gm)KEo!X_%^!F1^t_l5LXKZAX%ihuuo3|(|NG!wT8;%dk>`8YfzohTk ziW`Ofp8O8iO}gSu0?DMvzy+aJJAc{vzFva_U=D)nMX9e(@)LK4hh1{JLar6~9<^3M zgS3FW=O(nS*K}*od}lwHtodTZ>__mZDsniiB&;Q{?+Du?c=q~lEquYJ#UGi@=3h2a z0X;AXBzGH|rg125iRlnivPnq*62 z)ImsH&9*gcA7P%Nz8C$id>f?rgX4ais_Xg+>ROfeh4m|0ts=TH$nuN31bFR+@Cn~@ zJTo`SVo2UcD)(Sl+=5N9$T{ToJuBtU+KwF#%j2bnz5f6S1)YzHJOiWL-~FEI(&3`8 zx704<`%V0QTyAiheWs!$bNjIKmLxMOrj?^~6lG?9b9_42E`AOCVfcTjYcgBwQfeL} zn)6pxfAo9(KUIfPS1K58SlZlAD6}9tBz)nW0sHSy_-kt}mEsL9);n9z_(!$EIHvyqt-sA0RNlw3-`{CJ3Vs53x8fWg zAMoC(s%awev@3rUUC(x!c!D%}<)6-HmU&hft~Dl+N7}BVG7ZK_3s=tn0JA^FKNoo7 zPls^po-nu8w0{)cSn5!VZ92x<(*AbwU51Nxmk~*B>5*%D9I553Z{{>>zC|FlFWbB0 zqkK92hxH#0*ruUxr|aMFl4^Q;OgBiDx=)8~F72Xs>-){oMowFXWh9=SCf+Ll0ES=j ze~0`#B!^0~_*>$b^vy%<>>f*$k6-ZKp#;Pz*fR!>_c;tz#nJ^_0g~@RV;v8yz7BrT zx<8Gy3v2V?e+B7&EAc0X^myzweJaW*n@hdaEc_npaT+mw@wMNJnSjVpmSR7f&XBa; z9QeQB{{W9)7-iI-!JZet(L7NekoRe;*uiyhQ&QKjboK?4Y>a-v9kfArY&dp4Vab1p zehz$ce-G-b95>H>rfE73nX}7ub&xQ4fyi+H`R{qeF#1_6M@D7vWfpx0+ zmgZeX(&1y&URdIjXOiM9nA$c2?Kc21Q5r?b+-ja4lj7CpgC?P^Sol@7Ej{9UdF1;f z7FP1K|seokHBC&mnnaR7lg$KWSKXZ;W3Gb1nr#cZJuN@ zjj#K}2X-(@klf^4!4CrXTJgIawn4}li|%i#z=?*$!M)(MobO>R#LorSLYAx zf#Y`Z4d$ERyQI~B;UduD(zN|l+*$0omRYyQ48SU+lPvc3?=c==Xw`zbKm(7|4+vSc zhlMn|c@|}o3piqNG^R-d2*4rS{pQb8+~n4D>eQUywCFm~gkZS~Q;fSGyN<@H+()#G ze;01Q=dD=1kLAeYcXL%{xL8PG{(4nP*VM*42ZPyH#@e6H8Dz1y1D~0Hc|LMF5i-p4k5YKD8f;ShvL4w=)GAHJ0D*Zr(t&1eBzi%y-$51%y`0o8HKHY&Q+p*)`uO!Qj{m0Yan(3ypBO((Ug-9E^g;7qaY4-BU(Cd5kpq?#z& zAdMVsyOe-NMpTAVgT_GsgV>Dqu6XTo{EmTcepPl~zwHLxoAI1Lw0Q4a##^bj!UPS_#?&8qX|08h+4PFN!rU6>8twnqIkm4}_qPUR^fuNhI(= zq(ce4hDf7R8?i~Cw0DtHbc)gJ%q)ZVr&aKEp0nZDw4?T|KTKHG0S23ME}p9rs8Ja5 z83>Fo$`B9$IL0f0@&5qqEAaQ?mCTcCi{lMK;(KXqgN>#z$S^jSTp9Pl+01&8u16*x%dudhSyT%=S>*c)Q{Pgp80p*$}~E=u@iv zyyb;{q1}JMD}D=KL#W4fqF;|2l`xdtZOnz1?Ud4EK#cm4Y(qMzyM3B zf59xi3u{*Q*7q7$jU>0!t_|(vGJHbu42g9)p5oxj`>&2)&bqpncU|`qwsP!FoX)|` z8~*@b=6+@PAK}&i0K&hH9vjncgGsJV4Oitgg#-rM3OoYzxBb1T8|AH+eVmKXCPkhj}OERuZcN%>d;05xAn{{VtW z{1nmF-aB1C$55r5CfW_2Ki4Dj;x`L(<&Bp>?GvKN0pK!%Dn0Fn>O}mx*8UjWcq8Ga zrKZ~7UrT#u;Y~u`*5(UX);qh^@aKt5v|d(Oh&O%@vy~JA#zxzK$_d=~qwsd)!`~bJ zB_@xkU4LflJ~Y)d8(Vqpnj@sFUO4c?K^a$Jc-6c?6E56**6a@nn*Bl3zu=eu00?Ya zHd{X)D?uYnjy@uIg$JR!2c7 zUCrn3AS}|Z3lWDSk;N`qw{>FGNLtr_hxwn2f3$Cg^=pq5-a@*3_IfRxmvY%jb86)! zj`v*B^+@f&AQg@)PZz{Qgp$jVgwH;;;4g-LF#Vc572tht*TVXQarlEo(OXiwx|UrQ zJG%0}cX!%HiM5N!c7>l%@lK&M z%eENp#w7DY?hwNadx}j9wSLkgoM$tt8Cu z*eICX#{m_7=l~U5-@1|%CkX4U{zvC${1ZpuCyFnBY3ttxZua{7DQM{16@ z+pXdfE@O#V07Dy}m4cyl3oU+}vokpB^W~0DJX97M9hQk`w$?UwmKOPDSZ$$5V^(5W zh*barS&2MyNj>Ud%^3dj=L7!$uTdDas})L|+YFsnRBhb{p{%Q+zEjRf zTm0g)S|S`W?jJ!@F^=$blSANz8gZSHi`XrFKIH$^Gl4{Tz1i1*z^9vyWb5p8~Q zth!y;Sp9_7ptnzOImb?iouzWhR(1oB@ne?wwQ!+6RnV-m{CBw202%h{_|=$G&27p) z?_k|deLbq5j8a4Kftz>k{{XP;gA9y3zdU5{%*Zk^e`6|j6a`gkQXa}Kg z2Fy$H_QB$%j$w_xzdUs{TsnpdI}YEiIUSi9cm=;&tJqd2k;eVsr}_HSo7zAd_(1%H zS5U)@7CGnGasD*#n6LvZ06prItS(YnSm)*Yy-hkPHw9mqbmF6DXOpK+I`B<87}?tx zz|Vdt6J>^DB4-J2hk3h+LJ{{UK))-sEg5j8aySXO=n$J^CR7E-5m{sq~H)^f{c|VE1Zh0IW z5AmrlCHY6omSK_!>>(w{A0&TlwEKUFe?vGR2vy&*3!r_O+OH$0OjVkbc;u*hx zXOp)W>^TRF_s<10!%Oy$h>SQJ*3Hxb~vC?c7$m%J{Mcm*WhL zy>glbkY#1ysk+oL{J+j~T2LrR*#;F>V`Qoey!UD%WgV+4==}tESXyiqde-nCre_pwv}#sk7||~)uk*?I0mUfm3_sR`@Z0M zRj6Jx9rN4iO-rEcOtN9e&A0G0=;r+rjQjIc1MC}A<(vF!cqH48?zeC;j(Dq#gO{;8LId65!!`La`oC;a+;pM_;2gz!VizY!#ps@&r0pmQEk92 zh0h|L0(pCT=RE-DHCbdgL@J=b#|OPxf%{>;o}AT2Rsr}e-;Td~txF7|B?r*;^yyKT z+SvJV+b6gA{&d+7nI`S6o`Cw*a{+Z`P^XQ)l=Xf-U%lHuN~3QPL|%C9&NI@b5-d;O z+nUizk97$lKSgBb~#J{qfi9QQZ8=<+cc8oF0a;l(j7&FqXD&m-mf=z#qHD zKlsX=vO5w{gm8)03`&=$Zf0Af{-38lP05+skz^5(bGBzTdKeDTXP zJNBG^0fI2Yx0sxB!BY)^?s*5sFNiSsZ}y67X{0#+0EAk`eVI7I-D-^4IKtsmKRXPA zg$L$fw%$Fr)F#^w*x8sO+G?HMr@|12QhTM1@lgRFWQ_p_I#ySwo<83D1 zO|oC<`G~LXjl!&o#GHJ&?s|SbYi=(d=+}VxI;Dd(ylNc%p-4s}cG7df0OuZrR$6#B zPnLbD9FT2^BUge=vBsgez|Y~+H~?{)t3HXT_;56uY}Pu2(&KraF$>De!sUQ${JTKT za#Z#pEps)vTIR6$i=d|4hSs#NBW_kfa+_UvISrnD@Nx9*avvRNvZDV0r^3poKO!;D zIN;+Qk55BcY2(JQx!tLy&4VMTTZy9{Tek!dK_x~=KQ=btl5hzdR$9a|+s7!lwVFGC z959h~xFBtAI)lI-xgDsCh09`V4;^UoTmJ7+hG5@t42I7<73w+@$g284Jo;A~aua@v}{9qlUoM4`Vk~5y1_U5l! z>Y%G*KQfmptU2k^*SAibb5hT!2OC+L268|go}E2t&@5126j|QcsgB|~BVnERZKNK4 zQU}ePbQv`C@kOJ_wPq~kg03QV3{OsW_sBluK9sAfNfqMT{g&oPe4{YB1}`NJV2aQqs&pB22W3K^XXGU+aHJI@<=@@*1BAi zCH!q-5nUHiWpGpxV?SQJotPNG>UbicZ-`b;DM1#jSyv<`op5&YK<8#TBOZkG&l%e{ zQX+iYTbz)&<0sRvKHSu+xNScyXTDD&w2q=OWASaceV=F4T1YazySkGc*zdvz4{ z@r*GIqTf<13CI?5^TvAMXF12NcWh}ZEoUl*`ekr2!+G>H4C&jI#Lvu{a);YK3~L4 zZNqE}M3t11Syr$eoA# zEV3>}d6;L6W0$eXjdbGA`8X|Ckj`A z*N{i8C){e8I(~oFtLp2xl~;4kv>%0<^|iW3cW>d}F+m=6z3l6|k&KTgImSkEaB_Lh z1+nn=#F5;2k*<=;;|#BVAzm0@f2!z4@1e)u=N^F{@YI99`ShdMrJ@N{cQ|c3!u}>l zXzq0V9zYOgXz$xHI0N`hTyvk|?gcjA!rmdDOxre@XqO5@3}5b6An}o~d;b9Suq(e_ zzks1X)SsvIs&u@;RA-+J55!oJf3fYt@-VlV-PrTAdJdWXS)L6$#8KJhbSOm4)lQ+n z`GCgnD3Su__YY%G@0SKZgB!sGDv_chep{{Y9jf4V=NIe*rl z_eb)fMcZ)ZIk@~UstF4<-Ir0db~Y4xe?QKZJTa=fFp}#}^72aUg|~D#&rJ39>0J`_ z{{TMJ#Qra>S1*{kj#A^nn&r%_My;pDxCIyNuHcY!&RlSD_*9GF?OQtzWO|y=+@HdF0xtr}8rjDu?_b(N!jM&_A8B34h z!Rl~$#~Cs~;9Xq?*>$ToUoQu3fkrLQ&7Uy(j>o6Hb?Nv203Y$FNBYL2m!NYta?$uN zSXonC1he7zDG~Wx`ycAN+sx@TCZlATK!_Y6oepVoPzXB*LU70&2W4CbLZltCyTnKd z19Bil(jncWpyUn>Bj-Q}qpVdn#}$U7H3%W8>ye~H4us$kp}f(RKVDoHLU4#s-ss97 zFRlwAI7E;PA#xyOltZ$sn~a7W1R)jr#NKls7jS2a)^sV$Qv>wK?q4* zkPM;Z4k4qG`e|eoI0!;;$UrjKAA&=Ae{*-00XYzIghMWNY^V}hND73Ea>%BR3RXxy zzXu2z;SlmhnmZxH+98mUP7;JzIs`IOj)VIl?hb*BGBR+8l|x*}sIG$V145)7qC!Sd z$5s&;%=WrFq>Bud&KDvJLM)Sd78wWU4MtfblDg{u-VR=i3?UO75;!7lW&%Q-9ReA7 zWZ)2Ihm0XZt#F9BLu| zM=0)ngb*CkK!(w-Ukby&R2xQ<89J1u^qlDqwVhFKv2xM&Qd`4NO^$8NZ#m^%E RqoDu*002ovPDHLkV1nTs5oQ1Y literal 0 HcmV?d00001 diff --git a/data/mllib/images/multi-channel/chr30.4.184.jpg b/data/mllib/images/multi-channel/chr30.4.184.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7068b97deb344269811fd6000dc254d5005c1444 GIT binary patch literal 59472 zcmeFZ2UrwK(=a;AlB0kinU$#IoP$Ufi4r6WOIpH`vw)(4B0;i}a}JU-7?2VL(Sv65i+aSt!UItzJQFhKBfPjBLuNU>0F3RvP&KD=i2f&Xu@h-CGEz?|3lK(Xi z-M+|UU7W!ak8_La^uJG z(~2}8762~#?}C=%JPTC>Pe1fI&jJv2@E8ICY`_h026zE3fCJzI4{ksRyk77 z@W8mpD9ZeaH-@E?y{p4tQ0Vqf2ulRg9O-HY*1?s%bIx7$0te>tOk7>;<;{`iP9`qq zW*`F4^mH))NsRr2_#AUyqG)Fc126w?>koBxvelHms|hys)41ox|0_n>+0y(k7<3a` z7fln(zu@rU=3qT@4;Mve1x@AKruO!>=L|7_q5ndRWo7T=C1s1S{ADSEAIdBILW4Ci z%*-uJTy0%I5X;Tn$>raWYyU$3l^n;^Qrh0u-sw-qCH%qvru^>$unM4Jk36?EmW#cE zs;i5$`JWvJ+ZJ@B|ANIewRZs>>c8M|z%I7>SMrNdb@{xZ3s;3)WSK%c9f(n76prW8*pktyzae2W)LMZB0B5pJ(brWJb$A>)Vm%~#_ zZqhT5XuLJ$<#Y1yk;Y(*xU)I61$X8rO)GuGBonD=Ci^bkMc_)W*0lMzvd8cHTxYiB zn4>bDlt(xB&+cSad>EM9)wXa8h{>vK8Jyn}l-IFz4~)&OY8_hGCj?MXKxt7gWWqp4 zKV+2s} zUAu}!J`V;`Q!R=#3y5dFcdwaH)O2$Zswgf##VZ{0Ol+LlaVt3kl$TT6Y!5?EMqhlm zrz!WQmS>TJy$AyX?)U8)`|H?C$|~yugP;AD_$}slAB>Dk`zVoTXV8$ZGrlc2RkHcA zVrX*)On>%*_-b>m8LU3neG>XkDY+u;-GuADZH^yL>uEq(_X!1xSrR4lH^mR?(=fGgeRS4{To4t+?y8byRkW#>??09-M2eq({xLP}*5Hvu&l!MqC^|mPg~_KT4KfS>r)6L@Llxx6CLUH>L?whl6p!XBBV`5xi*f!$nbqRu|>xF z`}+IX$kqs*Z%YQ;fWXASVdbN>0|Fc0=OLkq_eWAhsxzo0zFKGJ8!=X)> z>SSi+@fGE9O_SNUcU95KE<1B&)Stt&)nx5Kw_?|0I#n!^Yp3?=i< zfV0mlULDCtT9XJsh{$b`qId7<$U;Tm#@9p1RK2tgX}5|9y|>k68a};-oUf~PLe7B4 zhrT&F1`)9G%ADh+y)%GsZ#i#P4@$+hHxp9i{gB=5tFe9l)h_#ru+ygR{E74{ z$+s8>T4@K~w_GLhfagzWQ9iMEmJNMP?tpp*Jh?Y|sC4~gBc<{%GrDGAXY3y3^Ji84 zHhk!B1n;O!jHNbZ`5g;p6O51O#=-HQPYXQY)JF~+9j z^<+p7!uTm<&tR>~)AE>p;!yRwN$~R6?38j9@YVB=euxsE9^~IoJ_=qz zmNtE7R1?W^DF|}nPPWhrwQ*dxN&1u`(4;@oqN{y!nzcW1yfDvZyJBmzZo%bxa}zHf z_kjyaR`t<|c)YFlmPxS5TjHiAkA1{!J3R-Tz;W{3tdGy-vnZ=d>}jmcyYcWa6!uzY zpDZM6`j#ZvNVG;RwoVm2wGoZ2RW0*7sX7CY9EbTMdq@N+e z8lKxPt61=!R@~DZl~hn}N_RFILMU&ik6e+hqPJ|kKTRm294lEkS@QVswtz^l_i5;$ z_Qdka@^vm}DqWb6O+Wwi=&AMCLH3#%dPs{Q{~3@0j+36H#uAN*QO=qIPUpL=gj`rO zgJ;0JevE8^d9bt5K9}y=Ya&5ARFs1=KwiJ$4A>!E7{q?IT*mf*yd5gd;7(Bm8l&;q-(MF88A8M8Ahu6z?DaIXL;%7eW6tPLK$%m zyNbRiy_;i7KJGrLqjK?WJ0*g4r*>f_-CvGtqjh>a^WvRx+|}VDpFVLQ=Y<9jH!?^M zqmjiE#wRE6Ni&PxypF@@B2RV;f#nXpGvHQP)6$V&;>4lx;bTg#A@z65ZeGz3AJz{U zF)yG46N;;4V+*>kdEIj@mtM4o&Z`KYmeQu4+)d-qacjNUUqZ1Z{x3G&S!aClk zlX}RZh%v1n-;cBC5ai#jrP`;)KpYSBoM;{Pb1~DhZz+JUKB!p7pC&#YY9YaURPpJR zxXMYn;a4aO?Suyt2mY}tEfy*_%pX7O-0MsHIDcEcJ-7v~@}wB+BQZL&POs#GFYU+^ ze(j?U^p%9RfT~lxjEcdUx+cNXLIcY|->CT*ugtQR`S9X*f_4U4?hk)`IZ)@JW_hm2nz0B!|Xt9D`Ewrz*a^}6_vj4K? zHCW^2SG)&UqzWZfk+w zDsoC+m+gy@8)IjHzfssjkB&EX!xbUTWH~P()9&08qiW7nHSf?4tIIKvv?o}R`1UTz zJ%u*=FHi2Zm4+t6m)trdE6RneWFLQfdoZe7vu+)uZNs|6-kvpByBf4W>2av~aOc%Q za?`q21Mk3VS~(?IVO;N5^(Z*DD3ms!-#CkhX)Ik92@uJ@OKCC9MhCr3U=x>`^VFtb zMs(^l_Y9~p7c6-Qqv1LmF`RBV0+EHKXbzt<4P9avj^Or1YcuMuAKF*NK|O6 z7$R->Ng^BNVTt)qI}b~IE%o?K^|h4So%L0UUVIAqy3C*2Y2Q$#yAvyY%5ylcqtxD< zO_S9y^t!I!k)g__@(d8%k(AsX@C~-JS#4YisqA~Y(yD|EH=zPOI?RBl5LI+A9N zzmjw}ez-B)-r7p!5oZ6P+yr}yzL#pq<lcu?ozqXO?I<<)T$ zP^(@6uiL|4i;p+D?{p`$5}7LUx%SCj`_@PQ6q3`^f|9LcV~KEpBDm2eS5nq)pU<7fEdA)@%C{*6_;4d&-74?8^=YUL4={avU3Rc+SHX zw>oAp|A?`;B3NR-q7HPNKK%);l7voPM@qEE=+Btt6gG~tuX&c@3)|H6!~%5iH`c3T zpA1G0oIOv^fN_lv85mdA@WxC&QQ%$vDv`36$V7lR84y3xyCItXntIS7gqqLp%NHkL z{44V_Nhe7qWLM_Xz5-F{Et>(J6)b1D$K=Ujt~S}??{lXyFeNg`PV=K)2CiH^`F;lU z!rQ;89+e&^pY`>h0TezQtLcGJZ6sP+pBl#A`^g>P(hl69`}o?c;6`ZyRNvq|ib8I- z{Yj19a$^Ng_P~vr)Ky1ss$M4P_;vfeSl;-cW8WtOMhnPMB_$(1v_c=d(NjDAiaTwJ zufn7(*YMVjsv#TdXFyV0BdZI3Pw|$nrpst(^-6@KHZkvBW%6@uIR5xF@jpXI(`-{UQ)-EkuM{cCcXCQRl1PC4@#AZvF^ zTzRSZ+q9kdI?%8HC+juZP66DVgJw!?x|R*~Mr{kn%M|iyYD8Hrf<;cKqYoz#KHfSv z$Da(!7SJSaIB^NSXcKd~vdS-A|BE=ilBDvQS zR;N`x+WMlHueecFGW9G6x?X>i-e@+Vdbm6y_04LBq|!&b9aZ?`hW`|_xAh+F&;$r1(bE z=J&TY*_Vy3D|dW+6+b-%wR(_nzxf_Ny3B4>r2Cy_JO-mye1S(I2Ww;LM{-5oK3(W^ zS>m(F=mFdvRpz$lCed76c;YF1exEUAHo~-}6$T#mN+nNVEN1z%;mzutthxjqZ>)OW z+1p32O<-$%z);<*f09^f>%^(rbUo1qB{*2As*MuqJjOSVlWW}^wT67?az*9uK?`r_tg`Pt8R*)@59@PuSUvLIE^c{18{DezoJ`xGB4i0Z z=cOAG7#`tHpA`FUs2z%m8)sG-f=>y@pPKgRcUV~^_bhY6R(#6L&VUg=E{}XzcfhACUd{{4Ke<9(Ca>JeIf7?j*`)rVXKnmQ54=|!TB{iw znw9}osy#=PCvl@Y+PtS7_&a_`)dl&Hrp3b`_{~1aG~Ooxjx%e*apEtTVwGFvSs;hrGCVo$?Qy^@Bm(b4c;!Tov2^J zA*ok^>A;eY4o<<0ZglXdxcGWuW6H<^e~NE(;h5p#mt}8BAIH@*ApQ&}Z&CtZBdZQW zcJ_P%EpsK~_x$?(X2exLA0~Hkls4VzE0RRmJ}k2z-fNv-+wqB)oIZI5d-%0sNwcRS z#EPmFTPh`1IP=pORKWiRVbQ$r${7%oQ|Ffat?0|3K=T>E+P8nLZdE@l|40MgG(smS zB2Ms0dHeWIw#Ed%SkQ_5m4Lc&*!yyJniTN)s>?lhO`d@H!F^y}o43i#PV@{wI2KJ* zc1vzb%u!03NqyZrscRT*9BoD%H#~d)3{}47h&rogJKeyS_)-Fn;g^w>oX;G*yT+qh z-D|X8yj4eh2llGJV)PsloaBwxEp6!e?a@k{0os1CdlN^8iaYy~mZy(*nBr+}T>XR^ zB9|QP^cAbiPh)pXx}fb-W$W9|C=1-l^QGnbP4gVB_B%~gD;{USJ6|V5-cx&3#hoho z9ElH-=7)Y6yG`-?$X&nWFNb!gvBxaWVN`0<{5eI9Sq3ij5e0{?pXJg|dC$IF%zIwF znDz5KpRGKf0VD&Y05iZIFa^v37@!7Nfv^jBIA1^xAWjyr1wE*_B(+S)rv~0DuA@A>TNZ}yD&9AkOue}EBkr9m<@*g5ucN_HMcWI zx;Vplz(5_q!^I^i!p$YZ%@5<|7U2~V;R53)SN-FDNsIE>Kk+vl{FLseRG^MenSa7y zxZ(vkDG#=b%EiOO^#b}!3zVoAsK4QyaL5lD2nrMkNBPaOJ^~etAf0yufYKa+_8Z?B zf$@_+Xu}B1pD;)mf%Ovx!$J|*zv=rU0xtMJ0Qroi{E~;8iz@~4lOAj#1?86|Awwz9 z->`c64|>RX*QNiE6Z8Q9$oN(F97g#=h79N*JTp*#$pAITK>LYT0_oBJME^?{kZH!B zbjkRo3+M?M*ngtO{R0M@K2PU1IzMF2uNS?>#V^7w4C4|8`%Z+1|3^>$3l5}31&?!8 z&Q&|Ff9`4L*TkP!@B#q89qRuqaM3n+0zC*EcfV{3j>JDk4z&D_kpr#&^T+|?A3t@2 zHlzIRS4xJTOn=0Se~tk#ZZZl;1DNRO;1oS31_lNe7A7|CWjtIQ99(iDVuH)m6tpzd z6jW3&1`cKz9UDCr)iuHEY+T%Ye0;Rb!s0?aVjR4DJm*XxSYQ|h2bT;FkBsLk)m5Iq zoqohZJb)A^yD}73A*el)%0)}@mD^0Ef#Y-GHI7` z%*1bIaFo6()&QAN$7*-=s$pII5=B4lZMZo1 zi4D1!>-6=H55+5|9lwppKV0f}&nB-d`XGYM(HUn zY(vXI@ce|ReYHnNMQW;0qbc4V>?x{fg6g_h8RQ;&9Xr@K<=a?mP)k0*GTdERsM`MIoHa6#G%>K`$8z){+hD*!b6LlOo~pCiPh{ov z$|tc#_7|VnA}o4uBTfCDi4690t`W9|rmF}c_-d(7Gd}VQD{;+Gd1)ayra}T1jJ!At ztxuUeUZ_SBC{=eU;?+;CO65bVKgKwR>(iF2dtVST*7V#2H(yAltGlF?6Kf52-X}jiC_s*{)RO0q3 z7pugIbzYbD@xMiw@bWz1$$ai<={QMygY_nkF%JW_g2yPn!ZP^qv@#f#$&4mglW)UwOA@3O%tAq=Pd0_rzra^_2icZS7E zR$HgaDe|eFGIjCXc+-|m+E>e@rgeqw`AsG$l;mjch)zk7XVwSGON<3|)H~CXxE{Xj5}%rS=g{ck z?X=_0J8SE*&PTl;KYdYMPYJH0S%5Nct>{kACPayneQsjQI!y8!jE#FNs3Vn_Q8r?( zy*5wE_R;Y(P1A<<_%jp!?Sjzw6zzlP_N!q?(IG>(+_%H?>7E~3VlpQUY1JpLRvw8J zshCxXQYIZccpeT(I&O|Hk21C1WpjLI&|2*`A)VY?eKjPCmTh2Y+KXESf79(a$Wo!P zOwqJfCc=z?pe#c3Malyn-Y!|QPabv(M|O!{Q^kA_nbf}%vF6J0em8Dg%nqrHl##UU z%l7FRmyi4sDw(KE8(Gcuirnvu$zU?u^w*xj{Jyw-GXtW82MRCrZ9M^|m7A%Su4t8y9zeR)oFuCGE2fkvJ*Fu#=Vi9C$cHDdwRrgu7~g^wvxb(< zWv=hSs%u4}N5@~dRgYd0Pp6iYXbN!{*$>j{KI4}RU+nLd1Y20Z0|a;;Se+8p&t3*ux$CJsj_b)~2o9@8#pGH+@>5_!hZn?kmD=`yWxR8TL; ziuBZr*|hPT?$|dYVlThZtbg7kRPj1I{!*|q5xAZ6DW=Myu`jLBk1%L4J1qIQKFD+~ z;=5T77cW^qXqe{9!_4;b_Z=_cCap1%pJ<*ike@^&@W_LlvZa&%4+-W4FD)D7&jPl#3C|6Pp?j@|xmWXje zsp=b-d3?HVTh{jqd*{MKM}(W5uu`b5MEb5)-C5^89K5xxBM`ghR7Uh=c_Cp2H5eYZ zI~}R4Foo=`6_0=COCP3LqRGQNLO<1Tat6>4%!@h?J_{I^ujc>uJ?It*BdNEhc{gf1 z)g?ho@lvmTwEkQTba`ZSgr2o^IaL}@O5#Uvvpyrc^bLRS$!!_^hJfQ{g`OcQ=<%{3 zx}fO8g}w$2ZM{f$CwiWe1AsMnrTr#OO%dAp}ok)_EpQ1V)1E_uP_`y;tkfgaCdXJnYEDtWcx5o`1PJX`8k{v)rew+d4-sw^^t1&wao zc{}u$<(LY+5n6j_c;yNFTg3nsUQfvzl1vQR0$EvIgRVU5;NoGw=R>E=pGj|OzU_eS zJE%(L6sb2kxXaSUwtRMc(;~l1@9P$!e9>G82)X2sXp$T3WrYuyw=bD_e= z9(klyI(J?_K`lCbOas65fJdg4F#|6tp_sN4PmFk_r*I(v?^< zF$(?mn7&KmP04fG7NG<&Cb*_*%}8uPss|#cDD}?-Rd}-YEp6>Y6=}4lcU=3|%|jX$ zIdpZogH7mBPn=hl49uHtt7uGY@2AX=&%*h4$Sl3oH{fjrh0M$9FHsq50~QAa5_85! zIx0t`w`pq?p*yp_B*SYg_Y2+{@qAvksIHG+J$$mw+FE}#jQvSBYH+fV>R@GReS?Ct zozs;3wE}5QYGM9`flMsyEh=)}I@AVi_ZrfNN64P3jxDw!NWQ zQNxkhi^NMomp6RG;KGb|NU@#Jzsz)$*zWYM?d@h|CLdS3-6b7GES#EB(^hH>qo6jyhi`t2?hL!UkIdZ%`adCp`9R4GQKt zSrxxke0-f?E0RH{r>%Fj^Hz4kJyPM*oApN@;u`Ap*uUo1p}143!C#OChk*}+8+P?F zX`x=oVHeMV0^I#vGQSxzO<1rup-~vFjbpY!1`%KxlHMNdA}+;%6c>%FH}vdtv_WR{ zWMq@RdS9SJ!M^G=D$4dInzNtEO(OkR)-}%f8GYTyaTb~HAEOuptM_s~sB6KysaLhI zgW~7BX%N%G5dy^B48! z!q8-sC4D1WHD(i*Tm#}z9&gjCTytW~`|{eF-B7vIo2hJ?JIL!U!&(Ke;GKR#ERSbX zX6PjihA7=`FI8p>7WmYf#S&dB3)x&DpK+wVOR)JtIZfIT=Ff%L z2%N`mttpreOC$0;2R~?;wVACv8Vc*JilSxLNqkS$8|l1X>u*p?jE+M#?iM@?AFQaC znEW1rYj!_a_>f}o$w|5R?ff$U-{?s(B`dGqT9~+${gRRKv?0FXO-jpQw$uT$r%~fEq1Br=^ z6}676?%pYI97vcmFKMJnrmN$&*!&b+8yws7Ok=X--e|j;D#jc$*-UTTq0H9t=5u$l zt1nO+$#i2Ty%sxa-fBnB7v_?%RJ$L+b?4jU;Z`gED|@s3NQE z@O4x@r_k1=a$NUC0^aaZ1vb4N{cFGQhbAu^GLl{$VlPC0{5+rTzi#KIXjeS-FtIYE zImhucis*brghFT`Zlv6sJhXQX=}_`J6Up&ZuKDrWrF4S&i4zZz^rsyW{cW^n+4iM= zSjUdZEw!)S;FEZVilg0Mr;}`m^yxN@+$S`;W;o0i_~ak}4YArRHtF_323g6Wsvi=5 zfAZ0T6n|FL_w^qBfHyjD2E0+W4?Jb#9D7%_Ik(6=jjThYj_UKbzu?O|y(`n@7M`3v zPO>maQ9sg&aH!n+V85`Y&7tT7h;z3Z4CU%Yk7aZjuq)GFb{Ts8vWRcg9pF>a66cHq-dE zQoW4QC~RJ)%lM*pv!~)uQ zubCSu1=~qZ>0KGzDC zo!mE9+h1&~Ui*ZIkCHvjtwYW`a#4Z(0DIzE(6S;tE^bM7L>%S0o#C!<7{&6X(Io=Y zr+VHWuG0_zxMs<|9`0^celVe(p0E* ze)6U8MJzg7tX6uI0yHV48oqoJQ+=+Ryr;(E4Lb&1$Q&6}_?uS#X6r^mV#L-mX<;lk zs}i|b2{c0dB)kg|kjZ9w(pT)RDf>j+%Xc4#9lZ2l^Z`vX$yW1bm~H?M!8d5)G->Cb2Q_69UlygplO!Oh1@=V_*os&Z(?g?@X+f#!5Owl?>^P)hYZ2o#?v$nD@yO@{HB;oN8Ftb=ucV-Uw7}9?D7;J3wvET zWmw!?9#O7-Nj%pPeK>gr95zvGC(@B9*=Q0z&CtypzdRF-#OOnPnB3LC@@c!ei9Q}Z zmd!Y~_&!O!etxZhIxeff4?XR)AB$N}!CQ^nB*J9$JH#eQ~u!0T!@_N`9|S!G$uxhv=%nw~xKu}|Je z7z23=LSw#2-WH5X_wvH)E=QJzZ^}#gf}6cIG3G*9vx)V%GI9z#t@@C@196ynl+5Id zTp|Nebh6h{9t9Jx8Q^Ha3)kRLTLSs)c?O8B8-?_FPM^Y}t1PMDDZp%4BP-9IvUZ1* z2hBH5&W@3rZJBM8Xf3?SMxOHx1rjvagi#fBHS+_x8Y;KMaqY`|0tSh`8NGUGuVrSK zUV@8*FM%9?6ts5C?R_Nytz3}U;FiB4H|YZV1mH%cHJ!>5Uat5)6YHqe;8}iG;oRbW zr&rMqChxATW;h~lXL(|a^;;c!j0-vK4aFj|$?X2MXfO3;(Icqt8yc;=t#)4GnXF6=$ z8*QFEQ?!q?XT%~W7Z0(=<-7io-JLsze&pt@ulTJWhEG4@ssN4#V%jYGn_tnzUNd4% z5;^vLJI-6lSbjC7ueom4xtzp4IAcIuOCP**GTi-*FUunxwVOo_!`(qL(X*mxS&_cS z|9;R(c&rC(fZ++O{@KE_QWB??Zdxu>$B7~ljix6RuntQ7@%DLccdnZ#ObSz~$iNzj z=z&13XbCOinDVRLp0`?_G(B-$9L}K7RBt9b?){qEegC*dSYj7k&V=ck=hV9_@VbS7 zwvJ6+)j`}k#%Odto^{1#Cc~Be?WDY2){Jhg{oxdaoMfc-V_%G3xLqk_&8xsJvl`#T z#wC(ITt=Z4I2Js!=dMCMZE9B9bM!$5Hb_>N0_KPklv+0^{q_(82ND z8792fIoR)`nS&i}6g`(?l@8Ui6zFokySbCZoz&%0f2qSX%-$j5&6G@nrImPCk*buA z3oRg>w(MY(mB|nxrwA(&M9F{`t2lqxa;>f@FBt|HV3yWGgs*+HpU zUJ@!=qwLYi7W_gT_G~h4;g(fOlq@~hdmdz~MD+XCsX}!N3z69ZB_tY*kmjpHOL-y( z;_K`8gCRn?`QaZX?O2o!7V2x`_9oZ1A9X2qOVy#$yjsyG^zc(C`8LumTY(_2v5<nG;Nu!s;yMJ~lBC?b#BGuPnbN>66Kp=XbKVac)PG$PCD zXzx|h?(L|a_GNGu#^9ZR(^qqi!v_)*%QU#Ly1=L|7PuD>3>BNeogI|#$Z4KO-4tb|Va_1Y&j{Sl z%iiEc{1ND5S5Sk&{#y1w<_O^qPA*`uOAVytH8XdHgYaVzwsm)LILB{-@MTk*a~S13 z8hF_W6cB{T&SA?RFxv&q5BSD8Y-Wcv18L4>b}+LuJBOP=_^F%g`PMt=Z4eG{LzugR za6bq$*t*&wK=^#O=4Cr`6L7%+>iJ$R1{ZU<)p^hxgmIlTHKajU3|uIHYxyVG^iQx0 zxCaiT1*GjAJTG>I!LGrt!?=ZogR;HE%VCz!K?3EUjU{;y8_A6NX7Sbvg(O~c&6 z+{qjXHl+=E83Jhub~n-t;exP7!Vt*+u804_YJZa99Q@(eAV8dP3S2hk04{Z32B2%5 z04gyK05wPgOCZ1d?KYM+xBv$LbZ8fT_&o@NPV0ptQjKq*iK)B|sU7N7(83=9CHz$7pStN@$99&iN4moXrC5Ml^9 zgciaGVTJHOgdh?SS;%dO21FNf4`K;HLR=yDA&(%>ArX)`NGc>3QUa-gyoIzwdLUmQ zQ;;RdH^?Ch6a^RMG72>c6ACAa5Q-Fv5{f2@A&Mo6BZ?QwW0X*omndl{1t^s$Z&5l> z22dtZR#0}qr3*MvQs`AEJ5&fN4ONBeK`o$8P+w>eGzOXmErQlUTcLf>N$4u{02LjT z7!`)fjw*_(fU1QGM|DK?MGZ!cL(M_0L~TaxL!Cn1L_I~rL!&}tMH4|&M7xV-iRO+L zfEI(6g;s&~0c`+n4s8z|9i0@N8C?)v0bLi}8r>T`7(Efa82t@;5Bd!HF1W&h9D@}@ z977eu1j7X*0OJ)#0Y(!>FUB0kcT5~iT1*~HIZQoFB&I)R3}zl?BW5q=0_G9cB`iiP z5iC_KGb~T6Fsux$TC8rYIjkdWLTqMiacoU&YwU;EFR_cTKVXkxZ{y(L(BTNxvtSn~nPh_Y3Ye9v&VO-VHoGJQuuByj;BZc;k5A z@rm&{@D=gR@cr=P@hkCr@z)8k2p9>Z2#g542x17z2)YSYFJWC`x+HzcL97ZlnZbKeUUPbNHUWe{a4

    oD6$_OHl{Zy7RVURhH4U{qwFC7l>i5*EG^8{)Xsl_XXc}pjXo+bh zXc4qgv`w_jFcO#)%og?%)(qPOmy;-5b-kK;wd?8;9W$L4oj+YM-B)^SdLeoX`Y8H$ z^qUN{49X1m8S)uM7_k_I7_As%8Cx0mnV6Y$m;#xqm=>AIm=&45nDd#(uHjviyykc< z?OOkJ^y@;`ZLTL?|IC8I!p~yG63^1jioz6J|rQrLhgM&IKoyT!-KXU> zS<25fvI()4YX+fKJ@l~I&$ zDMu=gsW7S_RjO10RYlcE)vs#IYL04k>geif>hbFH8r&N9HCpcw-!ZJt zsMa+tSFLy21lk7Lx!U_WGCC1DlXp4q`rPf%rO-v_R_kHtY3pU_?dr?wN9)fS@EZge z3>q>UdKk7FQ5e}8H5d~Zn;4hfL%XMQFYn&5iKMydq{Il_M*oU{OI) zJJH6`?J;~Y$+6h6uCbFZZ@;X5#qjFItK&GUxPkba@x=*L3C|Mt6U`EPlWrsxB~vFq zPd-dRqztFZr&gvhrNyM9r@N)kXJ}`%W(sBIW>I84&pOFQW>4m5)>yeNg())*{#PzE!HVsZG4CzFnlf=A+QZsty5g?=ydA#V7tx6dmED5f z)t`kw*Y=3@H1Un~zzbUC3SJS*%}@Tl&1L zzdX0%uyV8-utvC+uzqd5eB;JO=jPqbxhsZ zX3UZjc-Uf+7rfX(KI|3TZNe8BX4H>zgEWw;O-Jc#}-FZ3eoh&)Qhk9X7E*?%E9(Is|-PsfAV&cJ$bfyPeybuL( z_jk3w*aUI+!Xy_>{b)-DRLKe9aPf3#0)F$vMV#&z6Nf)F{ueG_0mz1z^9LI&v8(Lt zFc2pq?PPA^V($b)IKxbA5hl*&W-xa4bFHxcpt<8}YHbd8`8PCToaeg#NuGCvHvrXK(F}s!u;~cUv>lYzig%n`i8`R zyO!DlG^!2)X=d;4%mxF$N#cC|EDtWQ7CU$Hzp(hj*ykmGw%h-M+9@mlul2LD`|TQ_ z6w1mXGWKxS^Ht%BG7>*Vh?zYcVJ7msb3}yrxOq)@O!(LZO?kQ4xw*~Y?80XJX6&XG zyyoU`Zb4yQK8p)w|H|xK+yAVO%)oK{b6h!qU!u7fRc3Hba9_^F>UGdtxPR+l1{X2= zW0d^x=O2Sd%GTv)_D9|paX$V;Y)z1s;&dMDKR+b%v%7zZYK!>A;=uDDi3FYIXWi6shIcniSR3+EQ^0b%~)0MU5lBRE6PgZpLopw0>Rl0@AK!IEu=XllJ~LdlD|{F{Mj zbbL&brLN_4aiHMiG!Koh4BOZwB2J-l-ZNA)#a;~kr_`;k6v>z)8&Q*lrh&<=70#6s zlmQ>-<|39d!UM1n=zU{uIpv9%Cz)n8xfHi+u$y>H8hoQ$b$E(h-A~*FPT#+4&`cih z%AFzh_V*K*rtVDA-e-2ooqfHJ(V*#!s}nm$)V^swHZ;_^GV!4i zFRsV8byA#2ZDIB+(ftY4*tn$b4@T6w8mwWcRt}>xr@Ey{J}8s;^EKk=<0CD2%e`Mu z62fWjiV8NVuqJM8Jm=)o&yh{6iC1$w1JrtFaE50eEu*vEdiil@i2q8yzUM@h)8>H$ ze#RJbXWKGYd72jW8GY!?)M-*?6^SSN)Jh9?cmhQQBLBYR3s%2(S27DUP;^5c8*>UR zgtB(;7E^NsP6_WRwu;qS={jG`e7$RQ$$QNei zXRVtgdSpoVsjw@$`~_i#lVLl(uo0Jtb1TX1A=<(X*WQU)_sMF#P@c4G#gYAi_QE=j zrtr07YQM>HKOA4IWAiPHN2`ql@AJm2oKf{Ff=K*VAu(??O{*uJL>}EGo!B%k4)`|O zN_ihALbQifhK`9*Xmmf|DlB)mt+?YmzoGsIBaZsPjDgj4&e_BWlCQ~Pk=p_aXr@x&ykd?V5$Bm@+W^UjW2FJHJUXy;u{BbH+U@(}pT`Q*-IIxAgx2hCG_N zEeO@ZDZVZ)Na5;bp9nM)OQA@gpEp z(Nsn=!5rjuuax{pYh$U&b!#P~Tf=>9vxv(4s#ZTSU;g<3WBJ#J+}qtYy$!&TKFbVE z8McnyLW7_6g|nRSG0r;I%wlsKbvmh4ZAZST?X}VRrvz|bQIEpU>lC@FMXhe^)xB4# z`Vlw$7Moqa-S!`ae+TSX1a&Q>`H{%D%YP6%v7QD*EgHGKT4tD-a!Zlc82xJK<`vvj zl%H)|H;)2jaf|-|Wv74eqv^dz{tBP*cFa7m>fRZ=KQKj`!uoJXn3KUCUufjDzIF)9akM|{X};T1xQ$7Y6!Rk` zfY09GX9Lr%Y+ZP3PnE6z0JXKYSo5_YfbU>A>&I`?6<=JUoTWOCPM3W@Q}bGUGl{O_ z3RE7tUful9b@0FZ71!gwk!xpvsV{@=^!s%aNd*2KwMb=*j4CiiR@_g_GEOl`;{O2n zDUZdE5$b5~BN7CdH_e40g8u-Ar+n9mm}Avx{{XXJPd)uVQ}%8O z@cm$~h|H+MNu_I~?Y;L;$oGv)_Wk&^ukDiO$J)VrbZ;%85JIXCJeW{9JP(+Yj1NkW zNBzD2E$L8sQT$5r7nd6eho2;O4hi(!fMeL_^sfQDnnj4RaRhM!!P`4??d?l%s7Tmm zk~40qHx(H?`ubOoPZd?hQmE|it!?ev=jamQtffUKhllMm2Z7Tx=zrj%KN{WkuaC9m4uRQLI6h<79Ii%vE5k+GNiG$a$=X0E z*bX}PsWkb*q>#?OWINTNcWur{&&$R*?^?Gh+D>mxZL0qOf%lUx8_KP3tHS>P;ZO8E z_8<5u55_??ox^zhR@lIAg7iY`CQfh;&CACa&mF6t@o(+N@l(dPlYOVe+KVY&%?V(* zT%465M^3}*?T;aQL`POnEo6jb621Qbh;{UwnNH~?0?h1Qz)0*LQ z0p>I^MQsE(4fn#y)O9@m72JN-2JY5fKI$3m z(bhq-W>JEH&j9g^_O5@z)9YR`^5oR6AfHcZ_6Z}C&GJD0B6?u*0qwT4ol11$%LjMp zdeFdBl2=_dJ(ot)%SRn{ug*k4NZJCeSNB5&nSZ0 z)=?>y7=V(NJ+O0uk_B@@ym>jNxgXiJ9Mw2~UZ-E<--{66+(o)MBZfnXrbt=*ul9-K z0CyzUn`qy(M~1v*tLi#shOwt#Txsr2%t)bH2-hU%*y9{}*9-exe%N=P0QAYcF)fV# zCe{2+r<;wpGAv_F-lzm>fOg%J+~c-u=0AtO6#Nh1pNDa2_Ih@--aPSRi)~{52?0>d z`0~+5&9?x1jGE@DoooB3*{zp%{ztck;qPfp&b|3w&$936eTDGH_Qmkui}g0K@akP^ zH`ez}1d%{mKPDGa7I0gth8;Tl8v1L%zXx>PPFC?YuM*vPapfVolWZgoa!00roq6Bv zTl*O4z60>?m34om+JA2gF$}j6Mk8sW3|cjCKZs|O>MNW70D_AC&{~$2eWiRO@IA!V zy56i3O?u@RV@lIVA%;dsk;%&A1OcAHy-sSFTpY1jhpy>%eUF`~hmJQMuow+H%gJv2 z4kz}6{h((3q5dZ?hWrzBx^IL$FLWc+uWi_ae{~zK=+}@6j1ivv^{A2Aj5&wxB7rhGK;rIwXzAC(@S zsUMK%cI+q~VLjNMa(T(GB>lF&HR?;O_Mglr+cIAvkG`T9tJdHA9f1K^=R%j`3}6WO(E(tW|KM_;Jv& z9Gd-Jlod#OM)lhMXTZwySC++y`~j&)uG`z(642dV`AsXwl}PeQ8By4Q(zEdI z$N@rw&wBAM+n>aqH277m-Zq!;x5dWhJAo#tsE3N;)uSt&$gmb{0Nm~igN`ejo?$<2 z3R9BneNt(3dcNmfI&Jd3ZMD+B&5Psz01v~bK{dt3sMhnFR{J!wOXZ37s*S3Dyu*M8 zbBgk7;*D*xN#+jFuq9gr9P~e0^KE4$k&N2x0 zuA4#lkD_>{HE6COzJ--s1ds-Ay^aa>2Dow9JXL6^PBM+PcYF2Ir~CuTLO6!gNsGTG`#T#EQ_zB)It#3_5^0>CP%UhD(K; z%#vnBEV3z#7Va<)(z%&#S#B=3Ni5!G@ZpZ_`%lZyQPQ#B;$_X$D;leLK!YMLBsbLe z{S4Y&Yy z3&9+B>sPeM65qq}NpSv5toYcZZeR}pe_GSgwFo?qI!(7wupokPBVpW%;l)*&ai;fn zZLh-j2+@r=@>AQ$%xftvEthP|iH_xdpEs*|bp0v1ov+#S_ua7;R@z)SW(|Rn{_hQn zn&MTG-ZoWbvtxoAu0PN8rs^6bmgyn4k~xHjks#a@U>uKrMRDQla@10a(%S2#-~2J% z7|K-nB4e9>wEppBMo3p=m?6vOpcv`t)YYpEC9bbh7>At91_&4fudf-Z5opsvxt3wM zj_2$dKFJEtFt)X_)gav)`T-jfsK>b4DD`x|*1~lOJgE=Z__2Wcgp8yib4rwQW*`dDBvMZFEg&j4<^*=NYc1uZT3b?ci6A zF|?9h20G)eNc8MJl-nD*e7IWTRY*_B7Cb2`dJmWJ{HvGoFN$s=pX{xCvSSi1PS6fH z_0M|R(JZZJf3!;{nH&Ra@W#MbIRYrln|8c3J^673z7spWyZi-YZ7 z(JCUc^Ixd&l{E|`lIrgLPvw6cEc03tU&;Rfw==^30Bp-ow&)gd$7}Yf+11e>{ZkZe z*lPK=z`qe^@rH@wn>}vVPt)}WpX~6ekRzRQg#h$yj@hrEzu=$}Ua!O37FNEWYMM>il_QbNR?yjPVQkXx$55Cd!~7yK}P6J2>8 z9+aQAo9buA8MH81Sm7|%_m#a7&-{A$*QEZ+-vl*36KWB?-lV_i+HJC*-d$XK1xNuO5S@YEZ}{mqYR)G07C&_PxxRSIHPbz|S9l`0wIIi|{kU zw;I=qE|%Xw(&G}abN9{(+73GO!0%s|`oF{t4#(mA-|(D8bsJqH-($)|nWwsu=EQ&z zxJJ*nle-ID`~em1<&1FAdy(&U)$Q^;jxVY1s&JQ^@E*@yRtAmcg5 zwR`^n>}&Buz&{^>F-f+?8c}fu`2vw133VKRbBxzF`!ReX zUk&_cn@jOzntamUX^(%S-r7p*46<81LhcesA>;`>Po6LVufEPB^>r$fT0>iRK0_;; zs|eG+t=sWE5B5~}J*;?2{%eUWEVR?8%4T5)%-?rlVnzTgxFio@UqP>fV$`*Xqib}! z0SL($^RxUQcjKIURqYJe%WoS-vB~EkSll*tfyl}4!4HYwH*vdkM8 z*}gIk(}A9A=f8=6@NbvHy;fNtMDZ7hJRfqZS~%dB?Qp5V+6ul73Gcs;di@5xxtb|f zN#RMPjB-HagTTkX{{UKbzT9%t~@UhW_WowF_48SB-QZ5OHI3^`H14--C6Ha$h6H{yMa>h!uuCw(>UONAVB>qWc0) z2TJBOfA}HKiLD9$0EBzvMxk#sE+dcoM#+{kdl1o*Lg%;{$4cI&ZNvD58Z@zPHt!eu zXnhS1EzDxMr-tTRT}ggS^O4~Q?k}*j{h<{|fznkQvB~4wouYVN2=xT}WOo_mRFV!r zZ2p-1D}M9-30?a}XsXMnc;~}pe`-w|Rgo-w zF{WzJr~xLlzk*$xzV_U@5((oN&3aU8VQMej;^6JxO?P zFvo8*G-qO$eaxerl1CqprD$pTn%WkQMg7vlhIa%ko`7TU8LvFlJTvj<;ka0J&x5`w zV+?XUuMVFy{V{?JVSrrcBtdy9a*JE(D8JO+hZ$w@c9fi7W|{5t_>S7kM@4A{ghg-z zW78dZ?s=?#vT>>ve{$^a({-WhVfdpfl+>Mg-`3xm*m#e`jcKDvJTgNkost*LV zbsL#&;F8pNeF3{!74d?9h~uqu)XSYYMqbIQbhnZ8m`rsUMSD7tY4&UU&eOwEHof8a zR}yZE?;L_bCjgxBnw!Ko-)5Q%SmlvbpUQB+3@{n(&|;F}&eQu2cto>9kc*Z64yT45 z{p&M8zte4PCDZRL$&zH|=5{Q~6P?_0N$JxS9?>b;^xx2zMYzgM+}iM+<*X1G+Q#DP zTHyJT&4IOX&OJqR+GJK*GPF>_(Zt7l$UbF2bB-&Kwb!8*wyhjL@$x=YjFJ^QZYQVs zR`r(KMDrtsZlmb?nC212Hj#|}Ij(xqrBhhDw|DX~<#AVgf06Bfv5nMPf9*@~r6!St zlK9s~K+urG{y$>lC!Bs2{SvBg=Z>UwujdE$GL|L%slFEOkzfl>6qi0cWs;l@{_rbAos7TKj5lYGDZ76Myq@F zA2#|SxXK|C<1OWu?5%XcHe!A;5ss^m883~x6(XZ-XHM|?y(F~UT8PAvPPHc=+AR7;{|h)yNMijA4|MIwJDtEkVhjJ^fk$a zr^)WOw@*ZUU1@Uruvc#1p}_vpdKCIri>Je-&v&U`THm#@>0Gw`&4hPsbF|=N(APKn zLS1N@=8*@2^yn@l@fM${*`?mgTtZ<1e8YfF=3-kvO5W4_P2(Q}>)Q8=u49tiXm_b* zVV)9h?u_b;SYVJB44`q?)gOj3>Ru=Kf#a_pc`0Xoq3U++ZrAl1yNcCOMw|`{qD%Kqa@$}hk7fJhD$7mslAxP7M8=M2zo@%%4 z@B2+j;lBe|_#eXeQ2zkJS>ncPaka44b|V6L$9{(ZW0DU}E3Ehv;N|#3@MrCwCb(GT zjv2LWLRG>kc2gq`-*hS9@@wNyh8{8h0EPMSJ6ia$sarLr$B2;HXx7rkH(A*(?Ysg& zEt1E%+y`M^&VQLx%CLUPd#+k}-2CS)!NW1lDq`wuPMc53zKHDfj}F=RXG*ivZFe%m zle=I-j51bNfzwak2fL{310?N>q~O@ftldPZJ1ayp(*zZ#f9KQ<)p(Ui{!6 zUYZdlvZ?*#sR1Tr9mi5fOm+HK=1=YS;;A(++SA7n_=-sFZTv5z>2TRZ@{*roxQ0n* znNiLdnN`72*J#gd{0|Fc)bmWw*i=_`a(DIpPkV@S97Yy|t5ekB{w;pdUNHEJ<1Ie( z#fyD;;K)tH_7QpQZElj`WfC}Wo_UE`QCXCaL)4Ew_^;xv3&%RUzN6v~3+fs)&Aoi9 z^4ygtc?LeYAShPnIKZ!|zh^J_C%2BjXb**&XT%1#@ZZC|CscK~(^Fi~r`0v4(d7lL zZdK%vDOpAUO{6Yc@cFaH9|Zh;;Sbtp#GVuJ1($`iT_?k;@j>_TYTj&8`8ILhn9<@mVv5eu zz!jSd!RNJnZFS-OQ^EfL88xfTa>8q0?Axf|X(NqeiO@%Y%kJcE!RIINtp5PoFZP<$ zwMgv$0JH3GZ1lZ0@$KP~WC}GX5C+>S=MpNG2WbZ%FSU9YtiD*)?4@TJrkmgT)Zn9v zmKL`>m7HDI$o<0jVe$7(@n!YSlQqVgH7gLSd||ptqqlh(%utCJhIRm+K3-2X=^AbK z+9Ottm&#W$tNH=R$K#s(D*cTA0BcBo6wi08*y`4LRn*eQ4xyyYxuCl9CPOv+0ALAX zkL6f+Rt1<~^Ixq$vhTJ2^I{<5#zF@3l!7=B{EtfdM4^dyw z+*!dD@>=eEs#KoMFYCGc9!JD<@s!e>z3#`+8fDaC>30?p;IGO*h?T(~l}C7Ex4dbi zM)~=dAKvvflcQ^qTV8#VStJZ_xhH{;GwoZJK2^r}yx7`Ka$i4y;=U4{YKz{bA9Y{v zK91`A$fNc@v?raPHU-RNr%~FVpTW?}rs^71-IO}Tjr2>p3BX{-JAotgtrfX0 zvnI`iVOjl0(wlo1+qELaq=rafQ1~Qg9Zf&3N~iB-qp>b2rx(=8(;~FF(bnG8W3@|g z$>qZq#fWs-f!KEUtlO(~(`|%As}1VJM;?Z#~=P7%EC*1pfdo zJ?hP$hCjBX-xPDHjf!$X$2eYd>5A3wpj`}si5*r0^z$d@uM6h(H9%pGQ z*+0zZi>V1G1-`FS&(5R$lO&5PMf;f0ESbY^4Z!~Zcbw*=)-+8*Pm#jcmATg(eQ zcW@-ENm0NIq!k$%`qs3%Tr=ud4I4BO%t0y!Cpj%6Q5V*RtFqU7C#!t*L4q1Ei&Qw>@;W3%68A;H@REpPWH?_Om<(zQU z-Ks@DRFBM$AAi9N{3kTy`y<7^BD0P*{{UAsQB3iI*uh|{y9d4nX4?M%!31|)z?Od$ zH7FyvKQNa>4Im!569JE|EA<}ZM$_$Pf@vlS*HZaEy$UxH0So9kJv!o{O)4>QAU7UN zG4A96fsWblUY#C1;uz7l)-NaBC;A+fGJLa{a#FMMZGNZbeyjfg1Qqz1;p=(sd_VD0 z8_6L}z)b|U%gG=Ntd0l1RFAD>!T$gRxBaKQCv6(oYyJ_|A&o&lX}=Kt-Y~vod0;(w z`d8@>hdf(8lYMxzPHwF3BvS)EGad;ARDF5&&2~_D!aJthAd6~^z!8;0)3t5?01+cM zhr3cSZ@()24vaoq8WQHEDMk57pOT&~{hWVlZwB5+4BiE^)VK@gSkG?pGWN;Jo=-K0 zdwucG!FMp*>t6|cLK0)NJ*Jy9yN`S{YQX*mzh19rU4GtHjtJa_!CAR74E^l(=eMP0 z_-{?cvA9^a+(s}E6Y`Aq^y^yF@fxQs>){~!H>y1;Vliq~jv{|OkIsMCo8!NUyhr<7 zd?dBM_<5_^y~mGqC6T_%Dy^b6p%5>UfNfq0B=lf;uj&Yv2aI(!`3I}~Rgc5}03QAn z=x|$!ueGlh>BizXoz%k@+Tb}P9PI=rUOx)`elhZYsIS_%8#0zN1uCwRq_3jA?z^9d zc$bc^UmHu?Po2dz_@68P0Kr;59@}{P_Ia~ew->V2Ba$GlS1x0986U(xVs`si=XJ~$ z)}9v9;k#R?WRh~opca-5ndGi|5uDfP-~1I$+H>#%&f?vjWVO{JU=^?qw z9!qwS5UhDPz?6)Z$5K9(*E8DuIMuY>Ev#%-lHh^((urQ%AnIxxctKtX??Ta5kvy6mI|yyN{r*ctuW%=6b8^-ECvpt%z$8DaCuP z%HOHQ{8RBhxo6wyzz#$J&ott;}0bAMid5-Ist*S=eQj^SG9h?{{RZLZ70Te z-VD~QXSd&VW3SxY-MNvnfH+(peqzhVVk^Ra8Ee;H6~AiD3t7@`8q#YLVO741$&D?N zCX8iR^!ZhQJlEb|uw9z#T33lR$*&{TFST7u?5lY3?3kGU0CE#2w&dUrFh|zAGF%*I zS`}(1?&|$I92`ZDcw7YOMJq{O{cczP0D_!+TYVGZC6|CS`JHe4eKCbv2gH%t#>%ot z4^g?e{7rbqk@i0c=rGu;4<gJT)-g26^7n$v z=}=!3z3Fdzy_57I@yEtXPluiZn^x36b8BZ4lLUl$*9_SU9Al$1WE>BA{L=W>;=OC) zAI6^)_@l*^z8`CU7+c#Vvxs7rYkLg&am0>B)pTsQJRFnLx7B|hHCeymr*uyb>$bX* zrmJ&lZQ(oEt*5xsH0z0xqgk&WGyAqTK3a{#7;juxTk(7LEV#GuZ}yjn^!*v`H5lT) zyO&LD+uO?ue*KYO?n9I~PNzI|&3#`D=`+5~!hI*q2Sd#qWHvqsxdZP}VQ(byphxK#w; za%<-8H&WF6X{t@AUwDG@7x8t2#dRjCtefs7+C0M-_wu<1_;+z%W?guFJb$A@r|JF) z)ZfD@>!@4n4|@PuZA?z!{hZtg#D5Kq#l`i$qpj$I z2(^}38aWvw#8K~4%QJ9tItub%6u^ReJ5{%Z#jGM-k-IwO(mlf@^v^$sdio<=_-}RM z&k^5vfLO;gEo{PgWkPosk83+G_mx-w03FSFXNiL^+vxW8vD#fp{gk&&7Ug18 zDPVGVUIu#{SI}T_k&M2so~a&HS~7pVo~!XZr%=>k(e*txNLJ(RT6jYfFU+#AeBk5` zql){7_9XqV;qZ5jWzsJ#f8t-_q=xs(k9>@}_9KWy(y#heh=e`;AgW<`ccU(1R>qzRL0 z1!e)52I@DTTGKVk;hA9F8mn-@vH)JFPs)B7)5nh0wa- z7gBw8mgA|(9YL=~wz;~JC7BQI=8&{xA?Oz*pVGe${{UuR+ULYtPmHZ}T?baxd|RR2 zNG>GO?xaZWZ6J;qqT3Ue%%x&iSmVg&2dS^ruh`4>qOkb0rpB%fw|5oxcp zYl|Gy#$j*-auwhMP-6O*qO=?AdzXx88j$uH=$dcl-?q)g`u)H;3hmDza|_1n>#t z-!-46>;6rgmbV+@nF&;KcOf9-Irqg$9m?x=(perxe*V$aI?)nmw$MgS;um8{hmZoV)QHF;BUfWGp}f%loAC=f7I>t3gzD>QrK) zw%5BaU21kLtkW2-7BvO-ow?3<{3{bt)O77qC63Kywq|uMN_@O`-Pe;=^+;{?SR_?T zjj}iH$xWniGmmPMNAQey_UN%Nh2uV8FT*W&(ZIMp%MP4Zt}g)Iv0bvrLQ zIhTAuGk^{ooO@SOV|#0H6k=&NNW2W352Z81z7|%59g6#MrB!$U9`%0X!tt!CvMNTo z$aEhtJ$U!3bYtxOv=_bGV+cvx-`v6RCx>L7A#racM{AIFt}r^0`1H@Eb1_}d3=bu| zjU~;A%kMZv?Tmgk+iH4)3rU%QL;MQdhQ}W;A5Yf2x8gRhsc3#5)wL^4GHc@uW^c1V zcF5sKEC;E^c&=$wttqNejlEgDIHmi^Y+StYt;dQq`K_AXIoVlV%C0tGgP*DC#YcB) z*BVGMX|e_w1jful{YQGm_&4$I!haET=`K7!GhJWhOYkrhxNZxqK!z&DMM6pu;>>Wz=qZvS#9H40UMMI z40Fam&MT+Wd@FUQ*tBbEt0Sj7&g@{-Zy5Yi9s<+kzSQKkvD0l;ou{4FG87OA&V9$d zeC6>Q{t9hz1ghQ_wUb)=LL*0UBw`q31dzx9QKY1IZa9;|1AAREA+QUHcHT9%0Tz|qzWYI+gzGQIQMLd31 z1_shpWc^2_e$X)9e)cQ!x8Y5TcxU35hhU1|mDKGKLt$OF%@_pwfye3TU#c57j=A?e z>-R4Uan&(6YVfGO86=mLkI4K-!@`axmnT%)`_C2qy)7(X;1%@O=?szET)`wlLGp*) z!5-qjIdr?#@ZG=sBn4kp)BHY?-Cr1^4J620{aE!Qr>;$Yi2ne=MdgOy_FvcJxLwi0 zV3nAS+ZAR)#17|ZJuBglhnjAW@gKtyC9S=_oo2Tt+6a}tOMSN@vi?!dEU-$>n_W@GnD3=-Xm`q!T-kfVol zr+5DV57W8zF^h7&td@@VwuIAoX4~S7-|&({rL~Rn_=R+d~g|yGLL-M!GE(S#63Ua&&6vW4CwdP zHkSG>uRf^gleCR8HM+!r7RK+Go=>l8_uu#|hs3Y=R9~|K@sAQ_&};x*iU7Ibw4P7)H~A)!SV!uHIKXbM}7G zAy0+6SZkQBFSQ9C;`FKW43jy`h)?&F@m+_9=Xdc$mUc2U(A$#io3}DAp&ds(E9d)9 zh?)n$4~||g@go_^zBmDe2Wdm-LlGnO%edFx2pifLFT)Eg!fvl zx9v6JPZU^QS?XR2@x_Ws;oGQR_kue&5x(J{zQLMKnA&Ue{Og8s^e9xu)L!FVwe)AL z9xTG*@ot=SmHYJQciuSoQ)lA62K!CcQd`Sg=)&Ab5ZfUt+m+Ri-C)=t=Yjzo*NuEm z@Xy69AK`wXrTBxynzxO-A95}H%e_9&%ha@(#nOlSNf`S`WaIZ?P6!8~HS~tFuG{ME z_Kz}6+X6bF;1%H2Zx8CW`c%ZB*y`DDtVYi2HI}0jj zo)?bf*M|Ia{gyTF*+0iR9ruZKD~}cUcU;s>o{gyeo9z0OI-H4!VhXtrziW?_b|i33 zexi8eYhCGjtbb=!dr0l)ML#g>l1JClzG43Wf{tn)Z^0i0-$`?03uwO)v@J;|y9H5R z8RP|tUzjSVnVcNsIOe@9KNnY*62k_P_A_aDy)}LBe?EtgiQ)QG@>xk=+DBis9Cz>H@Dt$6Q`JaI4lA@`bn(OVz1NYGk{ zhY=Fs@sGNCA8}hB71mFNe-*XQ56N$Ne-)kG#ktkszGjbbnWA=G!#-Z_12^91n)6?X zR!^yTF77*NZlb=??%F*-u4LYa0yho7EJ?xJlZySz3xSHK6rIzv=IiKwb5dT;4r@ni ze^)tqE@koU&6$!kRkF8xK!{K83~cSc86}ywa(}|P&llMZJ6C43TbOkIAXb@hS=B(3 zr)|WM>MNkpbj?#uw~qZ}fn*c;K}RwiZ3HnKjO1~ia&uW;D$thF@`Um($`&MZFIrw*`PonA?d){6ju-#g0F-Zi87&9mgcPgTs?re9f zUl4z79|(LS@PvBCyvYr!TL!n09ZJj8DvtYz8Oa#rSK@DiziYn|YCb!-)O8IPPw^I~ zc^OeWe+tiSr2hbEG?3dy ziUA-w849X6>)RFhO>YV4LbH`xzT$iA{Flq2`d%N7D$0v=VIzNiFTR*bLXv(d|VwqC=Qy5@Zz$Lh9=@v}hHbnp2K1_I&i zq$~dbET>S#m9juMQIm|0YMr-^;TBf&M>c`f`)&q0&Kc?QuW9sPdGtR@ zZ1onK?5i62!f+ZXQyW6}&%gLrNgC;}#AA|W^DW#y!Jm}lJ!|1Ff?v0Ph4qPSbWJ-} zwY=1`y=`rlNlVR*&kg|E9eaghKMLu-9e=?=d_D22#2zTp*H!-jgnv!&;=rkWE3eyi zN1Qnk!vgMh3K5GPykmo0H28NPTbWdYi;nH+t@vEhrx0N%{_#mky^?9a!(-|<_@{2Z zG#VaArFnUF zAM06#96Pn;E#GCS>Shv+95ki!)6C?*;Fdlgw3k-#7QKC_TUq}AYupQ4D<*%QSIazw zRAVEC9_Nbt(KIm(Vi_7T#y>YPrrZJ0AI$z$`ET%_#dd!VzAS0F-Hqq=ot}aqxce%z zTYx}v{{T1#Zccf}8R$r_x%@@@P-gpvD0La&(h$t zj(317yh>G2mtoT+u1^@|v_EH$g3a+$U5~_?g^V__z0TydLs@N@=G`qsw*~ydFaaC6#6A4ri%<3XPIkl5Be(Zc);WjVN+Qu z-&)1B<8LJSh=6>+oOKmqTT`fLJCZ^g;Wsmd#_{QzucYa5rl8YB8xJ9OyJv=60seDe z9V#%DC(onqkD!e8_H7gB@<(m)Q$?EIG(`J#up7qaTuQhg*Xg`t1CL(y`I~QGx1KEU z^wFavtEj^MS^La9#jn#OIbZRw+B_lmaJ843{Ey0fOxCpzNBN`1zi&&6dwn5J9QW3AjP`+>U!i_)Vt8;4NvS)GXTaJ7VcLjE96i3 zDM!TnKZ0KjX7Ej{cN1&gHbEuqYbTi{y^==7os&3i&H&>W3twGIx_^fJLT@iM3s|GI zwJ|ECz}>N>GP^O(NC)y5^smmZ_$YP#zr`Qg_Ty6U{-H03uNzH*%4>B}#s-gg5urSQ z7aZfC!mlQcD%I&mO8d2Kv`=@j>u1@N>EY)}i};$qbgX!9!M}-KAowrh3w?Ia#DCcL zmV{drf%7-WazQ6K&PU~6QfMEx?~XreFN^;G5ByoH__S$9;ey+(VdA!Qdr~r%ky=)=O=4)9yZJBErg8xW;!$ z-Cp{w52CKWVoh`Qiuk4AJzK<X#N9zo(MqB^#peSDU3Ea!6DZ*P|w zcwq&)B>E3Nfiz_>O%8;wAKx zLwR_%*1E*ZrZ@|3I~iA(1gHRjG4Ee=!Qz-TO*+={*{x-@xQS&)ft8iifN{|BM;Onw zeo%Z<)1&adu?4-Pc5_FiphZZOJpN$rdvG#(*QIBzk9yH1i#%r*Cy1l(#oKGrm3Dh{qZ2UoU^bW;AIv z8*c$!NZ(Bf4g%g*q`QZ*ioEx#e8l-mApQGMvftBpPsrqH0`nOX4uLvE1ao2 z^LE_%3-+A-qI@snuZViKyt7^E3#{q|qg$(Ghi$sax-zJpxC885`@p}H%%~PxJ z+vatbc4JTQ)RSCXUR&JWMd!gBv-yhu060>4JRE`Z4!Ep4e-hZ<_;Tdm$uzBS_>hRg z$c?n)ACF$Om*v~(kWVGxKWZ_o&=CIsb+USn;Na)dvF@)V)@5lNHr^ev(Mjrm6U}?n zFtMpVS#{gx=5XRMGK9U}q&K>Q-D51nZSC7~ zBgW?r4%PYs%kG;HP>r-ey7#Oicz$CSspNhU-&tw>o=BmHoB*Wb&Iup zxGJ`QA!cO*l14c7uXy->;(b5G{uch;)9qp~WhOFiJBqpvcH;x_?_7VxOZ$(38oq~h zapAxCNwl>{*&lx4Yz|gdbQlfSBl=f)`#yLjbRUK758Ii=)Q;vb0wGs&$Z?#4szArJ zXNi3DJ*r;suB^HdsX}t4P29!!lW%e2EfKW~9UeQ_Y~zw^kqkS$wBvRTexIFky03_S z5Bw?chLPeA8$seNV^~X~Ewn4RHrl#ISQED+%>eED*-&%ABDQ`lcph&IXc5h(T}NXY zOD<*3(NTlW{m^wi`CH0DkHNyMe(t zCb}|pPP|%+w|C{(;LK%ND}Cp4@f-dM4dRV^T(>?QSo|ZW*~{kJ3pkU`jO{7Ie+Xsa zfFIJnvi|^rWPaG+4`RBLLAvnAhkP@vNwnOg7Ra$&Hs%6Y8)<)%GyBgmiO)_e;13S` z64g9)c9UuGo9kFelJaX@y4fc%_1h~h#v>RF-xa44WVjq=23Sg(UuPZsu9oZ6tdA9m!eXOd3h8ft$L{Ba zB(>BnZOmpCid+cUc~?{4ckuu`LBqau{3^f1uN>*04D=g~b5@4o?3Pg&Hc@$v8#3j0 z0)A1)ulQGjf5AQeHva&^-1uh-&8JVJ&vMhxErXMr78eeW9$4pxR^a2ceKn~3HqyLv zq0bC*ERA-dU9dMJC?N63`d9Lwe7e?}Q(= zZ-@MQ>Cj%P3)$Hj{>>p&WpG#T0#%jr2>FG5Icaj% z%e%`}`XUuJv>@akwD6M{o$V`Da<3f;>T zQ+%FoILPb|C)1kP`1SCENcgkim$$sPwSw)HT`!bvf0K}*y#o%O-t~#_Ye2p5^_BEj zQ%JLmzb_IvZ#j01k&ZF=3iR`=|(dVY?R z=#Y@hR}8~7^8UgwQKuz!rS0aAdR64-&l983=V|;eCfD-T#3%vB%AlTdDgGK{ zxwyY!HeqP|zc2oJgN(5D6<1E4>w3o0tZ>{Zl0}oTJY@S-%?9qu!&KbTEON)pi;b)P z?mKkP<6WN8DvYHKY`XS22-`j;Txd#GOC-ertPKMv>(!VY* zwRP0|cc$F!eb(1j{!=6E&SgXWKRW$KMZ%A})7HOY@Q%Gn(8NKl6qn?FSK?G{T7~;R z%^W}N6XE3XN5Xw-&BNrvB#bMSJJEwLBis%;9+mlLjrDJ*FeHh87 zTEi56T*rkwnU3xW&IUo}(!V-x?e!}w^zkS5#nW7B3e$Ot<|(&@7&%eXoO*FzuiiKK zZvOxi{{UuBA9xn(Z%s5hvplaAjmOG( zbnLeNy$_MJX!Q>ic(E;&n)2c{kL?dN&zQkUIV6+M%uZXsHT7rg(W1fe1H)b<@XRt> z%8_h%t{V!mOown(Bxju9{{Z!?-Y$ROoxTVCpMPWD5qwD0F11~9O<8TA(e(SY4A8u! zKILFXIAh07E2i*I!A}h6dJl$V({xKc9`er4C7s|VGji;|F>&389Gc6GxS9~9^=dKq zx_);IgCwaby7GSNZEL>Dqyx(=8*^HHctyV++aUtn-1nIsP1a`q!drI(#}k#A_d&G>UMNFB?Z+*QFZu z_u2G&e>i!KaVd`lcs!HW@vp~fW;K>3IFyoUr)0GGTVw93;c3omdfXW$?clgV5I-;&0iT zz`iH(JeroH;j1|{i`GGNa2Uqd3K3KmE03E1bKA9kiT?l*I*k7SbBLj3-Y)n1&yB6Y zHCD80b6>uVe9`mIgT4#tTAsNA>J6xA_V$-{&`LFnMVix7Wj{BNgc2nYEA3;?TvwES z+FuBCmH0hx;sYddX?KywsKsd&%+ahdOC+fh2qbLe$AWg_fHFH**n0l}?2BtHx7xfr z;?Ed(cH$?*Q^BcyrUg4#Di|gXHv6d!j+o}WSN4|h4~VaPAE;@66ZAImH-#h8&73;! zhYn}7x>*E}+}&HBnJ@!K`>c#P&r157yNL4qbyz4=<(k_1Ka>9ef_OQmL65^yZld>m zuKx6Xee0ePkH-H12(4^1`)kQAH1=eT+!a{ElEHf8j(<9)bbSX*@O8b!w)d>ip)sQ( zt+mafY2fll`g|*TQ}zxBDHnoce9a1=2XnO?2B20kM(e+Dsmf+k- zRJv{-d2?P(@Q2~aNxqMD)wFw~mPX3TRg`3~#tum7=xg1yU0%+A5nM%fu*EcR zlCYHA5O(LUK9%Cury6tThh5F-#?iEySK{A;uCKftnxt{I^gm2%SJAADs|!NmP$Wvuw%eEGKU0pKtC>C$(!2}e zdmD`+#^OyoMYR#mAA1SfLAjjr2JOW37#w@okIzxklw`f#_Bv`bq^fOx#UGD0pYV|Q zS(f5>Z6b?LxVX4kCRI{QE>Ac=G7>Np;PQPdCcM6#5&)8xYem}5$f`RXq-WFt z-n>8IhwS6v-G5YTjcZoaY_!O+EUOf;$0TtnJO2RI4p=D+7aZ&&y^O9{dU$uv-JZH9 zb@(1`JKM&(l|8)w0Lh<3e#*WB&^#HU>s})8zlt@6cdt6f=j+w#>#b7~l=c4;cjFx&Hu#KN7Wl8^`un8WLIfZ^V{z ztm`UjF3Bp1EBg>;+`SKa>UcGZ;2%!5>)) zvAxpR$FS2h-Fkb8^$W|W(*$`a;E3OG7(vtPp5)bUhCUfOZoeh5nro=$n5<1RZgg&{ zD#`Fm>37$YLMAPdDap#@_5)XM23L(5=as7G%&?H6F_t~Jugx!w9y5!tCi{7iW7fhZ^LlzGFdSEUKV=;;366rGGE~0JKN^6<6YajRu<(&7x@R&c;Wu`R4M( zukthck$_G(z^|as^17JJ!t~>NB)`h%lft>BJSHBZslMmvr26HprNzuMYtY4O6h{78 z3k76FW1Kcda0vN(mBo2BnpYhEDLBGt78((GcL0J%GEWb@b)AY=1mCxAH^9`*Jw?DybZH^y4jdRCun91~pX zYilG?zSZ83iC5f!wey*_NeUIIrz_7*Ed3ABvRuk_D9)nUtvmk!Rz27D`grZXW?zNg z502d}rd(RAw#8y5y1UYvTB-v?}H>}Orcl8V^CVbFa| zPbG`p3$|Y_IMjtjE%Po&9X}e*doL?bktU8g`7DYTEDD@g$W@}K^Ej;$(HSP_;c83X zlp2fc-1VtegkLQ=Cn8Nl#J7t-*=(%*!;A?GJ5(^}+3U`GRXaT@=U0hg)>aI%7^jXI zOGhkcB!l0R&IzrUd`$xB%C9lphhpv-!SBaz-<437@n^nm%(6!Ab}mY`PZ{acnk7z} z#KtZSdEamNAv%tx@21Cs{8s&fz7}hK2-aW2x(%m>d_f9JeQ$VfUO3)J_b?7O9vhX$ z4`E-L5o*@@w}|6rfnX~0c@xWRwoxLs(yDR7o}#}_e{KH&inrede`j5LT+{90xw=UO zf*G6$;PZnalknr8Tvz3fgYC?)SVgO8^RmYrmiF<)%&OA!A(XaA!S(1X`VYhWO-xQ6 zrCRkIl4(Z!>AL&2KON&LbSdIv8gaGOy_Nc&o#H(o#CqnXqQtTJvO}~LZdHCcQK>faK)4Qn5UbiI36@r|oo9X1ODdDaIkf>x1$A!7-@4A>dRJ*!*xd>7s& z)@*z;Yjq5I?zm-&z{ePPCXNnR&cyEdPC9UXYwT}_pR=}`bEMndOxlIMp?hSoO?zRe z-dHoMWbN9H2UCoCA8Pb-o+`rSn3&bCtF_Wwzdu8dE68eQG#|3tzxC*O&%w|5CnP@( zyft8R3@gj$`QiL76?f80beJ}c@UB$*# z60%xZ{%7Yu!V9fotmcMoo>D}N(kz9QTL5xDOncW|@eARewW)Y=+f>rEXyd+yhK6ole1HNn7ek%UK*Pao)hV#JpK5R3%lIF^N!)pZtYPNbD6Zq8~L*QSFd@JC+ z_mIr8d2H?@Y~*7+W0G@T=3m8F3{@FUjYQIGdv)Dw{{RelRkA!q3&x~d^gea?OQ?9C z$3Gn|+gjA&iq}rMXss-xoT^Ir3U?stLN*81zLwSPe#xXtwtr*PwB0>-PlxheG(AhR z9;>wD>0ch}`X`9K5PV#>@fM+@-|2RDu?d7x+PsZ*A!H9cz12df%He-Ltb2cAuRMLaxlsuP7zbwys<9t}#>XeSv*UYG25 z9y-)C9~~Jkq_w)!V7gnIYYXUZqm#{b$VUpN?u?$ijl#V5;-BpE;zY8wit60nY7tAe zOPL)ZhstzF;v8omEM$Y%kV(ns%r4y9>#X)b*63-Mu25}#~X=OU7)UefJJ)}!^;m9Do;hfnmB3N zwCz>uQTSz|c*_3(#M+E^kjrg*sMy;joO4@wim_Nh<*H#u(0sWG1-T=U(A1wEJ|5oq zPUpiHO>mbQPNi=Z+%ww7(X?^r3>ddOH$m-RS9qi2&X1s(JVB$`!J|zbm{+(7wmX#A zRPO7^2N|!J{73OWUihD`={_;FHgkAp&h-VYn3zts@mk2!-6Vr<23kZ=NWncbUq2jG zDzbGYrN6JgvFq2R8c9=|z5es1_(S1~FCFW;6{|-jL{{RrucZW-RI8qCNBv7OY zybN5vG4nAv0E~1$O8p1XG)XM)BV~zWRZY)A6huIL)2!m$!GAazP(9eJh}8<~D((wVTYz3hqYG za5+Ezs>#u> zmKOQQ+E*AHepJ4%9(gAhrokC<{AQ$gYh!;TibF>fLMGgKEz}Y0E1z4a<EwBwwEO;N8D#TeN{gb>uMb?T^nD z`jr~uBpZp^Hk0d)EAorPTIbvEVe(}1VGv5O!c*klK`co6fJf55R^pMGC2l@o4Su=d zOko;0hP1qvhvq&aPnx9!yG!#&KkQ7qz59b>^ zWpid7zqA^EW7wvg*Sf zLF->N{{X=%67#?w3AoX9Z8qZ0JC=DZUM1ZF@Li0L{5zC01?!b3ka1sJe#F{u{3gE; zEOiFA5yu70w$QhkpDi~B&OVGeuaV7W<(fuvpH-(#HT@6Mvafe3s`^GhrJ2|M+S-WL z{{U`F`+J!e$+gm~3En%;GD{AZ;j38QIAs0OgaJcirzQH1qZQm}+N=^n(mL)|0F_>c9dqx<>0gp@ z1Zk|!Rg>#?+g5#M1;!b53U8zr`I74p{{Ur09L&+PTo4L@{;M6&}aBkcDMEn4Am6rf-+E1%Hfzdrsmd>p;; zw~yw!_=l|BYLn@x*7mx5ku))CO2n6RAOW*r0|B#;Xv*^Xc&tYzD#|Hzz22vpm}4tp z<9czseO{J2@7cq|8c*%j@xNBp{8^yh%VVX*CcU8Cncr;i2lA2`A|sHe_kd$@$;rU4 zt$aUyFT@)!?MUvLNLgAysDP~4BnHnXp1AfI!LKOz3FE&DjU_Z+4OrW0I>w(Ib}eNb zvU$;4va!gKj!K*{$83E~ddGz>ZM0idUEgK-7X=&T-~c~NRdEmRW}(T+2<*K-%-1E4 zEK|QUp6~Se31Q*M46c&0N4Mt%xFr2Ln!l>)7TSE0J?zY^6e9ti3F(j1rBv|^?3bFn zQZa8QAgsL<83RM@)CkdQa^O@h)$TKeV>1d2uEFt);c}>!#dH zj$?v1z&m4Zgn&8pt$jDbHh&4cGkvG&dc@7*tzz!l4LeU*KzZf5bwzN*W93rJNX`P+ z^j`>Nk1fjMO02(mE6Q40*IRwtpNw%9FY9>EXti1-zGuJx0Kq-J8?=5R*L+=X8(v!4 z%AOz6CWz!lk?%Kdr1A4P+Di4?-oC`}%XMfvNj<|KG30+%XJ*OAO6()mjj~JUBO}~rByp41r%L=&h_eXe zvnsWw^mkDI01CCAk@~(zgLpg~DD~=gekaq$qPD(L!wiic94QQ?h~(qBtzQFpV%En? zTSmQkogDy}4Po?Fn7isC#cs7Wo%cN0Y-#;LR@<+<;VY}Z|Fbq4=`{OwGkvsql6fpz zRE!*|f!MJGeg>-n7}8(7dM9H!$-8o^o*(;Ud_OwJ#D9l(ntawV*k5WAXt$QKLS9RY zX`^F~GNkovwgzx|;=Vimss8|B9ZTSEiLWoTcyG0T3EV)lT+a68t}TMU1(kVhb|aqE z^{4Fx@iRuX_}}o~!&jP?pQP$iJeq{?T}pR_3=ts0;}SPG=c1l7Uq%f^-&fP&hwOXq zf)J~nqA*Fscx;o&<0SBF?XVnCjme>ha*g98^mq4{Gq=xk=W$ss4>9MAw3Jl0F~l?Gvkj`Ca!Q_f{}pzr4q<*{5~{{WF+hhlNnF*PMpqSpI) zey8c!dNHF%*{zbXX75TDcZ{avF@P1eHqb~NeZLyXz3@e>dW=_4-3ZK0wFU{``r@P1 zFK#rb?j(h6=7h^1AU5U$InO@-0HsH3uiV?g=S<#oIQzSP@Wp9Ts&P~Df1xUy=xK$G zY_C4gBr&46h4L`_`UU{=g19;SDY5ETT6740+>=If8*hG`ex|tnN5r>Jt0KbQWHT-# zbxiF8v8nAfg|xNZ9Klt|42qli_0MDNS^M#)yIX75;l za|S-6pkJj{*1SH=s)-@;x5ERVT^{P`rS5!>oGttE<#6_ZJ|_x}LF1sNygPidyZbdj>MtkJPm`M3uias6uS zkw%e-W5Y_P=L`Hz-2OGf>Dsl;wyd%*${G^BZ}3Lo_x8sXxZX56bW3qHyGJtspt}ST zA7hX6nwk=dpEUJu;ZiO>>`|ZNZ7NHBC(8k?rivZq*x`;lXZqLcZMGKPKlesB_OHk< z7i#0gk`Zx@Iok=DA8=(S1mn}1{dm(R^K{#G+&*ZO=huN>vv^rjQm0-mbz7@T!J< zzS&fu?LX5?oV*Pxe%5sAY}ev^P5T=7rrkedPl;Nu+2fN|U0eG??kPsp7%qznDEp|p zus_+rucSZVpWhF(iw}+dAKeZ5L*jb|ltyLS9m^*F0GU6#oREEfUrO;SKZVvmvmLjO zv>2`~uk3tO1=KLio^;PMzU_dF02d(V*b4d!_9wBrx4p2u7g5~D;gT%uospwqk-fJ1 z08TOWuMZ1~ql(O_;vlaV6#X>if0_CfI&q_u;vp^e`8~fA#C|sEaq9m7@KfK47aDwX zeV;+Ii$`~f_QcR4!G&f1@R-O+_a8ysM|&i8O9{fuF9@!FUEKk%EYW;Hb$9;&1r_+& zB<#xRtY1ScE(!Aj&mJ9DXbOXX0U+Rxwd%T=`!(vdzzZ8LFMuT`tQQ+KgM<|4y6Ql!Z1F} z%&f9R{kFRD2d!Sz^#z&*2*Nvf&)zGJN$=jQSnCUHB3p>$k*&xj)M6*>G2gv--&(1M zyyqo!+uvJsDPh##ElpE7+aHAbPk_GE);dgjWS7kx&fCiqCej-$$EHv9tGY;<8#a#M z8GOcM4CE@Daly~w#d04W{{U+3Kj40=rAMl1I@Y&iu1vO;Nu`j+N!hlfXCRUDU}pz8 z^)=+b8a_Js)590Ke~B+7)_fPG!z__UXKeN^g;?7^9Ps9+0$O~(#O#kdQ{ibM{f+7v{GbH#{*zHivA&neKyyZ{?wi= z(fmgkO7dBtjy_2PbFdu%%|~hRR@2A73?T4N!W$c34e2-X82nk{ z`$r*f#0ljhMI(*#PRPR&@r4_K1e*1-z7v)Lr`GccD<^cew^grx%d0psd__EDeTEJW z&v$+NkC?8$1$dLdo+I(~wug1P4+H8l#dE1k51DKcKX}SAdV~zE&Q$P4eH-vA_Gh{A zH^#3F`1)z=buA{tL5}-NHtmlp;%%UB3^EGD7;(TPXSI6w!cPWiKM%YEZ=~qfxwVgc zGlpQ&-EUBhEXDJ_N5)Ax&jj$ZJ}FvDz&CrFYL?!A@Cb&BJoC7>W@#7qpdL5i05W@0J+!f1$8S3AdxqVN z{nP4wF;o8l!bK$VL-umBsw-{-_(wdSU&g9k>Xz}^wAa`6YRuVTJB46=qd2e5>CvxN zYH?cro=4cmH@i2RPX5JJhHHrCSqB+r$UTo7XPnm7pRCy1mXCLtCkg-@UoV-%it}TkFeA zymS4cFseju4(x;X0OK|HTra||wjol?VAVFXYpdA#ULvoH!{J=;no9cmA2RsQ<8H0{ zQhaIgPNm`f9_rp}>+LsI(FOdG%_f{Ji7SQxjf}!f;NXHYUbsFfd|dFCfMfHt--uIq zk5RY+`h7FR7O8D^9Bm#*NcYAgW^4h=af9japR&jN5_{t{hl=j?Z3n_W1N$Atv=;Ec zi06B0pfR%_mK-rCHxbvbTKZvs;F6ykyj^!=Hl?h58_;g7(P56}-$|Bs^22T<@$(Q# z;~eI{R;`1>)0H|HdQHZfTdlfj`k$Xrrz+GZJVqf$ncsUoyCd-?`{A9h#&3ylrKgJF zv9`WZa@RWCZ)+L4kr2(iGnuX!6$`okY@7kNr>kl|vv-30H=)kD$HfgF!g`;Z9;c>U z-GbJ$PXY5HZHhr)N}%b$_Z9jd`!W9j!MM#Vw_YIeXT49Hq{?DHjH28I$e${U@_;igQReK9V zYcP^2z=>p820NVfuY{NOezov_impGlE;PICN;t%EMwT&3@cbb;fTI{rIwz}Jza!|~Pvd@v09Cx!8Mh~!7*Ih&oM#xSn!oKU;XMb# z@=0~8$EjUH8_R+-iIImL``6f7H~bQp;3Qu={6B+KX&+>g>Q|R{4exrB$&wu@U%dDdc za(C}#q510{?RDYH8@pdImP2xh`-PG}F3pK$Ekqr`Ug`0LhQlkUpInPXw zQD3Q+nl;3*w6WaCBpC^kL1I0-^H}lTWb4r;B_!Ap03fN!F)Emg78?I~!onJu)~Q zvst%4@JfG+SCFLBc++@=`$=>66xoIRuRQVEt?N%SI^r zJ)bZVg`6)be;nGL?4)7Vf$pZ>{`2(|muT+3SsG zWd-}(8REN|Ex>)!7hyO*G3Pyb^%dWK&7KyM!@9SLEF6hi-u@|LWrxYPmwb6`owB)6 z^*mS3?fY%`A5PS-uB|nzi>){1mdRqXXI1&YXq-v6W=00%(>eC8qr<=OR1XPw8%)%- zn=c0I`mB=qV%pjpq>V^c^BZp9!;IkPp4|5NJg?b8q-RkjFZc)Q8IDtqtAzGcVJl0c z)6c2lzAL$gKlm!=#+6wil08c4?x9vC3$XHIeB-*eT>96hc!K`s{xthUJ5fL@w}w&Z z2wlyO{DlA!m+eTv2O2{;z+sYaELxh{A zg{2%#RT)MrM%B|^mpd&l%o4BKm&~(|0zHRujGylw{{W3#@s5dbvBz|=w03YDNMwDO z^yBiY9vc4uf>nHO@icC}Ecn@>=yvx~Vnooic9v*{GEQYG3FtTj8O?e(f`8zR-VL^t z-FU0w2gFYnUCe`YbLwAaww1aJM;1!-=uhQY;k-V_)%T~%ua{2mL(HwjnOzI}N~!76 z@AN)(_`l$-2gY6zhx=Y@J4Lg&7goB0OcH1&S4T;s$rxqWeMdby5yn}5!XNG`eY}kH^0r_;cZXw~oKEZ2U8> zNHjQB8?7GAZX~g{FgV_mTQA>bZ{T3Vzk2mg7X?oR?-Pv5& zX~N0`iMUeGxKAsN&bb{3&3?9N_Lov_jiH7!aWRuVQVAq;l1mZEJohw7phGtz72}UKd}&V){_BTOe|h;wrT)o(wWhg0nGeFP z6T-I2-e`SeU$li>9taB?Gj1J5K(AJx{t0jK3&umk{{U{+V$g30S+#;LS5aHBjjFM1 zL@_8Qc3GI3{TI`;=ydVBJ6}DxV_2bRZnYbT-g!q61M;XK^PhZm zuToqWl}1r?=_cFTdt14jxbqz=IP%5%bUp^rf8dHX_K^LTTKK8r&1qXQGTB>4Z{;1w z_hFfsu==ilr&Hkn0Qe#o!QCfNisMrM0EFknw~G@7hW`LuXb_GE-Ha}O6WsA%VqbVN z^IpEVYss!H^p&_DWfpf&42R|9w%#+l2M4cuo5Ow^pHG=aq_+2(To(CBrk58M?xc^L z^yGg&mDL=h2Sz^aIIog8@f=y0)r*X4JHFZ;3-G)C361bC;rhC1e+#@fG^|6rUc9%2 zu<69Ez#clTIp)1{M))(}&xZO<!V;1RQ856kSRB@i9b+1CdTU&OOZS8i+ zk|T)}e)H3xUY^|6E~BXGo+I&f(T>U};nCB|wsyE5=`-@I#0~-{uRN2UmEHANMv{wG z{{Vt{weyT@qaJ#V>8iUc_+P`VV_>Z(lMTj9ud#;Qka7ZycEHEAdB5#dp?Iss{ua^v zHQ{|?{{X|9w~DSB7S$}8684ZPZk9(9hyER^r9e4h>6+=hWvXl1rLWkmG}){-LK;WU5mC;TTgZGlM z+UYaRt2)(VerA#AnoozJ(JkY(^5lQBN3b~nNgy5E_ayRazO&-(JHuB7Jx=1v&6(CH zR$H`JWsVy-}V|@dv)^yu8xyQsi7u36Os8X)50+jaXz5L0%1gG2!R%$L&k|I^IIIQ)=ED@h+V3EH^E& zr32+tj-c*d4`My4=fI60!#@jr38#2-!+I3jHLi~shFfcEVH9gG@w1RH&c~hK#7=wj z7+gm`EGN@ogngU$Pn&)p{{X=YJO!d8(|D&>)I3jp%0Viqw%ZvUaJ^1*$PORljjSIv#g#&O3=NZO0>&<%yi2Q%2ctcK` zQn<6cv$fOVXVvEN!rrQJj%i^>OO;K0?0W{x#4pq32p#$+e@B zMF+|wC3xf!P};=|w(`2j%_PJp?>h0%;Z9hhF)z%%b1?`*mLOxf&mWCvDl>Cd-*#g! zk3rU+AYuE(RJT_`!!cy>)c!nEuWu!}Qe@i{j)x%u$UQmq>-DNDdbaZIWsy9fR52fX zeNWc3Bk>))s+Ge_SuMNv(s4K&#+)Ubxl7fZd zXv&~m3~)L#iCiyghWl1;LxBLpbt(?5?D;y>_FPmebL0JHDxXXB3%TltXc7g~F1Hbw=^ zkX&5mC16KQp|@w+y;o52CY#}ZEhd>ETgevQ2;_VjTO;ILo?F;+UlM=8Z@gJHtD*cc z@Ei`&wzI6I)z_5YD@MEp_a87ng?SmyA{1!4>*D+mV})=p6PVMXKZ5%D9%uVAcy`~y z{{RCt*p@#sX0|g+8yv|aoE+oafr|Z#@V@5o-kx)e+eQBXpbGr7@Ew(|ou$O~yIFMf zREhQv-m-C%>N=mOuhgFlNhJOl(`0P!dsvyV>QED2H-|sGo+hl7oSJ8&;ymN+DX8tK za{eIEup>Nvb@AWrC;K+|N8;DU_187u4(XmEy0Nv6)e_{85lL(TkL6!tYAfbLCeze# zeJkbPAACjdj)mj9oj+3Xe};5@Y8PBXd8ow=)c*k8D;Z$K5^=yiYwU9jJu0}HSK8fA zkjd*+r#sQVH*~+xh0w{uBH>(eJL{R+8Z~Xl9V^B)9J*R$zGtJ#s#n z>AXMtG3ZwMznKP$q3P`$F~t?Vt&D1j#gFec0fUjo20HrzUW=x9)503evqxv*4Ifav zh!>S$xK{Fr87qb7J*$uSouf_Ti!`x|#rN7xqN!`iG>M>&J zx=D7rTIhP;T*lRIl%(Ztb#COEPltRJXRZx5O_N2`FXMOfEuvW3Si&5g$_C{FfPKy@ zaV~$eXI8bhw}D}1!l2=I3=lnduaEx#V$axX!oLrHYS}zV;eUz#Hu1Ihhiv7wzin4+ zQHHp0nlCgG(y$6LpzIjN2GNhWkAd3aWRXSJojGEOmEBhn74JK4*8Mtsj%cQ(d24?sr|4X;*CW&%<}(~@e(_9Tq=3ISKD9s1sU#}7--l^S8Qxf&3)%C9Dua1=#&ZZ4ve=27MC$KmbkPfa>IYoiyFad#Bal8?CUEz_rO>s?f9 zKaN)W_Ar!TDRay4K3M&;{{Uw{iQlxIp=BR~{wGJ^9TvVwx_?L04e$F4WKDFW#t6g}89|-s=NLNml<;Bjec^N`(1pK*>Mx+39l5@^0+J}Vg z;)8k#A(lkQ#mCBb9)uJ4{{Wm;XNw|#2KYl%y0?xitE+uB>T7#hk!BY%h?^0WEAs|8 z9r2pa9Vt_%%(t`HsM+15eqR2~U+`ERJI5au{4L?#R!dzegU1;B`&ZyU#eaq#De!NM zwQK9I4Qd`5)Gy*X&AMOM$*0JWtfXQWWd~_0#{;jw74@g=TmJwBMe%pT$-FP&-ySmS zz7*5tNoQ+|FFyF&9EhNe$q1yCQVw#z?Bv!eN>ur6bgk9z`JLXz$xg@W2gDzaUK{*t zX0p9fH(a1379b3%EF0xKiu{V!zha+=pSFMPM+x%-a*N2_?7`(n>i)HFR%$K^>hM0qkMN##ydo^hNC z+lNt-@#M9J*6g%zhVvI_BiQsOIO3Rf2omS_mhs0oB&5udl?#A-{XaUG`!9*exY2U= zQYidI@GAJT;r6ki>iXo%ruc#>Y%ku{F$AAxb!g;J>M{$4``F^NybWy+g*-QBrrX-L z_I{CV6}7aP&QTN~`j7yk!^Ji?&~1}ZwDQ@uvMlkj!(?EL=aQ#^lf_hn?EW3yWvAb1 z_p7+b(93u?!nytx&wiPrg?YI(&Az^db#N`MFMmUQ{7D_n(X}wD%m^yvWmI>o7gq~o z3#4qnFee|c70Zti#bXrC_nL3nbB1=FEXGL5T<56IPPEtYj+1G9Zx!aXrTKGn6D{;A z?me=@rFv14wfU9P@2g&iEgU^XJ4z4kW|hx~8qVt1&bFE1yZLu7A{=)=pS4=K))v~& z%NS_blNtN0*(35i8s;y2cca0%hPj}Qemd}UaP1@an%%Oh=RF4e z_2=qpxc#!e5LsWix%j8x3ssG_W?gCl=Ht+A?eFRI{{UwFIU^ApN!pk6zM9T&k!5QygoPW1`dWVjG;G_CBlIc9!e}wF<<+7P# zI09HhG<)rRsR45g#DZ}Iip*Di8^kpJd%h+dUdVE#HdOC09Zgzx<8j7 zbgzuAziE9R_O1T_f`jW{@P)3a;%N2DBld8XhTVsdBTq1_m?(fY+&h-bqJQtSZaMfyJXHBT6?^DLlledAMKj=vc{X`c`L0-EZ>$C2M?x`mP%ZfvDbK17JC?R-R7 zHdP7Q2|S#EU$8$8uVcITf1yl_Jh`;%c@Um53`$2@_INK3?43(Rr0u^`=Q#OZ6kFe4 z(A&9P*7)_!e5dBaB5~Uw^rk})&z6)%@kL;m_bVr zuhAbJm*Of?ZPRu%BKU8j*jd}d4UAI0QxrSf)TY$jPFL&NOM%_?JZf`yQPP%M!0OEkjKI0D@Kg zCDpto{f{#EwtX@=3Pl9E!?8Hv5=iW$0Cvx)6^=jPjQ;=!tmpexo#&5iZta7J)iu4* zx{-h!6gz-k6iLvweq@JJ3 zzsSMX{{Y~cU$h2^cjjq7wEqCaTR5bN)P18?mggh^&jc|0?a6Jx?#2c?44+zqN&SI; zY7I8ZIi&rmJWR;NgqHAIW+@boFj>PU+>8O9fE@Eh6iWpOD>nY~O8iZVNjC~_POy*p z7jJ)MZ`!U+Rm6X^HQcs#+{GK&M$@hb8FmD99Y6hbQrGre{j)AS!E5_OhkJQ^(z>mL z5=Y3!#Sz9%IpZg;IULbNcVaNGr%Lf}qTI?{MP6`=bm{z{jtAn0{1Wr_sQsdJ`&}2t z{{S6)DdHy0x3Ut%FN&eKt zbMoC=(SMPZVgCRG?X~k(TmJxyKOA)>cil2w__j$2e!)l|E=D=Q$G1w=)PLZb{t?#< zV^Hyj#!0-txflE=`iyE6bXM}*1|53w&uS>Eqec>!yoD^&6M0LREx+KE9}37iyQP*S-g~$O|5O~jg^{m)` z;E`Vfbp2vWsQwvgcM>}1f7|X+iR;=?$G1)>qKcGZOkuMPG zmj)d+;D8lllHGDP@!J4mzEb_0z8>FrBlf)i0EH;u41O2*>rn8`@7Xn!*7RFxWVb>V zMk4Kw`3M_!<2<%W;}lU&loOvL*xsfhtv3Cfn!4!O`fa8E0KqrD3)sZ3C&90VcJoOJ z`H>A485S1%-zg&*2iqK0&Y$}Rd?kj{?CJ2Q;oQi|h=1apvQ{di9G@?o<1|rE4+U45 zOQD>z3`>VIoUf-><>qf#{>8r!VQ9^-!=C}$=aNQ;XfQF1V`dMfK0nz*z|peB4}d%# zsH7ED((Tbx9)VO*MRY1Ka&fknG5*xWb4%ZUd;b73zP0-}cp5qFZDsI}hVG=1$yl3A zh;TAaTwvgPSB-woKd|S-?}nPDsdE>J{2!*rC6(p1j5c~Vn6|S`b0o3M;g}f;#$rrv z#v|`V6^%SCKW(D1xAw*^oFx@D)7>`h_ZPfT`wDzo)qWiKQ%kY&w}YYA^w^+_Th(qn zITXiRDJnb2>miVUR61=|0}_*rfNSkf59yOw=oU8f7~X4pc;a?N3arYY5(i<&1GN-a Wt$B#_`XxVyW%y9FEEU4s)KxH|-QcMa|k+)3~NK?4K`@+R3y_CEKV zci+4JytiO}Q>*){udA!8dU`?4)6&x?0Gf=1v;+VG0sQ5OKEN4TY{b?_NzyLsj+aSO{00>+F%ug8rsDZ%yBey_M z{lb9RAeevZB7<@7Ay|IOm|%Gv^7-GGHB;O2jM-Tnx+{02U@@P9A0^9%c?=W;PxsZXR|n0DvSg{kLt9 zPJvl}WUnyNUp9be6h`?=HU$7;0RT8KJ2Vpu3sWrAZ<`?lW1;`Z$#D>WVIUx(0CA9i zILqRp0f>KD2U!;n^QXNn9_|-E*kkeVzhp2l9^sda29^>3;CtgkfAaeS0Fa}(zwKdW zV#>t2BP4?ObUGI;18*AE~4c!~XZXJKOE|s0O!x&Spt>&ca}hv{4+}+A%4#css{4Us^X&jV*HnQfBI)O{}=K8kT46wK36t^BXz`N35a_K4-~2G!BUw%BV#po3PeDB z@e&&cmz;ugMj@=@k?l z68a`AJR&|JF)2AEH7z|azo4+FxTLhKuD+qMskx=Kt+%g#U~p)7WOQbBZhm2LX?bOH zYkOyRZ~ybb;rYea%d6{eH@A0BD**T(^AQAGdoSlkXuhiJALl0b-^+a9mIp3*v&4u`X$HzIna4DK_uXSj0VD%;!OR**HX$*kpd;Fpr|Am=$XEm81 z(zw*>HR+sCxRQdmFYOfRO;$zA=*mHX6UVJ!rCc{!V{zN-c`!#mZBJ@an9XGOBUS}Sn_M=g848{?nAt#)8X020?0Io)JXIJYcfV#Cv8|D;L~=bI`??6G_6z+8MCo< zR#M$=I1VK{9Vl9=r&NWx9S)01?bD7w&otvY8IyrW*)+4Re}AkH@JK^!29}=CVJO;5 zNk;<5ZT+ew)9?yE0v&-1F2`=3E%_89u;1{4c>-94#vu)xT?I}ILBgcV!7z6V9?z?$ zm?wZR)|)MAT)gpuh{#bU{gw&P7`M1ZK`hXBp5)f-y6ml62C~d<6~curATn|6Fw~{l2-<8$0198l3 zd_Nj#p*T~#^Q9CvOmYg7LaXy5QX$8{=oa#ng0a%Anu&a5clGmJ;AQGmHQ}rsRHGg>8*FGEB6o) zXZ3D&*IYR>D=6^`H{!iXkk;a?Fc0tW*dDDd1Dv9q*TS11KCa_Z_W2$1{5KD>Oq(&z zx`gDNoXtVhD*$2$&brC>Do@ zfQ^PnxTdUw&FHJeY1p>!8IVx*qFXB5&djpPpj)#}Oq`hFC2nrL9kmyUP#H}U9-iwN zzIpT#jTZ!}lbzXz89xCuJwaRqshAYyvYV&Rg@hi*L(Q(F{kNNSDz3dcO9U;Z?H2Jd z{J9=z3eqzjb}1vaI8F#jnbmVN>uJBYI;^+NIqHd$KuyYJD~{9L#yJGMax#g{Pmkv> zllZXiL8jh7N$do1Xnl}_ihpA-CwO`E@-dJ$^*q1Kim8faicn(n!k=f6$If6&e>#x) z0{BeSafBVS605S+(%}{Wkz)`9wNl?(;_WS#WE$eg7OA!?Wd}Fg=&pLmVzEIOttMng zrW6MxlqEVl(sC3t4u-Woj-T8w=X5r3t+uMSe#s&XRi_(j8l{;)g0l}r78G z5^cE&4o)FRBj#Bv;M0y-?a4iJp}`4#u`V-zzv1h=`ZxYQf^mD6!`}B5eKGH61vd>m z>>LQ;_k)(nxrwr^I;Z;D#d@0O8S{~rTI>(hp8%V%9U`aBAbOzkLWl2miJ$r`oQ1hF~EN1vYdvs**U5lqiKNQH)oaIrFDulx!Ynd0J`_j_zWExmq3RmaIa_ zA}bRlV%ztI(Vi8@qORkx8_fPg(3iZ{3of9GOIgf=CTOBqtf8C7`Rv05b>nUl1A)eI z5xy1^sU5sl#wnB4BemiPOL~|x%t#JY8)iNjqRF-xl(vHv@xgr3kZ3NzYUaZS?+0s) zIjI`Y1muoL)7U3K!&-WEi4O1DfXllLu0gUe2YtU~{use%T3XHQi5*L$ua%Y4PCHfj z93Fb;IR@(kPF}$(jpx$zdjUd>*s+G0?gTM2wT#A1rdsOjNXR+WjfM)VMX#tF;MPsP zDK6e$74?I06VF#^^bU4E)bPGBfKzB3 zUY7)N(`+UU%ATb7?g>>D1KdHr@G^6495%47G-Pih=U;%xd}Hu z%GxvGEI^M$-xy!IN}0I|ua|A?_#p17RwUgFt5OW(9Q*ps!6SQe*eEt@H8(QvW!WT$ z9rCm7Nh?`?iLH^P-Zid1>K!7P!^$gpUG@4G-_w~#j6q~I{4qFC?Dn!P%&Ymmrh!|6 zaid@3Y1IzzlDcO1PgMO)mssz+Ct(a~&(-pT+eE2FNAt4YyIl`$>o#@Kz0{Pp#$<^# zVbE9Q^@S^X-IY{r-o?J2hRM8ks7u~Ca;gbxW1(E<1`_aSB5kFe=xXR6o)5qiiUwrX z``ty^FIg>e;u^0E5G6Tdgs_Z!g*sJUK7hsBc4U%tBYz7KnXu7DHTd0cCefdmHwM~u zO6Jm*_zBRTJ==Yfa;wR|O>WOKDF<0n#|h;71aW_qpd^GiEg@Qjnin((}MHlLfD9Mq> zcVZ*NEkT^5KXomYSIaNT?anXkm>mZL>nmx}5Xm5{x4-j*k;mQj-Y6~5&J$L2Hd=__ z?uva+CpbmPAT?WQZF}cnrfiJa#ZX8&TY4u!9h6i+oG=@p$(9WtNgHpuYWCl=xIK&N^1b4;RH3{G#haFxsxKyQxn?Ux$Q*wvCj zNdKuRCm8un?*~PNPD%gfe7ASZwB0@1YWg zw+nd`lV@=Nmsp>@xF@5`{0`!u05pB)b*MFxBu*`4_=LHOY1XvFGj@gAMUX2mVzF#Y zFLAD`*(iL+^E47BZo`Ima-IPCE^QQ9EN>#DvJkMpBE~{RJD|8}oC_`|@=~Q>+ja^f zZdJ3%>&Sj>+gLu17Y&RS3lSvbD*TSDL@>A|6dd#wig4SRAA^p!CW(O?^G3+}#h=n6+*$;3jamiB!zm8Z_4x&qv2OABM)S(tSAWzpgplcIwG)eW~j7*4cj(OOz zC#G$O31Q_J3I{|Q8mAexX{~5-v&dY$HmgihdVg}1(_0>_q;R4rH;WU;+Hw9si>ES8 zJV%@f?v$(V58n$WPnVIK``S=h$)2c~fwJxlaJZ z=&?Tc^w4}qOsQ3@$GeNwOBDBYg1u2!LF zwFDQf7D_O`KnfLP;E6c6XGnQ&WzUpqVx|`tFZ>=#aKF%t4K&V0=5k+#8OCC3CsodI zW-Gd6_sMJS0u4wwbXoFsHR!~9ngd=hI_a|xg+jPFIBPD&>=r~0#iWuqp#M{A-%n{MPtWqoNS-Uxj zXVod0_qmX$atkUqUQ#~2DZD8+!uX4{B|3+=a8j0XRW91 zWL{`C@9MvfUF3)zWX#j<4^JJV;v!PMQFWz0v7K}K6&i3$h_EHVS za|H4a2&|8UK;aunh?`qC*RS;I%&=Rc;+XiqsTG)yde;|T#Zc5?X>X4i?{#Wk<2X1YpQS?eYAoo3S7nqu`~WvL#UMYlugt-9JXz|_8eOz5$AS30>V z$sX@x5E0dVca}Zb{qcM5qXC0x`7%`?v;&~*(J(Th=bm?gx8tSG-4Mf|eLdQCO#V^T*bG6?9j+nTm$SU0R5$&qZpk?JyUalA5S7* zG;^bZxCG@gwR|_#u`Jy|O?36CLCofM4(hVsBFRjt$K)j5=`%e^9(&=?0gtQ3*4gSS zT2}>_+zV+v2h>Ar;~h~{|3)&zmbiRGCq-llze+kQ6>|OM;UlT(@=;Oz9T)yX5+h`7 zropluKlptt?M6Iy}t*aluF$Dge{c5s$3AFO$G+IaNPr$Ty?Oh#~!62=Uj z99Z91aZKA^*(z+^7Dp_1Ocopu6Z>>DIleztYYZ#uoO>6zrO`q^a-ml%HnJV>IJ~bc&Qv2|1M})*IMCa&^+ia}4@BiNwzwP*qnt zMf>1^XOr2fgQQSc6kY8N-T|w&o>>FAonf|4g%5D6mhtGp-J*fRC~Siiq)cugR9p*J~5MDuKQoB8jywYG}3^%aISbg(J8i z55nOW_=+Y-(~26v{sP2_yQz1DrW=scpWcxT(GdA2MX5Vk^!<2%pAfYjAKbf@RlX;H zN&KN9Q=sk}ZnA9xkGqAgoK0Vf!FB0bTb~d4o>X|5UfGUv-LADQ4V`tm;`Nf$Uf+cf z2wCueMW|n^6t&OPAG^0?J2gi4R%CVkv-LvyAUOgCS!c2QBySuYi>x@lnJ7zm6le;z z^99VgR^2LqEY{(nux<9g7Dkaqg{$*jI5~^89+u}+YTmAn#C^qs_)4|&{IiuI&_2Tg z+HvNcw-nwU4Hv(p*Orem9jizBl%j%ogPHr3{KWeQ7%$W}Hy?3sC z+Fsn9&vex0pCh<39iWQ^c+fi(yarZTrY9$D={w8#vUiOjKXcj@rX=1EQqlInILq^G z*KIWr_(GzFIt1IG+S7*8F@iPMtjWSOtbKyL$FR;XvKahS>9i{M2pK4mX{N6rtY-ovdJp4=~tCbQe`+hcW&W2&1Sz<<-Ces$_ zo8ubpbg24NB6=OHI}`q+T8k%wBW12~b4^*dF(b_g-MEy^Rqy_RDhm!v6ot~2>V&Yp zS&cM*+iGueIUKEXcMGapB4SRx5A2r+y; zGPJ2#2g`L|8gjN~T=3N$nV^#fKh(I039$40!-scFQUKhjx9{|81F`RyzVKU5%)Dy8 z$x!dwOdKaweH@LRKlDPN&(q0#0-#>8J?N;tQB0P`Co=@4MFX&=gk6oC0=C)$g(TDl zm*LJUY~AjTMwl&E8k6!vhiE%vY5?bxyc!M9-+vIRGgxJNCuAeV&9&dDrO{~IKLy2& zO|EB-#U~dctf>rs=*Hq$i-wH5%I%j)!nq4jhDVGSK9~zIp!8|(dU5Wvc%59|N0hmt zG8T)2B%aLX708UH7Y1FTS52=gy z7KGFvBvH$kV_t0pDcHa8aFpcfY{mYJAULeSrw>+8uwmm4Sej}Y8YvM}JqtJNx`ei!zJtv3`)Wm|p35lw@w2*XIzA{_^f&JMUlFL}ES1f{cS`D$7 z;!TmF_jRd_p{3%eR_^^{*X91%M%YgBOeO20fYyY0&!#-=L3n~}6VXW(&4)9Ez;Vf? zVZyhx;FF~38ASk{4eo-_%Yg6<;Q53QA5xweXWn|~)cu3>BRlq)S*rfa$vId>Ww+Z$ zm*UOqDB9#?eG!o~qpss6^UqfE-GOXWMJ5CtW6?~_EK~yfC?k7~FMC%+sVEQ>ZScaG zp8)Svh>$Jbqh7*K5%o(ynq|BUZBYB(X4}p0lvcWZb`qK6F57@}8RFMlj4P0xwrYuZ zy$L8LAX4Bm0oh5N7^^&yf{t2gM*}1 zC2Hsw3@oKZJ|Kh7_2(hVZUb6p#fIp?UObDpsoVNx`fu^Q6BfFkmY+5ePAyzq9C#QR z?VK5mOzn+93?}xrj2=b~jLZy7i~v4C4+kR?Ymf`EG05D~j-T?hy_=HQ(v+VP{BXo1 z?;r}Yu$1z00;zf_sF`?Kn{b;_3JM_ddGL7HI@p3-jEFsKZS0(RJoqVplJkJ&AI*%E zV33ok8IOvX#4iZ=j-T?EE$;5_4DPH9_D<%E%-r1Ej7%(yEG+b31iiDTor{qNy`3`< z%z#MzlS2&TY~p0;;9_ZS_lwKnA6)-3FK`EVUaUWcC+1`PVZc9>|2skpdl!3W3wwt@ zCjKk!@9v*Y`=647UG<-7f3Lz1LSjBJj7QW7WaMJ+q-Jk#BOv@^wTP8|4^1p?1F{9# zxi}M>**g(yTH2Y~yF1enyMdgXe>jGPfr*dtzf%0qoB-CAm;diAwYB}}+MmYC%kzlY zo4Ee4Nm`8fhp(ML&i1ZOCZONJU}|q-Y0Bg9dj>pQY|N}iEJkefoW`t7^vukrCiL8< z?56a_W~?BP2{R`*E1MZ5G2{Pm?MzMn*zDlyWb-qsOidU;zuX00GUi`Ko0{;L{vAs{ zoc>pU3EQ~*y!q?S#_~s$@z@yIne$V6(1ZWeXyj_+LMiY^{420*EdL_-8D+m1_!$3f zi2q3xg1`kAf9CQp0r;cmPwWx%{kY|kwly*b{Uf}8I__7Wypc6X;Ad?FALCE(&xRjS z_Sc=gDac0P@2ZL4pr6fuP>N2L;I%dSk#_=W_73zmATyU=$bWYI!E332*WbqVC!Uk# z7xJHte~{Y$3Hi^)Kghp<>UNec0@8LS7Qcx9OV1w-ReLiRcOxgz|2}_yrpix0{K5Xw z7Bg}Isae{B1b!A6{3sP+CuU~mVddgsV*WMoAIKjiZBb)UCnbAl%O6v>v@;hFwWkN) zfJ}dB{9FGYO;LMW2Pcs8j|>)&Q4*IV25-*48-6MO6Y@vzpI-g{=W1~e2YY8%r@tJ= zOzaFYVfbbAzx4dET0+6q))?d@V8+P7!1W9HyX_C?x2lqpy}6T-t$?rt$jJpf2(g2c zrM;7-i|2nC;y3OOfwZckssjjQssfH&J9Fp1(_9tg_%CTL0Irw$54rpc_G_5`!1h-T ziGy=cju0N|Bn5CLjIWdJAA+?1pdG^viU30{)#0VW1kCSq}o{}Ld7 z_y19qH}d!$R)1--FmV6SB&ITBWMSZ>{>AWbu)hgFE+8jxbjgEvOaUb$7x10kZ>4{B z{ZWv!G`Db3{rTxApl)YvXYc+S|IeQP{CWMCtHCEJ@R^G7&&dmX^XKdavI8I3oPM0l zepdQGLO}fZ{dGaXK>fI2p}}=HaIoO=AMgtv0SN&f9uZs&gouKOh=dHj;1Hgnq98x} zAwTv0sPtKcp!-$n1FoX_tJ3FpoeuyS66!~t4+H?=cfkw*;`l~#>FYgQf z;S#1XP?~8EA6odi{gC`)&>Za`{Gr}kH>O9wUQmHz)a0mV$_9Cg6g4x7&qbHp=^eYY zzw7IE-Aa_=T~7mgKdrM%zXBrqx&C`&La8Zf=>3J$85#i(F|Rn)u{#Ntdi8yjXRqTD z@JMKqIIxekkJ&5i$Hl4^(dkBcl>j z<5N&)Jy}WuuE9TW06)maww?qd*y=tlFCix@m2v94n=4>m* zNA1@Rd%lcW=CO@3mr19kjIXI#m^Y{vm&>Az%(?326XqMDZES0sZ8&VpeC?vNi`4R$ zmc2m3`pED0Oic>&X>QHFhIBX5Pv+(s)_h{l1+ltwXr?(`*-q@-LB~5o~-1?g6qq#dT!>OzD-}MwZ*EE-uJ%k-?o;P~Q ze@0?d1n)B$8APX|hb5CHfyD=L!4a&+%~E@%eV6GCYz-VKiHl!oGY-m`V5K?Q>x8 zdX9U$btwmWW^K>C4!?#n=*p(u;FzrzyIS&#zjLd{;*0j1rhi>KvUgGsJogY)b_S{?8B&~@uo*d;-sihDM?n};z&bBW1vL#1e9{r%uzy?8E z4xMR-xT=+mG1Ae-SEkBeSo?V*C@7Z|or{cU!#-9tX|K9~|L^C~R|hHfmRUd6O}e#_ zNtKS}(~_>f$|L)$1g>j!{{0-Y_*6-G3JGB#SIf6>Y9pSeg`(Cx_tIBm2BBlkl>EL^ z{AhktSr!WEP1V!-GHF+I5u=!x29zXA`#YyI0?+V!VWVZk{J#q4=)lnS5)g>YUSy}U zJpq(ufVSZBK6*3&1T+LRBm~UQ-2uEeKtVu)%l`my=ors2v0jjn0-0IhUsA9Nle7KW zB_LoR({S+!t1Bc)tTQ9LI$89sFN>-qNj_OLHRijwBj_wUwrVrE^l4UAa~4%4?{Ypi zmz@?!lDMa1e`x<`fG+5VexiTIdntdET%rFZr;jsvsQdev(Y$-$od4@DdaS}5x&iPR zt)}M+>#~j!9H$%M_02LH!qP1%a39>CeUf+ES)(c1HDE|^u4+xmUYU3e;m%-Z7hVxp zw%AZ2jZa$s1hCY`!QOO_~^mb8pC7~ib=T(TvlCUr66mc1DpjJ+tCxX`G-?Y!2(epIzqvKU3{ z^5WJtfc4RQvTja&&U9bgt5ris{DVZWXCViTVv`6LH<@K<`FOq@9imZ+Hlau>L6u^o zdOZS47svF_ryN4DIQDRo1w&z`oUN%5Y<7)&KXxaIxCxW2z6P^=ee~fOuIPlK+R2$( zof+10uH$zx^~-8iG4mB<0ap_I9z6(y7CGbX^)9QyboHa2Ay|89g^gE5Qf(kTxo#FW z0h;exOh+}DR1Uxs2{B=eHsW&9y+zZDK}kfBl5u<2>V#~mG!1QTUP=qr0ZQq<2BM{f zvSyjf*O{kdJ8L^6>MM8+WQP#woq8Sbpeq+cRE1;DyV&D4dB;{&K-`$fR!1=Q_XC1y z^}K@tru9%1mO$(-kN#ko7+m>|mW#@hRROnIBf_=v?>)weuFCL>Nb&8!#4-HP$&f2{ zhi-*hoggt%QIX_RyCW8Q?I%ES+g3$ZW@i<-zFbM$u&uC=^11}a8&Ynz&o=5;=u^;U zx0n-ZMm(0LY(_4FWh3#GNT!U@oTo7)4;sLgC=4d*;8wUK%C8Pd2fB2D%ug<2F$rzMD_8`}b;dm`+Tt197XC7BqLBxM6r z@1#SFkBfKE0pc&@bma@*H5oNw;Km@y4v8b}N zfuY~hh*OU`lk`|w$l|F#U%)86()G&-wO4W;8Uy?%v}35<L5#gKZKta8>79GViePNW2MvTr=;UJYCemZX2;!J zUx(XznXtK2hSAqDd|~`j%T6ar_U%W+1tz&DXD+wygW1_uqG^{;dCiNI3GXjwWjekS zD2qjV=U>itjUQIe`g2r=qnnH?^HTsDcBN6G8T*q<3BT*M8!X>!$5p&l5gx5LyJKuj zbQ4)UjE}JUtmi=kF@x{fj8YxoPB(jEar!!=%t)GTnIOPaQ&w?gNBAYeI0pX8n7?)a z);)*hsRjxAwMN)I>&+9uRphhKw zJT9kNzfVWKZq59iOQBEmYxaUlu>)kT+*IQ1 zJC^Iy$PMp)oiGt6h3PJ`v1W{=PBl{)GrGmS{CZ_?&#J`{W5D*DfCOD8iUyXH8{%c6 zR!M})BL@Ys$?V7eTI{LaxpqY{H@x({3OrhfJKnb@{$l0P9itg)%R+TA7BsEQMc|EpYp8#3S%zI_GQAvqQu`!X5Jz|TqyJ7J zf|b)dg{qqI&AI;kwVQeh@=shSAK=+Cv#=Fh@1|c&^ABhym3<_vDvU?5TaYw~Iy$fH zVV$KMDt)cAi^wK=zk(&PO8M1s9xq~%jYC;2D3SYxH9q!05Pjvh?5Qp0xFyNl_EME* z3hotsNm{Xy?%kIDC%~4jv$idNww58fM10@$K2WNHHS|W?B~T%AM021xgK(xyux7CC z&Lj$oG|^ba+%c+b{e6rxDYWAYI&DidCkoPfrxV#{SaL<0x>BPEzJNrPVXF^hQV%}< zXbHm$?x{KBw>oRa_(nn&R+8u6V@yY7+N@b;rJgx*Qn%<4-x3sU?nLL-^?t*Sj9r#e z9!=glayA1Pk6d8aebKdxM;|{QvjRd6h5{=7c^+T*s*#g9(x#aiShHqa8YHy$|iFwcV85mACxn z%AFMkhAT(v5|VK(_O{u2!xKhl9MM4QO>rEjaeWove126`Qw@j4NK^TSsE*u|kyw93 zsy>aQU|jxD^?7G_;35`K1n(hq34M-y=-mx1eQt%JLPhdbl#KRdZkY+C?bYI#MtqBE zkfSC=@5Ho;}^=qp9*kdiaO-kuj2JaqkrgG^$&b z+T6lm8y|8XzC83O1!X9fdo_$J~NrGo>^G(6pBHJvT;thl!TG=1{clgRFSLY zO9{h>=vvtO@L<&UcukqnQy9eMZ^a6u>7&NU4A>;6P?=Np4O%Y=&X7c!>$G$MbG=a{ zN00?_1Iw4kPfQ>6^A~F3vZWJf9~E)Z{2`y~H^(}D;Xt7?gxD5pIX!#%}dgg7LX@m>575i>KOcI1_Iv}mlx z2h^7`8Kj|(MSzs~7g6cGv|g;=iu1QTAhTPfay6ARy0-#8ShONAu}Yv8_5JoquhT z)H!^ES8z9r!7sZr*-7ga;oQTk(Q+9KGNcfY{{q+VDrct1?jjr{FSZ#zePT{U(%?A# zK(x*m(WZEpbZ`0m*rwCkiGHp=y@j0C35V+veN147R@iAM4{m5xW)e=ontB_$GV{fV zLbM;_L?w}IbiY0Q(Dz~X*L<7^OP6Nt+smXVRj@bFzW6OQ(KuKyzPVqR#f|lx%W{+2 zWNbpI(8ZYfV47zb0?J=4dxeFmkb3IQmPZ>%u|Cc;cLnb^n&J&V0q&dCr8werr18`i zI$Jz4k2=<}ZH`AbLhQroPDIs^orLN$YWQuixyDAsrjMKpqT6AMBdg`RtG2|~5qYlHndS*-^6u zLS0^AnzPaK&hMreuk4v}$hs#y1JW2zYu?kFCsgI3d$WIjDZL$^5@7nm$ac|%>O$u; zlhc{0a8($Bitp&>qLJe*Nei?SiNqL(hUI5g7^)qWoEo*AD<8^rSXv9UyV#QLL1G<| z$0y_6DEPGZJi4mLtHY24$4ehp^I{&HBsx*GxFbPIxu2yiVNJjCf zuRy4)m&KglxUkf=(Q-RhJ@U-hPCI2JHZ+AB$SkZacjLXm*76VV<_P>?q4El%jX@cM z&t=@Cc6XSs5+zOyeY}PtNIPMPz@Z|0oN?vtyV_MnyUPqwQmnJro0K=Z4%5q!>a5sy z$eh3)lF2ov*9;`w)<}cZb+t1u_AYER#}l=cF_!C(Xh;ydknM8q>&lDMO6EG>;8cy3 z_p0s@wLFoe=7_#9Ij_F_ROPn2nLAkbv3T{F#;USd!$?p=BEmZ(%=wHA3x1(R2cwZt)Z55sWUTF&d0bwbk^3crox}#j6t3OapnivfPtzm* zc0NtbQe`WG1)?aKYs>^9HQPJsxWl)1xcr(1;Hufk?pjYq9YWbnXBx#ndW>)W72D!# z*hi4Frfi>X>&xSznsFT-i9B!y@6kN+z7c@&J*rOy(%H_215_xMm&_a-JL-B&Y{kE1 z2`_utYGv@oNcy@_zfxhY9V?6(t+AM(aaip2{D{jEB|VR|Lzm_(=fpt=#o=6lQx1w6 zOssvdDJIB6wrXLagF|z8cFSW-TD%(lpiV~mxznW*C+m|*!!L+9-DDQ zS*>1b&!4^B(iKnmEF{uLiyXX2^FZ6sJwAnBKZ0OdG{jtC z?;**v4h@Cqx?^JnL+?*tPSCfe(5nf&ERbyUm)dfHrptgASoeV~&>ADhnsS#Yz2hD{ z9Fo};>5nOLlgIMrduJLl6`RmZ6=N{~v;I!-WEGFfFq^t&yvbO*viB^k@%B~mvM!9O zxK&X)Wf|3?0jKp~*r|HSvh+cJQ$-S@9afvl(cF@Lw0YtdyNctX!1yLFuM}sq0lC*Z zUhCnoJsBKytlEb7xG3#BlaGA#i<3QwyyVsLzSasuRh!!rmYk}%W`*_?958MQ_{|O{ zP9B#?I+woQ-Dxk#^^NS30`OQmn}QOkuJdxI1(&VdT2{sNRN&0KXgTuyXq(&uJgAtdJO}y0b)=k$@)3N_#l;T<8=THo43>Yb6Zjz-G(cf z@Ww{Dnd3rjgA8F2mbbPvLVnkKxkKH{W$iA5NBT&NFIH4);r!=-aKV}s3F!Xm9IQ~8 z&d~McbQOq8^V9wLG2o|Daq!dWukW5gA;N;ccLo4|`3xNZ4gT)gb5dbt7#1T;5yyac zXh4;?T2|v8SS&I&QB|ku&9fKmCeDF%d3--ZdzYZ(=b%22 z$A`~TgqjXb^X3TvE3yYvyeP=Iapn7X^<1zJ9bA;q1gt2=r4fK1|H2UgK?t{LNhUq^ zYEv&8YZLzHJ2h;9&W()ZLxiCFW{*JfR=2sC--GD~n3!izfZO1O@H&y=^(J{KuVp_x zv&%b;rY`I~Nk8)-Ecz7>FZEMDNiii?dJki^Z#J>g1(j1KvX*l}cB!Ow*Zxy%Tca3& zQq3^_u?C?B8o|sHV(SW;?rN^;dvU*UqWchVfiV~FnVz*IZLt8>CXSwf@d5E@bhokk zOhIDWfhbgk8o82Up|n~~l!AnX!m4sSuco?5BeE)kTl%o?(IyUTLMP_(O%RS8Jw1@{ z@3FLVMb2iXtA2qa|= zNE=~0@Db$no9`aNj_*E@5Yuk!y`!HVKiVkf&MHBfKM)mxSR0(%mhx;Y?PmHyi{z83 zqYjNi>|PE28#EM%fOfe;y|2DejjQ4NApz}5jm)UNN7KN3S6||y1vdW-J^=Fx7 zC>c;QS;B*!=~K)2qfw$p(S~fgd&0<8SriP7CJcCE@K9-Nnry$ShwS9llsUB2Nnj`7 zS%_zuNvOoAM3LSjDn#v(#6bR5`T1x}%&L~hOXJLYfqP14I1RFc*71Qp=V53mOhNKM zQ_7dVX_G5@Fdfvul-@pA0ysYzy zN{7S&MsWGbj-B^x3IlTHMN@W}ri|C#rn9G;lLNRkmmhp^=qsfaWeCM{z@`%lM17Q~ zIK@C2Xbxvj7yf2?EMKD}f$Okm91(%`5p6P2g4++bcWXYcvnNtDhgx93 zYhV~1L3`+ddgo(9jD_%kv_fwxVosNY?y+hr`bOOZQ5KeUb)|iZt#u}iKLU>$LH^gu z5~nOPU#d7*Un^o%9lQf!;_UeymDWazTzG<`-hxAW?bUmeFG99H?3IQv5_7{k5Tafs zyKQaaaEX1d7dR2rN{MI2Y%~1(@i`Lo3}uZ)rr0^^Bw?AAvMr0Wo27NF_8oPL(nVAQ zMO+9YL6{1!X@a1c2YIE z>5|*-jn24`GsoT>QrVbWeK9KtC(RwnMWNnFOeS)9iP(gQDcdwIFt%YpDzeE=I`QH} zGAki?`t9mYt_`=mpX7N+K@=Q1t}pG$scrU<)|}!@^oz zhDBWY2PCXw)HI?C$dm@8qi4l|1cF``;iPY^C5kuLQVH$mn?x_Gkro7t>Yo5D-*Jb! z9!lrP03VoImk2ML{b`LXfr9?h8f3s?}vF5l-UJE^)`oUg^nw~3RTEse|@5{X4rHWvPtPR&R>y8uj;<~;b znHw?oyO{8n*s3p2fJ!Xgn>-Ck2i1EE!-xFb*W^QoR8Yn#{7R~r7~g`|tZ#HFFRS5Y z?z!i7U-w^Rzv>_|&q2^SdVU}+1Q+^lQk?Z|#EVa69@(b2h}6;|AO+#0nVfbdpnnH( zG)Vr{t3f2A{_Wy)Yu+SbK-=cs>;~PIBPV`EK@2HvE0kB)OK5VAiUu^wH6^l~$&7Pk z+YbI>&y4LFeeXb@vbn4B4B23vy7{hvFD3#j?hDzvpQ$%X5FlaaqH$k@W*89X9K%(6 zsSdlMP{ra~d@1Fr@Cuko*V3tNQILN;yKz~mC7Y8<2YMaCW8o1gG;pLdsrZFRdM&UW z>A#>F^k3{?Kkh_PkUH3?Qvq&DlifyIOKL%;P?Av{9Ow*SL+vlrL8WSjFg>UBI_~iQ z%tB;Lx#og?t17GwsW`6Qr?3N68ErOE%_5f19mDH(hohLcp zu&<0pLky}-Hu-H7Cp=y{{6P(yzCpOq(n+}3-k<4PHHDW}tjFs?8aF#V$oJsCKCy)P3B8LcKL_$GCJ>(gTTcqBz_u6n6eDU^DIbnh}DG@{z(UB2R z^UxM|*d`>>A$vF^HOzuIJuj&1t6HzgkF%7(j0NDj&l-u#w|RJCi$;7Oyb1$;`7l&S zv{5ZuD0QI!sOGbBtYeA$kp4bjp(S+mlLF9dbS-2!K?0g0Vx3yl^3bLFb15UJF3jIj zd}a_4k{J`T(l1KaKj--QT}&;=K|~Ec{W@l)*`7uZMX`{?N+!U;397=aG85&TT+*`I z66I63#U|OaX-!VP(ian0K6k1(Okb^X>8s!YrgS{nXsZ7xKVpZk$~X$sziPlQ50_CnCq=S)vUuO5&b z`4kCIR2$1B{h_)u;F2vKQ3KUV&YBl@?lAomQ~Ysns2=kNcYYYmkfz4RlR2{-59AhV z8r%Urfo6&*Iir;`an@^yEVKoAh!Y4vE8qs_wc%--* z+ZaqOy%loM@mj4yCU9AedvR0PiNa)|IqrPv12JamLY;IVwF1{D1HM5)w74w+&DOR* zi&kP|e{7sIB6bcixbd7%G?6Ia}H2gg-c@PoeDrYsFq zfp@WtC`wyW>y0Xd= z$HmKjQM3x52Jv`NPD;HP7Rkbw_&NqE36bSq2r^7VG9M>YVCU^5?uCn|zO;?mSk1nK z)K`MB!mVwOA&j>h><{ACVyqI%xS+$=7CYK;(Ay5$U#u>M*mnOWp9emm#y! zCNZR~fku@u$VBeOKD2J)EeAsAC|0sIydVly?G`N=y+`48O+h7BVjPc>WLsRPhYX<< z6{olYK`%g_=e~53HQh9IH;bBD!2wlWNFAe)o<=JcAzc$KA^Fbw+BiFX#?-zQ%4%Rz zk8Sv}I60Vn)jjD6u+VZ`An{dCjO7Iyf*ML``P*rSuggv~@5YLcq^~_1Y$;Uz=QJTS z7Otm=*eJ!lVv@f~hXplKe-o)m-Fc4LwD{IARftkuLoh^}QaYktBtmzVeFq|r$6n}D z@Jm=RsD(GSpC1!lBdNU(&cGaSoHn&d`u`BqnFxyGm!!ZLxdV z2Ty%WKy@y%s*{A@XXPs1PnHFt*+ENINFwuv?VW=pR%11kz%g_IacH|eTe+LQXsL@L zkK2s~YV7=~Lkh+_9NL)*Oiv8fi|<1+O-fSRRq82QPXI@;=LD+fd769SSc1lCX%IS! zULfMf%TrlvwHO&AnWpRH25@C^r^IWMM#z;w=p1HVC9hBA7;;p%R=PL^R&_bf+Zq$4 zR}_1@-1S6n^Ip4<*Nv#f$Mj4zq9O0f@%qz;+)iz>H=;EZ(4NoqhG*s*G(cNX%Ph^w z!C09!2XYq8->9YDKMutmkS11Sj^Krtk2Q8}y|fu`iiO6#&St4IsIRT?IYdt^v)Wt2 zOkLKGg6;}jeCDuVs@l`=h7c-NPPz!+yC>VS0TMw3B)Jf~Tle9rO<_cnMSwKaM`9+4 zjF1Jy;KOg6;7Ht5gO7kn3PHgpkeJAkl*HLwNR|lxn;Bw~G=hYc)H3c|IO4a26@MA^ z__i4(wStopN@sEI*a8$v9jP&(o|52s!PG))m=sE>Ou~V|(CewjrB;m0d?Mu|MLk`8 zJ{w$Vx6Ke2#rCTRqmTV)Kb2ooA*T}kgropLcm?kKqBvJ1=eD{QMP zd%MeZ4Ty6g`5JM*6KC3OLs$pc8l_VqN>9^DC}S-SxSi;6)RZ9!3P_#BWP^edok8nK z!6+@Ou+>yDZA(hP1t7#8^~GBbVz#k4nD`oLW0@mDdeVa65H_StN2j|;Q%F&5l)?3y zbS#p!oi*~-Pdefo8kisuK6H35xpKwAM-jCu0qUKp1%Qidlo_AJOx(6Z1(wTi&bX@} zX(~w=Cz3bnCZD0G@f6-joaA zBn-)`rIJa={Hwz3!f_6`G591>d#02Izhn(TthCcg#R#8RsCqjP4un(Q&cbc31g$H| z1d19)@Je->WJHlk7R^5O;&S1J^^WuDNnPnmRvYIkR0A8*eVOhiE(EUul0c61LD&uE zlIu=@u&sG^rE#4BEvgcLccX?5Hs1Xrii%X&R!Ks_ftpBCb8sD0bP=%qDZrpow-B5G zOz43hp{r$OlTIWk6o`RLCJN2Muu+=Dq~hCV2g17wj`Vs1K@&vu0x44vXykOE8l>0m z&)E6i5KT4@1)lhC6Hh+iblK zGC5FkH`WM>-n)u@_RLO%_7+<0r%u0280gul^!BoS|F_sEfhZO zDPTw=3XmiGR5Xs!TBOZuq*U$& zNd$@my&Vp;BcimdPnKvwfd&WIc=)KPy@cEltzi_dJ2_5CFum~)|QnW64E2K z)PLBZbmsa}opt0OCtq0WPQ7-*${Gds6+4ZfkF9Y+q;K)9bfcN8)VQ=2T*FPLsWJ@w zwXBff+Y{VV_lX&$C)wH~qs}nT$^sVQWoaW-B#K(u z3wF^$RB`4_E0HTyWP+6ehpk(sTr|3n3M25+S!r@f0F@^!MLMs$&GUj<2%nfA#+y;* zIC7ExRkT-|N(u^yrQ$&8UU?K1k@-}JKN=myK;EOuK%q+mc&oO+Cl^=K)Y6%CwIC~0 z#@{+Z@qiRkHQKMai49vMKsr>`Dmc!@io8MPI$PySWI%GbHpoj)HUcmN$WO+<4Fl7? zDb<`%-m^}umAckvhNG|_3R`WweW~RnHr#GiOxw3OvNi5iywRQ2K|~m&>AC#tnh)@z z+I7!U8^w5SvfOQ_47fs%)8jy6&H8#!Y%DWsb+~Q@Id&bW@LfPqKxrCzJ5j#6xDGOo zRjx4t4?)(EezAgf0F6^nRTLA%-)d19ks_q(C`O>~S{jg|CbF8JP5V|wFn?Ot?*25N zg+)Jb!V;x2qbN|2>E66a+(CDX#85w_Tj;pCM8xQ-@kKTcEFd$c(Wmj6j;4|cpM^t! zFYxyitQrlK!THo3W`aTG-k!8*rPmRU7a3O#C;0kR`dC($P7tREl=rJ^@=)VQU@feA zDE|PZ9CzM;KJ`IRGf}ysqs#!LF~K)iOKDH7v?wPA)7DL0IKxT>Yi-tZAeQm}0A_@@ z;s)WxbC&KCioP0Y6~gwQgs+A^MkJ{0p{MytUC>S;wr(%nk;RQF$%a^2gWO@8?5s*`(bJx3LN>2{Ydad<`)TatsO z)SrbKy3$^Q9Zj{1t`+PfrB$PF~?zc9Pc#32n2Q)3L~Q zrMquwVS5FT$6)|!OP}&JhC~IHx@Z|1XT9rRqJRb+P>pyjo`qpe% z+cBQMvH*(H9K8ijR2fss4wTAu^9onqst(%Holg&l@2-T_t)_aEQgY%gCsH@7`=tZH zRl&qfBm5)mT8P+E9>D>5NKqAXjp1O8Aty@ltw9beA@mi@E1Ogey(lYW4d#a#NKOjL zE>58L75CRSH`ab^T>=&Jis2&F^S}_MQISGpx@nW)-U#*EPa`BL4d5f;(xiA(HmTq5 zpGf^G-#5bx-R^ECUk=ra-Ya(ZUPOe)r3E?d*DH;d{C<@UcwM>=FB(frm3`Ws?tWC?iQC>= zD%Hk<7Pgm`L{8sICh6BMm(clit9yo%ah0;OPEPfRrpPHNOyrtL*q)U>UzKa2t73;V zZAXYtXr!rW$fwG#wFV&8um?($tzO(FLyB*Nh}r2K_ud_5N$YM0ihKUxU3v0;=Tbbf{B` z5_(dzDoa^%;HWkCn603l$>vQQQ}7+>T!032PdAv>2yCNGS4HPtD^@tq;T02*K9%iH zv=p}ajCqI}(#8ebYnI{_)|B zJ$u%k?CAyIO7h)5;ipR z4?cKz0HJt@IP7UHg#Ft|2}!JDTGEp#Ft4RYPL+@6-i`?aEyYNIKyMYON`Tf4GSQOV zq7qPJxZ0Pr$4Y4TK>a?&Hs$5BFAj@v1s|0Vi-ufS-W9b33$?VLe6&DU){VuliQe9> z-32S+=sCJiTApb>m92RRGgGFO0PG~xFC_!%Q3_Zhx`Katol-zjlG=(=b;!ol>eX!J zvx^*L=7QOY=~mYF{{Vk#*^`nbc=Q5_UTMZybg5z2X94R!I<>QSG81aw3nUW@gS}`K z@tROcqZ*SynVKmHR-NlqfPHkWTugjLVA|VXSS#9(J9zv?l#sWYgSR-Z+p&kMg0;M~ z*GS_gq+eNq>hf!w#cUHD-Wt}vG2!E{?L`6n^JD)2+PQp5w@EP#rAj{bdLpXE)@szP)%r}BwK17U!gIR5}ByrzEj0FR)gvwP#Vmk1x;Gp%Irh%nW{ zIU5Z$PdBWZ33K^bd|O6=E!#E(cGjoj@}F@UQvj@`EVWE`Xn< zO8i^L9PgdkG&q-@GGnK=xIaoxNxM?}s_n&3){ZuV9n~u*TIaN6#=C7rL&F zIxH>5BOlI{#PI89TTi&O0x3(fXZgLs(X~Tr^eM0AU*^vXzH-@kJ;r4ur>~o>YvA4( zl18I~3|2qQUL1OVS4a6*FNau<{{RYV2Z!0m{j}0jZ7q~fP8g?DE^ONvEhq?PaV129 z>!f~FYT7Ss!|i|`bynceE?%8QF>L@K9NGx;fev_FUkAi80uYS`+P^oU2lLs8(|G&tXk8NWSQEMxX`^=H)k z)@xRdBh*Itu86LY&Z|f!?ZUtM!8V329y5+&G2y5J_unmcyTltlRru6+GjA7e# zd<0^Fve3)<2$~&}KNd^$6>aMtF$61c>SRRgGX|zb(+^V?CcGlchC^6|$hBLfQk3rr zjGHVOGc*xvT};!{Ct=*MkCe5U-ktbBq_h&(w80jX66_o+FT{1QhW>AvMS(v9nMX=( zw*LU-%j$8Q)4%ZsTzpt6hw4&7^#y#ULC9eXaXVuuvWgY%V}@(iYN{HG9RmLV@*r9l zRWQGLMV{l({$~xErBukgwCZ5x95*bD3vaY+Nr0rH zgKKka#(Z3DCZ$O%3zXh0c#F;bbauhp*O^$3iuUe&rV;TkP2MaIgk?A_`xUtNuJA}; z^2#-~y))K{4ol9P!~&Fj^$Y=5RhBgeh%+`MST>hdhi0HVd3E1ej8anY5@ZgOkKz!= znECTAbPBHRzw>e85Z5b}*{ms4Daad{dih{ro(CjCfvsVWi26l`cZ#6Ha~;NoX69eH zf^6ODraxFr8O6j?YW;`)>Jk*!Hn{C@Po~s5w8JJq$`~IpkXuHz(A`dMnhUHxWf7fE zZ_9~804>eq^#h6Ni$IO!3~J*+oD1DMxN;p}(T=^qV})$o--S+Fa3-4zyz^0Oo`?rh zzMBUbJBKl1obTola#Oay$SpJnRJ9bV4yvova@>{PZnY~w^c}|9aElmHt4KL>F9ptx zvspLl3#k`y&#;vXuI^tXsy^+?DE?Y6xL;VhMjtZgJRC*11@{9I78T4jN(H(g?M>!a zL@B3_aZR>_uQm~B*Nzc4a^Sy(e$KxPcV4*>g@OyFB-3c-eJs3JHy z%W!}f$-+2&U-2;v>%yuxlG;53xEWO9af_2{FMO>;b-%`Ata|5#dLBw&yYBpiSw%vV%vahx`02Ml2bo;M&2Q!!@P{+G`yeUG$DuR%29|)}-F4d- zXrV>)^>NH1BFSOwKbYx$HoI`Ih(=1UU*x%o1-}hLl(1}Ya~{;wl~u|>1KP&;F}4VO zW?C3ob%_B;tx>Zt8~&8eQoty6B)OwH5O$$jR}&L%S7LCszfy@}jTu9C!yV#L3D4>k z@pwSxYtMYa!Zt9E2pY~u!8ez2#wE)QtNe(-D_Cmaa)Djc!aNG)cl=DKMH<(WQHJW8 zcDu?DzZ!2CyQuth?KyWv${ye?Wb}>EpNsf_#@`FAlZH2BkH{*^8@~ArepnSja%)HT z1tFm2Ue;YPoi+8>tDFLaw=|~_g`l?E3YG3oAC3CSF*R2T(D!T_a5;j~-&HTT+kjnt z@8&g61Bh7(g%cg%r|L;tS#{vlj*o%iu?% zvAsle@peiY;aO?_01??JtprsvraOB-$w60&&xmp^htWN}H6`7GBw+ZnH$BMq|eARi+r zcN2w(v)*@3yB0eT3yDLa9K+AcEEEYrYX%77hY>E1lOGo|HgHI69IO(`VI?wJ7P!_j zE0`mhgl-oAXoA{|wia@ITyW+Xc#T5$gO~-m4kl5K1*Ulh2V@1XR!8_ojodX_hh0IM z!~hm0;~#SiZzs$gxJ_gp2tJLQ(Gf`zXw+ZX97-!3h_nD4=5ly8R4cSyGgAEc`XWqD zC9rKU_9SENg6Oi_7!F{WZ@q2?_awA9J!kGUEO}HGkD-A{CaKl?CY2=O%HPW!PPz?O z`zdAw9HF*g&vNOAZU-=Rj-fn?BAd^aa1z4`JXOkhc!h@5ZQT_+1M(7;W3b|{;vc_R znRK?>^9==*GXq-IZhXqWfK+I&K#+OHB2#{W@zg0z-IT$#y{)oP>1H_UFpBmf{YkZV zFLqn$1%iu(L1Moq$X3wSC?E<4X5tRxq37J{+tpU$8X#k7X1*e|SQ~~xs=0(Wg%6O# z#he%4JBC9&hbX#m%fj4)9l}3|Mcf<|8DQL6zD4xDrt9g>3vc64g3Kd9z4GCPkVk_Q zuV&I@F(}#7*%=!#F3^Q#8v_Q!cy4hv!Z!VfgaY=P6^pvC_?k|T9jM*YsbGa;agv@l z6td^M{FVTwso1yWZW32couU1qvEr#oY_J(dO2F8oc8@8e(P7R&Ib*UNfh(czV)tl- zbX>-?+FEtjbC+_XhU|VPe3lM*C(N+bK&Uodbpv4ghUB+KIh&w9ub9@fQ=TdSOCv+W z_cJaic{ zrqxJzn}3lZ@5gB>2(NeH0 zmX;r0x^WVx0koJeifra#!tTU8^bn*NQCee~RJV9n+#KJEx5Oq+Y=rQX;uXVywV{6$AG6(G8HA_eijI@Y8ot6DVs+IxAs0bc~0a~2f z{7xNY_{^Rkg*R3?b#pi@ulx+) z&HyGKHVmn;Vz@?&y|DO}5{^H6#}d?MgD$FynH&5i)S-2TWt>48N&#}ilL)ssoihVU zH)zeX=L!~IQprW0Tet_u@|M8E!2t6ipG;J*x5S}kQ4+Wei5E_93Wlw7tnoLouvYxk zprj$Y=#!W)b92C@v*?9gWWrgjZ!!8ZmBN?7xF@kp9#Yo{=mWy|MLy<#Q@rdsc^EGd zrQ}InQY{w%jyI%z$I6E}F5`sYhYmQua1FW3tw0+V%Qq_rb_}wLfl(Pq&}j7viMv6r z5MiiKUs0@$kdoD7)BVo_Z5ZWq4p|nbEqH~(qRW;40KCD-RZIh#BWDn>R^w57j`}yA zqU|fyIUM3wiW{dLq2C|m(lMpx>)>gaFhL>$@5F0>iNS2tK!o6c0=Ll$Roh+8TP7Gnfo1YRXt2k|U(YGmse}zSQXlKj}Xw) zBBv$TOX3)pL&_24G290TwKZ5rhcE(}GwvNI%=ypkd;FPiYJR}@gbjP<%Kd_@MB4JP z`G8~xxS78&AX3oca|?kmiEb&^nCSq?VL_VaUJ_~CG-t=G#?P!z^!2yk6rx40?@o(+#ygaW4n<>rfaxiH7vG~^;Ohi}EQD9*aHGc=hPioRd zte%;#8uPgFy;QDWv%#j0KiP)RAiN)h2Qa9GF^qd7j4blxyYmp>n&ch6V4NTw#(N=} znpgIMHv?V7#zu~y37ZSzqhtd*g7T#@&sQvPNyG*og67m{>cZ zcv|0FCAu8hf ziY8+u5{hzpxWol^0bG5@D9ha(TR2L@Yzsz|6vmO=;C7ir23}!ZMg5f)FOIpHWiqI~ zJmv%(w$wa^;n-yeD3_N0GYutW^78kX1+;9b%AV#VumlRx>C5RC5}{*Ta?>bjS)%q! zqrlTp`8hy~XxTX#PM#)85)^a8ms zMq;S9%Wf>Tyh}dBcrO0{xq%{BO2z)ih_T>;S@Gg zUtSXcB;I~yP_|arN;;`iHh|xP--tjM%WY)5s?2nx6bK5gVrWeRxUDyw(QH)?<70vI z8;fM>x-sSbP4})1eUS9U49K*2XXZHYFuGqB>zF&aVZh&EzW3gs0{*3xFVI5rb^id8;uI>8F;3***Ffs!SThd61lM*%PYEsVs@>_Fu^{A2 za7XZ5FRj<_feYe$H-uqk1K5d2_1@JIj1a+-Dv_ z3NCEH#JC#pthPl4G#5E9F|wmMe*vjPmP$~+9ZF|rFzvo8#9wv)08B#JbS(K}UlXiT z4e(*8Qu1BJJT=S}70|j4Sd2+jn)-(Os&fILG}E8iisSw5y7-$I5f$4}HAGv}Fjffb za7CQT;CmG?$Z|G-$EB66A85b;22;X<_y_YFKXn1iLG~=a+>aH|Ot}0rFazD~+EZ@S z7*N$!BvWufzYz)6=Jdvh{%w8)EaVsNjIF}N^_a4_PxyJh(Yl-HEdsg`YW1(YlV85CRyv>OMN{w7CX;RU?rB|@={ zxUx4_%H3kU#aC<@%&XQQ5TW(y98vm!>q!`V zBEn^jS>;>91A4wz^-;tSd?pG2OM$dQh|&u<7U);Z5SFrItxO^vZQR6=aN-OAdY3$^ zydn6zg{wB3Az%72DA_LNT)TAC!!?2-oUAL{!cGP-+OFlqZuPKml%>^TTN)G}v6KZc z2%m{4Ec0xB68;8(r=d|ayF^~`daa&iqU$A|Vzmwil{|Z+v=cP=gmx3P)y;W8S``%&ka0Cp1Jro&Oh)`ojE6QT@4Jsn51*zbElLtyb%Ht6#jMvd$v zkJ!6Z0IOm;^0V~`P@w*t48xuDGTXqi;K2+H@Rwbtf}^u!rKxTXE&{4b>E2hO-!qL< z%RV-L*@)_oH4!TN7`VKwX=((Z5sKzAI-!xrh>@sui+s(= z-fs1mAJm~@rpJCWc!!O+dx0_4~Gc)^+$?uENWfG(T1Yd#Bcjb#vE~T*yv{Hs{07eJNeXc6qmrX9GnvdcO zDk3XAN2^@UAOSv@g%_3_`GxA|OBQ(+HLs|#8+2BOh`D^<)Ncau79Wc97h9kFH47!5>r^_oH zu*e}6TsK|+0M`gAo)CV)`ZZtvv{p9gT*1>dBbvhI8c6zTJb}==RvHK{o z>+HH|;z*I_l!_{_{{a8Q03{Fs0RRF50s;a90RaI3000015da}EK~Z6GfsvuH!SK=H z@&DQY2mt{A0Y4BLIZmG6j5*k?4_CwPWh+%L2UKpacu*i3Ce}sQU1JN8qS>r35?LEzz>bc+F=3nO_knrv9*DMw9WSV03YE0wcKG&Cz^hn`3;9gVfZz z>o1g`BATFa(evs8t!XTf?QY)&3U1#xNFW}RERXY_QX-kp9~#a8>`B_U&a&@|*L)s( z!rBo+yE?^eLpUEVjBr?B;7)yfXLzI9HjiW8ZTlq;bB$%%=&b0}e1ylWagAHf4R3ZB zMvaKF46dcRLyhZm@uJs}29G>BYpjIRkXbjW;|WeR6dhS^nsFdhA~4W#HFJ_x2O<9eTmexKl_)j;0Ba;n&8Qw|4Ry6YQZEIf0AD+=;A&bs|fTi}YE8{+H7d9JW7 z1S9qF-NQX9i*4)HKtL2LzdkX8v;)*N&v-(4qbYh#cZnnlFK(Rp!nFuG9j{^TVYPFH zv9M{T*!ajp<0KQXA&iCA_}+A{9z&OZSOG&1b(}t20@VN@9ydFV01=YtJ>NdPV{q>| z_xSc@qzUp&>IZ7$I;wRd-%OhGj1`Na0`dcneK@hoS_5@FU*`sjk`FQfh?l+dk7~UU zM51g&*v+g11=5Z;9z10Ql?kA4YI5a*+S;W`Izx;YVs1e;cqcxwBP|+{PX>VPnN-3S z2;Lu@9l<#WJXejGq=KCXk>SW{fcPIdsX8b_{$!yCjB*C}+`s_=0QMmKVn7OTX;+W- zvc^J2lcU}aOc+EazwS>a&Kd#HqHDfzP=e_Eyg2=4`dteMDfmWkn$kyn1mnjpMX(VF z9QoDG%z*%29wpC#kVRdm-VG{L=QEm~IzBO|tzebG`+hKzL}ZN{J=cxrHdMsSgK!B= zn>xZt0I?DhKjpwAh)0Y&x{bnKIjz9!sMu67p>h!v z3D(bS>fvt*WJ4D~dn=GRfk2>6mv55;#S%skQlhpZ1?bJ|lK25nY29~RMAe{3PX;J? zGz**31T6%ho3PwF#K;j49vmp!kgX(#TL1yyc&6b6!m7YFh-R?im`-U;ogxjCPgrs| zAsr1bk-}9-DdbSvUyKy=8sVsWc=^c@8vSE?_lEC%4}+V4nk24QZcZ^s1s$aMJ^uiC zq!n5_5p~}3&;wL9h5j&fir9eYKfHJVkrn~kVFd(EQ~mv7fD2%O2A=j|U@74M-{YJ@ zszgx|S+5rc0IVb}*NAUeUF&;Sli&Vkp$SF?Z_V#DK6bEKr}u{0C|eGf)2XbSiXoGzWsXx|~zQ&_~9AZ#!fyeHivtFqfwA<=bR2;k}7 z4$V=#WGKLz9zrtN>l0+*>M0;FQ1oG1hE$t$JJ#o{8PXp}o5N7M!H-o0j4@*O$KEEC z4lG1c!OTcD`=xazfsdcOh|_|i)OIu1)w~2-1^8)zZu_)gm#XEJLz*$ z=5lOCfDMM(jv++A*}M?(g$A7x!JvnHS%RvXBX?2TNv?ISa;YG9N`(*&rT+l9+N?#% zA!3C(+I(RGFoYlRAGa<5pAHf{v4bcW zNU@O$TL#O3D7(wDh*$Li$}$8s2FmxO%21FWaQ<-nX3*L?C%^o}FxDLCG03nYqek?7 zf3=;{dm)C{FmS-TPI3ffx60^y?;NwT;BUif@x;yrU9~!C?^Cx)Vp|6aKS)%n-gU=LCQjJ}IydI{3n)!@&}@ZQdC>Ma#GY^yDvI z6D_3RRd^2ZqpfD8D@0AFNF(axv>@oVkEOm$yB%c;fCYs$W6In_ETIo_?bw?n@Zlm> zQq2gV0Bffi0TAxtQgC!{6vRO@1Fj2B+^erW;DJGHb5A_1IfYm;0Q+nqrh?5_3tm3Ya0L@P%!s}#UELxHe%_k z$3>*0O+*G{0b#c_VQx)VzK7t*bn}3%Z$YO92m-gNWnM1+yUC!BP_8sNHHiDB&=Q0M zc<>NhiegfE3qEU3TAYqD6-Yh`SM`rt5KRcE{Q1Rvi+Dfuz~rL@i=lPy=G@x0sLOrK zopuzr#9lEcsYE>t=HB*ZB6R$w%k1ff(+Vv&LRM;J}}ZX9T4R{_`Wf46};+` z?-7z5HAeirYY(JIDi4v~1SS#$CiOqlH&qSk91QFMjVAJ<9DE3|{AXAcymW;kzbA|x zD)Ewn0D#6{F-PWM*FlY;x4w=rTuK)5-YSolHxy=&A;<8ba{ljt{qxTO^4*rf=(3do5=lzXy zaVzc8n_A&$3AVq{0<(T{L?H&c(7iJ^l?qUask%fanU(;MG=iZ34n_IEA?k~pr&Jq{ zE(mj{g^g^8HGp_-1=H*!tBTr$>DqowFAUXeozM4Jfr%WDJp0~ChGj|Z{{S-eWCO9| zee;~8K-ke|tP^U7Tmy2Ypab3mRM7N!*SrE3?h&mIkKQ#AD}@L=vg;-=GoIVVk-Rqs zpb4h;#x=q&0R^|&?-2wT0H`Ci3A0%}Vxkaof&=-uq_={!HvNBjvGC}M3**;$QPBV{NzLKfak ziX-G=NDvF6-Wqaj%8(Tm(cTFuqjG`W1*tY*Ru{jLCiGlR#tAQUQ?&<|BM>We%7hpz zDZJ|lHxIGFU&b*42_8FNfAbi~GrLR928;nkjdao~rPi{MK})Y?@*CNKAcS~-^uJjk z;QO5ikVd<>U?SzcK`Ei_)Q=?@YEZ!9hNrD)?2a&u>$rW4&K}#K(fRapYNJMP3)!lA6pl(pA(si%- zlh6aOwMp}L;}xS_XH=($e&~d7WOizo0!OgV*{{UQ}7A_1Qn~ZY02s3(~@DxCeB|6frJtO)t=s_FK^F#%wHeI&eF9#SM zAYGbv>7l1HG)NE>5`FW|M_2$00s+Ot5ZaBSnSm(*DGi0Xrv|y!E=rM%EiB~m&IJ?{ zp%@0#n)8-pZAu>Whu?Ta5PA-c`26Bo0er!BBz%FoQvg^#vbE^Y1NzJTH2MY2?n8Li zYJts#cl3Q>@@RK~1=C)C=3X#4^q>yR9;PVBsw^FykbPlsk8YWbX z8_4AYe<@*xUCPGy*_sdl4u#z?m0Jf^QRu`)1e9RU=y8X9myEe6XApw;vpZyB(MzN$ zYvUI%gto3)gZL+mc$~pW71$BEzpNTUCWZ|*;Dr^Qao^z0S~@E2-&s&4r23<8CNtC{ z;LEUh0?aibEjMpkE6xg)BXy_KHB!}C(TyBjv42q}uUQIZ{{W(77I89C69qjf0+GN7 z({HNZC+{AGgzh|@jBkt*V^ZNfzi#tq^{ipuVY(ypy*{&CX%c9CFs(owb)C6x!tAEE zR?OnSCl-r`Dt}nSDqtOJEffjsDA4tqC}<4)8q$FF!G z4x!Lgnx94u@^AyYjy{mE3x-1nK)SE2wAfJm4!FO0=PK5({eE+z;CF@%pcPH9_8Q^i z8&r6pJ6-Z|Gd`qgQNe5|@5XPCL6k<0H#}=JUJ*pKRo>0*ywnM}mK84SfAb3y8OjuB zUuBrQ2ZU!Yho3mkAjEBJYP4=7=4CJu@ZV)&bkI|* zBp_M_j#FFw;>3!nh#ExQj?C4CqMpl6mtxtPYAmqRw2c$t!^jKG_!ZZ2*BZ$Y01&3w zfP!CHEJ4(qJ@WI45ft5-M-C&%Cs&&Hm&tqTg72$!^k6Eu0IKdgQVKhKWLN+#ATr?H zgTg#GRK+9#6&wS7b((e^?FA$9L~%Ec+?0kAaEBz0@rcrdu+W#l zmB*40@mh-1^>9H0R3bxxC7%7|s8Mo`&4Cd? zQ6p224k1J-Bi8xfj87pRa*vjA)ejWl3@It*HeU^qT9Z$T!?)O}&J1S=0|=bu?H zDjjV-2e$MC29%N5q|Qq zD=lHEee;}}mPif{mnYXCG~WlTZ89E7!+&!G(2!6LE1ll5g{tK$qu$-c$2l080M@pM z^k)kfdqF0bCTqnG&O+v^YlyIgb|eTI;90(~){{msooW&P0GP-Ea}do$7uz?nus{u; zsqYW?UNG#Man*)zZP`Fn?vL5q=xJ5s?40xFu5k;|suNX_R6)EY>*a2LGIQz~? z6jT*4utgE54V5YY7kSHU9S2G!>r~#dm{(HAy>-{)D8ctMb$H`=5H#XMzfEt<$JtT{ zDbh#J1}Z>!Zz-Nh#-BN0PhCewa>?QKmCONCXu@(qd&ok0ApjZ>kSV@$EUj?>Y;eW< z#Z!Z&CDhv7Jd+H7_rNxAxDt?abRMI0JKLHFuLB)78b`!@=OzWg}K*l zqb@sOUDECrTFnl6Rbb9CO7<@(J;rO`*o1GL(dC zctCQg*vw%~o;0p}4d+yrL_r*F=5kCfKo^0W-+MU7ktoO} zzf8x1P)JvRYY0^w6?6q5M5>B)K3wt0o}dTnVvVqPRO|iYA`&IA4aYn2?-boo;*UW< z-;5Pu;5w+A5iBsI0@^9L-R~&#U8aHq{{_!m$$M7AJn?5m+ zBd{5ubfG?+VSzS4o+in@7OqC7Wne2Jmu{-*#<@qt$8lCEHnbUC+-a);Su}%Vd%`?I zP?k1G4XWZdksKNb!vUr`HF+dH$0d3#E)^JjM37U`vSUPZ~5}dEi z>Gg(EQd6}Z@$g_tr#h&0@@q7}PYrwZHHuIGNUu!)09=h^*Ag*jjpm3cTh5GOJeS5c z8p>uWY*oBn=7R+BiI;>v1|<&F`;XK6*MV0ulI=xMMRp@ezk~UWE3T@urG`>bwLewae%O@@E0^h{+(jHhhnWJG*Jj@ z2^2(YLYJ%x(h=6673(--BJ%S1m;eCPp!DlkcRI;H8>eIH?7n^x0exe=dHz>`Nj^N<8FGK;N+bvJR7w&#^F0Gbbj3pqht zn1KzNYoW$)0IvX0(IfHT0=kH2O_`^R7GhNyr7J_lJavpwS_$VJBsHPYF$4!4En4CB z!QiG231ChG3N+H_qt|=KmBC;%U0ZeR#sf>SMywW)y zAm9VZ=cfn4$^|%{D8yE4SJx*ytajb!eoTDF&w;!)hs&&S zP=d>5cgcL=+>r1GfjRS+8WCDgmw$LVD;A%PelYkT6{m?Q@^6k1B}%y)AZXB8Fa!V( z>roE-!oZjH=_fH{+k2638mp7$%hWF=&lcDPPEbv1oxt2-LqAOT$m##4||{ zZ^ir8Sx{1PJ>e`6Mfg zePnX!$Z!DFjo(}Ajed|gTnZu(d%>v4FtCm5WWkESBDGu%km~O zL8cGE-Qx)P6 zruf5Au!=7=_|A4B2*XvV{{ZGsi1Mb%$nZTUtP`jo+|Y{Ok-hE+2;M~z#0A_Y2uQ~B z4(o3D#ux{iTl&C_fMY>BZoczQ4G81%{NPc}ciCi#2*XI|U4udcg_{Z_E#V&w$H=2Av z>k|fCBx$HleFhqqw6KEbr_-5x;oVS^^E8yD9YW*r3q1Mi^U@rA5|LR~`e{9t~h zBTm)k{CALmD*7Nt_SkX ztvvXo3c{;^)Igz`ArokkOxzVb%x1p@CX8p^5F9KjC!ZM#8-Vk5-zM^@V09Yb$Na#9 zXbumrUUA)s-CQ{9`@%;`O_jng-}+z>fFa2(3LVRfHSL6uMwEv>^_DnZXGc+$2H%m# zU=PedX>e&2L6%5}D^Q4t+Vi6;D zZCMsj6x`Yy#XiBw7lVFx^58D$ORs~&-;)Le!C)6#-itC1y@q(;-_^ii@S|Qq#r$kI z$AnmLnxL7Y%+Uc+pl|O=#VZQ!p_?fa9i}r?=HWn7&^m%|=H*ejR2o)QtIi`(TG4>b zSTdE!y(;O3l}?NfuaL)zF<3|iaBzQ|rGm?hNZOt|GAsI~5aVDv=KbWAM1X}-#`>Hd z?$8vyolK%gkgpikFecrz2S`qe`OaXp7Euw2AYK8x$Cz0129%SNkN)J%3fqK|GCHp_ zX7QB-sRc235Fr{cM7*fb98G{5z%Y{lQh)+TT3+ThNm3Ll)CYQo7pq{YqL!*UO%E&< zni?WUgzK&TbAB&`>LBDz#(LuzTDW*|Bpmat^_)PBu$))(@s~?CPBp5wl(B`lSS6?( z00`>$dBGw5yO&2^xYiM#SlHR_o41#(W}^^#8V-I6^>F@lRUl9f9YC4C;=dvWuy9?{ z>SUsb`2a~r)Yd36t}vT T): T = { + val flagName = FileInputFormat.INPUT_DIR_RECURSIVE + val hadoopConf = spark.sparkContext.hadoopConfiguration + val old = Option(hadoopConf.get(flagName)) + hadoopConf.set(flagName, value.toString) + try f finally { + old match { + case Some(v) => hadoopConf.set(flagName, v) + case None => hadoopConf.unset(flagName) + } + } + } +} + +/** + * Filter that allows loading a fraction of HDFS files. + */ +private class SamplePathFilter extends Configured with PathFilter { + val random = new Random() + + // Ratio of files to be read from disk + var sampleRatio: Double = 1 + + override def setConf(conf: Configuration): Unit = { + if (conf != null) { + sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1) + val seed = conf.getLong(SamplePathFilter.seedParam, 0) + random.setSeed(seed) + } + } + + override def accept(path: Path): Boolean = { + // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead + !SamplePathFilter.isFile(path) || random.nextDouble() < sampleRatio + } +} + +private object SamplePathFilter { + val ratioParam = "sampleRatio" + val seedParam = "seed" + + def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != "" + + /** + * Sets the HDFS PathFilter flag and then restores it. + * Only applies the filter if sampleRatio is less than 1. + * + * @param sampleRatio Fraction of the files that the filter picks + * @param spark Existing Spark session + * @param seed Random number seed + * @param f The function to evaluate after setting the flag + * @return Returns the evaluation result T of the function + */ + def withPathFilter[T]( + sampleRatio: Double, + spark: SparkSession, + seed: Long)(f: => T): T = { + val sampleImages = sampleRatio < 1 + if (sampleImages) { + val flagName = FileInputFormat.PATHFILTER_CLASS + val hadoopConf = spark.sparkContext.hadoopConfiguration + val old = Option(hadoopConf.getClass(flagName, null)) + hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio) + hadoopConf.setLong(SamplePathFilter.seedParam, seed) + hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter]) + try f finally { + hadoopConf.unset(SamplePathFilter.ratioParam) + hadoopConf.unset(SamplePathFilter.seedParam) + old match { + case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter]) + case None => hadoopConf.unset(flagName) + } + } + } else { + f + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala new file mode 100644 index 0000000000000..f7850b238465b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.image + +import java.awt.Color +import java.awt.color.ColorSpace +import java.io.ByteArrayInputStream +import javax.imageio.ImageIO + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Defines the image schema and methods to read and manipulate images. + */ +@Experimental +@Since("2.3.0") +object ImageSchema { + + val undefinedImageType = "Undefined" + + /** + * (Scala-specific) OpenCV type mapping supported + */ + val ocvTypes: Map[String, Int] = Map( + undefinedImageType -> -1, + "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24 + ) + + /** + * (Java-specific) OpenCV type mapping supported + */ + val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava + + /** + * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte]) + */ + val columnSchema = StructType( + StructField("origin", StringType, true) :: + StructField("height", IntegerType, false) :: + StructField("width", IntegerType, false) :: + StructField("nChannels", IntegerType, false) :: + // OpenCV-compatible type: CV_8UC3 in most cases + StructField("mode", IntegerType, false) :: + // Bytes in OpenCV-compatible order: row-wise BGR in most cases + StructField("data", BinaryType, false) :: Nil) + + val imageFields: Array[String] = columnSchema.fieldNames + + /** + * DataFrame with a single column of images named "image" (nullable) + */ + val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil) + + /** + * Gets the origin of the image + * + * @return The origin of the image + */ + def getOrigin(row: Row): String = row.getString(0) + + /** + * Gets the height of the image + * + * @return The height of the image + */ + def getHeight(row: Row): Int = row.getInt(1) + + /** + * Gets the width of the image + * + * @return The width of the image + */ + def getWidth(row: Row): Int = row.getInt(2) + + /** + * Gets the number of channels in the image + * + * @return The number of channels in the image + */ + def getNChannels(row: Row): Int = row.getInt(3) + + /** + * Gets the OpenCV representation as an int + * + * @return The OpenCV representation as an int + */ + def getMode(row: Row): Int = row.getInt(4) + + /** + * Gets the image data + * + * @return The image data + */ + def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5) + + /** + * Default values for the invalid image + * + * @param origin Origin of the invalid image + * @return Row with the default values + */ + private[spark] def invalidImageRow(origin: String): Row = + Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0))) + + /** + * Convert the compressed image (jpeg, png, etc.) into OpenCV + * representation and store it in DataFrame Row + * + * @param origin Arbitrary string that identifies the image + * @param bytes Image bytes (for example, jpeg) + * @return DataFrame Row or None (if the decompression fails) + */ + private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = { + + val img = ImageIO.read(new ByteArrayInputStream(bytes)) + + if (img == null) { + None + } else { + val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY + val hasAlpha = img.getColorModel.hasAlpha + + val height = img.getHeight + val width = img.getWidth + val (nChannels, mode) = if (isGray) { + (1, ocvTypes("CV_8UC1")) + } else if (hasAlpha) { + (4, ocvTypes("CV_8UC4")) + } else { + (3, ocvTypes("CV_8UC3")) + } + + val imageSize = height * width * nChannels + assert(imageSize < 1e9, "image is too large") + val decoded = Array.ofDim[Byte](imageSize) + + // Grayscale images in Java require special handling to get the correct intensity + if (isGray) { + var offset = 0 + val raster = img.getRaster + for (h <- 0 until height) { + for (w <- 0 until width) { + decoded(offset) = raster.getSample(w, h, 0).toByte + offset += 1 + } + } + } else { + var offset = 0 + for (h <- 0 until height) { + for (w <- 0 until width) { + val color = new Color(img.getRGB(w, h)) + + decoded(offset) = color.getBlue.toByte + decoded(offset + 1) = color.getGreen.toByte + decoded(offset + 2) = color.getRed.toByte + if (nChannels == 4) { + decoded(offset + 3) = color.getAlpha.toByte + } + offset += nChannels + } + } + } + + // the internal "Row" is needed, because the image is a single DataFrame column + Some(Row(Row(origin, height, width, nChannels, mode, decoded))) + } + } + + /** + * Read the directory of images from the local or remote source + * + * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, + * there may be a race condition where one job overwrites the hadoop configs of another. + * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + * potentially non-deterministic. + * + * @param path Path to the image directory + * @return DataFrame with a single column "image" of images; + * see ImageSchema for the details + */ + def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0) + + /** + * Read the directory of images from the local or remote source + * + * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, + * there may be a race condition where one job overwrites the hadoop configs of another. + * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + * potentially non-deterministic. + * + * @param path Path to the image directory + * @param sparkSession Spark Session, if omitted gets or creates the session + * @param recursive Recursive path search flag + * @param numPartitions Number of the DataFrame partitions, + * if omitted uses defaultParallelism instead + * @param dropImageFailures Drop the files that are not valid images from the result + * @param sampleRatio Fraction of the files loaded + * @return DataFrame with a single column "image" of images; + * see ImageSchema for the details + */ + def readImages( + path: String, + sparkSession: SparkSession, + recursive: Boolean, + numPartitions: Int, + dropImageFailures: Boolean, + sampleRatio: Double, + seed: Long): DataFrame = { + require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1") + + val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate + val partitions = + if (numPartitions > 0) { + numPartitions + } else { + session.sparkContext.defaultParallelism + } + + RecursiveFlag.withRecursiveFlag(recursive, session) { + SamplePathFilter.withPathFilter(sampleRatio, session, seed) { + val binResult = session.sparkContext.binaryFiles(path, partitions) + val streams = if (numPartitions == -1) binResult else binResult.repartition(partitions) + val convert = (origin: String, bytes: PortableDataStream) => + decode(origin, bytes.toArray()) + val images = if (dropImageFailures) { + streams.flatMap { case (origin, bytes) => convert(origin, bytes) } + } else { + streams.map { case (origin, bytes) => + convert(origin, bytes).getOrElse(invalidImageRow(origin)) + } + } + session.createDataFrame(images, imageSchema) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala new file mode 100644 index 0000000000000..dba61cd1eb1cc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.image + +import java.nio.file.Paths +import java.util.Arrays + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.image.ImageSchema._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { + // Single column of images named "image" + private lazy val imagePath = "../data/mllib/images" + + test("Smoke test: create basic ImageSchema dataframe") { + val origin = "path" + val width = 1 + val height = 1 + val nChannels = 3 + val data = Array[Byte](0, 0, 0) + val mode = ocvTypes("CV_8UC3") + + // Internal Row corresponds to image StructType + val rows = Seq(Row(Row(origin, height, width, nChannels, mode, data)), + Row(Row(null, height, width, nChannels, mode, data))) + val rdd = sc.makeRDD(rows) + val df = spark.createDataFrame(rdd, ImageSchema.imageSchema) + + assert(df.count === 2, "incorrect image count") + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + } + + test("readImages count test") { + var df = readImages(imagePath) + assert(df.count === 1) + + df = readImages(imagePath, null, true, -1, false, 1.0, 0) + assert(df.count === 9) + + df = readImages(imagePath, null, true, -1, true, 1.0, 0) + val countTotal = df.count + assert(countTotal === 7) + + df = readImages(imagePath, null, true, -1, true, 0.5, 0) + // Random number about half of the size of the original dataset + val count50 = df.count + assert(count50 > 0 && count50 < countTotal) + } + + test("readImages partition test") { + val df = readImages(imagePath, null, true, 3, true, 1.0, 0) + assert(df.rdd.getNumPartitions === 3) + } + + // Images with the different number of channels + test("readImages pixel values test") { + + val images = readImages(imagePath + "/multi-channel/").collect + + images.foreach { rrow => + val row = rrow.getAs[Row](0) + val filename = Paths.get(getOrigin(row)).getFileName().toString() + if (firstBytes20.contains(filename)) { + val mode = getMode(row) + val bytes20 = getData(row).slice(0, 20) + + val (expectedMode, expectedBytes) = firstBytes20(filename) + assert(ocvTypes(expectedMode) === mode, "mode of the image is not read correctly") + assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image") + } + } + } + + // number of channels and first 20 bytes of OpenCV representation + // - default representation for 3-channel RGB images is BGR row-wise: + // (B00, G00, R00, B10, G10, R10, ...) + // - default representation for 4-channel RGB images is BGRA row-wise: + // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + private val firstBytes20 = Map( + "grayscale.jpg" -> + (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, + -57, -60, -63, -53, -49, -55, -69))), + "chr30.4.184.jpg" -> (("CV_8UC3", + Array[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57, + -71, -58, -56, -73, -64))), + "BGRA.png" -> (("CV_8UC4", + Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) + ) +} diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 01627ba92b633..6a5d81706f071 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -97,6 +97,14 @@ pyspark.ml.fpm module :undoc-members: :inherited-members: +pyspark.ml.image module +---------------------------- + +.. automodule:: pyspark.ml.image + :members: + :undoc-members: + :inherited-members: + pyspark.ml.util module ---------------------------- diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py new file mode 100644 index 0000000000000..7d14f05295572 --- /dev/null +++ b/python/pyspark/ml/image.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +.. attribute:: ImageSchema + + An attribute of this module that contains the instance of :class:`_ImageSchema`. + +.. autoclass:: _ImageSchema + :members: +""" + +import numpy as np +from pyspark import SparkContext +from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string +from pyspark.sql import DataFrame, SparkSession + + +class _ImageSchema(object): + """ + Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to be private and + not to be instantized. Use `pyspark.ml.image.ImageSchema` attribute to access the + APIs of this class. + """ + + def __init__(self): + self._imageSchema = None + self._ocvTypes = None + self._imageFields = None + self._undefinedImageType = None + + @property + def imageSchema(self): + """ + Returns the image schema. + + :return: a :class:`StructType` with a single column of images + named "image" (nullable). + + .. versionadded:: 2.3.0 + """ + + if self._imageSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() + self._imageSchema = _parse_datatype_json_string(jschema.json()) + return self._imageSchema + + @property + def ocvTypes(self): + """ + Returns the OpenCV type mapping supported. + + :return: a dictionary containing the OpenCV type mapping supported. + + .. versionadded:: 2.3.0 + """ + + if self._ocvTypes is None: + ctx = SparkContext._active_spark_context + self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) + return self._ocvTypes + + @property + def imageFields(self): + """ + Returns field names of image columns. + + :return: a list of field names. + + .. versionadded:: 2.3.0 + """ + + if self._imageFields is None: + ctx = SparkContext._active_spark_context + self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) + return self._imageFields + + @property + def undefinedImageType(self): + """ + Returns the name of undefined image type for the invalid image. + + .. versionadded:: 2.3.0 + """ + + if self._undefinedImageType is None: + ctx = SparkContext._active_spark_context + self._undefinedImageType = \ + ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() + return self._undefinedImageType + + def toNDArray(self, image): + """ + Converts an image to an array with metadata. + + :param image: The image to be converted. + :return: a `numpy.ndarray` that is an image. + + .. versionadded:: 2.3.0 + """ + + height = image.height + width = image.width + nChannels = image.nChannels + return np.ndarray( + shape=(height, width, nChannels), + dtype=np.uint8, + buffer=image.data, + strides=(width * nChannels, nChannels, 1)) + + def toImage(self, array, origin=""): + """ + Converts an array with metadata to a two-dimensional image. + + :param array array: The array to convert to image. + :param str origin: Path to the image, optional. + :return: a :class:`Row` that is a two dimensional image. + + .. versionadded:: 2.3.0 + """ + + if array.ndim != 3: + raise ValueError("Invalid array shape") + height, width, nChannels = array.shape + ocvTypes = ImageSchema.ocvTypes + if nChannels == 1: + mode = ocvTypes["CV_8UC1"] + elif nChannels == 3: + mode = ocvTypes["CV_8UC3"] + elif nChannels == 4: + mode = ocvTypes["CV_8UC4"] + else: + raise ValueError("Invalid number of channels") + data = bytearray(array.astype(dtype=np.uint8).ravel()) + # Creating new Row with _create_row(), because Row(name = value, ... ) + # orders fields by name, which conflicts with expected schema order + # when the new DataFrame is created by UDF + return _create_row(self.imageFields, + [origin, height, width, nChannels, mode, data]) + + def readImages(self, path, recursive=False, numPartitions=-1, + dropImageFailures=False, sampleRatio=1.0, seed=0): + """ + Reads the directory of images from the local or remote source. + + .. note:: If multiple jobs are run in parallel with different sampleRatio or recursive flag, + there may be a race condition where one job overwrites the hadoop configs of another. + + .. note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + potentially non-deterministic. + + :param str path: Path to the image directory. + :param bool recursive: Recursive search flag. + :param int numPartitions: Number of DataFrame partitions. + :param bool dropImageFailures: Drop the files that are not valid images. + :param float sampleRatio: Fraction of the images loaded. + :param int seed: Random number seed. + :return: a :class:`DataFrame` with a single column of "images", + see ImageSchema for details. + + >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df.count() + 4 + + .. versionadded:: 2.3.0 + """ + + ctx = SparkContext._active_spark_context + spark = SparkSession(ctx) + image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema + jsession = spark._jsparkSession + jresult = image_schema.readImages(path, jsession, recursive, numPartitions, + dropImageFailures, float(sampleRatio), seed) + return DataFrame(jresult, spark._wrapped) + + +ImageSchema = _ImageSchema() + + +# Monkey patch to disallow instantization of this class. +def _disallow_instance(_): + raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") +_ImageSchema.__init__ = _disallow_instance diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2f1f3af957e4d..2258d61c95333 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -54,6 +54,7 @@ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel +from pyspark.ml.image import ImageSchema from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ SparseMatrix, SparseVector, Vector, VectorUDT, Vectors from pyspark.ml.param import Param, Params, TypeConverters @@ -1818,6 +1819,24 @@ def tearDown(self): del self.data +class ImageReaderTest(SparkSessionTestCase): + + def test_read_images(self): + data_path = 'data/mllib/images/kittens' + df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + self.assertEqual(df.count(), 4) + first_row = df.take(1)[0][0] + array = ImageSchema.toNDArray(first_row) + self.assertEqual(len(array), first_row[1]) + self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) + self.assertEqual(df.schema, ImageSchema.imageSchema) + expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} + self.assertEqual(ImageSchema.ocvTypes, expected) + expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] + self.assertEqual(ImageSchema.imageFields, expected) + self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From b4edafa99bd3858c166adeefdafd93dcd4bc9734 Mon Sep 17 00:00:00 2001 From: Jakub Nowacki Date: Thu, 23 Nov 2017 12:47:38 +0900 Subject: [PATCH 1713/1765] [SPARK-22495] Fix setup of SPARK_HOME variable on Windows ## What changes were proposed in this pull request? Fixing the way how `SPARK_HOME` is resolved on Windows. While the previous version was working with the built release download, the set of directories changed slightly for the PySpark `pip` or `conda` install. This has been reflected in Linux files in `bin` but not for Windows `cmd` files. First fix improves the way how the `jars` directory is found, as this was stoping Windows version of `pip/conda` install from working; JARs were not found by on Session/Context setup. Second fix is adding `find-spark-home.cmd` script, which uses `find_spark_home.py` script, as the Linux version, to resolve `SPARK_HOME`. It is based on `find-spark-home` bash script, though, some operations are done in different order due to the `cmd` script language limitations. If environment variable is set, the Python script `find_spark_home.py` will not be run. The process can fail if Python is not installed, but it will mostly use this way if PySpark is installed via `pip/conda`, thus, there is some Python in the system. ## How was this patch tested? Tested on local installation. Author: Jakub Nowacki Closes #19370 from jsnowacki/fix_spark_cmds. --- appveyor.yml | 1 + bin/find-spark-home.cmd | 60 +++++++++++++++++++++++++++++++++++++++++ bin/pyspark2.cmd | 2 +- bin/run-example.cmd | 4 ++- bin/spark-class2.cmd | 2 +- bin/spark-shell2.cmd | 4 ++- bin/sparkR2.cmd | 2 +- 7 files changed, 70 insertions(+), 5 deletions(-) create mode 100644 bin/find-spark-home.cmd diff --git a/appveyor.yml b/appveyor.yml index dc2d81fcdc091..48740920cd09b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -33,6 +33,7 @@ only_commits: - core/src/main/scala/org/apache/spark/api/r/ - mllib/src/main/scala/org/apache/spark/ml/r/ - core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala + - bin/*.cmd cache: - C:\Users\appveyor\.m2 diff --git a/bin/find-spark-home.cmd b/bin/find-spark-home.cmd new file mode 100644 index 0000000000000..c75e7eedb9418 --- /dev/null +++ b/bin/find-spark-home.cmd @@ -0,0 +1,60 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Path to Python script finding SPARK_HOME +set FIND_SPARK_HOME_PYTHON_SCRIPT=%~dp0find_spark_home.py + +rem Default to standard python interpreter unless told otherwise +set PYTHON_RUNNER=python +rem If PYSPARK_DRIVER_PYTHON is set, it overwrites the python version +if not "x%PYSPARK_DRIVER_PYTHON%"=="x" ( + set PYTHON_RUNNER=%PYSPARK_DRIVER_PYTHON% +) +rem If PYSPARK_PYTHON is set, it overwrites the python version +if not "x%PYSPARK_PYTHON%"=="x" ( + set PYTHON_RUNNER=%PYSPARK_PYTHON% +) + +rem If there is python installed, trying to use the root dir as SPARK_HOME +where %PYTHON_RUNNER% > nul 2>$1 +if %ERRORLEVEL% neq 0 ( + if not exist %PYTHON_RUNNER% ( + if "x%SPARK_HOME%"=="x" ( + echo Missing Python executable '%PYTHON_RUNNER%', defaulting to '%~dp0..' for SPARK_HOME ^ +environment variable. Please install Python or specify the correct Python executable in ^ +PYSPARK_DRIVER_PYTHON or PYSPARK_PYTHON environment variable to detect SPARK_HOME safely. + set SPARK_HOME=%~dp0.. + ) + ) +) + +rem Only attempt to find SPARK_HOME if it is not set. +if "x%SPARK_HOME%"=="x" ( + if not exist "%FIND_SPARK_HOME_PYTHON_SCRIPT%" ( + rem If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + rem need to search the different Python directories for a Spark installation. + rem Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + rem spark-submit in another directory we want to use that version of PySpark rather than the + rem pip installed version of PySpark. + set SPARK_HOME=%~dp0.. + ) else ( + rem We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + for /f "delims=" %%i in ('%PYTHON_RUNNER% %FIND_SPARK_HOME_PYTHON_SCRIPT%') do set SPARK_HOME=%%i + ) +) diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 46d4d5c883cfb..663670f2fddaf 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -18,7 +18,7 @@ rem limitations under the License. rem rem Figure out where the Spark framework is installed -set SPARK_HOME=%~dp0.. +call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] diff --git a/bin/run-example.cmd b/bin/run-example.cmd index efa5f81d08f7f..cc6b234406e4a 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -17,7 +17,9 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem Figure out where the Spark framework is installed +call "%~dp0find-spark-home.cmd" + set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] rem The outermost quotes are used to prevent Windows command line parse error diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index a93fd2f0e54bc..5da7d7a430d79 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -18,7 +18,7 @@ rem limitations under the License. rem rem Figure out where the Spark framework is installed -set SPARK_HOME=%~dp0.. +call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 7b5d396be888c..aaf71906c6526 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -17,7 +17,9 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem Figure out where the Spark framework is installed +call "%~dp0find-spark-home.cmd" + set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd index 459b780e2ae33..b48bea345c0b9 100644 --- a/bin/sparkR2.cmd +++ b/bin/sparkR2.cmd @@ -18,7 +18,7 @@ rem limitations under the License. rem rem Figure out where the Spark framework is installed -set SPARK_HOME=%~dp0.. +call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" From 42f83d7c40bb4e9c7c50f2cbda515b331fb2097f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Nov 2017 18:20:16 +0100 Subject: [PATCH 1714/1765] [SPARK-17920][FOLLOWUP] simplify the schema file creation in test ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/19779 , to simplify the file creation. ## How was this patch tested? test only change Author: Wenchen Fan Closes #19799 from cloud-fan/minor. --- .../spark/sql/hive/client/VersionsSuite.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index fbf6877c994a4..9d15dabc8d3f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -843,15 +843,12 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: SPARK-17920: Insert into/overwrite avro table") { withTempDir { dir => - val path = dir.getAbsolutePath - val schemaPath = s"""$path${File.separator}avroschemadir""" - - new File(schemaPath).mkdir() val avroSchema = - """{ + """ + |{ | "name": "test_record", | "type": "record", - | "fields": [ { + | "fields": [{ | "name": "f0", | "type": [ | "null", @@ -862,17 +859,17 @@ class VersionsSuite extends SparkFunSuite with Logging { | "logicalType": "decimal" | } | ] - | } ] + | }] |} """.stripMargin - val schemaUrl = s"""$schemaPath${File.separator}avroDecimal.avsc""" - val schemaFile = new File(schemaPath, "avroDecimal.avsc") + val schemaFile = new File(dir, "avroDecimal.avsc") val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() + val schemaPath = schemaFile.getCanonicalPath val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile) + val srcLocation = new File(url.getFile).getCanonicalPath val destTableName = "tab1" val srcTableName = "tab2" @@ -886,7 +883,7 @@ class VersionsSuite extends SparkFunSuite with Logging { | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' |LOCATION '$srcLocation' - |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') """.stripMargin ) @@ -898,7 +895,7 @@ class VersionsSuite extends SparkFunSuite with Logging { |STORED AS | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') """.stripMargin ) versionSpark.sql( From c1217565e20bd3297f3b1bc8f18f5dea933211c0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Nov 2017 15:33:26 -0800 Subject: [PATCH 1715/1765] [SPARK-22592][SQL] cleanup filter converting for hive ## What changes were proposed in this pull request? We have 2 different methods to convert filters for hive, regarding a config. This introduces duplicated and inconsistent code(e.g. one use helper objects for pattern match and one doesn't). ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19801 from cloud-fan/cleanup. --- .../spark/sql/hive/client/HiveShim.scala | 144 +++++++++--------- 1 file changed, 69 insertions(+), 75 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bd1b300416990..1eac70dbf19cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -585,53 +585,17 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * Unsupported predicates are skipped. */ def convertFilters(table: Table, filters: Seq[Expression]): String = { - if (SQLConf.get.advancedPartitionPredicatePushdownEnabled) { - convertComplexFilters(table, filters) - } else { - convertBasicFilters(table, filters) - } - } - - - /** - * An extractor that matches all binary comparison operators except null-safe equality. - * - * Null-safe equality is not supported by Hive metastore partition predicate pushdown - */ - object SpecialBinaryComparison { - def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { - case _: EqualNullSafe => None - case _ => Some((e.left, e.right)) + /** + * An extractor that matches all binary comparison operators except null-safe equality. + * + * Null-safe equality is not supported by Hive metastore partition predicate pushdown + */ + object SpecialBinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case _: EqualNullSafe => None + case _ => Some((e.left, e.right)) + } } - } - - private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - lazy val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || - col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) - .map(col => col.getName).toSet - - filters.collect { - case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => - s"${a.name} ${op.symbol} $v" - case op @ SpecialBinaryComparison(Literal(v, _: IntegralType), a: Attribute) => - s"$v ${op.symbol} ${a.name}" - case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: StringType)) - if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" - case op @ SpecialBinaryComparison(Literal(v, _: StringType), a: Attribute) - if !varcharKeys.contains(a.name) => - s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" - }.mkString(" and ") - } - - private def convertComplexFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - lazy val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || - col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) - .map(col => col.getName).toSet object ExtractableLiteral { def unapply(expr: Expression): Option[String] = expr match { @@ -643,9 +607,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiterals { def unapply(exprs: Seq[Expression]): Option[Seq[String]] = { - exprs.map(ExtractableLiteral.unapply).foldLeft(Option(Seq.empty[String])) { - case (Some(accum), Some(value)) => Some(accum :+ value) - case _ => None + val extractables = exprs.map(ExtractableLiteral.unapply) + if (extractables.nonEmpty && extractables.forall(_.isDefined)) { + Some(extractables.map(_.get)) + } else { + None } } } @@ -660,40 +626,68 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def unapply(values: Set[Any]): Option[Seq[String]] = { - values.toSeq.foldLeft(Option(Seq.empty[String])) { - case (Some(accum), value) if valueToLiteralString.isDefinedAt(value) => - Some(accum :+ valueToLiteralString(value)) - case _ => None + val extractables = values.toSeq.map(valueToLiteralString.lift) + if (extractables.nonEmpty && extractables.forall(_.isDefined)) { + Some(extractables.map(_.get)) + } else { + None } } } - def convertInToOr(a: Attribute, values: Seq[String]): String = { - values.map(value => s"${a.name} = $value").mkString("(", " or ", ")") + object NonVarcharAttribute { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + private val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + def unapply(attr: Attribute): Option[String] = { + if (varcharKeys.contains(attr.name)) { + None + } else { + Some(attr.name) + } + } + } + + def convertInToOr(name: String, values: Seq[String]): String = { + values.map(value => s"$name = $value").mkString("(", " or ", ")") } - lazy val convert: PartialFunction[Expression, String] = { - case In(a: Attribute, ExtractableLiterals(values)) - if !varcharKeys.contains(a.name) && values.nonEmpty => - convertInToOr(a, values) - case InSet(a: Attribute, ExtractableValues(values)) - if !varcharKeys.contains(a.name) && values.nonEmpty => - convertInToOr(a, values) - case op @ SpecialBinaryComparison(a: Attribute, ExtractableLiteral(value)) - if !varcharKeys.contains(a.name) => - s"${a.name} ${op.symbol} $value" - case op @ SpecialBinaryComparison(ExtractableLiteral(value), a: Attribute) - if !varcharKeys.contains(a.name) => - s"$value ${op.symbol} ${a.name}" - case And(expr1, expr2) - if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) => - (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")") - case Or(expr1, expr2) - if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) => - s"(${convert(expr1)} or ${convert(expr2)})" + val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + + def convert(expr: Expression): Option[String] = expr match { + case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + Some(convertInToOr(name, values)) + + case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + Some(convertInToOr(name, values)) + + case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + Some(s"$name ${op.symbol} $value") + + case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + Some(s"$value ${op.symbol} $name") + + case And(expr1, expr2) if useAdvanced => + val converted = convert(expr1) ++ convert(expr2) + if (converted.isEmpty) { + None + } else { + Some(converted.mkString("(", " and ", ")")) + } + + case Or(expr1, expr2) if useAdvanced => + for { + left <- convert(expr1) + right <- convert(expr2) + } yield s"($left or $right)" + + case _ => None } - filters.map(convert.lift).collect { case Some(filterString) => filterString }.mkString(" and ") + filters.flatMap(convert).mkString(" and ") } private def quoteStringLiteral(str: String): String = { From 62a826f17c549ed93300bdce562db56bddd5d959 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Nov 2017 11:46:58 +0100 Subject: [PATCH 1716/1765] [SPARK-22591][SQL] GenerateOrdering shouldn't change CodegenContext.INPUT_ROW ## What changes were proposed in this pull request? When I played with codegen in developing another PR, I found the value of `CodegenContext.INPUT_ROW` is not reliable. Under wholestage codegen, it is assigned to null first and then suddenly changed to `i`. The reason is `GenerateOrdering` changes `CodegenContext.INPUT_ROW` but doesn't restore it back. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19800 from viirya/SPARK-22591. --- .../expressions/codegen/GenerateOrdering.scala | 16 ++++++++++------ .../sql/catalyst/expressions/OrderingSuite.scala | 11 ++++++++++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1639d1b9dda1f..4a459571ed634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -72,13 +72,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR * Generates the code for ordering based on the given order. */ def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { + val oldInputRow = ctx.INPUT_ROW + val oldCurrentVars = ctx.currentVars + val inputRow = "i" + ctx.INPUT_ROW = inputRow + // to use INPUT_ROW we must make sure currentVars is null + ctx.currentVars = null + val comparisons = ordering.map { order => - val oldCurrentVars = ctx.currentVars - ctx.INPUT_ROW = "i" - // to use INPUT_ROW we must make sure currentVars is null - ctx.currentVars = null val eval = order.child.genCode(ctx) - ctx.currentVars = oldCurrentVars val asc = order.isAscending val isNullA = ctx.freshName("isNullA") val primitiveA = ctx.freshName("primitiveA") @@ -147,10 +149,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR """ }.mkString }) + ctx.currentVars = oldCurrentVars + ctx.INPUT_ROW = oldInputRow // make sure INPUT_ROW is declared even if splitExpressions // returns an inlined block s""" - |InternalRow ${ctx.INPUT_ROW} = null; + |InternalRow $inputRow = null; |$code """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index aa61ba2bff2bb..d0604b8eb7675 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateOrdering, LazilyGeneratedOrdering} import org.apache.spark.sql.types._ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -156,4 +156,13 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { assert(genOrdering.compare(rowB1, rowB2) < 0) } } + + test("SPARK-22591: GenerateOrdering shouldn't change ctx.INPUT_ROW") { + val ctx = new CodegenContext() + ctx.INPUT_ROW = null + + val schema = new StructType().add("field", FloatType, nullable = true) + GenerateOrdering.genComparisons(ctx, schema) + assert(ctx.INPUT_ROW == null) + } } From 554adc77d24c411a6df6d38c596aa33cdf68f3c1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 24 Nov 2017 12:08:49 +0100 Subject: [PATCH 1717/1765] [SPARK-22595][SQL] fix flaky test: CastSuite.SPARK-22500: cast for struct should not generate codes beyond 64KB ## What changes were proposed in this pull request? This PR reduces the number of fields in the test case of `CastSuite` to fix an issue that is pointed at [here](https://github.com/apache/spark/pull/19800#issuecomment-346634950). ``` java.lang.OutOfMemoryError: GC overhead limit exceeded java.lang.OutOfMemoryError: GC overhead limit exceeded at org.codehaus.janino.UnitCompiler.findClass(UnitCompiler.java:10971) at org.codehaus.janino.UnitCompiler.findTypeByName(UnitCompiler.java:7607) at org.codehaus.janino.UnitCompiler.getReferenceType(UnitCompiler.java:5758) at org.codehaus.janino.UnitCompiler.getType2(UnitCompiler.java:5732) at org.codehaus.janino.UnitCompiler.access$13200(UnitCompiler.java:206) at org.codehaus.janino.UnitCompiler$18.visitReferenceType(UnitCompiler.java:5668) at org.codehaus.janino.UnitCompiler$18.visitReferenceType(UnitCompiler.java:5660) at org.codehaus.janino.Java$ReferenceType.accept(Java.java:3356) at org.codehaus.janino.UnitCompiler.getType(UnitCompiler.java:5660) at org.codehaus.janino.UnitCompiler.buildLocalVariableMap(UnitCompiler.java:2892) at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:2764) at org.codehaus.janino.UnitCompiler.compileDeclaredMethods(UnitCompiler.java:1262) at org.codehaus.janino.UnitCompiler.compileDeclaredMethods(UnitCompiler.java:1234) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:538) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:890) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:894) at org.codehaus.janino.UnitCompiler.access$600(UnitCompiler.java:206) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:377) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:369) at org.codehaus.janino.Java$MemberClassDeclaration.accept(Java.java:1128) at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:369) at org.codehaus.janino.UnitCompiler.compileDeclaredMemberTypes(UnitCompiler.java:1209) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:564) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:890) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:894) at org.codehaus.janino.UnitCompiler.access$600(UnitCompiler.java:206) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:377) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:369) at org.codehaus.janino.Java$MemberClassDeclaration.accept(Java.java:1128) at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:369) at org.codehaus.janino.UnitCompiler.compileDeclaredMemberTypes(UnitCompiler.java:1209) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:564) ... ``` ## How was this patch tested? Used existing test case Author: Kazuaki Ishizaki Closes #19806 from kiszk/SPARK-22595. --- .../org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 84bd8b2f91e4f..7837d6529d12b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -829,7 +829,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-22500: cast for struct should not generate codes beyond 64KB") { - val N = 250 + val N = 25 val fromInner = new StructType( (1 to N).map(i => StructField(s"s$i", DoubleType)).toArray) From 449e26ecdc891039198c26ece99454a2e76d5455 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 24 Nov 2017 15:07:43 +0100 Subject: [PATCH 1718/1765] [SPARK-22559][CORE] history server: handle exception on opening corrupted listing.ldb ## What changes were proposed in this pull request? Currently history server v2 failed to start if `listing.ldb` is corrupted. This patch get rid of the corrupted `listing.ldb` and re-create it. The exception handling follows [opening disk store for app](https://github.com/apache/spark/blob/0ffa7c488fa8156e2a1aa282e60b7c36b86d8af8/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala#L307) ## How was this patch tested? manual test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #19786 from gengliangwang/listingException. --- .../spark/deploy/history/FsHistoryProvider.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 25f82b55f2003..69ccde3a8149d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -34,10 +34,10 @@ import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException +import org.fusesource.leveldbjni.internal.NativeDB import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ @@ -132,7 +132,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) AppStatusStore.CURRENT_VERSION, logDir.toString()) try { - open(new File(path, "listing.ldb"), metadata) + open(dbPath, metadata) } catch { // If there's an error, remove the listing database and any existing UI database // from the store directory, since it's extremely likely that they'll all contain @@ -140,7 +140,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") path.listFiles().foreach(Utils.deleteRecursively) - open(new File(path, "listing.ldb"), metadata) + open(dbPath, metadata) + case dbExc: NativeDB.DBException => + // Get rid of the corrupted listing.ldb and re-create it. + logWarning(s"Failed to load disk store $dbPath :", dbExc) + Utils.deleteRecursively(dbPath) + open(dbPath, metadata) } }.getOrElse(new InMemoryStore()) @@ -568,7 +573,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() - logInfo(s"Replaying log path: $logPath") val bus = new ReplayListenerBus() val listener = new AppListingListener(fileStatus, clock) From efd0036ec88bdc385f5a9ea568d2e2bbfcda2912 Mon Sep 17 00:00:00 2001 From: GuoChenzhao Date: Fri, 24 Nov 2017 15:09:43 +0100 Subject: [PATCH 1719/1765] [SPARK-22537][CORE] Aggregation of map output statistics on driver faces single point bottleneck ## What changes were proposed in this pull request? In adaptive execution, the map output statistics of all mappers will be aggregated after previous stage is successfully executed. Driver takes the aggregation job while it will get slow when the number of `mapper * shuffle partitions` is large, since it only uses single thread to compute. This PR uses multi-thread to deal with this single point bottleneck. ## How was this patch tested? Test cases are in `MapOutputTrackerSuite.scala` Author: GuoChenzhao Author: gczsjdy Closes #19763 from gczsjdy/single_point_mapstatistics. --- .../org/apache/spark/MapOutputTracker.scala | 60 ++++++++++++++++++- .../spark/internal/config/package.scala | 11 ++++ .../apache/spark/MapOutputTrackerSuite.scala | 23 +++++++ 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7f760a59bda2f..195fd4f818b36 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException @@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) } + /** + * Grouped function of Range, this is to avoid traverse of all elements of Range using + * IterableLike's grouped function. + */ + def rangeGrouped(range: Range, size: Int): Seq[Range] = { + val start = range.start + val step = range.step + val end = range.end + for (i <- start.until(end, size * step)) yield { + i.until(i + size * step, step) + } + } + + /** + * To equally divide n elements into m buckets, basically each bucket should have n/m elements, + * for the remaining n%m elements, add one more element to the first n%m buckets each. + */ + def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = { + val elementsPerBucket = numElements / numBuckets + val remaining = numElements % numBuckets + val splitPoint = (elementsPerBucket + 1) * remaining + if (elementsPerBucket == 0) { + rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) + } else { + rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++ + rangeGrouped(splitPoint.until(numElements), elementsPerBucket) + } + } + /** * Return statistics about all of the outputs for a given shuffle. */ def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) + val parallelAggThreshold = conf.get( + SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) + val parallelism = math.min( + Runtime.getRuntime.availableProcessors(), + statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt + if (parallelism <= 1) { + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + } else { + val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") + try { + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map { + reduceIds => Future { + for (s <- statuses; i <- reduceIds) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + } + ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf) + } finally { + threadPool.shutdown() } } new MapOutputStatistics(dep.shuffleId, totalSizes) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7a9072736b9aa..8fa25c0281493 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -499,4 +499,15 @@ package object config { "array in the sorter.") .intConf .createWithDefault(Integer.MAX_VALUE) + + private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD = + ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold") + .internal() + .doc("Multi-thread is used when the number of mappers * shuffle partitions is greater than " + + "or equal to this threshold. Note that the actual parallelism is calculated by number of " + + "mappers * shuffle partitions / this threshold + 1, so this threshold should be positive.") + .intConf + .checkValue(v => v > 0, "The threshold should be positive.") + .createWithDefault(10000000) + } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ebd826b0ba2f6..50b8ea754d8d9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("equally divide map statistics tasks") { + val func = newTrackerMaster().equallyDivide _ + val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5)) + val expects = Seq( + Seq(0, 0, 0, 0, 0), + Seq(1, 1, 1, 1, 0), + Seq(3, 3, 3, 3, 3), + Seq(4, 3, 3, 3, 3), + Seq(4, 4, 3, 3, 3), + Seq(4, 4, 4, 3, 3), + Seq(4, 4, 4, 4, 3), + Seq(4, 4, 4, 4, 4)) + cases.zip(expects).foreach { case ((num, divisor), expect) => + val answer = func(num, divisor).toSeq + var wholeSplit = (0 until num) + answer.zip(expect).foreach { case (split, expectSplitLength) => + val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength) + assert(currentSplit.toSet == split.toSet) + wholeSplit = rest + } + } + } + } From a1877f45c3451d18879083ed9b71dd9d5f583f1c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 24 Nov 2017 19:55:26 +0100 Subject: [PATCH 1720/1765] [SPARK-22597][SQL] Add spark-sql cmd script for Windows users ## What changes were proposed in this pull request? This PR proposes to add cmd scripts so that Windows users can also run `spark-sql` script. ## How was this patch tested? Manually tested on Windows. **Before** ```cmd C:\...\spark>.\bin\spark-sql '.\bin\spark-sql' is not recognized as an internal or external command, operable program or batch file. C:\...\spark>.\bin\spark-sql.cmd '.\bin\spark-sql.cmd' is not recognized as an internal or external command, operable program or batch file. ``` **After** ```cmd C:\...\spark>.\bin\spark-sql ... spark-sql> SELECT 'Hello World !!'; ... Hello World !! ``` Author: hyukjinkwon Closes #19808 from HyukjinKwon/spark-sql-cmd. --- bin/find-spark-home.cmd | 2 +- bin/run-example.cmd | 2 +- bin/spark-sql.cmd | 25 +++++++++++++++++++++++++ bin/spark-sql2.cmd | 25 +++++++++++++++++++++++++ bin/sparkR2.cmd | 3 +-- 5 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 bin/spark-sql.cmd create mode 100644 bin/spark-sql2.cmd diff --git a/bin/find-spark-home.cmd b/bin/find-spark-home.cmd index c75e7eedb9418..6025f67c38de4 100644 --- a/bin/find-spark-home.cmd +++ b/bin/find-spark-home.cmd @@ -32,7 +32,7 @@ if not "x%PYSPARK_PYTHON%"=="x" ( ) rem If there is python installed, trying to use the root dir as SPARK_HOME -where %PYTHON_RUNNER% > nul 2>$1 +where %PYTHON_RUNNER% > nul 2>&1 if %ERRORLEVEL% neq 0 ( if not exist %PYTHON_RUNNER% ( if "x%SPARK_HOME%"=="x" ( diff --git a/bin/run-example.cmd b/bin/run-example.cmd index cc6b234406e4a..2dd396e785358 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed call "%~dp0find-spark-home.cmd" -set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] +set _SPARK_CMD_USAGE=Usage: .\bin\run-example [options] example-class [example args] rem The outermost quotes are used to prevent Windows command line parse error rem when there are some quotes in parameters, see SPARK-21877. diff --git a/bin/spark-sql.cmd b/bin/spark-sql.cmd new file mode 100644 index 0000000000000..919e3214b5863 --- /dev/null +++ b/bin/spark-sql.cmd @@ -0,0 +1,25 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkSQL. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-sql2.cmd" %*" diff --git a/bin/spark-sql2.cmd b/bin/spark-sql2.cmd new file mode 100644 index 0000000000000..c34a3c5aa0739 --- /dev/null +++ b/bin/spark-sql2.cmd @@ -0,0 +1,25 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +call "%~dp0find-spark-home.cmd" + +set _SPARK_CMD_USAGE=Usage: .\bin\spark-sql [options] [cli option] + +call "%SPARK_HOME%\bin\spark-submit2.cmd" --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd index b48bea345c0b9..446f0c30bfe82 100644 --- a/bin/sparkR2.cmd +++ b/bin/sparkR2.cmd @@ -21,6 +21,5 @@ rem Figure out where the Spark framework is installed call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" - - +set _SPARK_CMD_USAGE=Usage: .\bin\sparkR [options] call "%SPARK_HOME%\bin\spark-submit2.cmd" sparkr-shell-main %* From 70221903f54eaa0514d5d189dfb6f175a62228a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 24 Nov 2017 21:50:30 -0800 Subject: [PATCH 1721/1765] [SPARK-22596][SQL] set ctx.currentVars in CodegenSupport.consume ## What changes were proposed in this pull request? `ctx.currentVars` means the input variables for the current operator, which is already decided in `CodegenSupport`, we can set it there instead of `doConsume`. also add more comments to help people understand the codegen framework. After this PR, we now have a principle about setting `ctx.currentVars` and `ctx.INPUT_ROW`: 1. for non-whole-stage-codegen path, never set them. (permit some special cases like generating ordering) 2. for whole-stage-codegen `produce` path, mostly we don't need to set them, but blocking operators may need to set them for expressions that produce data from data source, sort buffer, aggregate buffer, etc. 3. for whole-stage-codegen `consume` path, mostly we don't need to set them because `currentVars` is automatically set to child input variables and `INPUT_ROW` is mostly not used. A few plans need to tweak them as they may have different inputs, or they use the input row. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19803 from cloud-fan/codegen. --- .../catalyst/expressions/BoundAttribute.scala | 23 +++++++++------- .../expressions/codegen/CodeGenerator.scala | 14 +++++++--- .../sql/execution/DataSourceScanExec.scala | 14 +++++----- .../spark/sql/execution/ExpandExec.scala | 3 --- .../spark/sql/execution/GenerateExec.scala | 2 -- .../sql/execution/WholeStageCodegenExec.scala | 27 ++++++++++++++----- .../execution/basicPhysicalOperators.scala | 6 +---- .../apache/spark/sql/execution/objects.scala | 20 +++++--------- 8 files changed, 59 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 7d16118c9d59f..6a17a397b3ef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -59,21 +59,24 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { val oev = ctx.currentVars(ordinal) ev.isNull = oev.isNull ev.value = oev.value - val code = oev.code - oev.code = "" - ev.copy(code = code) - } else if (nullable) { - ev.copy(code = s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""") + ev.copy(code = oev.code) } else { - ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false") + assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") + val javaType = ctx.javaType(dataType) + val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + if (nullable) { + ev.copy(code = + s""" + |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """.stripMargin) + } else { + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9df8a8d6f6609..0498e61819f48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -133,6 +133,17 @@ class CodegenContext { term } + /** + * Holding the variable name of the input row of the current operator, will be used by + * `BoundReference` to generate code. + * + * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW` + * to generate code. If you want to make sure the generated code use `INPUT_ROW`, you need to set + * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling + * `Expression.genCode`. + */ + final var INPUT_ROW = "i" + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. @@ -386,9 +397,6 @@ class CodegenContext { final val JAVA_FLOAT = "float" final val JAVA_DOUBLE = "double" - /** The variable name of the input row in generated code. */ - final var INPUT_ROW = "i" - /** * The map from a variable name to it's next ID. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a477c23140536..747749bc72e66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -123,7 +123,7 @@ case class RowDataSourceScanExec( |while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput, null).trim} + | ${consume(ctx, columnsRowInput).trim} | if (shouldStop()) return; |} """.stripMargin @@ -355,19 +355,21 @@ case class FileSourceScanExec( // PhysicalRDD always just has one input val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - val exprRows = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable) - } val row = ctx.freshName("row") + ctx.INPUT_ROW = row ctx.currentVars = null - val columnsRowInput = exprRows.map(_.genCode(ctx)) + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. + val outputVars = output.zipWithIndex.map{ case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } val inputRow = if (needsUnsafeRowConversion) null else row s""" |while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput, inputRow).trim} + | ${consume(ctx, outputVars, inputRow).trim} | if (shouldStop()) return; |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 33849f4389b92..a7bd5ebf93ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -133,9 +133,6 @@ case class ExpandExec( * size explosion. */ - // Set input variables - ctx.currentVars = input - // Tracks whether a column has the same output for all rows. // Size of sameOutput array should equal N. // If sameOutput(i) is true, then the i-th column has the same value for all output rows given diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index c142d3b5ed4f2..e1562befe14f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -135,8 +135,6 @@ case class GenerateExec( override def needCopyResult: Boolean = true override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - ctx.currentVars = input - // Add input rows to the values when we are joining val values = if (join) { input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 16b5706c03bf9..7166b7771e4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -108,20 +108,22 @@ trait CodegenSupport extends SparkPlan { /** * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. + * + * Note that `outputVars` and `row` can't both be null. */ final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { val inputVars = - if (row != null) { + if (outputVars != null) { + assert(outputVars.length == output.length) + // outputVars will be used to generate the code for UnsafeRow, so we should copy them + outputVars.map(_.copy()) + } else { + assert(row != null, "outputVars and row cannot both be null.") ctx.currentVars = null ctx.INPUT_ROW = row output.zipWithIndex.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable).genCode(ctx) } - } else { - assert(outputVars != null) - assert(outputVars.length == output.length) - // outputVars will be used to generate the code for UnsafeRow, so we should copy them - outputVars.map(_.copy()) } val rowVar = if (row != null) { @@ -147,6 +149,11 @@ trait CodegenSupport extends SparkPlan { } } + // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` + // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to + // generate code of `rowVar` manually. + ctx.currentVars = inputVars + ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) s""" @@ -193,7 +200,8 @@ trait CodegenSupport extends SparkPlan { def usedInputs: AttributeSet = references /** - * Generate the Java source code to process the rows from child SparkPlan. + * Generate the Java source code to process the rows from child SparkPlan. This should only be + * called from `consume`. * * This should be override by subclass to support codegen. * @@ -207,6 +215,11 @@ trait CodegenSupport extends SparkPlan { * } * * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). + * When consuming as a listing of variables, the code to produce the input is already + * generated and `CodegenContext.currentVars` is already set. When consuming as UnsafeRow, + * implementations need to put `row.code` in the generated code and set + * `CodegenContext.INPUT_ROW` manually. Some plans may need more tweaks as they have + * different inputs(join build side, aggregate buffer, etc.), or other special cases. */ def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f205bdf3da709..c9a15147e30d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -56,9 +56,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val exprs = projectList.map(x => - ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) - ctx.currentVars = input + val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output)) val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) @@ -152,8 +150,6 @@ case class FilterExec(condition: Expression, child: SparkPlan) """.stripMargin } - ctx.currentVars = input - // To generate the predicates we will follow this algorithm. // For each predicate that is not IsNotNull, we will generate them one by one loading attributes // as necessary. For each of both attributes, if there is an IsNotNull predicate we will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d861109436a08..d1bd8a7076863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -81,11 +81,8 @@ case class DeserializeToObjectExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val bound = ExpressionCanonicalizer.execute( - BindReferences.bindReference(deserializer, child.output)) - ctx.currentVars = input - val resultVars = bound.genCode(ctx) :: Nil - consume(ctx, resultVars) + val resultObj = BindReferences.bindReference(deserializer, child.output).genCode(ctx) + consume(ctx, resultObj :: Nil) } override protected def doExecute(): RDD[InternalRow] = { @@ -118,11 +115,9 @@ case class SerializeFromObjectExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val bound = serializer.map { expr => - ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) + val resultVars = serializer.map { expr => + BindReferences.bindReference[Expression](expr, child.output).genCode(ctx) } - ctx.currentVars = input - val resultVars = bound.map(_.genCode(ctx)) consume(ctx, resultVars) } @@ -224,12 +219,9 @@ case class MapElementsExec( val funcObj = Literal.create(func, ObjectType(funcClass)) val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) - val bound = ExpressionCanonicalizer.execute( - BindReferences.bindReference(callFunc, child.output)) - ctx.currentVars = input - val resultVars = bound.genCode(ctx) :: Nil + val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) - consume(ctx, resultVars) + consume(ctx, result :: Nil) } override protected def doExecute(): RDD[InternalRow] = { From e3fd93f149ff0ff1caff28a5191215e2a29749a9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 24 Nov 2017 22:43:47 -0800 Subject: [PATCH 1722/1765] [SPARK-22604][SQL] remove the get address methods from ColumnVector ## What changes were proposed in this pull request? `nullsNativeAddress` and `valuesNativeAddress` are only used in tests and benchmark, no need to be top class API. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19818 from cloud-fan/minor. --- .../vectorized/ArrowColumnVector.java | 10 --- .../execution/vectorized/ColumnVector.java | 7 -- .../vectorized/OffHeapColumnVector.java | 6 +- .../vectorized/OnHeapColumnVector.java | 9 -- .../vectorized/ColumnarBatchBenchmark.scala | 32 ++++---- .../vectorized/ColumnarBatchSuite.scala | 82 +++++++------------ 6 files changed, 47 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 949035bfb177c..3a10e9830f581 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -59,16 +59,6 @@ public boolean anyNullsSet() { return numNulls() > 0; } - @Override - public long nullsNativeAddress() { - throw new RuntimeException("Cannot get native address for arrow column"); - } - - @Override - public long valuesNativeAddress() { - throw new RuntimeException("Cannot get native address for arrow column"); - } - @Override public void close() { if (childColumns != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 666fd63fdcf2f..360ed83e2af2a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -62,13 +62,6 @@ public abstract class ColumnVector implements AutoCloseable { */ public abstract boolean anyNullsSet(); - /** - * Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid - * to call for off heap columns. - */ - public abstract long nullsNativeAddress(); - public abstract long valuesNativeAddress(); - /** * Returns whether the value at rowId is NULL. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 2bf523b7e7198..6b5c783d4fa87 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,6 +19,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -73,12 +75,12 @@ public OffHeapColumnVector(int capacity, DataType type) { reset(); } - @Override + @VisibleForTesting public long valuesNativeAddress() { return data; } - @Override + @VisibleForTesting public long nullsNativeAddress() { return nulls; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index d699d292711dc..a7b103a62b17a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -79,15 +79,6 @@ public OnHeapColumnVector(int capacity, DataType type) { reset(); } - @Override - public long valuesNativeAddress() { - throw new RuntimeException("Cannot get native address for on heap column"); - } - @Override - public long nullsNativeAddress() { - throw new RuntimeException("Cannot get native address for on heap column"); - } - @Override public void close() { super.close(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 1331f157363b0..705b26b8c91e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -36,15 +36,6 @@ import org.apache.spark.util.collection.BitSet * Benchmark to low level memory access using different ways to manage buffers. */ object ColumnarBatchBenchmark { - - def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { - if (memMode == MemoryMode.OFF_HEAP) { - new OffHeapColumnVector(capacity, dt) - } else { - new OnHeapColumnVector(capacity, dt) - } - } - // This benchmark reads and writes an array of ints. // TODO: there is a big (2x) penalty for a random access API for off heap. // Note: carefully if modifying this code. It's hard to reason about the JIT. @@ -151,7 +142,7 @@ object ColumnarBatchBenchmark { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -170,7 +161,7 @@ object ColumnarBatchBenchmark { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -189,7 +180,7 @@ object ColumnarBatchBenchmark { // Access by directly getting the buffer backing the column. val columnOffheapDirect = { i: Int => - val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -255,7 +246,7 @@ object ColumnarBatchBenchmark { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -330,7 +321,7 @@ object ColumnarBatchBenchmark { for (n <- 0L until iters) { var i = 0 while (i < count) { - if (i % 2 == 0) b(i) = 1; + if (i % 2 == 0) b(i) = 1 i += 1 } i = 0 @@ -351,7 +342,7 @@ object ColumnarBatchBenchmark { } def stringAccess(iters: Long): Unit = { - val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" val random = new Random(0) def randomString(min: Int, max: Int): String = { @@ -359,10 +350,10 @@ object ColumnarBatchBenchmark { val sb = new StringBuilder(len) var i = 0 while (i < len) { - sb.append(chars.charAt(random.nextInt(chars.length()))); + sb.append(chars.charAt(random.nextInt(chars.length()))) i += 1 } - return sb.toString + sb.toString } val minString = 3 @@ -373,7 +364,12 @@ object ColumnarBatchBenchmark { .map(_.getBytes(StandardCharsets.UTF_8)).toArray def column(memoryMode: MemoryMode) = { i: Int => - val column = allocate(count, BinaryType, memoryMode) + val column = if (memoryMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(count, BinaryType) + } else { + new OnHeapColumnVector(count, BinaryType) + } + var sum = 0L for (n <- 0L until iters) { var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 4a6c8f5521d18..80a50866aa504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -50,11 +50,11 @@ class ColumnarBatchSuite extends SparkFunSuite { name: String, size: Int, dt: DataType)( - block: (WritableColumnVector, MemoryMode) => Unit): Unit = { + block: WritableColumnVector => Unit): Unit = { test(name) { Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode => val vector = allocate(size, dt, mode) - try block(vector, mode) finally { + try block(vector) finally { vector.close() } } @@ -62,7 +62,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } testVector("Null APIs", 1024, IntegerType) { - (column, memMode) => + column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 assert(!column.anyNullsSet()) @@ -121,15 +121,11 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.isNullAt(v._2)) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.nullsNativeAddress() - assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) - } } } testVector("Byte APIs", 1024, ByteType) { - (column, memMode) => + column => val reference = mutable.ArrayBuffer.empty[Byte] var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray @@ -173,16 +169,12 @@ class ColumnarBatchSuite extends SparkFunSuite { idx += 3 reference.zipWithIndex.foreach { v => - assert(v._1 == column.getByte(v._2), "MemoryMode" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getByte(null, addr + v._2)) - } + assert(v._1 == column.getByte(v._2), "VectorType=" + column.getClass.getSimpleName) } } testVector("Short APIs", 1024, ShortType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] @@ -248,16 +240,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getShort(v._2), "Seed = " + seed + " Mem Mode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getShort(null, addr + 2 * v._2)) - } + assert(v._1 == column.getShort(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Int APIs", 1024, IntegerType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] @@ -329,16 +318,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Mem Mode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) - } + assert(v._1 == column.getInt(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Long APIs", 1024, LongType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] @@ -413,16 +399,12 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getLong(v._2), "idx=" + v._2 + - " Seed = " + seed + " MemMode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) - } + " Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Float APIs", 1024, FloatType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] @@ -500,16 +482,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getFloat(v._2), "Seed = " + seed + " MemMode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) - } + assert(v._1 == column.getFloat(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Double APIs", 1024, DoubleType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] @@ -587,16 +566,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " MemMode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) - } + assert(v._1 == column.getDouble(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("String APIs", 6, StringType) { - (column, memMode) => + column => val reference = mutable.ArrayBuffer.empty[String] assert(column.arrayData().elementsAppended == 0) @@ -643,9 +619,9 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => - assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == column.getUTF8String(v._2).toString, - "MemoryMode" + memMode) + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v._1.length == column.getArrayLength(v._2), errMsg) + assert(v._1 == column.getUTF8String(v._2).toString, errMsg) } column.reset() @@ -653,7 +629,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } testVector("Int Array", 10, new ArrayType(IntegerType, true)) { - (column, _) => + column => // Fill the underlying data with all the arrays back to back. val data = column.arrayData() @@ -763,7 +739,7 @@ class ColumnarBatchSuite extends SparkFunSuite { testVector( "Struct Column", 10, - new StructType().add("int", IntegerType).add("double", DoubleType)) { (column, _) => + new StructType().add("int", IntegerType).add("double", DoubleType)) { column => val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) assert(c1.dataType() == IntegerType) @@ -789,7 +765,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } testVector("Nest Array in Array", 10, new ArrayType(new ArrayType(IntegerType, true), true)) { - (column, _) => + column => val childColumn = column.arrayData() val data = column.arrayData().arrayData() (0 until 6).foreach { @@ -822,7 +798,7 @@ class ColumnarBatchSuite extends SparkFunSuite { testVector( "Nest Struct in Array", 10, - new ArrayType(structType, true)) { (column, _) => + new ArrayType(structType, true)) { column => val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -851,7 +827,7 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new StructType() .add("int", IntegerType) - .add("array", new ArrayType(IntegerType, true))) { (column, _) => + .add("array", new ArrayType(IntegerType, true))) { column => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -880,7 +856,7 @@ class ColumnarBatchSuite extends SparkFunSuite { testVector( "Nest Struct in Struct", 10, - new StructType().add("int", IntegerType).add("struct", subSchema)) { (column, _) => + new StructType().add("int", IntegerType).add("struct", subSchema)) { column => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) From 4d8ace48698c9a5e45cfc896b170914f1c21dc7b Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Sat, 25 Nov 2017 07:32:28 -0600 Subject: [PATCH 1723/1765] [SPARK-22583] First delegation token renewal time is not 75% of renewal time in Mesos The first scheduled renewal time is is set to the exact expiration time, and all subsequent renewal times are 75% of the renewal time. This makes it so that the inital renewal time is also 75%. ## What changes were proposed in this pull request? Set the initial renewal time to be 75% of renewal time. ## How was this patch tested? Tested locally in a test HDFS cluster, checking various renewal times. Author: Kalvin Chau Closes #19798 from kalvinnchau/fix-inital-renewal-time. --- .../cluster/mesos/MesosHadoopDelegationTokenManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 325dc179d63ea..7165bfae18a5e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -63,7 +63,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), rt) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") From fba63c1a7bc5c907b909bfd9247b85b36efba469 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 26 Nov 2017 07:42:44 -0600 Subject: [PATCH 1724/1765] [SPARK-22607][BUILD] Set large stack size consistently for tests to avoid StackOverflowError ## What changes were proposed in this pull request? Set `-ea` and `-Xss4m` consistently for tests, to fix in particular: ``` OrderingSuite: ... - GenerateOrdering with ShortType *** RUN ABORTED *** java.lang.StackOverflowError: at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:370) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) ... ``` ## How was this patch tested? Existing tests. Manually verified it resolves the StackOverflowError this intends to resolve. Author: Sean Owen Closes #19820 from srowen/SPARK-22607. --- pom.xml | 4 ++-- project/SparkBuild.scala | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index 0297311dd6e64..3b2c629f8ec30 100644 --- a/pom.xml +++ b/pom.xml @@ -2101,7 +2101,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize} + -ea -Xmx3g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes_2.11 + jar + Spark Project Kubernetes + + kubernetes + 3.0.0 + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + + io.fabric8 + kubernetes-client + ${kubernetes.client.version} + + + com.fasterxml.jackson.core + * + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${fasterxml.jackson.version} + + + + + com.google.guava + guava + + + + + org.mockito + mockito-core + test + + + + com.squareup.okhttp3 + okhttp + 3.8.1 + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala new file mode 100644 index 0000000000000..f0742b91987b6 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit + +private[spark] object Config extends Logging { + + val KUBERNETES_NAMESPACE = + ConfigBuilder("spark.kubernetes.namespace") + .doc("The namespace that will be used for running the driver and executor pods. When using " + + "spark-submit in cluster mode, this can also be passed to spark-submit via the " + + "--kubernetes-namespace command line argument.") + .stringConf + .createWithDefault("default") + + val EXECUTOR_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.executor.docker.image") + .doc("Docker image to use for the executors. Specify this using the standard Docker tag " + + "format.") + .stringConf + .createOptional + + val DOCKER_IMAGE_PULL_POLICY = + ConfigBuilder("spark.kubernetes.docker.image.pullPolicy") + .doc("Kubernetes image pull policy. Valid values are Always, Never, and IfNotPresent.") + .stringConf + .checkValues(Set("Always", "Never", "IfNotPresent")) + .createWithDefault("IfNotPresent") + + val APISERVER_AUTH_DRIVER_CONF_PREFIX = + "spark.kubernetes.authenticate.driver" + val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = + "spark.kubernetes.authenticate.driver.mounted" + val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" + val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" + val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" + val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile" + val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" + + val KUBERNETES_SERVICE_ACCOUNT_NAME = + ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") + .doc("Service account that is used when running the driver pod. The driver pod uses " + + "this service account when requesting executor pods from the API server. If specific " + + "credentials are given for the driver pod to use, the driver will favor " + + "using those credentials instead.") + .stringConf + .createOptional + + // Note that while we set a default for this when we start up the + // scheduler, the specific default value is dynamically determined + // based on the executor memory. + val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.executor.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This " + + "is memory that accounts for things like VM overheads, interned strings, other native " + + "overheads, etc. This tends to grow with the executor size. (typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." + val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + + val KUBERNETES_DRIVER_POD_NAME = + ConfigBuilder("spark.kubernetes.driver.pod.name") + .doc("Name of the driver pod.") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.executor.podNamePrefix") + .doc("Prefix to use in front of the executor pod names.") + .internal() + .stringConf + .createWithDefault("spark") + + val KUBERNETES_ALLOCATION_BATCH_SIZE = + ConfigBuilder("spark.kubernetes.allocation.batch.size") + .doc("Number of pods to launch at once in each round of executor allocation.") + .intConf + .checkValue(value => value > 0, "Allocation batch size should be a positive integer") + .createWithDefault(5) + + val KUBERNETES_ALLOCATION_BATCH_DELAY = + ConfigBuilder("spark.kubernetes.allocation.batch.delay") + .doc("Number of seconds to wait between each round of executor allocation.") + .longConf + .checkValue(value => value > 0, "Allocation batch delay should be a positive integer") + .createWithDefault(1) + + val KUBERNETES_EXECUTOR_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.executor.limit.cores") + .doc("Specify the hard cpu limit for a single executor pod") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS = + ConfigBuilder("spark.kubernetes.executor.lostCheck.maxAttempts") + .doc("Maximum number of attempts allowed for checking the reason of an executor loss " + + "before it is assumed that the executor failed.") + .intConf + .checkValue(value => value > 0, "Maximum attempts of checks of executor lost reason " + + "must be a positive integer") + .createWithDefault(10) + + val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala new file mode 100644 index 0000000000000..01717479fddd9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import org.apache.spark.SparkConf + +private[spark] object ConfigurationUtils { + + /** + * Extract and parse Spark configuration properties with a given name prefix and + * return the result as a Map. Keys must not have more than one value. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing the configuration property keys and values + */ + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String): Map[String, String] = { + sparkConf.getAllWithPrefix(prefix).toMap + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala new file mode 100644 index 0000000000000..4ddeefb15a89d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] object Constants { + + // Labels + val SPARK_APP_ID_LABEL = "spark-app-selector" + val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id" + val SPARK_ROLE_LABEL = "spark-role" + val SPARK_POD_DRIVER_ROLE = "driver" + val SPARK_POD_EXECUTOR_ROLE = "executor" + + // Default and fixed ports + val DEFAULT_DRIVER_PORT = 7078 + val DEFAULT_BLOCKMANAGER_PORT = 7079 + val BLOCK_MANAGER_PORT_NAME = "blockmanager" + val EXECUTOR_PORT_NAME = "executor" + + // Environment Variables + val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" + val ENV_DRIVER_URL = "SPARK_DRIVER_URL" + val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" + val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" + val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" + val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" + val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" + val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" + val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" + val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" + + // Miscellaneous + val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" + val MEMORY_OVERHEAD_FACTOR = 0.10 + val MEMORY_OVERHEAD_MIN_MIB = 384L +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala new file mode 100644 index 0000000000000..1e3f055e05766 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.utils.HttpClientUtils +import okhttp3.Dispatcher + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.ThreadUtils + +/** + * Spark-opinionated builder for Kubernetes clients. It uses a prefix plus common suffixes to + * parse configuration keys, similar to the manner in which Spark's SecurityManager parses SSL + * options for different components. + */ +private[spark] object SparkKubernetesClientFactory { + + def createKubernetesClient( + master: String, + namespace: Option[String], + kubernetesAuthConfPrefix: String, + sparkConf: SparkConf, + defaultServiceAccountToken: Option[File], + defaultServiceAccountCaCert: Option[File]): KubernetesClient = { + val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX" + val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX" + val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf) + .map(new File(_)) + .orElse(defaultServiceAccountToken) + val oauthTokenValue = sparkConf.getOption(oauthTokenConf) + ConfigurationUtils.requireNandDefined( + oauthTokenFile, + oauthTokenValue, + s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " + + s"value $oauthTokenConf.") + + val caCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX") + .orElse(defaultServiceAccountCaCert.map(_.getAbsolutePath)) + val clientKeyFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX") + val clientCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") + val dispatcher = new Dispatcher( + ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) + val config = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(master) + .withWebsocketPingInterval(0) + .withOption(oauthTokenValue) { + (token, configBuilder) => configBuilder.withOauthToken(token) + }.withOption(oauthTokenFile) { + (file, configBuilder) => + configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + }.withOption(caCertFile) { + (file, configBuilder) => configBuilder.withCaCertFile(file) + }.withOption(clientKeyFile) { + (file, configBuilder) => configBuilder.withClientKeyFile(file) + }.withOption(clientCertFile) { + (file, configBuilder) => configBuilder.withClientCertFile(file) + }.withOption(namespace) { + (ns, configBuilder) => configBuilder.withNamespace(ns) + }.build() + val baseHttpClient = HttpClientUtils.createHttpClient(config) + val httpClientWithCustomDispatcher = baseHttpClient.newBuilder() + .dispatcher(dispatcher) + .build() + new DefaultKubernetesClient(httpClientWithCustomDispatcher, config) + } + + private implicit class OptionConfigurableConfigBuilder(val configBuilder: ConfigBuilder) + extends AnyVal { + + def withOption[T] + (option: Option[T]) + (configurator: ((T, ConfigBuilder) => ConfigBuilder)): ConfigBuilder = { + option.map { opt => + configurator(opt, configBuilder) + }.getOrElse(configBuilder) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala new file mode 100644 index 0000000000000..f79155b117b67 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.ConfigurationUtils +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Utils + +/** + * A factory class for configuring and creating executor pods. + */ +private[spark] trait ExecutorPodFactory { + + /** + * Configure and construct an executor pod with the given parameters. + */ + def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod +} + +private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) + extends ExecutorPodFactory { + + private val executorExtraClasspath = + sparkConf.get(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH) + + private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + + private val executorAnnotations = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + private val nodeSelector = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_NODE_SELECTOR_PREFIX) + + private val executorDockerImage = sparkConf + .get(EXECUTOR_DOCKER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor Docker image")) + private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val blockManagerPort = sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + + private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) + private val executorMemoryString = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_MEMORY.key, + org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = sparkConf + .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod = { + val name = s"$executorPodNamePrefix-exec-$executorId" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val resolvedExecutorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> applicationId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorLabels + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryMiB}Mi") + .build() + val executorMemoryLimitQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCores.toString) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = sparkConf + .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + // Executor backend expects integral value for executor cores, so round it up to an int. + (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, applicationId), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder() + .withName("executor") + .withImage(executorDockerImage) + .withImagePullPolicy(dockerImagePullPolicy) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryLimitQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .build() + + val executorPod = new PodBuilder() + .withNewMetadata() + .withName(name) + .withLabels(resolvedExecutorLabels.asJava) + .withAnnotations(executorAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(nodeSelector.asJava) + .endSpec() + .build() + + val containerWithExecutorLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + + new PodBuilder(executorPod) + .editSpec() + .addToContainers(containerWithExecutorLimitCores) + .endSpec() + .build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala new file mode 100644 index 0000000000000..68ca6a7622171 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.File + +import io.fabric8.kubernetes.client.Config + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.util.ThreadUtils + +private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { + + override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") + + override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { + new TaskSchedulerImpl(sc) + } + + override def createSchedulerBackend( + sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = { + val sparkConf = sc.getConf + + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( + KUBERNETES_MASTER_INTERNAL_URL, + Some(sparkConf.get(KUBERNETES_NAMESPACE)), + APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + sparkConf, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + + val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val allocatorExecutor = ThreadUtils + .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") + val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( + "kubernetes-executor-requests") + new KubernetesClusterSchedulerBackend( + scheduler.asInstanceOf[TaskSchedulerImpl], + sc.env.rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala new file mode 100644 index 0000000000000..e79c987852db2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable +import java.net.InetAddress +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} +import javax.annotation.concurrent.GuardedBy + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} +import org.apache.spark.util.Utils + +private[spark] class KubernetesClusterSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + executorPodFactory: ExecutorPodFactory, + kubernetesClient: KubernetesClient, + allocatorExecutor: ScheduledExecutorService, + requestExecutorsService: ExecutorService) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + import KubernetesClusterSchedulerBackend._ + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + private val RUNNING_EXECUTOR_PODS_LOCK = new Object + @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") + private val runningExecutorsToPods = new mutable.HashMap[String, Pod] + private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() + private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() + private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() + + private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( + requestExecutorsService) + + private val driverPod = kubernetesClient.pods() + .inNamespace(kubernetesNamespace) + .withName(kubernetesDriverPodName) + .get() + + protected override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + super.minRegisteredRatio + } + + private val executorWatchResource = new AtomicReference[Closeable] + private val totalExpectedExecutors = new AtomicInteger(0) + + private val driverUrl = RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + + private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) + + private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val executorLostReasonCheckMaxAttempts = conf.get( + KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) + + private val allocatorRunnable = new Runnable { + + // Maintains a map of executor id to count of checks performed to learn the loss reason + // for an executor. + private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] + + override def run(): Unit = { + handleDisconnectedExecutors() + + val executorsToAllocate = mutable.Map[String, Pod]() + val currentTotalRegisteredExecutors = totalRegisteredExecutors.get + val currentTotalExpectedExecutors = totalExpectedExecutors.get + val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { + logDebug("Waiting for pending executors before scaling") + } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { + logDebug("Maximum allowed executor limit reached. Not scaling up further.") + } else { + for (_ <- 0 until math.min( + currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { + val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString + val executorPod = executorPodFactory.createExecutorPod( + executorId, + applicationId(), + driverUrl, + conf.getExecutorEnv, + driverPod, + currentNodeToLocalTaskCount) + executorsToAllocate(executorId) = executorPod + logInfo( + s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") + } + } + } + + val allocatedExecutors = executorsToAllocate.mapValues { pod => + Utils.tryLog { + kubernetesClient.pods().create(pod) + } + } + + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + allocatedExecutors.map { + case (executorId, attemptedAllocatedExecutor) => + attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => + runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) + } + } + } + } + + def handleDisconnectedExecutors(): Unit = { + // For each disconnected executor, synchronize with the loss reasons that may have been found + // by the executor pod watcher. If the loss reason was discovered by the watcher, + // inform the parent class with removeExecutor. + disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { + case (executorId, executorPod) => + val knownExitReason = Option(podsWithKnownExitReasons.remove( + executorPod.getMetadata.getName)) + knownExitReason.fold { + removeExecutorOrIncrementLossReasonCheckCount(executorId) + } { executorExited => + logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) + removeExecutor(executorId, executorExited) + // We don't delete the pod running the executor that has an exit condition caused by + // the application from the Kubernetes API server. This allows users to debug later on + // through commands such as "kubectl logs " and + // "kubectl describe pod ". Note that exited containers have terminated and + // therefore won't take CPU and memory resources. + // Otherwise, the executor pod is marked to be deleted from the API server. + if (executorExited.exitCausedByApp) { + logInfo(s"Executor $executorId exited because of the application.") + deleteExecutorFromDataStructures(executorId) + } else { + logInfo(s"Executor $executorId failed because of a framework error.") + deleteExecutorFromClusterAndDataStructures(executorId) + } + } + } + } + + def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { + val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) + if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { + removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) + deleteExecutorFromClusterAndDataStructures(executorId) + } else { + executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) + } + } + + def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { + deleteExecutorFromDataStructures(executorId).foreach { pod => + kubernetesClient.pods().delete(pod) + } + } + + def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { + disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) + executorReasonCheckAttemptCounts -= executorId + podsWithKnownExitReasons.remove(executorId) + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.remove(executorId).orElse { + logWarning(s"Unable to remove pod for unknown executor $executorId") + None + } + } + } + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + } + + override def start(): Unit = { + super.start() + executorWatchResource.set( + kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .watch(new ExecutorPodsWatcher())) + + allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS) + + if (!Utils.isDynamicAllocationEnabled(conf)) { + doRequestTotalExecutors(initialExecutors) + } + } + + override def stop(): Unit = { + // stop allocation of new resources and caches. + allocatorExecutor.shutdown() + allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) + + // send stop message to executors so they shut down cleanly + super.stop() + + try { + val resource = executorWatchResource.getAndSet(null) + if (resource != null) { + resource.close() + } + } catch { + case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + } + + // then delete the executor pods + Utils.tryLogNonFatalError { + deleteExecutorPodsOnStop() + executorPodsByIPs.clear() + } + Utils.tryLogNonFatalError { + logInfo("Closing kubernetes client") + kubernetesClient.close() + } + } + + /** + * @return A map of K8s cluster nodes to the number of tasks that could benefit from data + * locality if an executor launches on the cluster node. + */ + private def getNodesWithLocalTaskCounts() : Map[String, Int] = { + val nodeToLocalTaskCount = synchronized { + mutable.Map[String, Int]() ++ hostToLocalTaskCount + } + + for (pod <- executorPodsByIPs.values().asScala) { + // Remove cluster nodes that are running our executors already. + // TODO: This prefers spreading out executors across nodes. In case users want + // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut + // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html + nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || + nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || + nodeToLocalTaskCount.remove( + InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + } + nodeToLocalTaskCount.toMap[String, Int] + } + + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { + totalExpectedExecutors.set(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { + executorIds.flatMap { executorId => + runningExecutorsToPods.remove(executorId) match { + case Some(pod) => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + Some(pod) + + case None => + logWarning(s"Unable to remove pod for unknown executor $executorId") + None + } + } + } + + kubernetesClient.pods().delete(podsToDelete: _*) + true + } + + private def deleteExecutorPodsOnStop(): Unit = { + val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { + val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) + runningExecutorsToPods.clear() + runningExecutorPodsCopy + } + kubernetesClient.pods().delete(executorPodsToDelete: _*) + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + + private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 + + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + val podIP = pod.getStatus.getPodIP + + action match { + case Action.MODIFIED if (pod.getStatus.getPhase == "Running" + && pod.getMetadata.getDeletionTimestamp == null) => + val clusterNodeName = pod.getSpec.getNodeName + logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") + executorPodsByIPs.put(podIP, pod) + + case Action.DELETED | Action.ERROR => + val executorId = getExecutorId(pod) + logDebug(s"Executor pod $podName at IP $podIP was at $action.") + if (podIP != null) { + executorPodsByIPs.remove(podIP) + } + + val executorExitReason = if (action == Action.ERROR) { + logWarning(s"Received error event of executor pod $podName. Reason: " + + pod.getStatus.getReason) + executorExitReasonOnError(pod) + } else if (action == Action.DELETED) { + logWarning(s"Received delete event of executor pod $podName. Reason: " + + pod.getStatus.getReason) + executorExitReasonOnDelete(pod) + } else { + throw new IllegalStateException( + s"Unknown action that should only be DELETED or ERROR: $action") + } + podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) + + if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { + log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + + s"watch received an event of type $action for this executor. The executor may " + + "have failed to start in the first place and never registered with the driver.") + } + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + + case _ => logDebug(s"Received event of executor pod $podName: " + action) + } + } + + override def onClose(cause: KubernetesClientException): Unit = { + logDebug("Executor pod watch closed.", cause) + } + + private def getExecutorExitStatus(pod: Pod): Int = { + val containerStatuses = pod.getStatus.getContainerStatuses + if (!containerStatuses.isEmpty) { + // we assume the first container represents the pod status. This assumption may not hold + // true in the future. Revisit this if side-car containers start running inside executor + // pods. + getExecutorExitStatus(containerStatuses.get(0)) + } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS + } + + private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { + Option(containerStatus.getState).map { containerState => + Option(containerState.getTerminated).map { containerStateTerminated => + containerStateTerminated.getExitCode.intValue() + }.getOrElse(UNKNOWN_EXIT_CODE) + }.getOrElse(UNKNOWN_EXIT_CODE) + } + + private def isPodAlreadyReleased(pod: Pod): Boolean = { + val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + !runningExecutorsToPods.contains(executorId) + } + } + + private def executorExitReasonOnError(pod: Pod): ExecutorExited = { + val containerExitStatus = getExecutorExitStatus(pod) + // container was probably actively killed by the driver. + if (isPodAlreadyReleased(pod)) { + ExecutorExited(containerExitStatus, exitCausedByApp = false, + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + + "request.") + } else { + val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + + s"exited with exit status code $containerExitStatus." + ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) + } + } + + private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { + val exitMessage = if (isPodAlreadyReleased(pod)) { + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." + } else { + s"Pod ${pod.getMetadata.getName} deleted or lost." + } + ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) + } + + private def getExecutorId(pod: Pod): String = { + val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) + require(executorId != null, "Unexpected pod metadata; expected all executor pods " + + s"to have label $SPARK_EXECUTOR_ID_LABEL.") + executorId + } + } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new KubernetesDriverEndpoint(rpcEnv, properties) + } + + private class KubernetesDriverEndpoint( + rpcEnv: RpcEnv, + sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.get(executorId).foreach { pod => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + } + } + } + } + } + } +} + +private object KubernetesClusterSchedulerBackend { + private val UNKNOWN_EXIT_CODE = -1 +} diff --git a/resource-managers/kubernetes/core/src/test/resources/log4j.properties b/resource-managers/kubernetes/core/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..ad95fadb7c0c0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala new file mode 100644 index 0000000000000..1c7717c238096 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, _} +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + private val driverPodName: String = "driver-pod" + private val driverPodUid: String = "driver-uid" + private val executorPrefix: String = "base" + private val executorImage: String = "executor-image" + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .withUid(driverPodUid) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) + .set(EXECUTOR_DOCKER_IMAGE, executorImage) + } + + test("basic executor pod has reasonable defaults") { + val factory = new ExecutorPodFactoryImpl(baseConf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + // The executor pod name and default labels. + assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") + assert(executor.getMetadata.getLabels.size() === 3) + assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getImage === executorImage) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) + assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. + assert(executor.getSpec.getNodeSelector.isEmpty) + assert(executor.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor, driverPodUid) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, + "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor, driverPodUid) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> "dummy", + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> "dummy", + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) + val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..3febb2f47cfd4 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{any, eq => mockitoEq} +import org.mockito.Mockito.{doNothing, never, times, verify, when} +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ +import scala.collection.JavaConverters._ +import scala.concurrent.Future + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc._ +import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.ThreadUtils + +class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { + + private val APP_ID = "test-spark-app" + private val DRIVER_POD_NAME = "spark-driver-pod" + private val NAMESPACE = "test-namespace" + private val SPARK_DRIVER_HOST = "localhost" + private val SPARK_DRIVER_PORT = 7077 + private val POD_ALLOCATION_INTERVAL = 60L + private val DRIVER_URL = RpcEndpointAddress( + SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val FIRST_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod1") + .endMetadata() + .withNewSpec() + .withNodeName("node1") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private val SECOND_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod2") + .endMetadata() + .withNewSpec() + .withNodeName("node2") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.101") + .endStatus() + .build() + + private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + private type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + private type IN_NAMESPACE_PODS = NonNamespaceOperation[ + Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + + @Mock + private var sparkContext: SparkContext = _ + + @Mock + private var listenerBus: LiveListenerBus = _ + + @Mock + private var taskSchedulerImpl: TaskSchedulerImpl = _ + + @Mock + private var allocatorExecutor: ScheduledExecutorService = _ + + @Mock + private var requestExecutorsService: ExecutorService = _ + + @Mock + private var executorPodFactory: ExecutorPodFactory = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var podsWithLabelOperations: LABELED_PODS = _ + + @Mock + private var podsInNamespace: IN_NAMESPACE_PODS = _ + + @Mock + private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + + @Mock + private var rpcEnv: RpcEnv = _ + + @Mock + private var driverEndpointRef: RpcEndpointRef = _ + + @Mock + private var executorPodsWatch: Watch = _ + + @Mock + private var successFuture: Future[Boolean] = _ + + private var sparkConf: SparkConf = _ + private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ + private var allocatorRunnable: ArgumentCaptor[Runnable] = _ + private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ + private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .addToLabels(SPARK_APP_ID_LABEL, APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .endMetadata() + .build() + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_NAMESPACE, NAMESPACE) + .set("spark.driver.host", SPARK_DRIVER_HOST) + .set("spark.driver.port", SPARK_DRIVER_PORT.toString) + .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL) + executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) + when(sparkContext.conf).thenReturn(sparkConf) + when(sparkContext.listenerBus).thenReturn(listenerBus) + when(taskSchedulerImpl.sc).thenReturn(sparkContext) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) + when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) + .thenReturn(executorPodsWatch) + when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) + when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) + when(podsWithDriverName.get()).thenReturn(driverPod) + when(allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable.capture(), + mockitoEq(0L), + mockitoEq(POD_ALLOCATION_INTERVAL), + mockitoEq(TimeUnit.SECONDS))).thenReturn(null) + // Creating Futures in Scala backed by a Java executor service resolves to running + // ExecutorService#execute (as opposed to submit) + doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) + when(rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + .thenReturn(driverEndpointRef) + + // Used by the CoarseGrainedSchedulerBackend when making RPC calls. + when(driverEndpointRef.ask[Boolean] + (any(classOf[Any])) + (any())).thenReturn(successFuture) + when(successFuture.failed).thenReturn(Future[Throwable] { + // emulate behavior of the Future.failed method. + throw new NoSuchElementException() + }(ThreadUtils.sameThread)) + } + + test("Basic lifecycle expectations when starting and stopping the scheduler.") { + val scheduler = newSchedulerBackend() + scheduler.start() + assert(executorPodsWatcherArgument.getValue != null) + assert(allocatorRunnable.getValue != null) + scheduler.stop() + verify(executorPodsWatch).close() + } + + test("Static allocation should request executors upon first allocator run.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + verify(podOperations).create(firstResolvedPod) + verify(podOperations).create(secondResolvedPod) + } + + test("Killing executors deletes the executor pods") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + scheduler.doKillExecutors(Seq("2")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations).delete(secondResolvedPod) + verify(podOperations, never()).delete(firstResolvedPod) + } + + test("Executors should be requested in batches.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(firstResolvedPod) + verify(podOperations, never()).create(secondResolvedPod) + val registerFirstExecutorMessage = RegisterExecutor( + "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + allocatorRunnable.getValue.run() + verify(podOperations).create(secondResolvedPod) + } + + test("Scaled down executors should be cleaned up") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + + // The scheduler backend spins up one executor pod. + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + + // Request that there are 0 executors and trigger deletion from driver. + scheduler.doRequestTotalExecutors(0) + requestExecutorRunnable.getAllValues.asScala.last.run() + scheduler.doKillExecutors(Seq("1")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations, times(1)).delete(resolvedPod) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + + val exitedPod = exitPod(resolvedPod, 0) + executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) + allocatorRunnable.getValue.run() + + // No more deletion attempts of the executors. + // This is graceful termination and should not be detected as a failure. + verify(podOperations, times(1)).delete(resolvedPod) + verify(driverEndpointRef, times(1)).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 0, + exitCausedByApp = false, + s"Container in pod ${exitedPod.getMetadata.getName} exited from" + + s" explicit termination request."))) + } + + test("Executors that fail should not be deleted.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + executorPodsWatcherArgument.getValue.eventReceived( + Action.ERROR, exitPod(firstResolvedPod, 1)) + + // A replacement executor should be created but the error pod should persist. + val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + scheduler.doRequestTotalExecutors(1) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getAllValues.asScala.last.run() + verify(podOperations, never()).delete(firstResolvedPod) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 1, + exitCausedByApp = true, + s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + + " exit status code 1."))) + } + + test("Executors disconnected due to unknown reasons are deleted and replaced.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val executorLostReasonCheckMaxAttempts = sparkConf.get( + KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) + + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + 1 to executorLostReasonCheckMaxAttempts foreach { _ => + allocatorRunnable.getValue.run() + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).delete(firstResolvedPod) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) + } + + test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(firstResolvedPod)) + .thenThrow(new RuntimeException("test")) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).create(firstResolvedPod) + val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(recreatedResolvedPod) + } + + test("Executors that are initially created but the watch notices them fail are rebuilt" + + " in the next batch.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).create(firstResolvedPod) + executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) + val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(recreatedResolvedPod) + } + + private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { + new KubernetesClusterSchedulerBackend( + taskSchedulerImpl, + rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) { + + override def applicationId(): String = APP_ID + } + } + + private def exitPod(basePod: Pod, exitCode: Int): Pod = { + new PodBuilder(basePod) + .editStatus() + .addNewContainerStatus() + .withNewState() + .withNewTerminated() + .withExitCode(exitCode) + .endTerminated() + .endState() + .endContainerStatus() + .endStatus() + .build() + } + + private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { + val resolvedPod = new PodBuilder(expectedPod) + .editMetadata() + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + when(executorPodFactory.createExecutorPod( + executorId.toString, + APP_ID, + DRIVER_URL, + sparkConf.getExecutorEnv, + driverPod, + Map.empty)).thenReturn(resolvedPod) + resolvedPod + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7052fb347106b..506adb363aa90 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -41,6 +41,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId +import org.apache.spark.scheduler.cluster.SchedulerBackendUtils import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} /** @@ -109,7 +110,7 @@ private[yarn] class YarnAllocator( sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) @volatile private var targetNumExecutors = - YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) private var currentNodeBlacklist = Set.empty[String] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 3d9f99f57bed7..9c1472cb50e3a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -133,8 +133,6 @@ object YarnSparkHadoopUtil { val ANY_HOST = "*" - val DEFAULT_NUMBER_EXECUTORS = 2 - // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) @@ -279,27 +277,5 @@ object YarnSparkHadoopUtil { securityMgr.getModifyAclsGroups) ) } - - /** - * Getting the initial target number of executors depends on whether dynamic allocation is - * enabled. - * If not using dynamic allocation it gets the number of executors requested by the user. - */ - def getInitialTargetExecutorNumber( - conf: SparkConf, - numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { - if (Utils.isDynamicAllocationEnabled(conf)) { - val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) - val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) - val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) - require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, - s"initial executor number $initialNumExecutors must between min executor number " + - s"$minNumExecutors and max executor number $maxNumExecutors") - - initialNumExecutors - } else { - conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors) - } - } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d482376d14dd7..b722cc401bb73 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -52,7 +52,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) val args = new ClientArguments(argsArrayBuf.toArray) - totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf) + totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) client = new Client(args, conf) bindToYarn(client.submitApplication(), None) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 4f3d5ebf403e0..e2d477be329c3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -34,7 +34,7 @@ private[spark] class YarnClusterSchedulerBackend( val attemptId = ApplicationMaster.getAttemptId bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() - totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) + totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sc.conf) } override def getDriverLogUrls: Option[Map[String, String]] = { From 20b239845b695fe6a893ebfe97b49ef05fae773d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 29 Nov 2017 19:18:47 +0800 Subject: [PATCH 1738/1765] [SPARK-22605][SQL] SQL write job should also set Spark task output metrics ## What changes were proposed in this pull request? For SQL write jobs, we only set metrics for the SQL listener and display them in the SQL plan UI. We should also set metrics for Spark task output metrics, which will be shown in spark job UI. ## How was this patch tested? test it manually. For a simple write job ``` spark.range(1000).write.parquet("/tmp/p1") ``` now the spark job UI looks like ![ui](https://user-images.githubusercontent.com/3182036/33326478-05a25b7c-d490-11e7-96ef-806117774356.jpg) Author: Wenchen Fan Closes #19833 from cloud-fan/ui. --- .../execution/datasources/BasicWriteStatsTracker.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 11af0aaa7b206..9dbbe9946ee99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -22,7 +22,7 @@ import java.io.FileNotFoundException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution @@ -44,7 +44,6 @@ case class BasicWriteTaskStats( /** * Simple [[WriteTaskStatsTracker]] implementation that produces [[BasicWriteTaskStats]]. - * @param hadoopConf */ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) extends WriteTaskStatsTracker with Logging { @@ -106,6 +105,13 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) override def getFinalStats(): WriteTaskStats = { statCurrentFile() + + // Reports bytesWritten and recordsWritten to the Spark output metrics. + Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics => + outputMetrics.setBytesWritten(numBytes) + outputMetrics.setRecordsWritten(numRows) + } + if (submittedFiles != numFiles) { logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + "This could be due to the output format not writing empty files, " + From 57687280d4171db98d4d9404c7bd3374f51deac0 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 29 Nov 2017 09:17:39 -0800 Subject: [PATCH 1739/1765] [SPARK-22615][SQL] Handle more cases in PropagateEmptyRelation ## What changes were proposed in this pull request? Currently, in the optimize rule `PropagateEmptyRelation`, the following cases is not handled: 1. empty relation as right child in left outer join 2. empty relation as left child in right outer join 3. empty relation as right child in left semi join 4. empty relation as right child in left anti join 5. only one empty relation in full outer join case 1 / 2 / 5 can be treated as **Cartesian product** and cause exception. See the new test cases. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19825 from gengliangwang/SPARK-22615. --- .../optimizer/PropagateEmptyRelation.scala | 36 +++- .../PropagateEmptyRelationSuite.scala | 16 +- .../sql-tests/inputs/join-empty-relation.sql | 28 +++ .../results/join-empty-relation.sql.out | 194 ++++++++++++++++++ 4 files changed, 257 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 52fbb4df2f58e..a6e5aa6daca65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -41,6 +41,10 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming) + // Construct a project list from plan's output, while the value is always NULL. + private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = + plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => empty(p) @@ -49,16 +53,28 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // as stateful streaming joins need to perform other state management operations other than // just processing the input data. case p @ Join(_, _, joinType, _) - if !p.children.exists(_.isStreaming) && p.children.exists(isEmptyLocalRelation) => - joinType match { - case _: InnerLike => empty(p) - // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule. - // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. - case LeftOuter | LeftSemi | LeftAnti if isEmptyLocalRelation(p.left) => empty(p) - case RightOuter if isEmptyLocalRelation(p.right) => empty(p) - case FullOuter if p.children.forall(isEmptyLocalRelation) => empty(p) - case _ => p - } + if !p.children.exists(_.isStreaming) => + val isLeftEmpty = isEmptyLocalRelation(p.left) + val isRightEmpty = isEmptyLocalRelation(p.right) + if (isLeftEmpty || isRightEmpty) { + joinType match { + case _: InnerLike => empty(p) + // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule. + // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. + case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) + case LeftSemi if isRightEmpty => empty(p) + case LeftAnti if isRightEmpty => p.left + case FullOuter if isLeftEmpty && isRightEmpty => empty(p) + case LeftOuter | FullOuter if isRightEmpty => + Project(p.left.output ++ nullValueProjectList(p.right), p.left) + case RightOuter if isRightEmpty => empty(p) + case RightOuter | FullOuter if isLeftEmpty => + Project(nullValueProjectList(p.left) ++ p.right.output, p.right) + case _ => p + } + } else { + p + } case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmptyLocalRelation) => p match { case _: Project => empty(p) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index bc1c48b99c295..3964508e3a55e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -21,8 +21,9 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.StructType @@ -78,17 +79,18 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, None), + (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, None), - (true, false, LeftAnti, None), - (true, false, LeftSemi, None), + (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftAnti, Some(testRelation1)), + (true, false, LeftSemi, Some(LocalRelation('a.int))), (false, true, Inner, Some(LocalRelation('a.int, 'b.int))), (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), - (false, true, RightOuter, None), - (false, true, FullOuter, None), + (false, true, RightOuter, + Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql new file mode 100644 index 0000000000000..8afa3270f4de4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -0,0 +1,28 @@ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); + +CREATE TEMPORARY VIEW empty_table as SELECT a FROM t2 WHERE false; + +SELECT * FROM t1 INNER JOIN empty_table; +SELECT * FROM t1 CROSS JOIN empty_table; +SELECT * FROM t1 LEFT OUTER JOIN empty_table; +SELECT * FROM t1 RIGHT OUTER JOIN empty_table; +SELECT * FROM t1 FULL OUTER JOIN empty_table; +SELECT * FROM t1 LEFT SEMI JOIN empty_table; +SELECT * FROM t1 LEFT ANTI JOIN empty_table; + +SELECT * FROM empty_table INNER JOIN t1; +SELECT * FROM empty_table CROSS JOIN t1; +SELECT * FROM empty_table LEFT OUTER JOIN t1; +SELECT * FROM empty_table RIGHT OUTER JOIN t1; +SELECT * FROM empty_table FULL OUTER JOIN t1; +SELECT * FROM empty_table LEFT SEMI JOIN t1; +SELECT * FROM empty_table LEFT ANTI JOIN t1; + +SELECT * FROM empty_table INNER JOIN empty_table; +SELECT * FROM empty_table CROSS JOIN empty_table; +SELECT * FROM empty_table LEFT OUTER JOIN empty_table; +SELECT * FROM empty_table RIGHT OUTER JOIN empty_table; +SELECT * FROM empty_table FULL OUTER JOIN empty_table; +SELECT * FROM empty_table LEFT SEMI JOIN empty_table; +SELECT * FROM empty_table LEFT ANTI JOIN empty_table; diff --git a/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out b/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out new file mode 100644 index 0000000000000..857073a827f24 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 24 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW empty_table as SELECT a FROM t2 WHERE false +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT * FROM t1 INNER JOIN empty_table +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +SELECT * FROM t1 CROSS JOIN empty_table +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT * FROM t1 LEFT OUTER JOIN empty_table +-- !query 5 schema +struct +-- !query 5 output +1 NULL + + +-- !query 6 +SELECT * FROM t1 RIGHT OUTER JOIN empty_table +-- !query 6 schema +struct +-- !query 6 output + + + +-- !query 7 +SELECT * FROM t1 FULL OUTER JOIN empty_table +-- !query 7 schema +struct +-- !query 7 output +1 NULL + + +-- !query 8 +SELECT * FROM t1 LEFT SEMI JOIN empty_table +-- !query 8 schema +struct +-- !query 8 output + + + +-- !query 9 +SELECT * FROM t1 LEFT ANTI JOIN empty_table +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT * FROM empty_table INNER JOIN t1 +-- !query 10 schema +struct +-- !query 10 output + + + +-- !query 11 +SELECT * FROM empty_table CROSS JOIN t1 +-- !query 11 schema +struct +-- !query 11 output + + + +-- !query 12 +SELECT * FROM empty_table LEFT OUTER JOIN t1 +-- !query 12 schema +struct +-- !query 12 output + + + +-- !query 13 +SELECT * FROM empty_table RIGHT OUTER JOIN t1 +-- !query 13 schema +struct +-- !query 13 output +NULL 1 + + +-- !query 14 +SELECT * FROM empty_table FULL OUTER JOIN t1 +-- !query 14 schema +struct +-- !query 14 output +NULL 1 + + +-- !query 15 +SELECT * FROM empty_table LEFT SEMI JOIN t1 +-- !query 15 schema +struct +-- !query 15 output + + + +-- !query 16 +SELECT * FROM empty_table LEFT ANTI JOIN t1 +-- !query 16 schema +struct +-- !query 16 output + + + +-- !query 17 +SELECT * FROM empty_table INNER JOIN empty_table +-- !query 17 schema +struct +-- !query 17 output + + + +-- !query 18 +SELECT * FROM empty_table CROSS JOIN empty_table +-- !query 18 schema +struct +-- !query 18 output + + + +-- !query 19 +SELECT * FROM empty_table LEFT OUTER JOIN empty_table +-- !query 19 schema +struct +-- !query 19 output + + + +-- !query 20 +SELECT * FROM empty_table RIGHT OUTER JOIN empty_table +-- !query 20 schema +struct +-- !query 20 output + + + +-- !query 21 +SELECT * FROM empty_table FULL OUTER JOIN empty_table +-- !query 21 schema +struct +-- !query 21 output + + + +-- !query 22 +SELECT * FROM empty_table LEFT SEMI JOIN empty_table +-- !query 22 schema +struct +-- !query 22 output + + + +-- !query 23 +SELECT * FROM empty_table LEFT ANTI JOIN empty_table +-- !query 23 schema +struct +-- !query 23 output + From 284836862b2312aea5d7555c8e3c9d3c4dbc8eaf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Nov 2017 01:19:37 +0800 Subject: [PATCH 1740/1765] [SPARK-22608][SQL] add new API to CodeGeneration.splitExpressions() ## What changes were proposed in this pull request? This PR adds a new API to ` CodeGenenerator.splitExpression` since since several ` CodeGenenerator.splitExpression` are used with `ctx.INPUT_ROW` to avoid code duplication. ## How was this patch tested? Used existing test suits Author: Kazuaki Ishizaki Closes #19821 from kiszk/SPARK-22608. --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 24 ++++++++++++-- .../sql/catalyst/expressions/predicates.scala | 10 +++--- .../expressions/stringExpressions.scala | 32 ++++++------------- 4 files changed, 37 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12baddf1bf7ac..8cafaef61c7d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1040,7 +1040,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ } - val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + val fieldsEvalCodes = if (ctx.currentVars == null) { ctx.splitExpressions( expressions = fieldsEvalCode, funcName = "castStruct", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 668c816b3fd8d..1645db12c53f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -788,11 +788,31 @@ class CodegenContext { * @param expressions the codes to evaluate expressions. */ def splitExpressions(expressions: Seq[String]): String = { + splitExpressions(expressions, funcName = "apply", extraArguments = Nil) + } + + /** + * Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name + * and extra arguments. + * + * @param expressions the codes to evaluate expressions. + * @param funcName the split function name base. + * @param extraArguments the list of (type, name) of the arguments of the split function + * except for ctx.INPUT_ROW + */ + def splitExpressions( + expressions: Seq[String], + funcName: String, + extraArguments: Seq[(String, String)]): String = { // TODO: support whole stage codegen if (INPUT_ROW == null || currentVars != null) { - return expressions.mkString("\n") + expressions.mkString("\n") + } else { + splitExpressions( + expressions, + funcName, + arguments = ("InternalRow", INPUT_ROW) +: extraArguments) } - splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index eb7475354b104..1aaaaf1db48d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -251,12 +251,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } """) - val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil - ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args) - } else { - listCode.mkString("\n") - } + val listCodes = ctx.splitExpressions( + expressions = listCode, + funcName = "valueIn", + extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil) ev.copy(code = s""" ${valueGen.code} ${ev.value} = false; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ee5cf925d3cef..34917ace001fa 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -73,14 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } - val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( - expressions = inputs, - funcName = "valueConcat", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) - } else { - inputs.mkString("\n") - } + val codes = ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcat", + extraArguments = ("UTF8String[]", args) :: Nil) ev.copy(s""" UTF8String[] $args = new UTF8String[${evals.length}]; $codes @@ -156,14 +152,10 @@ case class ConcatWs(children: Seq[Expression]) "" } } - val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( + val codes = ctx.splitExpressions( expressions = inputs, funcName = "valueConcatWs", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) - } else { - inputs.mkString("\n") - } + extraArguments = ("UTF8String[]", args) :: Nil) ev.copy(s""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} @@ -1388,14 +1380,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC $argList[$index] = $value; """ } - val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( - expressions = argListCode, - funcName = "valueFormatString", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil) - } else { - argListCode.mkString("\n") - } + val argListCodes = ctx.splitExpressions( + expressions = argListCode, + funcName = "valueFormatString", + extraArguments = ("Object[]", argList) :: Nil) val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName From 193555f79cc73873613674a09a7c371688b6dbc7 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Wed, 29 Nov 2017 14:15:35 -0800 Subject: [PATCH 1741/1765] [SPARK-18935][MESOS] Fix dynamic reservations on mesos ## What changes were proposed in this pull request? - Solves the issue described in the ticket by preserving reservation and allocation info in all cases (port handling included). - upgrades to 1.4 - Adds extra debug level logging to make debugging easier in the future, for example we add reservation info when applicable. ``` 17/09/29 14:53:07 DEBUG MesosCoarseGrainedSchedulerBackend: Accepting offer: f20de49b-dee3-45dd-a3c1-73418b7de891-O32 with attributes: Map() allocation info: role: "spark-prive" reservation info: name: "ports" type: RANGES ranges { range { begin: 31000 end: 32000 } } role: "spark-prive" reservation { principal: "test" } allocation_info { role: "spark-prive" } ``` - Some style cleanup. ## How was this patch tested? Manually by running the example in the ticket with and without a principal. Specifically I tested it on a dc/os 1.10 cluster with 7 nodes and played with reservations. From the master node in order to reserve resources I executed: ```for i in 0 1 2 3 4 5 6 do curl -i \ -d slaveId=90ec65ea-1f7b-479f-a824-35d2527d6d26-S$i \ -d resources='[ { "name": "cpus", "type": "SCALAR", "scalar": { "value": 2 }, "role": "spark-role", "reservation": { "principal": "" } }, { "name": "mem", "type": "SCALAR", "scalar": { "value": 8026 }, "role": "spark-role", "reservation": { "principal": "" } } ]' \ -X POST http://master.mesos:5050/master/reserve done ``` Nodes had 4 cpus (m3.xlarge instances) and I reserved either 2 or 4 cpus (all for a role). I verified it launches tasks on nodes with reserved resources under `spark-role` role only if a) there are remaining resources for (*) default role and the spark driver has no role assigned to it. b) the spark driver has a role assigned to it and it is the same role used in reservations. I also tested this locally on my machine. Author: Stavros Kontopoulos Closes #19390 from skonto/fix_dynamic_reservation. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- resource-managers/mesos/pom.xml | 2 +- .../cluster/mesos/MesosClusterScheduler.scala | 1 - .../MesosCoarseGrainedSchedulerBackend.scala | 17 +++- .../cluster/mesos/MesosSchedulerUtils.scala | 99 ++++++++++++------- 6 files changed, 80 insertions(+), 43 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 21c8a75796387..50ac6d139bbd4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -138,7 +138,7 @@ lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar -mesos-1.3.0-shaded-protobuf.jar +mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 7173426c7bf74..1b1e3166d53db 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -139,7 +139,7 @@ lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar -mesos-1.3.0-shaded-protobuf.jar +mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index de8f1c913651d..70d0c1750b14e 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -29,7 +29,7 @@ Spark Project Mesos mesos - 1.3.0 + 1.4.0 shaded-protobuf diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index c41283e4a3e39..d224a7325820a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -36,7 +36,6 @@ import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionRes import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils - /** * Tracks the current state of a Mesos Task that runs a Spark driver. * @param driverDescription Submitted driver description from diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index c392061fdb358..191415a2578b2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -400,13 +400,20 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val offerMem = getResource(offer.getResourcesList, "mem") val offerCpus = getResource(offer.getResourcesList, "cpus") val offerPorts = getRangeResource(offer.getResourcesList, "ports") + val offerReservationInfo = offer + .getResourcesList + .asScala + .find { r => r.getReservation != null } val id = offer.getId.getValue if (tasks.contains(offer.getId)) { // accept val offerTasks = tasks(offer.getId) logDebug(s"Accepting offer: $id with attributes: $offerAttributes " + - s"mem: $offerMem cpu: $offerCpus ports: $offerPorts." + + offerReservationInfo.map(resInfo => + s"reservation info: ${resInfo.getReservation.toString}").getOrElse("") + + s"mem: $offerMem cpu: $offerCpus ports: $offerPorts " + + s"resources: ${offer.getResourcesList.asScala.mkString(",")}." + s" Launching ${offerTasks.size} Mesos tasks.") for (task <- offerTasks) { @@ -416,7 +423,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val ports = getRangeResource(task.getResourcesList, "ports").mkString(",") logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus" + - s" ports: $ports") + s" ports: $ports" + s" on slave with slave id: ${task.getSlaveId.getValue} ") } driver.launchTasks( @@ -431,7 +438,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } else { declineOffer( driver, - offer) + offer, + Some("Offer was declined due to unmet task launch constraints.")) } } } @@ -513,6 +521,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( totalGpusAcquired += taskGPUs gpusByTaskId(taskId) = taskGPUs } + } else { + logDebug(s"Cannot launch a task for offer with id: $offerId on slave " + + s"with id: $slaveId. Requirements were not met for this offer.") } } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 6fcb30af8a733..e75450369ad85 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -28,7 +28,8 @@ import com.google.common.base.Splitter import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability -import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} +import org.apache.mesos.Protos.Resource.ReservationInfo +import org.apache.mesos.protobuf.{ByteString, GeneratedMessageV3} import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.TaskState @@ -36,8 +37,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.util.Utils - - /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper * methods and Mesos scheduler will use. @@ -46,6 +45,8 @@ trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered private final val registerLatch = new CountDownLatch(1) + private final val ANY_ROLE = "*" + /** * Creates a new MesosSchedulerDriver that communicates to the Mesos master. * @@ -175,17 +176,36 @@ trait MesosSchedulerUtils extends Logging { registerLatch.countDown() } - def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { + private def setReservationInfo( + reservationInfo: Option[ReservationInfo], + role: Option[String], + builder: Resource.Builder): Unit = { + if (!role.contains(ANY_ROLE)) { + reservationInfo.foreach { res => builder.setReservation(res) } + } + } + + def createResource( + name: String, + amount: Double, + role: Option[String] = None, + reservationInfo: Option[ReservationInfo] = None): Resource = { val builder = Resource.newBuilder() .setName(name) .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder().setValue(amount).build()) - role.foreach { r => builder.setRole(r) } - + setReservationInfo(reservationInfo, role, builder) builder.build() } + private def getReservation(resource: Resource): Option[ReservationInfo] = { + if (resource.hasReservation) { + Some(resource.getReservation) + } else { + None + } + } /** * Partition the existing set of resources into two groups, those remaining to be * scheduled and those requested to be used for a new task. @@ -203,14 +223,17 @@ trait MesosSchedulerUtils extends Logging { var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { case r => + val reservation = getReservation(r) if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && r.getName == resourceName) { val usage = Math.min(remain, r.getScalar.getValue) - requestedResources += createResource(resourceName, usage, Some(r.getRole)) + requestedResources += createResource(resourceName, usage, + Option(r.getRole), reservation) remain -= usage - createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole)) + createResource(resourceName, r.getScalar.getValue - usage, + Option(r.getRole), reservation) } else { r } @@ -228,16 +251,6 @@ trait MesosSchedulerUtils extends Logging { (attr.getName, attr.getText.getValue.split(',').toSet) } - - /** Build a Mesos resource protobuf object */ - protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } - /** * Converts the attributes from the resource offer into a Map of name to Attribute Value * The attribute values are the mesos attribute types and they are @@ -245,7 +258,8 @@ trait MesosSchedulerUtils extends Logging { * @param offerAttributes the attributes offered * @return */ - protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + protected def toAttributeMap(offerAttributes: JList[Attribute]) + : Map[String, GeneratedMessageV3] = { offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar @@ -266,7 +280,7 @@ trait MesosSchedulerUtils extends Logging { */ def matchesAttributeRequirements( slaveOfferConstraints: Map[String, Set[String]], - offerAttributes: Map[String, GeneratedMessage]): Boolean = { + offerAttributes: Map[String, GeneratedMessageV3]): Boolean = { slaveOfferConstraints.forall { // offer has the required attribute and subsumes the required values for that attribute case (name, requiredValues) => @@ -427,10 +441,10 @@ trait MesosSchedulerUtils extends Logging { // partition port offers val (resourcesWithoutPorts, portResources) = filterPortResources(offeredResources) - val portsAndRoles = requestedPorts. - map(x => (x, findPortAndGetAssignedRangeRole(x, portResources))) + val portsAndResourceInfo = requestedPorts. + map { x => (x, findPortAndGetAssignedResourceInfo(x, portResources)) } - val assignedPortResources = createResourcesFromPorts(portsAndRoles) + val assignedPortResources = createResourcesFromPorts(portsAndResourceInfo) // ignore non-assigned port resources, they will be declined implicitly by mesos // no need for splitting port resources. @@ -450,16 +464,25 @@ trait MesosSchedulerUtils extends Logging { managedPortNames.map(conf.getLong(_, 0)).filter( _ != 0) } + private case class RoleResourceInfo( + role: String, + resInfo: Option[ReservationInfo]) + /** Creates a mesos resource for a specific port number. */ - private def createResourcesFromPorts(portsAndRoles: List[(Long, String)]) : List[Resource] = { - portsAndRoles.flatMap{ case (port, role) => - createMesosPortResource(List((port, port)), Some(role))} + private def createResourcesFromPorts( + portsAndResourcesInfo: List[(Long, RoleResourceInfo)]) + : List[Resource] = { + portsAndResourcesInfo.flatMap { case (port, rInfo) => + createMesosPortResource(List((port, port)), Option(rInfo.role), rInfo.resInfo)} } /** Helper to create mesos resources for specific port ranges. */ private def createMesosPortResource( ranges: List[(Long, Long)], - role: Option[String] = None): List[Resource] = { + role: Option[String] = None, + reservationInfo: Option[ReservationInfo] = None): List[Resource] = { + // for ranges we are going to use (user defined ports fall in there) create mesos resources + // for each range there is a role associated with it. ranges.map { case (rangeStart, rangeEnd) => val rangeValue = Value.Range.newBuilder() .setBegin(rangeStart) @@ -468,7 +491,8 @@ trait MesosSchedulerUtils extends Logging { .setName("ports") .setType(Value.Type.RANGES) .setRanges(Value.Ranges.newBuilder().addRange(rangeValue)) - role.foreach(r => builder.setRole(r)) + role.foreach { r => builder.setRole(r) } + setReservationInfo(reservationInfo, role, builder) builder.build() } } @@ -477,19 +501,21 @@ trait MesosSchedulerUtils extends Logging { * Helper to assign a port to an offered range and get the latter's role * info to use it later on. */ - private def findPortAndGetAssignedRangeRole(port: Long, portResources: List[Resource]) - : String = { + private def findPortAndGetAssignedResourceInfo(port: Long, portResources: List[Resource]) + : RoleResourceInfo = { val ranges = portResources. - map(resource => - (resource.getRole, resource.getRanges.getRangeList.asScala - .map(r => (r.getBegin, r.getEnd)).toList)) + map { resource => + val reservation = getReservation(resource) + (RoleResourceInfo(resource.getRole, reservation), + resource.getRanges.getRangeList.asScala.map(r => (r.getBegin, r.getEnd)).toList) + } - val rangePortRole = ranges - .find { case (role, rangeList) => rangeList + val rangePortResourceInfo = ranges + .find { case (resourceInfo, rangeList) => rangeList .exists{ case (rangeStart, rangeEnd) => rangeStart <= port & rangeEnd >= port}} // this is safe since we have previously checked about the ranges (see checkPorts method) - rangePortRole.map{ case (role, rangeList) => role}.get + rangePortResourceInfo.map{ case (resourceInfo, rangeList) => resourceInfo}.get } /** Retrieves the port resources from a list of mesos offered resources */ @@ -564,3 +590,4 @@ trait MesosSchedulerUtils extends Logging { } } } + From 8ff474f6e543203fac5d49af7fbe98a8a98da567 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 29 Nov 2017 14:34:41 -0800 Subject: [PATCH 1742/1765] [SPARK-20650][CORE] Remove JobProgressListener. The only remaining use of this class was the SparkStatusTracker, which was modified to use the new status store. The test code to wait for executors was moved to TestUtils and now uses the SparkStatusTracker API. Indirectly, ConsoleProgressBar also uses this data. Because it has some lower latency requirements, a shortcut to efficiently get the active stages from the active listener was added to the AppStateStore. Now that all UI code goes through the status store to get its data, the FsHistoryProvider can be cleaned up to only replay event logs when needed - that is, when there is no pre-existing disk store for the application. As part of this change I also modified the streaming UI to read the needed data from the store, which was missed in the previous patch that made JobProgressListener redundant. Author: Marcelo Vanzin Closes #19750 from vanzin/SPARK-20650. --- .../scala/org/apache/spark/SparkContext.scala | 11 +- .../org/apache/spark/SparkStatusTracker.scala | 76 +-- .../scala/org/apache/spark/TestUtils.scala | 26 +- .../deploy/history/FsHistoryProvider.scala | 65 +- .../spark/status/AppStatusListener.scala | 51 +- .../apache/spark/status/AppStatusStore.scala | 17 +- .../org/apache/spark/status/LiveEntity.scala | 8 +- .../spark/status/api/v1/StagesResource.scala | 1 - .../apache/spark/ui/ConsoleProgressBar.scala | 18 +- .../spark/ui/jobs/JobProgressListener.scala | 612 ------------------ .../org/apache/spark/ui/jobs/StagePage.scala | 1 - .../org/apache/spark/ui/jobs/UIData.scala | 311 --------- .../org/apache/spark/DistributedSuite.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../org/apache/spark/StatusTrackerSuite.scala | 6 +- .../spark/broadcast/BroadcastSuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 4 +- .../SparkListenerWithClusterSuite.scala | 4 +- .../ui/jobs/JobProgressListenerSuite.scala | 442 ------------- project/MimaExcludes.scala | 2 + .../apache/spark/streaming/ui/BatchPage.scala | 75 ++- 21 files changed, 208 insertions(+), 1528 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala delete mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala delete mode 100644 core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 23fd54f59268a..984dd0a6629a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -58,7 +58,6 @@ import org.apache.spark.status.{AppStatusPlugin, AppStatusStore} import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} -import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ /** @@ -195,7 +194,6 @@ class SparkContext(config: SparkConf) extends Logging { private var _eventLogCodec: Option[String] = None private var _listenerBus: LiveListenerBus = _ private var _env: SparkEnv = _ - private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ private var _progressBar: Option[ConsoleProgressBar] = None private var _ui: Option[SparkUI] = None @@ -270,8 +268,6 @@ class SparkContext(config: SparkConf) extends Logging { val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]() map.asScala } - private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener - def statusTracker: SparkStatusTracker = _statusTracker private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar @@ -421,11 +417,6 @@ class SparkContext(config: SparkConf) extends Logging { _listenerBus = new LiveListenerBus(_conf) - // "_jobProgressListener" should be set up before creating SparkEnv because when creating - // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. - _jobProgressListener = new JobProgressListener(_conf) - listenerBus.addToStatusQueue(jobProgressListener) - // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. _statusStore = AppStatusStore.createLiveStore(conf, l => listenerBus.addToStatusQueue(l)) @@ -440,7 +431,7 @@ class SparkContext(config: SparkConf) extends Logging { _conf.set("spark.repl.class.uri", replUri) } - _statusTracker = new SparkStatusTracker(this) + _statusTracker = new SparkStatusTracker(this, _statusStore) _progressBar = if (_conf.get(UI_SHOW_CONSOLE_PROGRESS) && !log.isInfoEnabled) { diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 22a553e68439a..70865cb58c571 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -17,7 +17,10 @@ package org.apache.spark -import org.apache.spark.scheduler.TaskSchedulerImpl +import java.util.Arrays + +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.StageStatus /** * Low-level status reporting APIs for monitoring job and stage progress. @@ -33,9 +36,7 @@ import org.apache.spark.scheduler.TaskSchedulerImpl * * NOTE: this class's constructor should be considered private and may be subject to change. */ -class SparkStatusTracker private[spark] (sc: SparkContext) { - - private val jobProgressListener = sc.jobProgressListener +class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore) { /** * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then @@ -46,9 +47,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * its result. */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { - jobProgressListener.synchronized { - jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray - } + val expected = Option(jobGroup) + store.jobsList(null).filter(_.jobGroup == expected).map(_.jobId).toArray } /** @@ -57,9 +57,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * This method does not guarantee the order of the elements in its result. */ def getActiveStageIds(): Array[Int] = { - jobProgressListener.synchronized { - jobProgressListener.activeStages.values.map(_.stageId).toArray - } + store.stageList(Arrays.asList(StageStatus.ACTIVE)).map(_.stageId).toArray } /** @@ -68,19 +66,15 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * This method does not guarantee the order of the elements in its result. */ def getActiveJobIds(): Array[Int] = { - jobProgressListener.synchronized { - jobProgressListener.activeJobs.values.map(_.jobId).toArray - } + store.jobsList(Arrays.asList(JobExecutionStatus.RUNNING)).map(_.jobId).toArray } /** * Returns job information, or `None` if the job info could not be found or was garbage collected. */ def getJobInfo(jobId: Int): Option[SparkJobInfo] = { - jobProgressListener.synchronized { - jobProgressListener.jobIdToData.get(jobId).map { data => - new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) - } + store.asOption(store.job(jobId)).map { job => + new SparkJobInfoImpl(jobId, job.stageIds.toArray, job.status) } } @@ -89,21 +83,16 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * garbage collected. */ def getStageInfo(stageId: Int): Option[SparkStageInfo] = { - jobProgressListener.synchronized { - for ( - info <- jobProgressListener.stageIdToInfo.get(stageId); - data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) - ) yield { - new SparkStageInfoImpl( - stageId, - info.attemptId, - info.submissionTime.getOrElse(0), - info.name, - info.numTasks, - data.numActiveTasks, - data.numCompleteTasks, - data.numFailedTasks) - } + store.asOption(store.lastStageAttempt(stageId)).map { stage => + new SparkStageInfoImpl( + stageId, + stage.attemptId, + stage.submissionTime.map(_.getTime()).getOrElse(0L), + stage.name, + stage.numTasks, + stage.numActiveTasks, + stage.numCompleteTasks, + stage.numFailedTasks) } } @@ -111,17 +100,20 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. */ def getExecutorInfos: Array[SparkExecutorInfo] = { - val executorIdToRunningTasks: Map[String, Int] = - sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors + store.executorList(true).map { exec => + val (host, port) = exec.hostPort.split(":", 2) match { + case Array(h, p) => (h, p.toInt) + case Array(h) => (h, -1) + } + val cachedMem = exec.memoryMetrics.map { mem => + mem.usedOnHeapStorageMemory + mem.usedOffHeapStorageMemory + }.getOrElse(0L) - sc.getExecutorStorageStatus.map { status => - val bmId = status.blockManagerId new SparkExecutorInfoImpl( - bmId.host, - bmId.port, - status.cacheSize, - executorIdToRunningTasks.getOrElse(bmId.executorId, 0) - ) - } + host, + port, + cachedMem, + exec.activeTasks) + }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index a80016dd22fc5..93e7ee3d2a404 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate import java.util.Arrays -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -232,6 +232,30 @@ private[spark] object TestUtils { } } + /** + * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting + * time elapsed before `numExecutors` executors up. Exposed for testing. + * + * @param numExecutors the number of executors to wait at least + * @param timeout time to wait in milliseconds + */ + private[spark] def waitUntilExecutorsUp( + sc: SparkContext, + numExecutors: Int, + timeout: Long): Unit = { + val finishTime = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeout) + while (System.nanoTime() < finishTime) { + if (sc.statusTracker.getExecutorInfos.length > numExecutors) { + return + } + // Sleep rather than using wait/notify, because this is used only for testing and wait/notify + // add overhead in the general case. + Thread.sleep(10) + } + throw new TimeoutException( + s"Can't find $numExecutors executors before $timeout milliseconds elapsed") + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 69ccde3a8149d..6a83c106f6d84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -299,8 +299,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) attempt.adminAclsGroups.getOrElse("")) secManager.setViewAclsGroups(attempt.viewAclsGroups.getOrElse("")) - val replayBus = new ReplayListenerBus() - val uiStorePath = storePath.map { path => getStorePath(path, appId, attemptId) } val (kvstore, needReplay) = uiStorePath match { @@ -320,48 +318,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) (new InMemoryStore(), true) } - val listener = if (needReplay) { - val _listener = new AppStatusListener(kvstore, conf, false, + if (needReplay) { + val replayBus = new ReplayListenerBus() + val listener = new AppStatusListener(kvstore, conf, false, lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) - replayBus.addListener(_listener) + replayBus.addListener(listener) AppStatusPlugin.loadPlugins().foreach { plugin => plugin.setupListeners(conf, kvstore, l => replayBus.addListener(l), false) } - Some(_listener) - } else { - None + try { + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + listener.flush() + } catch { + case e: Exception => + try { + kvstore.close() + } catch { + case _e: Exception => logInfo("Error closing store.", _e) + } + uiStorePath.foreach(Utils.deleteRecursively) + if (e.isInstanceOf[FileNotFoundException]) { + return None + } else { + throw e + } + } } - val loadedUI = { - val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, secManager, app.info.name, - HistoryServer.getAttemptURI(appId, attempt.info.attemptId), - attempt.info.startTime.getTime(), - attempt.info.appSparkVersion) - LoadedAppUI(ui) + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, secManager, app.info.name, + HistoryServer.getAttemptURI(appId, attempt.info.attemptId), + attempt.info.startTime.getTime(), + attempt.info.appSparkVersion) + AppStatusPlugin.loadPlugins().foreach { plugin => + plugin.setupUI(ui) } - try { - AppStatusPlugin.loadPlugins().foreach { plugin => - plugin.setupUI(loadedUI.ui) - } - - val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) - replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - listener.foreach(_.flush()) - } catch { - case e: Exception => - try { - kvstore.close() - } catch { - case _e: Exception => logInfo("Error closing store.", _e) - } - uiStorePath.foreach(Utils.deleteRecursively) - if (e.isInstanceOf[FileNotFoundException]) { - return None - } else { - throw e - } - } + val loadedUI = LoadedAppUI(ui) synchronized { activeUIs((appId, attemptId)) = loadedUI diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index f2d8e0a5480ba..9c23d9d8c923a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -18,7 +18,10 @@ package org.apache.spark.status import java.util.Date +import java.util.concurrent.ConcurrentHashMap +import java.util.function.Function +import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import org.apache.spark._ @@ -59,7 +62,7 @@ private[spark] class AppStatusListener( // Keep track of live entities, so that task metrics can be efficiently updated (without // causing too many writes to the underlying store, and other expensive operations). - private val liveStages = new HashMap[(Int, Int), LiveStage]() + private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]() private val liveJobs = new HashMap[Int, LiveJob]() private val liveExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() @@ -268,13 +271,15 @@ private[spark] class AppStatusListener( val now = System.nanoTime() // Check if there are any pending stages that match this job; mark those as skipped. - job.stageIds.foreach { sid => - val pending = liveStages.filter { case ((id, _), _) => id == sid } - pending.foreach { case (key, stage) => + val it = liveStages.entrySet.iterator() + while (it.hasNext()) { + val e = it.next() + if (job.stageIds.contains(e.getKey()._1)) { + val stage = e.getValue() stage.status = v1.StageStatus.SKIPPED job.skippedStages += stage.info.stageId job.skippedTasks += stage.info.numTasks - liveStages.remove(key) + it.remove() update(stage, now) } } @@ -336,7 +341,7 @@ private[spark] class AppStatusListener( liveTasks.put(event.taskInfo.taskId, task) liveUpdate(task, now) - liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) maybeUpdate(stage, now) @@ -403,7 +408,7 @@ private[spark] class AppStatusListener( (0, 1, 0) } - liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { stage.metrics.update(metricsDelta) } @@ -466,12 +471,19 @@ private[spark] class AppStatusListener( } } - maybeUpdate(exec, now) + // Force an update on live applications when the number of active tasks reaches 0. This is + // checked in some tests (e.g. SQLTestUtilsBase) so it needs to be reliably up to date. + if (exec.activeTasks == 0) { + liveUpdate(exec, now) + } else { + maybeUpdate(exec, now) + } } } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -540,7 +552,7 @@ private[spark] class AppStatusListener( val delta = task.updateMetrics(metrics) maybeUpdate(task, now) - liveStages.get((sid, sAttempt)).foreach { stage => + Option(liveStages.get((sid, sAttempt))).foreach { stage => stage.metrics.update(delta) maybeUpdate(stage, now) @@ -563,7 +575,7 @@ private[spark] class AppStatusListener( /** Flush all live entities' data to the underlying store. */ def flush(): Unit = { val now = System.nanoTime() - liveStages.values.foreach { stage => + liveStages.values.asScala.foreach { stage => update(stage, now) stage.executorSummaries.values.foreach(update(_, now)) } @@ -574,6 +586,18 @@ private[spark] class AppStatusListener( pools.values.foreach(update(_, now)) } + /** + * Shortcut to get active stages quickly in a live application, for use by the console + * progress bar. + */ + def activeStages(): Seq[v1.StageData] = { + liveStages.values.asScala + .filter(_.info.submissionTime.isDefined) + .map(_.toApi()) + .toList + .sortBy(_.stageId) + } + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { val now = System.nanoTime() val executorId = event.blockUpdatedInfo.blockManagerId.executorId @@ -708,7 +732,10 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) + val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + new Function[(Int, Int), LiveStage]() { + override def apply(key: (Int, Int)): LiveStage = new LiveStage() + }) stage.info = info stage } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index d0615e5dd0223..22d768b3cb990 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -32,7 +32,9 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** * A wrapper around a KVStore that provides methods for accessing the API data stored within. */ -private[spark] class AppStatusStore(val store: KVStore) { +private[spark] class AppStatusStore( + val store: KVStore, + listener: Option[AppStatusListener] = None) { def applicationInfo(): v1.ApplicationInfo = { store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info @@ -70,6 +72,14 @@ private[spark] class AppStatusStore(val store: KVStore) { store.read(classOf[ExecutorSummaryWrapper], executorId).info } + /** + * This is used by ConsoleProgressBar to quickly fetch active stages for drawing the progress + * bar. It will only return anything useful when called from a live application. + */ + def activeStages(): Seq[v1.StageData] = { + listener.map(_.activeStages()).getOrElse(Nil) + } + def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { val it = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info) if (statuses != null && !statuses.isEmpty()) { @@ -338,11 +348,12 @@ private[spark] object AppStatusStore { */ def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { val store = new InMemoryStore() - addListenerFn(new AppStatusListener(store, conf, true)) + val listener = new AppStatusListener(store, conf, true) + addListenerFn(listener) AppStatusPlugin.loadPlugins().foreach { p => p.setupListeners(conf, store, addListenerFn, true) } - new AppStatusStore(store) + new AppStatusStore(store, listener = Some(listener)) } } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index ef2936c9b69a4..983c58a607aa8 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -408,8 +408,8 @@ private class LiveStage extends LiveEntity { new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) } - override protected def doUpdate(): Any = { - val update = new v1.StageData( + def toApi(): v1.StageData = { + new v1.StageData( status, info.stageId, info.attemptId, @@ -449,8 +449,10 @@ private class LiveStage extends LiveEntity { None, None, killedSummary) + } - new StageDataWrapper(update, jobIds) + override protected def doUpdate(): Any = { + new StageDataWrapper(toApi(), jobIds) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index bd4dfe3c68885..b3561109bc636 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -25,7 +25,6 @@ import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.StageUIData @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class StagesResource extends BaseAppResource { diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 3ae80ecfd22e6..3c4ee4eb6bbb9 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -21,10 +21,11 @@ import java.util.{Timer, TimerTask} import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.status.api.v1.StageData /** * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the - * status of active stages from `sc.statusTracker` periodically, the progress bar will be showed + * status of active stages from the app state store periodically, the progress bar will be showed * up after the stage has ran at least 500ms. If multiple stages run in the same time, the status * of them will be combined together, showed in one line. */ @@ -64,9 +65,8 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { if (now - lastFinishTime < firstDelayMSec) { return } - val stageIds = sc.statusTracker.getActiveStageIds() - val stages = stageIds.flatMap(sc.statusTracker.getStageInfo).filter(_.numTasks() > 1) - .filter(now - _.submissionTime() > firstDelayMSec).sortBy(_.stageId()) + val stages = sc.statusStore.activeStages() + .filter { s => now - s.submissionTime.get.getTime() > firstDelayMSec } if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } @@ -77,15 +77,15 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { * after your last output, keeps overwriting itself to hold in one line. The logging will follow * the progress bar, then progress bar will be showed in next line without overwrite logs. */ - private def show(now: Long, stages: Seq[SparkStageInfo]) { + private def show(now: Long, stages: Seq[StageData]) { val width = TerminalWidth / stages.size val bar = stages.map { s => - val total = s.numTasks() - val header = s"[Stage ${s.stageId()}:" - val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" + val total = s.numTasks + val header = s"[Stage ${s.stageId}:" + val tailer = s"(${s.numCompleteTasks} + ${s.numActiveTasks}) / $total]" val w = width - header.length - tailer.length val bar = if (w > 0) { - val percent = w * s.numCompletedTasks() / total + val percent = w * s.numCompleteTasks / total (0 until w).map { i => if (i < percent) "=" else if (i == percent) ">" else " " }.mkString("") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala deleted file mode 100644 index a18e86ec0a73b..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ /dev/null @@ -1,612 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import java.util.concurrent.TimeoutException - -import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap, ListBuffer} - -import org.apache.spark._ -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData._ - -/** - * :: DeveloperApi :: - * Tracks task-level information to be displayed in the UI. - * - * All access to the data structures in this class must be synchronized on the - * class, since the UI thread and the EventBus loop may otherwise be reading and - * updating the internal data structures concurrently. - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { - - // Define a handful of type aliases so that data structures' types can serve as documentation. - // These type aliases are public because they're used in the types of public fields: - - type JobId = Int - type JobGroupId = String - type StageId = Int - type StageAttemptId = Int - type PoolName = String - type ExecutorId = String - - // Application: - @volatile var startTime = -1L - @volatile var endTime = -1L - - // Jobs: - val activeJobs = new HashMap[JobId, JobUIData] - val completedJobs = ListBuffer[JobUIData]() - val failedJobs = ListBuffer[JobUIData]() - val jobIdToData = new HashMap[JobId, JobUIData] - val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] - - // Stages: - val pendingStages = new HashMap[StageId, StageInfo] - val activeStages = new HashMap[StageId, StageInfo] - val completedStages = ListBuffer[StageInfo]() - val skippedStages = ListBuffer[StageInfo]() - val failedStages = ListBuffer[StageInfo]() - val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] - val stageIdToInfo = new HashMap[StageId, StageInfo] - val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] - val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]() - // Total of completed and failed stages that have ever been run. These may be greater than - // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than - // JobProgressListener's retention limits. - var numCompletedStages = 0 - var numFailedStages = 0 - var numCompletedJobs = 0 - var numFailedJobs = 0 - - // Misc: - val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() - - def blockManagerIds: Seq[BlockManagerId] = executorIdToBlockManagerId.values.toSeq - - var schedulingMode: Option[SchedulingMode] = None - - // To limit the total memory usage of JobProgressListener, we only track information for a fixed - // number of non-active jobs and stages (there is no limit for active jobs and stages): - - val retainedStages = conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) - val retainedJobs = conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) - val retainedTasks = conf.get(UI_RETAINED_TASKS) - - // We can test for memory leaks by ensuring that collections that track non-active jobs and - // stages do not grow without bound and that collections for active jobs/stages eventually become - // empty once Spark is idle. Let's partition our collections into ones that should be empty - // once Spark is idle and ones that should have a hard- or soft-limited sizes. - // These methods are used by unit tests, but they're defined here so that people don't forget to - // update the tests when adding new collections. Some collections have multiple levels of - // nesting, etc, so this lets us customize our notion of "size" for each structure: - - // These collections should all be empty once Spark is idle (no active stages / jobs): - private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = { - Map( - "activeStages" -> activeStages.size, - "activeJobs" -> activeJobs.size, - "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum, - "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum - ) - } - - // These collections should stop growing once we have run at least `spark.ui.retainedStages` - // stages and `spark.ui.retainedJobs` jobs: - private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = { - Map( - "completedJobs" -> completedJobs.size, - "failedJobs" -> failedJobs.size, - "completedStages" -> completedStages.size, - "skippedStages" -> skippedStages.size, - "failedStages" -> failedStages.size - ) - } - - // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to - // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: - private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { - Map( - "jobIdToData" -> jobIdToData.size, - "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size, - "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, - // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: - "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size - ) - } - - /** If stages is too large, remove and garbage collect old stages */ - private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { - if (stages.size > retainedStages) { - val toRemove = calculateNumberToRemove(stages.size, retainedStages) - stages.take(toRemove).foreach { s => - stageIdToData.remove((s.stageId, s.attemptId)) - stageIdToInfo.remove(s.stageId) - } - stages.trimStart(toRemove) - } - } - - /** If jobs is too large, remove and garbage collect old jobs */ - private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized { - if (jobs.size > retainedJobs) { - val toRemove = calculateNumberToRemove(jobs.size, retainedJobs) - jobs.take(toRemove).foreach { job => - // Remove the job's UI data, if it exists - jobIdToData.remove(job.jobId).foreach { removedJob => - // A null jobGroupId is used for jobs that are run without a job group - val jobGroupId = removedJob.jobGroup.orNull - // Remove the job group -> job mapping entry, if it exists - jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => - jobsInGroup.remove(job.jobId) - // If this was the last job in this job group, remove the map entry for the job group - if (jobsInGroup.isEmpty) { - jobGroupToJobIds.remove(jobGroupId) - } - } - } - } - jobs.trimStart(toRemove) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { - val jobGroup = for ( - props <- Option(jobStart.properties); - group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) - ) yield group - val jobData: JobUIData = - new JobUIData( - jobId = jobStart.jobId, - submissionTime = Option(jobStart.time).filter(_ >= 0), - stageIds = jobStart.stageIds, - jobGroup = jobGroup, - status = JobExecutionStatus.RUNNING) - // A null jobGroupId is used for jobs that are run without a job group - jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) - jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) - // Compute (a potential underestimate of) the number of tasks that will be run by this job. - // This may be an underestimate because the job start event references all of the result - // stages' transitive stage dependencies, but some of these stages might be skipped if their - // output is available from earlier runs. - // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. - jobData.numTasks = { - val allStages = jobStart.stageInfos - val missingStages = allStages.filter(_.completionTime.isEmpty) - missingStages.map(_.numTasks).sum - } - jobIdToData(jobStart.jobId) = jobData - activeJobs(jobStart.jobId) = jobData - for (stageId <- jobStart.stageIds) { - stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) - } - // If there's no information for a stage, store the StageInfo received from the scheduler - // so that we can display stage descriptions for pending stages: - for (stageInfo <- jobStart.stageInfos) { - stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) - stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData) - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { - logWarning(s"Job completed for unknown job ${jobEnd.jobId}") - new JobUIData(jobId = jobEnd.jobId) - } - jobData.completionTime = Option(jobEnd.time).filter(_ >= 0) - - jobData.stageIds.foreach(pendingStages.remove) - jobEnd.jobResult match { - case JobSucceeded => - completedJobs += jobData - trimJobsIfNecessary(completedJobs) - jobData.status = JobExecutionStatus.SUCCEEDED - numCompletedJobs += 1 - case JobFailed(_) => - failedJobs += jobData - trimJobsIfNecessary(failedJobs) - jobData.status = JobExecutionStatus.FAILED - numFailedJobs += 1 - } - for (stageId <- jobData.stageIds) { - stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => - jobsUsingStage.remove(jobEnd.jobId) - if (jobsUsingStage.isEmpty) { - stageIdToActiveJobIds.remove(stageId) - } - stageIdToInfo.get(stageId).foreach { stageInfo => - if (stageInfo.submissionTime.isEmpty) { - // if this stage is pending, it won't complete, so mark it as "skipped": - skippedStages += stageInfo - trimStagesIfNecessary(skippedStages) - jobData.numSkippedStages += 1 - jobData.numSkippedTasks += stageInfo.numTasks - } - } - } - } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - val stage = stageCompleted.stageInfo - stageIdToInfo(stage.stageId) = stage - val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { - logWarning("Stage completed for unknown stage " + stage.stageId) - new StageUIData - }) - - for ((id, info) <- stageCompleted.stageInfo.accumulables) { - stageData.accumulables(id) = info - } - - poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap => - hashMap.remove(stage.stageId) - } - activeStages.remove(stage.stageId) - if (stage.failureReason.isEmpty) { - completedStages += stage - numCompletedStages += 1 - trimStagesIfNecessary(completedStages) - } else { - failedStages += stage - numFailedStages += 1 - trimStagesIfNecessary(failedStages) - } - - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveStages -= 1 - if (stage.failureReason.isEmpty) { - if (stage.submissionTime.isDefined) { - jobData.completedStageIndices.add(stage.stageId) - } - } else { - jobData.numFailedStages += 1 - } - } - } - - /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val stage = stageSubmitted.stageInfo - activeStages(stage.stageId) = stage - pendingStages.remove(stage.stageId) - val poolName = Option(stageSubmitted.properties).map { - p => p.getProperty("spark.scheduler.pool", SparkUI.DEFAULT_POOL_NAME) - }.getOrElse(SparkUI.DEFAULT_POOL_NAME) - - stageIdToInfo(stage.stageId) = stage - val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData) - stageData.schedulingPool = poolName - - stageData.description = Option(stageSubmitted.properties).flatMap { - p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) - } - - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) - stages(stage.stageId) = stage - - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveStages += 1 - - // If a stage retries again, it should be removed from completedStageIndices set - jobData.completedStageIndices.remove(stage.stageId) - } - } - - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { - val taskInfo = taskStart.taskInfo - if (taskInfo != null) { - val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { - logWarning("Task start for unknown stage " + taskStart.stageId) - new StageUIData - }) - stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo)) - } - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveTasks += 1 - } - } - - override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { - // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in - // stageToTaskInfos already has the updated status. - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val info = taskEnd.taskInfo - // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task - // completion event is for. Let's just drop it here. This means we might have some speculation - // tasks on the web ui that's never marked as complete. - if (info != null && taskEnd.stageAttemptId != -1) { - val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { - logWarning("Task end for unknown stage " + taskEnd.stageId) - new StageUIData - }) - - for (accumulableInfo <- info.accumulables) { - stageData.accumulables(accumulableInfo.id) = accumulableInfo - } - - val execSummaryMap = stageData.executorSummary - val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) - - taskEnd.reason match { - case Success => - execSummary.succeededTasks += 1 - case kill: TaskKilled => - execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( - kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) - case commitDenied: TaskCommitDenied => - execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( - commitDenied.toErrorString, execSummary.reasonToNumKilled.getOrElse( - commitDenied.toErrorString, 0) + 1) - case _ => - execSummary.failedTasks += 1 - } - execSummary.taskTime += info.duration - stageData.numActiveTasks -= 1 - - val errorMessage: Option[String] = - taskEnd.reason match { - case org.apache.spark.Success => - stageData.completedIndices.add(info.index) - stageData.numCompleteTasks += 1 - None - case kill: TaskKilled => - stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( - kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) - Some(kill.toErrorString) - case commitDenied: TaskCommitDenied => - stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( - commitDenied.toErrorString, stageData.reasonToNumKilled.getOrElse( - commitDenied.toErrorString, 0) + 1) - Some(commitDenied.toErrorString) - case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates - stageData.numFailedTasks += 1 - Some(e.toErrorString) - case e: TaskFailedReason => // All other failure cases - stageData.numFailedTasks += 1 - Some(e.toErrorString) - } - - val taskMetrics = Option(taskEnd.taskMetrics) - taskMetrics.foreach { m => - val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics) - updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) - } - - val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info)) - taskData.updateTaskInfo(info) - taskData.updateTaskMetrics(taskMetrics) - taskData.errorMessage = errorMessage - - // If Tasks is too large, remove and garbage collect old tasks - if (stageData.taskData.size > retainedTasks) { - stageData.taskData = stageData.taskData.drop( - calculateNumberToRemove(stageData.taskData.size, retainedTasks)) - } - - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveTasks -= 1 - taskEnd.reason match { - case Success => - jobData.completedIndices.add((taskEnd.stageId, info.index)) - jobData.numCompletedTasks += 1 - case kill: TaskKilled => - jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( - kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) - case commitDenied: TaskCommitDenied => - jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( - commitDenied.toErrorString, jobData.reasonToNumKilled.getOrElse( - commitDenied.toErrorString, 0) + 1) - case _ => - jobData.numFailedTasks += 1 - } - } - } - } - - /** - * Remove at least (maxRetained / 10) items to reduce friction. - */ - private def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int = { - math.max(retainedSize / 10, dataSize - retainedSize) - } - - /** - * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage - * aggregate metrics by calculating deltas between the currently recorded metrics and the new - * metrics. - */ - def updateAggregateMetrics( - stageData: StageUIData, - execId: String, - taskMetrics: TaskMetrics, - oldMetrics: Option[TaskMetricsUIData]) { - val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) - - val shuffleWriteDelta = - taskMetrics.shuffleWriteMetrics.bytesWritten - - oldMetrics.map(_.shuffleWriteMetrics.bytesWritten).getOrElse(0L) - stageData.shuffleWriteBytes += shuffleWriteDelta - execSummary.shuffleWrite += shuffleWriteDelta - - val shuffleWriteRecordsDelta = - taskMetrics.shuffleWriteMetrics.recordsWritten - - oldMetrics.map(_.shuffleWriteMetrics.recordsWritten).getOrElse(0L) - stageData.shuffleWriteRecords += shuffleWriteRecordsDelta - execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta - - val shuffleReadDelta = - taskMetrics.shuffleReadMetrics.totalBytesRead - - oldMetrics.map(_.shuffleReadMetrics.totalBytesRead).getOrElse(0L) - stageData.shuffleReadTotalBytes += shuffleReadDelta - execSummary.shuffleRead += shuffleReadDelta - - val shuffleReadRecordsDelta = - taskMetrics.shuffleReadMetrics.recordsRead - - oldMetrics.map(_.shuffleReadMetrics.recordsRead).getOrElse(0L) - stageData.shuffleReadRecords += shuffleReadRecordsDelta - execSummary.shuffleReadRecords += shuffleReadRecordsDelta - - val inputBytesDelta = - taskMetrics.inputMetrics.bytesRead - - oldMetrics.map(_.inputMetrics.bytesRead).getOrElse(0L) - stageData.inputBytes += inputBytesDelta - execSummary.inputBytes += inputBytesDelta - - val inputRecordsDelta = - taskMetrics.inputMetrics.recordsRead - - oldMetrics.map(_.inputMetrics.recordsRead).getOrElse(0L) - stageData.inputRecords += inputRecordsDelta - execSummary.inputRecords += inputRecordsDelta - - val outputBytesDelta = - taskMetrics.outputMetrics.bytesWritten - - oldMetrics.map(_.outputMetrics.bytesWritten).getOrElse(0L) - stageData.outputBytes += outputBytesDelta - execSummary.outputBytes += outputBytesDelta - - val outputRecordsDelta = - taskMetrics.outputMetrics.recordsWritten - - oldMetrics.map(_.outputMetrics.recordsWritten).getOrElse(0L) - stageData.outputRecords += outputRecordsDelta - execSummary.outputRecords += outputRecordsDelta - - val diskSpillDelta = - taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) - stageData.diskBytesSpilled += diskSpillDelta - execSummary.diskBytesSpilled += diskSpillDelta - - val memorySpillDelta = - taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageData.memoryBytesSpilled += memorySpillDelta - execSummary.memoryBytesSpilled += memorySpillDelta - - val timeDelta = - taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) - stageData.executorRunTime += timeDelta - - val cpuTimeDelta = - taskMetrics.executorCpuTime - oldMetrics.map(_.executorCpuTime).getOrElse(0L) - stageData.executorCpuTime += cpuTimeDelta - } - - override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) { - val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { - logWarning("Metrics update for task in unknown stage " + sid) - new StageUIData - }) - val taskData = stageData.taskData.get(taskId) - val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) - taskData.foreach { t => - if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics) - // Overwrite task metrics - t.updateTaskMetrics(Some(metrics)) - } - } - } - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - schedulingMode = environmentUpdate - .environmentDetails("Spark Properties").toMap - .get("spark.scheduler.mode") - .map(SchedulingMode.withName) - } - } - - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { - synchronized { - val blockManagerId = blockManagerAdded.blockManagerId - val executorId = blockManagerId.executorId - executorIdToBlockManagerId(executorId) = blockManagerId - } - } - - override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { - synchronized { - val executorId = blockManagerRemoved.blockManagerId.executorId - executorIdToBlockManagerId.remove(executorId) - } - } - - override def onApplicationStart(appStarted: SparkListenerApplicationStart) { - startTime = appStarted.time - } - - override def onApplicationEnd(appEnded: SparkListenerApplicationEnd) { - endTime = appEnded.time - } - - /** - * For testing only. Wait until at least `numExecutors` executors are up, or throw - * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. - * Exposed for testing. - * - * @param numExecutors the number of executors to wait at least - * @param timeout time to wait in milliseconds - */ - private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { - val finishTime = System.currentTimeMillis() + timeout - while (System.currentTimeMillis() < finishTime) { - val numBlockManagers = synchronized { - blockManagerIds.size - } - if (numBlockManagers >= numExecutors + 1) { - // Need to count the block manager in driver - return - } - // Sleep rather than using wait/notify, because this is used only for testing and wait/notify - // add overhead in the general case. - Thread.sleep(10) - } - throw new TimeoutException( - s"Can't find $numExecutors executors before $timeout milliseconds elapsed") - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5f93f2ffb412f..11a6a34344976 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala deleted file mode 100644 index 5acec0d0f54c9..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import scala.collection.mutable -import scala.collection.mutable.{HashMap, LinkedHashMap} - -import com.google.common.collect.Interners - -import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor._ -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} -import org.apache.spark.util.AccumulatorContext -import org.apache.spark.util.collection.OpenHashSet - -private[spark] object UIData { - - class ExecutorSummary { - var taskTime : Long = 0 - var failedTasks : Int = 0 - var succeededTasks : Int = 0 - var reasonToNumKilled : Map[String, Int] = Map.empty - var inputBytes : Long = 0 - var inputRecords : Long = 0 - var outputBytes : Long = 0 - var outputRecords : Long = 0 - var shuffleRead : Long = 0 - var shuffleReadRecords : Long = 0 - var shuffleWrite : Long = 0 - var shuffleWriteRecords : Long = 0 - var memoryBytesSpilled : Long = 0 - var diskBytesSpilled : Long = 0 - var isBlacklisted : Int = 0 - } - - class JobUIData( - var jobId: Int = -1, - var submissionTime: Option[Long] = None, - var completionTime: Option[Long] = None, - var stageIds: Seq[Int] = Seq.empty, - var jobGroup: Option[String] = None, - var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, - /* Tasks */ - // `numTasks` is a potential underestimate of the true number of tasks that this job will run. - // This may be an underestimate because the job start event references all of the result - // stages' transitive stage dependencies, but some of these stages might be skipped if their - // output is available from earlier runs. - // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. - var numTasks: Int = 0, - var numActiveTasks: Int = 0, - var numCompletedTasks: Int = 0, - var completedIndices: OpenHashSet[(Int, Int)] = new OpenHashSet[(Int, Int)](), - var numSkippedTasks: Int = 0, - var numFailedTasks: Int = 0, - var reasonToNumKilled: Map[String, Int] = Map.empty, - /* Stages */ - var numActiveStages: Int = 0, - // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: - var completedStageIndices: mutable.HashSet[Int] = new mutable.HashSet[Int](), - var numSkippedStages: Int = 0, - var numFailedStages: Int = 0 - ) - - class StageUIData { - var numActiveTasks: Int = _ - var numCompleteTasks: Int = _ - var completedIndices = new OpenHashSet[Int]() - var numFailedTasks: Int = _ - var reasonToNumKilled: Map[String, Int] = Map.empty - - var executorRunTime: Long = _ - var executorCpuTime: Long = _ - - var inputBytes: Long = _ - var inputRecords: Long = _ - var outputBytes: Long = _ - var outputRecords: Long = _ - var shuffleReadTotalBytes: Long = _ - var shuffleReadRecords : Long = _ - var shuffleWriteBytes: Long = _ - var shuffleWriteRecords: Long = _ - var memoryBytesSpilled: Long = _ - var diskBytesSpilled: Long = _ - var isBlacklisted: Int = _ - var lastUpdateTime: Option[Long] = None - - var schedulingPool: String = "" - var description: Option[String] = None - - var accumulables = new HashMap[Long, AccumulableInfo] - var taskData = new LinkedHashMap[Long, TaskUIData] - var executorSummary = new HashMap[String, ExecutorSummary] - - def hasInput: Boolean = inputBytes > 0 - def hasOutput: Boolean = outputBytes > 0 - def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 - def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 - def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 || diskBytesSpilled > 0 - } - - /** - * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. - */ - class TaskUIData private(private var _taskInfo: TaskInfo) { - - private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY) - - var errorMessage: Option[String] = None - - def taskInfo: TaskInfo = _taskInfo - - def metrics: Option[TaskMetricsUIData] = _metrics - - def updateTaskInfo(taskInfo: TaskInfo): Unit = { - _taskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) - } - - def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = { - _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics) - } - - def taskDuration(lastUpdateTime: Option[Long] = None): Option[Long] = { - if (taskInfo.status == "RUNNING") { - Some(_taskInfo.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis))) - } else { - _metrics.map(_.executorRunTime) - } - } - } - - object TaskUIData { - - private val stringInterner = Interners.newWeakInterner[String]() - - /** String interning to reduce the memory usage. */ - private def weakIntern(s: String): String = { - stringInterner.intern(s) - } - - def apply(taskInfo: TaskInfo): TaskUIData = { - new TaskUIData(dropInternalAndSQLAccumulables(taskInfo)) - } - - /** - * We don't need to store internal or SQL accumulables as their values will be shown in other - * places, so drop them to reduce the memory usage. - */ - private[spark] def dropInternalAndSQLAccumulables(taskInfo: TaskInfo): TaskInfo = { - val newTaskInfo = new TaskInfo( - taskId = taskInfo.taskId, - index = taskInfo.index, - attemptNumber = taskInfo.attemptNumber, - launchTime = taskInfo.launchTime, - executorId = weakIntern(taskInfo.executorId), - host = weakIntern(taskInfo.host), - taskLocality = taskInfo.taskLocality, - speculative = taskInfo.speculative - ) - newTaskInfo.gettingResultTime = taskInfo.gettingResultTime - newTaskInfo.setAccumulables(taskInfo.accumulables.filter { - accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) - }) - newTaskInfo.finishTime = taskInfo.finishTime - newTaskInfo.failed = taskInfo.failed - newTaskInfo.killed = taskInfo.killed - newTaskInfo - } - } - - case class TaskMetricsUIData( - executorDeserializeTime: Long, - executorDeserializeCpuTime: Long, - executorRunTime: Long, - executorCpuTime: Long, - resultSize: Long, - jvmGCTime: Long, - resultSerializationTime: Long, - memoryBytesSpilled: Long, - diskBytesSpilled: Long, - peakExecutionMemory: Long, - inputMetrics: InputMetricsUIData, - outputMetrics: OutputMetricsUIData, - shuffleReadMetrics: ShuffleReadMetricsUIData, - shuffleWriteMetrics: ShuffleWriteMetricsUIData) - - object TaskMetricsUIData { - def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = { - TaskMetricsUIData( - executorDeserializeTime = m.executorDeserializeTime, - executorDeserializeCpuTime = m.executorDeserializeCpuTime, - executorRunTime = m.executorRunTime, - executorCpuTime = m.executorCpuTime, - resultSize = m.resultSize, - jvmGCTime = m.jvmGCTime, - resultSerializationTime = m.resultSerializationTime, - memoryBytesSpilled = m.memoryBytesSpilled, - diskBytesSpilled = m.diskBytesSpilled, - peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics), - outputMetrics = OutputMetricsUIData(m.outputMetrics), - shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), - shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) - } - - val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty) - } - - case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) - object InputMetricsUIData { - def apply(metrics: InputMetrics): InputMetricsUIData = { - if (metrics.bytesRead == 0 && metrics.recordsRead == 0) { - EMPTY - } else { - new InputMetricsUIData( - bytesRead = metrics.bytesRead, - recordsRead = metrics.recordsRead) - } - } - private val EMPTY = InputMetricsUIData(0, 0) - } - - case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long) - object OutputMetricsUIData { - def apply(metrics: OutputMetrics): OutputMetricsUIData = { - if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) { - EMPTY - } else { - new OutputMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten) - } - } - private val EMPTY = OutputMetricsUIData(0, 0) - } - - case class ShuffleReadMetricsUIData( - remoteBlocksFetched: Long, - localBlocksFetched: Long, - remoteBytesRead: Long, - remoteBytesReadToDisk: Long, - localBytesRead: Long, - fetchWaitTime: Long, - recordsRead: Long, - totalBytesRead: Long, - totalBlocksFetched: Long) - - object ShuffleReadMetricsUIData { - def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = { - if ( - metrics.remoteBlocksFetched == 0 && - metrics.localBlocksFetched == 0 && - metrics.remoteBytesRead == 0 && - metrics.localBytesRead == 0 && - metrics.fetchWaitTime == 0 && - metrics.recordsRead == 0 && - metrics.totalBytesRead == 0 && - metrics.totalBlocksFetched == 0) { - EMPTY - } else { - new ShuffleReadMetricsUIData( - remoteBlocksFetched = metrics.remoteBlocksFetched, - localBlocksFetched = metrics.localBlocksFetched, - remoteBytesRead = metrics.remoteBytesRead, - remoteBytesReadToDisk = metrics.remoteBytesReadToDisk, - localBytesRead = metrics.localBytesRead, - fetchWaitTime = metrics.fetchWaitTime, - recordsRead = metrics.recordsRead, - totalBytesRead = metrics.totalBytesRead, - totalBlocksFetched = metrics.totalBlocksFetched - ) - } - } - private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0, 0) - } - - case class ShuffleWriteMetricsUIData( - bytesWritten: Long, - recordsWritten: Long, - writeTime: Long) - - object ShuffleWriteMetricsUIData { - def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = { - if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) { - EMPTY - } else { - new ShuffleWriteMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten, - writeTime = metrics.writeTime - ) - } - } - private val EMPTY = ShuffleWriteMetricsUIData(0, 0, 0) - } - -} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index ea9f6d2fc20f4..e09d5f59817b9 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -156,7 +156,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) - sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) + TestUtils.waitUntilExecutorsUp(sc, 2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index fe944031bc948..472952addf353 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -66,7 +66,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 5483f2b8434aa..a15ae040d43a9 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -44,13 +44,13 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont stageIds.size should be(2) val firstStageInfo = eventually(timeout(10 seconds)) { - sc.statusTracker.getStageInfo(stageIds(0)).get + sc.statusTracker.getStageInfo(stageIds.min).get } - firstStageInfo.stageId() should be(stageIds(0)) + firstStageInfo.stageId() should be(stageIds.min) firstStageInfo.currentAttemptId() should be(0) firstStageInfo.numTasks() should be(2) eventually(timeout(10 seconds)) { - val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds(0)).get + val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds.min).get updatedFirstStageInfo.numCompletedTasks() should be(2) updatedFirstStageInfo.numActiveTasks() should be(0) updatedFirstStageInfo.numFailedTasks() should be(0) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 46f9ac6b0273a..159629825c677 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -224,7 +224,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) + TestUtils.waitUntilExecutorsUp(_sc, numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d395e09969453..feefb6a4d73f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2406,13 +2406,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // OutputCommitCoordinator requires the task info itself to not be null. private def createFakeTaskInfo(): TaskInfo = { val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false) - info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info.finishTime = 1 info } private def createFakeTaskInfoWithId(taskId: Long): TaskInfo = { val info = new TaskInfo(taskId, 0, 0, 0L, "", "", TaskLocality.ANY, false) - info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info.finishTime = 1 info } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index 9fa8859382911..123f7f49d21b5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.scheduler.cluster.ExecutorInfo /** @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala deleted file mode 100644 index 48be3be81755a..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ /dev/null @@ -1,442 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import java.util.Properties - -import org.scalatest.Matchers - -import org.apache.spark._ -import org.apache.spark.{LocalSparkContext, SparkConf, Success} -import org.apache.spark.executor._ -import org.apache.spark.scheduler._ -import org.apache.spark.ui.jobs.UIData.TaskUIData -import org.apache.spark.util.{AccumulatorContext, Utils} - -class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers { - - val jobSubmissionTime = 1421191042750L - val jobCompletionTime = 1421191296660L - - private def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, null, "") - SparkListenerStageSubmitted(stageInfo) - } - - private def createStageEndEvent(stageId: Int, failed: Boolean = false) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, null, "") - if (failed) { - stageInfo.failureReason = Some("Failed!") - } - SparkListenerStageCompleted(stageInfo) - } - - private def createJobStartEvent( - jobId: Int, - stageIds: Seq[Int], - jobGroup: Option[String] = None): SparkListenerJobStart = { - val stageInfos = stageIds.map { stageId => - new StageInfo(stageId, 0, stageId.toString, 0, null, null, "") - } - val properties: Option[Properties] = jobGroup.map { groupId => - val props = new Properties() - props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) - props - } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) - } - - private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { - val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded - SparkListenerJobEnd(jobId, jobCompletionTime, result) - } - - private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { - val stagesThatWontBeRun = jobId * 200 to jobId * 200 + 10 - val stageIds = jobId * 100 to jobId * 100 + 50 - listener.onJobStart(createJobStartEvent(jobId, stageIds ++ stagesThatWontBeRun)) - for (stageId <- stageIds) { - listener.onStageSubmitted(createStageStartEvent(stageId)) - listener.onStageCompleted(createStageEndEvent(stageId, failed = stageId % 2 == 0)) - } - listener.onJobEnd(createJobEndEvent(jobId, shouldFail)) - } - - private def assertActiveJobsStateIsEmpty(listener: JobProgressListener) { - listener.getSizesOfActiveStateTrackingCollections.foreach { case (fieldName, size) => - assert(size === 0, s"$fieldName was not empty") - } - } - - test("test LRU eviction of stages") { - def runWithListener(listener: JobProgressListener) : Unit = { - for (i <- 1 to 50) { - listener.onStageSubmitted(createStageStartEvent(i)) - listener.onStageCompleted(createStageEndEvent(i)) - } - assertActiveJobsStateIsEmpty(listener) - } - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - var listener = new JobProgressListener(conf) - - // Test with 5 retainedStages - runWithListener(listener) - listener.completedStages.size should be (5) - listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) - - // Test with 0 retainedStages - conf.set("spark.ui.retainedStages", 0.toString) - listener = new JobProgressListener(conf) - runWithListener(listener) - listener.completedStages.size should be (0) - } - - test("test clearing of stageIdToActiveJobs") { - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - val listener = new JobProgressListener(conf) - val jobId = 0 - val stageIds = 1 to 50 - // Start a job with 50 stages - listener.onJobStart(createJobStartEvent(jobId, stageIds)) - for (stageId <- stageIds) { - listener.onStageSubmitted(createStageStartEvent(stageId)) - } - listener.stageIdToActiveJobIds.size should be > 0 - - // Complete the stages and job - for (stageId <- stageIds) { - listener.onStageCompleted(createStageEndEvent(stageId, failed = false)) - } - listener.onJobEnd(createJobEndEvent(jobId, false)) - assertActiveJobsStateIsEmpty(listener) - listener.stageIdToActiveJobIds.size should be (0) - } - - test("test clearing of jobGroupToJobIds") { - def runWithListener(listener: JobProgressListener): Unit = { - // Run 50 jobs, each with one stage - for (jobId <- 0 to 50) { - listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) - listener.onStageSubmitted(createStageStartEvent(0)) - listener.onStageCompleted(createStageEndEvent(0, failed = false)) - listener.onJobEnd(createJobEndEvent(jobId, false)) - } - assertActiveJobsStateIsEmpty(listener) - } - val conf = new SparkConf() - conf.set("spark.ui.retainedJobs", 5.toString) - - var listener = new JobProgressListener(conf) - runWithListener(listener) - // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs - listener.jobGroupToJobIds.size should be (5) - - // Test with 0 jobs - conf.set("spark.ui.retainedJobs", 0.toString) - listener = new JobProgressListener(conf) - runWithListener(listener) - listener.jobGroupToJobIds.size should be (0) - } - - test("test LRU eviction of jobs") { - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - conf.set("spark.ui.retainedJobs", 5.toString) - val listener = new JobProgressListener(conf) - - // Run a bunch of jobs to get the listener into a state where we've exceeded both the - // job and stage retention limits: - for (jobId <- 1 to 10) { - runJob(listener, jobId, shouldFail = false) - } - for (jobId <- 200 to 210) { - runJob(listener, jobId, shouldFail = true) - } - assertActiveJobsStateIsEmpty(listener) - // Snapshot the sizes of various soft- and hard-size-limited collections: - val softLimitSizes = listener.getSizesOfSoftSizeLimitedCollections - val hardLimitSizes = listener.getSizesOfHardSizeLimitedCollections - // Run some more jobs: - for (jobId <- 11 to 50) { - runJob(listener, jobId, shouldFail = false) - // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: - listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) - listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) - } - - listener.completedJobs.size should be (5) - listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) - - for (jobId <- 51 to 100) { - runJob(listener, jobId, shouldFail = true) - // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: - listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) - listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) - } - assertActiveJobsStateIsEmpty(listener) - - // Completed and failed jobs each their own size limits, so this should still be the same: - listener.completedJobs.size should be (5) - listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) - listener.failedJobs.size should be (5) - listener.failedJobs.map(_.jobId).toSet should be (Set(100, 99, 98, 97, 96)) - } - - test("test executor id to summary") { - val conf = new SparkConf() - val listener = new JobProgressListener(conf) - val taskMetrics = TaskMetrics.empty - val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() - assert(listener.stageIdToData.size === 0) - - // finish this task, should get updated shuffleRead - shuffleReadMetrics.incRemoteBytesRead(1000) - taskMetrics.mergeShuffleReadMetrics() - var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - var task = new ShuffleMapTask(0) - val taskType = Utils.getFormattedClassName(task) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse((0, 0), fail()) - .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 1000) - - // finish a task with unknown executor-id, nothing should happen - taskInfo = - new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) - taskInfo.finishTime = 1 - task = new ShuffleMapTask(0) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.size === 1) - - // finish this task, should get updated duration - taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - task = new ShuffleMapTask(0) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse((0, 0), fail()) - .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 2000) - - // finish this task, should get updated duration - taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - task = new ShuffleMapTask(0) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse((0, 0), fail()) - .executorSummary.getOrElse("exe-2", fail()).shuffleRead === 1000) - } - - test("test task success vs failure counting for different task end reasons") { - val conf = new SparkConf() - val listener = new JobProgressListener(conf) - val metrics = TaskMetrics.empty - val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - val task = new ShuffleMapTask(0) - val taskType = Utils.getFormattedClassName(task) - - // Go through all the failure cases to make sure we are counting them as failures. - val taskFailedReasons = Seq( - Resubmitted, - new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), - TaskResultLost, - ExecutorLostFailure("0", true, Some("Induced failure")), - UnknownReason) - var failCount = 0 - for (reason <- taskFailedReasons) { - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, reason, taskInfo, metrics)) - failCount += 1 - assert(listener.stageIdToData((task.stageId, 0)).numCompleteTasks === 0) - assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) - } - - // Make sure killed tasks are accounted for correctly. - listener.onTaskEnd( - SparkListenerTaskEnd( - task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1)) - - // Make sure we count success as success. - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 1)).numCompleteTasks === 1) - assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) - } - - test("test update metrics") { - val conf = new SparkConf() - val listener = new JobProgressListener(conf) - - val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) - val execId = "exe-1" - - def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = TaskMetrics.registered - val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() - val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics - val inputMetrics = taskMetrics.inputMetrics - val outputMetrics = taskMetrics.outputMetrics - shuffleReadMetrics.incRemoteBytesRead(base + 1) - shuffleReadMetrics.incLocalBytesRead(base + 9) - shuffleReadMetrics.incRemoteBlocksFetched(base + 2) - taskMetrics.mergeShuffleReadMetrics() - shuffleWriteMetrics.incBytesWritten(base + 3) - taskMetrics.setExecutorRunTime(base + 4) - taskMetrics.incDiskBytesSpilled(base + 5) - taskMetrics.incMemoryBytesSpilled(base + 6) - inputMetrics.setBytesRead(base + 7) - outputMetrics.setBytesWritten(base + 8) - taskMetrics - } - - def makeTaskInfo(taskId: Long, finishTime: Int = 0): TaskInfo = { - val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, - false) - taskInfo.finishTime = finishTime - taskInfo - } - - listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1234L))) - listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1235L))) - listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1236L))) - listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) - - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, 0, makeTaskMetrics(0).accumulators().map(AccumulatorSuite.makeInfo)), - (1235L, 0, 0, makeTaskMetrics(100).accumulators().map(AccumulatorSuite.makeInfo)), - (1236L, 1, 0, makeTaskMetrics(200).accumulators().map(AccumulatorSuite.makeInfo))))) - - var stage0Data = listener.stageIdToData.get((0, 0)).get - var stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadTotalBytes == 220) - assert(stage1Data.shuffleReadTotalBytes == 410) - assert(stage0Data.shuffleWriteBytes == 106) - assert(stage1Data.shuffleWriteBytes == 203) - assert(stage0Data.executorRunTime == 108) - assert(stage1Data.executorRunTime == 204) - assert(stage0Data.diskBytesSpilled == 110) - assert(stage1Data.diskBytesSpilled == 205) - assert(stage0Data.memoryBytesSpilled == 112) - assert(stage1Data.memoryBytesSpilled == 206) - assert(stage0Data.inputBytes == 114) - assert(stage1Data.inputBytes == 207) - assert(stage0Data.outputBytes == 116) - assert(stage1Data.outputBytes == 208) - - assert( - stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 2) - assert( - stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 102) - assert( - stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 202) - - // task that was included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1), - makeTaskMetrics(300))) - // task that wasn't included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, taskType, Success, makeTaskInfo(1237L, 1), - makeTaskMetrics(400))) - - stage0Data = listener.stageIdToData.get((0, 0)).get - stage1Data = listener.stageIdToData.get((1, 0)).get - // Task 1235 contributed (100+1)+(100+9) = 210 shuffle bytes, and task 1234 contributed - // (300+1)+(300+9) = 610 total shuffle bytes, so the total for the stage is 820. - assert(stage0Data.shuffleReadTotalBytes == 820) - // Task 1236 contributed 410 shuffle bytes, and task 1237 contributed 810 shuffle bytes. - assert(stage1Data.shuffleReadTotalBytes == 1220) - assert(stage0Data.shuffleWriteBytes == 406) - assert(stage1Data.shuffleWriteBytes == 606) - assert(stage0Data.executorRunTime == 408) - assert(stage1Data.executorRunTime == 608) - assert(stage0Data.diskBytesSpilled == 410) - assert(stage1Data.diskBytesSpilled == 610) - assert(stage0Data.memoryBytesSpilled == 412) - assert(stage1Data.memoryBytesSpilled == 612) - assert(stage0Data.inputBytes == 414) - assert(stage1Data.inputBytes == 614) - assert(stage0Data.outputBytes == 416) - assert(stage1Data.outputBytes == 616) - assert( - stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 302) - assert( - stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 402) - } - - test("drop internal and sql accumulators") { - val taskInfo = new TaskInfo(0, 0, 0, 0, "", "", TaskLocality.ANY, false) - val internalAccum = - AccumulableInfo(id = 1, name = Some("internal"), None, None, true, false, None) - val sqlAccum = AccumulableInfo( - id = 2, - name = Some("sql"), - update = None, - value = None, - internal = false, - countFailedValues = false, - metadata = Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) - val userAccum = AccumulableInfo( - id = 3, - name = Some("user"), - update = None, - value = None, - internal = false, - countFailedValues = false, - metadata = None) - taskInfo.setAccumulables(List(internalAccum, sqlAccum, userAccum)) - - val newTaskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) - assert(newTaskInfo.accumulables === Seq(userAccum)) - } - - test("SPARK-19146 drop more elements when stageData.taskData.size > retainedTasks") { - val conf = new SparkConf() - conf.set("spark.ui.retainedTasks", "100") - val taskMetrics = TaskMetrics.empty - taskMetrics.mergeShuffleReadMetrics() - val task = new ShuffleMapTask(0) - val taskType = Utils.getFormattedClassName(task) - - val listener1 = new JobProgressListener(conf) - for (t <- 1 to 101) { - val taskInfo = new TaskInfo(t, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - listener1.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - } - // 101 - math.max(100 / 10, 101 - 100) = 91 - assert(listener1.stageIdToData((task.stageId, task.stageAttemptId)).taskData.size === 91) - - val listener2 = new JobProgressListener(conf) - for (t <- 1 to 150) { - val taskInfo = new TaskInfo(t, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - listener2.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - } - // 150 - math.max(100 / 10, 150 - 100) = 100 - assert(listener2.stageIdToData((task.stageId, task.stageAttemptId)).taskData.size === 100) - } - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 915c7e2e2fda3..5b8dcd0338cce 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,6 +45,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkStatusTracker.this"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.jobs.JobProgressListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 69e15655ad790..6748dd4ec48e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -23,16 +23,16 @@ import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.status.api.v1.{JobData, StageData} import org.apache.spark.streaming.Time import org.apache.spark.streaming.ui.StreamingJobProgressListener.SparkJobId import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} -import org.apache.spark.ui.jobs.UIData.JobUIData -private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) +private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobData: Option[JobData]) private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private val streamingListener = parent.listener - private val sparkListener = parent.ssc.sc.jobProgressListener + private val store = parent.parent.store private def columns: Seq[Node] = { Output Op Id @@ -52,13 +52,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, - sparkJob: SparkJobIdWithUIData): Seq[Node] = { - if (sparkJob.jobUIData.isDefined) { + jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { + if (jobIdWithData.jobData.isDefined) { generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, - numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) + numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, - numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) + numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.sparkJobId) } } @@ -94,15 +94,15 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, - sparkJob: JobUIData): Seq[Node] = { + sparkJob: JobData): Seq[Node] = { val duration: Option[Long] = { sparkJob.submissionTime.map { start => - val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) - end - start + val end = sparkJob.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) + end - start.getTime() } } val lastFailureReason = - sparkJob.stageIds.sorted.reverse.flatMap(sparkListener.stageIdToInfo.get). + sparkJob.stageIds.sorted.reverse.flatMap(getStageData). dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") @@ -135,7 +135,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { {formattedDuration} - {sparkJob.completedStageIndices.size}/{sparkJob.stageIds.size - sparkJob.numSkippedStages} + {sparkJob.numCompletedStages}/{sparkJob.stageIds.size - sparkJob.numSkippedStages} {if (sparkJob.numFailedStages > 0) s"(${sparkJob.numFailedStages} failed)"} {if (sparkJob.numSkippedStages > 0) s"(${sparkJob.numSkippedStages} skipped)"} @@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { completed = sparkJob.numCompletedTasks, failed = sparkJob.numFailedTasks, skipped = sparkJob.numSkippedTasks, - reasonToNumKilled = sparkJob.reasonToNumKilled, + reasonToNumKilled = sparkJob.killedTasksSummary, total = sparkJob.numTasks - sparkJob.numSkippedTasks) } @@ -246,11 +246,19 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {

    } - private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { - sparkListener.activeJobs.get(sparkJobId).orElse { - sparkListener.completedJobs.find(_.jobId == sparkJobId).orElse { - sparkListener.failedJobs.find(_.jobId == sparkJobId) - } + private def getJobData(sparkJobId: SparkJobId): Option[JobData] = { + try { + Some(store.job(sparkJobId)) + } catch { + case _: NoSuchElementException => None + } + } + + private def getStageData(stageId: Int): Option[StageData] = { + try { + Some(store.lastStageAttempt(stageId)) + } catch { + case _: NoSuchElementException => None } } @@ -282,25 +290,22 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val sparkJobIds = outputOpIdToSparkJobIds.getOrElse(outputOpId, Seq.empty) (outputOperation, sparkJobIds) }.toSeq.sortBy(_._1.id) - sparkListener.synchronized { - val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => - (outputOpData, - sparkJobIds.map(sparkJobId => SparkJobIdWithUIData(sparkJobId, getJobData(sparkJobId)))) - } + val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => + (outputOpData, sparkJobIds.map { jobId => SparkJobIdWithUIData(jobId, getJobData(jobId)) }) + } - - - {columns} - - - { - outputOpWithJobs.map { case (outputOpData, sparkJobIds) => - generateOutputOpIdRow(outputOpData, sparkJobIds) - } +
    + + {columns} + + + { + outputOpWithJobs.map { case (outputOpData, sparkJobs) => + generateOutputOpIdRow(outputOpData, sparkJobs) } - -
    - } + } + + } def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { From ab6f60c4d6417cbb0240216a6b492aadcca3043e Mon Sep 17 00:00:00 2001 From: Jakub Dubovsky Date: Thu, 30 Nov 2017 10:24:30 +0900 Subject: [PATCH 1743/1765] [SPARK-22585][CORE] Path in addJar is not url encoded ## What changes were proposed in this pull request? This updates a behavior of `addJar` method of `sparkContext` class. If path without any scheme is passed as input it is used literally without url encoding/decoding it. ## How was this patch tested? A unit test is added for this. Author: Jakub Dubovsky Closes #19834 from james64/SPARK-22585-encode-add-jar. --- .../main/scala/org/apache/spark/SparkContext.scala | 6 +++++- .../scala/org/apache/spark/SparkContextSuite.scala | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 984dd0a6629a2..c174939ca2e54 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1837,7 +1837,11 @@ class SparkContext(config: SparkConf) extends Logging { Utils.validateURL(uri) uri.getScheme match { // A JAR file which exists only on the driver node - case null | "file" => addJarFile(new File(uri.getPath)) + case null => + // SPARK-22585 path without schema is not url encoded + addJarFile(new File(uri.getRawPath)) + // A JAR file which exists only on the driver node + case "file" => addJarFile(new File(uri.getPath)) // A JAR file which exists locally on every worker node case "local" => "file:" + uri.getPath case _ => path diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 0ed5f26863dad..2bde8757dae5d 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -309,6 +309,17 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(sc.listJars().head.contains(tmpJar.getName)) } + test("SPARK-22585 addJar argument without scheme is interpreted literally without url decoding") { + val tmpDir = new File(Utils.createTempDir(), "host%3A443") + tmpDir.mkdirs() + val tmpJar = File.createTempFile("t%2F", ".jar", tmpDir) + + sc = new SparkContext("local", "test") + + sc.addJar(tmpJar.getAbsolutePath) + assert(sc.listJars().size === 1) + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) From 92cfbeeb5ce9e2c618a76b3fe60ce84b9d38605b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 30 Nov 2017 10:26:55 +0900 Subject: [PATCH 1744/1765] [SPARK-21866][ML][PYTHON][FOLLOWUP] Few cleanups and fix image test failure in Python 3.6.0 / NumPy 1.13.3 ## What changes were proposed in this pull request? Image test seems failed in Python 3.6.0 / NumPy 1.13.3. I manually tested as below: ``` ====================================================================== ERROR: test_read_images (pyspark.ml.tests.ImageReaderTest) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/ml/tests.py", line 1831, in test_read_images self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) File "/.../spark/python/pyspark/ml/image.py", line 149, in toImage data = bytearray(array.astype(dtype=np.uint8).ravel()) TypeError: only integer scalar arrays can be converted to a scalar index ---------------------------------------------------------------------- Ran 1 test in 7.606s ``` To be clear, I think the error seems from NumPy - https://github.com/numpy/numpy/blob/75b2d5d427afdb1392f2a0b2092e0767e4bab53d/numpy/core/src/multiarray/number.c#L947 For a smaller scope: ```python >>> import numpy as np >>> bytearray(np.array([1]).astype(dtype=np.uint8)) Traceback (most recent call last): File "", line 1, in TypeError: only integer scalar arrays can be converted to a scalar index ``` In Python 2.7 / NumPy 1.13.1, it prints: ``` bytearray(b'\x01') ``` So, here, I simply worked around it by converting it to bytes as below: ```python >>> bytearray(np.array([1]).astype(dtype=np.uint8).tobytes()) bytearray(b'\x01') ``` Also, while looking into it again, I realised few arguments could be quite confusing, for example, `Row` that needs some specific attributes and `numpy.ndarray`. I added few type checking and added some tests accordingly. So, it shows an error message as below: ``` TypeError: array argument should be numpy.ndarray; however, it got []. ``` ## How was this patch tested? Manually tested with `./python/run-tests`. And also: ``` PYSPARK_PYTHON=python3 SPARK_TESTING=1 bin/pyspark pyspark.ml.tests ImageReaderTest ``` Author: hyukjinkwon Closes #19835 from HyukjinKwon/SPARK-21866-followup. --- python/pyspark/ml/image.py | 27 ++++++++++++++++++++++++--- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 7d14f05295572..2b61aa9c0d9e9 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -108,12 +108,23 @@ def toNDArray(self, image): """ Converts an image to an array with metadata. - :param image: The image to be converted. + :param `Row` image: A row that contains the image to be converted. It should + have the attributes specified in `ImageSchema.imageSchema`. :return: a `numpy.ndarray` that is an image. .. versionadded:: 2.3.0 """ + if not isinstance(image, Row): + raise TypeError( + "image argument should be pyspark.sql.types.Row; however, " + "it got [%s]." % type(image)) + + if any(not hasattr(image, f) for f in self.imageFields): + raise ValueError( + "image argument should have attributes specified in " + "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields)) + height = image.height width = image.width nChannels = image.nChannels @@ -127,15 +138,20 @@ def toImage(self, array, origin=""): """ Converts an array with metadata to a two-dimensional image. - :param array array: The array to convert to image. + :param `numpy.ndarray` array: The array to convert to image. :param str origin: Path to the image, optional. :return: a :class:`Row` that is a two dimensional image. .. versionadded:: 2.3.0 """ + if not isinstance(array, np.ndarray): + raise TypeError( + "array argument should be numpy.ndarray; however, it got [%s]." % type(array)) + if array.ndim != 3: raise ValueError("Invalid array shape") + height, width, nChannels = array.shape ocvTypes = ImageSchema.ocvTypes if nChannels == 1: @@ -146,7 +162,12 @@ def toImage(self, array, origin=""): mode = ocvTypes["CV_8UC4"] else: raise ValueError("Invalid number of channels") - data = bytearray(array.astype(dtype=np.uint8).ravel()) + + # Running `bytearray(numpy.array([1]))` fails in specific Python versions + # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. + # Here, it avoids it by converting it to bytes. + data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) + # Creating new Row with _create_row(), because Row(name = value, ... ) # orders fields by name, which conflicts with expected schema order # when the new DataFrame is created by UDF diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2258d61c95333..89ef555cf3442 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -71,7 +71,7 @@ from pyspark.sql.functions import rand from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase ser = PickleSerializer() @@ -1836,6 +1836,24 @@ def test_read_images(self): self.assertEqual(ImageSchema.imageFields, expected) self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "image argument should be pyspark.sql.types.Row; however", + lambda: ImageSchema.toNDArray("a")) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + ValueError, + "image argument should have attributes specified in", + lambda: ImageSchema.toNDArray(Row(a=1))) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "array argument should be numpy.ndarray; however, it got", + lambda: ImageSchema.toImage("a")) + class ALSTest(SparkSessionTestCase): From 444a2bbb67c2548d121152bc922b4c3337ddc8e8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Nov 2017 18:28:58 +0800 Subject: [PATCH 1745/1765] [SPARK-22652][SQL] remove set methods in ColumnarRow ## What changes were proposed in this pull request? As a step to make `ColumnVector` public, the `ColumnarRow` returned by `ColumnVector#getStruct` should be immutable. However we do need the mutability of `ColumnaRow` for the fast vectorized hashmap in hash aggregate. To solve this, this PR introduces a `MutableColumnarRow` for this use case. ## How was this patch tested? existing test. Author: Wenchen Fan Closes #19847 from cloud-fan/mutable-row. --- .../sql/execution/vectorized/ColumnarRow.java | 102 +------ .../vectorized/MutableColumnarRow.java | 278 ++++++++++++++++++ .../aggregate/HashAggregateExec.scala | 3 +- .../VectorizedHashMapGenerator.scala | 82 +++--- .../vectorized/ColumnVectorSuite.scala | 12 + .../vectorized/ColumnarBatchSuite.scala | 23 -- 6 files changed, 336 insertions(+), 164 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index 98a907322713b..cabb7479525d9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; - import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; @@ -32,17 +30,10 @@ public final class ColumnarRow extends InternalRow { protected int rowId; private final ColumnVector[] columns; - private final WritableColumnVector[] writableColumns; // Ctor used if this is a struct. ColumnarRow(ColumnVector[] columns) { this.columns = columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } } public ColumnVector[] columns() { return columns; } @@ -205,97 +196,8 @@ public Object get(int ordinal, DataType dataType) { } @Override - public void update(int ordinal, Object value) { - if (value == null) { - setNullAt(ordinal); - } else { - DataType dt = columns[ordinal].dataType(); - if (dt instanceof BooleanType) { - setBoolean(ordinal, (boolean) value); - } else if (dt instanceof IntegerType) { - setInt(ordinal, (int) value); - } else if (dt instanceof ShortType) { - setShort(ordinal, (short) value); - } else if (dt instanceof LongType) { - setLong(ordinal, (long) value); - } else if (dt instanceof FloatType) { - setFloat(ordinal, (float) value); - } else if (dt instanceof DoubleType) { - setDouble(ordinal, (double) value); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType) dt; - setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), - t.precision()); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dt); - } - } - } - - @Override - public void setNullAt(int ordinal) { - getWritableColumn(ordinal).putNull(rowId); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putBoolean(rowId, value); - } + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } @Override - public void setByte(int ordinal, byte value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putByte(rowId, value); - } - - @Override - public void setShort(int ordinal, short value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putShort(rowId, value); - } - - @Override - public void setInt(int ordinal, int value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putInt(rowId, value); - } - - @Override - public void setLong(int ordinal, long value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putLong(rowId, value); - } - - @Override - public void setFloat(int ordinal, float value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putFloat(rowId, value); - } - - @Override - public void setDouble(int ordinal, double value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDouble(rowId, value); - } - - @Override - public void setDecimal(int ordinal, Decimal value, int precision) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDecimal(rowId, value, precision); - } - - private WritableColumnVector getWritableColumn(int ordinal) { - WritableColumnVector column = writableColumns[ordinal]; - assert (!column.isConstant); - return column; - } + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java new file mode 100644 index 0000000000000..f272cc163611b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash + * aggregate. + * + * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to + * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes. + */ +public final class MutableColumnarRow extends InternalRow { + public int rowId; + private final WritableColumnVector[] columns; + + public MutableColumnarRow(WritableColumnVector[] columns) { + this.columns = columns; + } + + @Override + public int numFields() { return columns.length; } + + @Override + public InternalRow copy() { + GenericInternalRow row = new GenericInternalRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i).copy()); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); + } else { + throw new RuntimeException("Not implemented. " + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getStruct(rowId); + } + + @Override + public ColumnarArray getArray(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getArray(rowId); + } + + @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { + if (value == null) { + setNullAt(ordinal); + } else { + DataType dt = columns[ordinal].dataType(); + if (dt instanceof BooleanType) { + setBoolean(ordinal, (boolean) value); + } else if (dt instanceof IntegerType) { + setInt(ordinal, (int) value); + } else if (dt instanceof ShortType) { + setShort(ordinal, (short) value); + } else if (dt instanceof LongType) { + setLong(ordinal, (long) value); + } else if (dt instanceof FloatType) { + setFloat(ordinal, (float) value); + } else if (dt instanceof DoubleType) { + setDouble(ordinal, (double) value); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType) dt; + Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale()); + setDecimal(ordinal, d, t.precision()); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dt); + } + } + } + + @Override + public void setNullAt(int ordinal) { + columns[ordinal].putNull(rowId); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putBoolean(rowId, value); + } + + @Override + public void setByte(int ordinal, byte value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putByte(rowId, value); + } + + @Override + public void setShort(int ordinal, short value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putShort(rowId, value); + } + + @Override + public void setInt(int ordinal, int value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putInt(rowId, value); + } + + @Override + public void setLong(int ordinal, long value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putLong(rowId, value); + } + + @Override + public void setFloat(int ordinal, float value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putFloat(rowId, value); + } + + @Override + public void setDouble(int ordinal, double value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDouble(rowId, value); + } + + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDecimal(rowId, value, precision); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index dc8aecf185a96..913978892cd8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -894,7 +895,7 @@ case class HashAggregateExec( ${ if (isVectorizedHashMapEnabled) { s""" - | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null; + | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null; """.stripMargin } else { s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index fd783d905b776..44ba539ebf7c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ /** @@ -76,10 +77,9 @@ class VectorizedHashMapGenerator( }.mkString("\n").concat(";") s""" - | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors; - | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors; - | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; - | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; + | private ${classOf[OnHeapColumnVector].getName}[] vectors; + | private ${classOf[ColumnarBatch].getName} batch; + | private ${classOf[MutableColumnarRow].getName} aggBufferRow; | private int[] buckets; | private int capacity = 1 << 16; | private double loadFactor = 0.5; @@ -91,19 +91,16 @@ class VectorizedHashMapGenerator( | $generatedAggBufferSchema | | public $generatedClassName() { - | batchVectors = org.apache.spark.sql.execution.vectorized - | .OnHeapColumnVector.allocateColumns(capacity, schema); - | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( - | schema, batchVectors, capacity); + | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); + | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); | - | bufferVectors = new org.apache.spark.sql.execution.vectorized - | .OnHeapColumnVector[aggregateBufferSchema.fields().length]; + | // Generates a projection to return the aggregate buffer only. + | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = + | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length]; | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) { - | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}]; + | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}]; | } - | // TODO: Possibly generate this projection in HashAggregate directly - | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( - | aggregateBufferSchema, bufferVectors, capacity); + | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors); | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); @@ -114,13 +111,13 @@ class VectorizedHashMapGenerator( /** * Generates a method that returns true if the group-by keys exist at a given index in the - * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we - * have 2 long group-by keys, the generated function would be of the form: + * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, + * if we have 2 long group-by keys, the generated function would be of the form: * * {{{ * private boolean equals(int idx, long agg_key, long agg_key1) { - * return batchVectors[0].getLong(buckets[idx]) == agg_key && - * batchVectors[1].getLong(buckets[idx]) == agg_key1; + * return vectors[0].getLong(buckets[idx]) == agg_key && + * vectors[1].getLong(buckets[idx]) == agg_key1; * } * }}} */ @@ -128,7 +125,7 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]", + s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", key.dataType), key.name)})""" }.mkString(" && ") } @@ -141,29 +138,35 @@ class VectorizedHashMapGenerator( } /** - * Generates a method that returns a mutable - * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the + * Generates a method that returns a + * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the * generated method adds the corresponding row in the associated - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, if we * have 2 long group-by keys, the generated function would be of the form: * * {{{ - * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert( - * long agg_key, long agg_key1) { + * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) { * long h = hash(agg_key, agg_key1); * int step = 0; * int idx = (int) h & (numBuckets - 1); * while (step < maxSteps) { * // Return bucket index if it's either an empty slot or already contains the key * if (buckets[idx] == -1) { - * batchVectors[0].putLong(numRows, agg_key); - * batchVectors[1].putLong(numRows, agg_key1); - * batchVectors[2].putLong(numRows, 0); - * buckets[idx] = numRows++; - * return batch.getRow(buckets[idx]); + * if (numRows < capacity) { + * vectors[0].putLong(numRows, agg_key); + * vectors[1].putLong(numRows, agg_key1); + * vectors[2].putLong(numRows, 0); + * buckets[idx] = numRows++; + * aggBufferRow.rowId = numRows; + * return aggBufferRow; + * } else { + * // No more space + * return null; + * } * } else if (equals(idx, agg_key, agg_key1)) { - * return batch.getRow(buckets[idx]); + * aggBufferRow.rowId = buckets[idx]; + * return aggBufferRow; * } * idx = (idx + 1) & (numBuckets - 1); * step++; @@ -177,20 +180,19 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name) + ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, + ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, buffVars(ordinal), nullable = true) } } s""" - |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${ - groupingKeySignature}) { + |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) { | long h = hash(${groupingKeys.map(_.name).mkString(", ")}); | int step = 0; | int idx = (int) h & (numBuckets - 1); @@ -208,15 +210,15 @@ class VectorizedHashMapGenerator( | ${genCodeToSetAggBuffers(bufferValues).mkString("\n")} | | buckets[idx] = numRows++; - | batch.setNumRows(numRows); - | aggregateBufferBatch.setNumRows(numRows); - | return aggregateBufferBatch.getRow(buckets[idx]); + | aggBufferRow.rowId = buckets[idx]; + | return aggBufferRow; | } else { | // No more space | return null; | } | } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) { - | return aggregateBufferBatch.getRow(buckets[idx]); + | aggBufferRow.rowId = buckets[idx]; + | return aggBufferRow; | } | idx = (idx + 1) & (numBuckets - 1); | step++; @@ -229,8 +231,8 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" - |public java.util.Iterator - | rowIterator() { + |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() { + | batch.setNumRows(numRows); | return batch.rowIterator(); |} """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 3c76ca79f5dda..e28ab710f5a99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -163,6 +163,18 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } + testVectors("mutable ColumnarRow", 10, IntegerType) { testVector => + val mutableRow = new MutableColumnarRow(Array(testVector)) + (0 until 10).foreach { i => + mutableRow.rowId = i + mutableRow.setInt(0, 10 - i) + } + (0 until 10).foreach { i => + mutableRow.rowId = i + assert(mutableRow.getInt(0) === (10 - i)) + } + } + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) testVectors("array", 10, arrayType) { testVector => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 80a50866aa504..1b4e2bad09a20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1129,29 +1129,6 @@ class ColumnarBatchSuite extends SparkFunSuite { testRandomRows(false, 30) } - test("mutable ColumnarBatch rows") { - val NUM_ITERS = 10 - val types = Array( - BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType, - DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, - DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), - new DecimalType(12, 2), new DecimalType(30, 10)) - for (i <- 0 to NUM_ITERS) { - val random = new Random(System.nanoTime()) - val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types) - val oldRow = RandomDataGenerator.randomRow(random, schema) - val newRow = RandomDataGenerator.randomRow(random, schema) - - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava) - val columnarBatchRow = batch.getRow(0) - newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1)) - compareStruct(schema, columnarBatchRow, newRow, 0) - batch.close() - } - } - } - test("exceeding maximum capacity should throw an error") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val column = allocate(1, ByteType, memMode) From 9c29c557635caf739fde942f53255273aac0d7b1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Nov 2017 18:34:38 +0800 Subject: [PATCH 1746/1765] [SPARK-22643][SQL] ColumnarArray should be an immutable view ## What changes were proposed in this pull request? To make `ColumnVector` public, `ColumnarArray` need to be public too, and we should not have mutable public fields in a public class. This PR proposes to make `ColumnarArray` an immutable view of the data, and always create a new instance of `ColumnarArray` in `ColumnVector#getArray` ## How was this patch tested? new benchmark in `ColumnarBatchBenchmark` Author: Wenchen Fan Closes #19842 from cloud-fan/column-vector. --- .../parquet/VectorizedColumnReader.java | 2 +- .../vectorized/ArrowColumnVector.java | 1 - .../execution/vectorized/ColumnVector.java | 14 +- .../vectorized/ColumnVectorUtils.java | 18 +-- .../execution/vectorized/ColumnarArray.java | 10 +- .../vectorized/OffHeapColumnVector.java | 2 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../vectorized/WritableColumnVector.java | 13 +- .../vectorized/ColumnVectorSuite.scala | 41 +++-- .../vectorized/ColumnarBatchBenchmark.scala | 142 ++++++++++++++---- .../vectorized/ColumnarBatchSuite.scala | 18 +-- 11 files changed, 164 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 0f1f470dc597e..71ca8b1b96a98 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -425,7 +425,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.isArray()) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { for (int i = 0; i < num; i++) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 5c502c9d91be4..0071bd66760be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -315,7 +315,6 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - resultArray = new ColumnarArray(childColumns[0]); } else if (vector instanceof MapVector) { MapVector mapVector = (MapVector) vector; accessor = new StructAccessor(mapVector); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 940457f2e3363..cca14911fbb28 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -175,9 +175,7 @@ public ColumnarRow getStruct(int rowId, int size) { * Returns the array at rowid. */ public final ColumnarArray getArray(int rowId) { - resultArray.length = getArrayLength(rowId); - resultArray.offset = getArrayOffset(rowId); - return resultArray; + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } /** @@ -213,21 +211,11 @@ public MapData getMap(int ordinal) { */ public abstract ColumnVector getChildColumn(int ordinal); - /** - * Returns true if this column is an array. - */ - public final boolean isArray() { return resultArray != null; } - /** * Data type for this column. */ protected DataType type; - /** - * Reusable Array holder for getArray(). - */ - protected ColumnarArray resultArray; - /** * Reusable Struct holder for getStruct(). */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b4b5f0a265934..bc62bc43484e5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -98,21 +98,13 @@ public static void populate(WritableColumnVector col, InternalRow row, int field * For example, an array of IntegerType will return an int[]. * Throws exceptions for unhandled schemas. */ - public static Object toPrimitiveJavaArray(ColumnarArray array) { - DataType dt = array.data.dataType(); - if (dt instanceof IntegerType) { - int[] result = new int[array.length]; - ColumnVector data = array.data; - for (int i = 0; i < result.length; i++) { - if (data.isNullAt(array.offset + i)) { - throw new RuntimeException("Cannot handle NULL values."); - } - result[i] = data.getInt(array.offset + i); + public static int[] toJavaIntArray(ColumnarArray array) { + for (int i = 0; i < array.numElements(); i++) { + if (array.isNullAt(i)) { + throw new RuntimeException("Cannot handle NULL values."); } - return result; - } else { - throw new UnsupportedOperationException(); } + return array.toIntArray(); } private static void appendValue(WritableColumnVector dst, DataType t, Object o) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java index b9da641fc66c8..cbc39d1d0aec2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java @@ -29,12 +29,14 @@ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). - public final ColumnVector data; - public int length; - public int offset; + private final ColumnVector data; + private final int offset; + private final int length; - ColumnarArray(ColumnVector data) { + ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; + this.offset = offset; + this.length = length; } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1cbaf08569334..806d0291a6c49 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -532,7 +532,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; - if (this.resultArray != null) { + if (isArray()) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 85d72295ab9b8..6e7f74ce12f16 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -505,7 +505,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { - if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { + if (isArray()) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index e7653f0c00b9a..0bea4cc97142d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -75,7 +75,6 @@ public void close() { } dictionary = null; resultStruct = null; - resultArray = null; } public void reserve(int requiredCapacity) { @@ -650,6 +649,11 @@ public WritableColumnVector getDictionaryIds() { */ protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type); + protected boolean isArray() { + return type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType || + DecimalType.isByteArrayDecimalType(type); + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. @@ -658,8 +662,7 @@ protected WritableColumnVector(int capacity, DataType type) { super(type); this.capacity = capacity; - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType - || DecimalType.isByteArrayDecimalType(type)) { + if (isArray()) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { @@ -670,7 +673,6 @@ protected WritableColumnVector(int capacity, DataType type) { } this.childColumns = new WritableColumnVector[1]; this.childColumns[0] = reserveNewColumn(childCapacity, childType); - this.resultArray = new ColumnarArray(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; @@ -678,18 +680,15 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } - this.resultArray = null; this.resultStruct = new ColumnarRow(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); - this.resultArray = null; this.resultStruct = new ColumnarRow(this.childColumns); } else { this.childColumns = null; - this.resultArray = null; this.resultStruct = null; } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index e28ab710f5a99..54b31cee031f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -57,7 +57,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendBoolean(i % 2 == 0) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, BooleanType) === (i % 2 == 0)) @@ -69,7 +69,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByte(i.toByte) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, ByteType) === i.toByte) @@ -81,7 +81,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendShort(i.toShort) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, ShortType) === i.toShort) @@ -93,7 +93,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendInt(i) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, IntegerType) === i) @@ -105,7 +105,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendLong(i) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, LongType) === i) @@ -117,7 +117,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendFloat(i.toFloat) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, FloatType) === i.toFloat) @@ -129,7 +129,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendDouble(i.toDouble) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, DoubleType) === i.toDouble) @@ -142,7 +142,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) @@ -155,7 +155,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") @@ -191,12 +191,10 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.putArray(2, 3, 0) testVector.putArray(3, 3, 3) - val array = new ColumnarArray(testVector) - - assert(array.getArray(0).toIntArray() === Array(0)) - assert(array.getArray(1).toIntArray() === Array(1, 2)) - assert(array.getArray(2).toIntArray() === Array.empty[Int]) - assert(array.getArray(3).toIntArray() === Array(3, 4, 5)) + assert(testVector.getArray(0).toIntArray() === Array(0)) + assert(testVector.getArray(1).toIntArray() === Array(1, 2)) + assert(testVector.getArray(2).toIntArray() === Array.empty[Int]) + assert(testVector.getArray(3).toIntArray() === Array(3, 4, 5)) } val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) @@ -208,12 +206,10 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { c1.putInt(1, 456) c2.putDouble(1, 5.67) - val array = new ColumnarArray(testVector) - - assert(array.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(array.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(array.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(array.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) + assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) + assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { @@ -226,9 +222,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.reserve(16) // Check that none of the values got lost/overwritten. - val array = new ColumnarArray(testVector) (0 until 8).foreach { i => - assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + assert(testVector.getArray(i).toIntArray() === Array(i)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 705b26b8c91e6..38ea2e47fdef8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -23,11 +23,9 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.ColumnVector import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector -import org.apache.spark.sql.execution.vectorized.WritableColumnVector -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -265,20 +263,22 @@ object ColumnarBatchBenchmark { } /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - Java Array 248.8 1317.04 1.00 X - ByteBuffer Unsafe 435.6 752.25 0.57 X - ByteBuffer API 1752.0 187.03 0.14 X - DirectByteBuffer 595.4 550.35 0.42 X - Unsafe Buffer 235.2 1393.20 1.06 X - Column(on heap) 189.8 1726.45 1.31 X - Column(off heap) 408.4 802.35 0.61 X - Column(off heap direct) 237.6 1379.12 1.05 X - UnsafeRow (on heap) 414.6 790.35 0.60 X - UnsafeRow (off heap) 487.2 672.58 0.51 X - Column On Heap Append 530.1 618.14 0.59 X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Java Array 177 / 181 1856.4 0.5 1.0X + ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X + ByteBuffer API 1411 / 1418 232.2 4.3 0.1X + DirectByteBuffer 467 / 474 701.8 1.4 0.4X + Unsafe Buffer 178 / 185 1843.6 0.5 1.0X + Column(on heap) 178 / 184 1840.8 0.5 1.0X + Column(off heap) 341 / 344 961.8 1.0 0.5X + Column(off heap direct) 178 / 184 1845.4 0.5 1.0X + UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X + UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X + Column On Heap Append 309 / 318 1059.1 0.9 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -332,11 +332,13 @@ object ColumnarBatchBenchmark { } }} /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Boolean Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - Bitset 895.88 374.54 1.00 X - Byte Array 578.96 579.56 1.55 X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Bitset 726 / 727 462.4 2.2 1.0X + Byte Array 530 / 542 632.7 1.6 1.4X */ benchmark.run() } @@ -387,10 +389,13 @@ object ColumnarBatchBenchmark { } /* - String Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------------- - On Heap 457.0 35.85 1.00 X - Off Heap 1206.0 13.59 0.38 X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + On Heap 332 / 338 49.3 20.3 1.0X + Off Heap 466 / 467 35.2 28.4 0.7X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -398,9 +403,94 @@ object ColumnarBatchBenchmark { benchmark.run } + def arrayAccess(iters: Int): Unit = { + val random = new Random(0) + val count = 4 * 1000 + + val onHeapVector = new OnHeapColumnVector(count, ArrayType(IntegerType)) + val offHeapVector = new OffHeapColumnVector(count, ArrayType(IntegerType)) + + val minSize = 3 + val maxSize = 32 + var arraysCount = 0 + var elementsCount = 0 + while (arraysCount < count) { + val size = random.nextInt(maxSize - minSize) + minSize + val onHeapArrayData = onHeapVector.arrayData() + val offHeapArrayData = offHeapVector.arrayData() + + var i = 0 + while (i < size) { + val value = random.nextInt() + onHeapArrayData.appendInt(value) + offHeapArrayData.appendInt(value) + i += 1 + } + + onHeapVector.putArray(arraysCount, elementsCount, size) + offHeapVector.putArray(arraysCount, elementsCount, size) + elementsCount += size + arraysCount += 1 + } + + def readArrays(onHeap: Boolean): Unit = { + System.gc() + val vector = if (onHeap) onHeapVector else offHeapVector + + var sum = 0L + for (_ <- 0 until iters) { + var i = 0 + while (i < count) { + sum += vector.getArray(i).numElements() + i += 1 + } + } + } + + def readArrayElements(onHeap: Boolean): Unit = { + System.gc() + val vector = if (onHeap) onHeapVector else offHeapVector + + var sum = 0L + for (_ <- 0 until iters) { + var i = 0 + while (i < count) { + val array = vector.getArray(i) + val size = array.numElements() + var j = 0 + while (j < size) { + sum += array.getInt(j) + j += 1 + } + i += 1 + } + } + } + + val benchmark = new Benchmark("Array Vector Read", count * iters) + benchmark.addCase("On Heap Read Size Only") { _ => readArrays(true) } + benchmark.addCase("Off Heap Read Size Only") { _ => readArrays(false) } + benchmark.addCase("On Heap Read Elements") { _ => readArrayElements(true) } + benchmark.addCase("Off Heap Read Elements") { _ => readArrayElements(false) } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + On Heap Read Size Only 415 / 422 394.7 2.5 1.0X + Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X + On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X + Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + */ + benchmark.run + } + def main(args: Array[String]): Unit = { intAccess(1024 * 40) booleanAccess(1024 * 40) stringAccess(1024 * 4) + arrayAccess(1024 * 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 1b4e2bad09a20..0ae4f2d117609 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -645,26 +645,26 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(2, 2, 0) column.putArray(3, 3, 3) - val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] - val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] - val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) + val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) // Verify the ArrayData APIs - assert(column.getArray(0).length == 1) + assert(column.getArray(0).numElements() == 1) assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).length == 2) + assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).length == 0) + assert(column.getArray(2).numElements() == 0) - assert(column.getArray(3).length == 3) + assert(column.getArray(3).numElements() == 3) assert(column.getArray(3).getInt(0) == 3) assert(column.getArray(3).getInt(1) == 4) assert(column.getArray(3).getInt(2) == 5) @@ -677,7 +677,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } From 6eb203fae7bbc9940710da40f314b89ffb4dd324 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 1 Dec 2017 01:21:52 +0900 Subject: [PATCH 1747/1765] [SPARK-22654][TESTS] Retry Spark tarball download if failed in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Adds a simple loop to retry download of Spark tarballs from different mirrors if the download fails. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19851 from srowen/SPARK-22654. --- .../HiveExternalCatalogVersionsSuite.scala | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6859432c406a9..a3d5b941a6761 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.io.File import java.nio.file.Files +import scala.sys.process._ + import org.apache.spark.TestUtils import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -50,14 +52,24 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.afterAll() } - private def downloadSpark(version: String): Unit = { - import scala.sys.process._ + private def tryDownloadSpark(version: String, path: String): Unit = { + // Try mirrors a few times until one succeeds + for (i <- 0 until 3) { + val preferredMirror = + Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim + val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + logInfo(s"Downloading Spark $version from $url") + if (Seq("wget", url, "-q", "-P", path).! == 0) { + return + } + logWarning(s"Failed to download Spark $version from $url") + } + fail(s"Unable to download Spark $version") + } - val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" - Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! + private def downloadSpark(version: String): Unit = { + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath From 932bd09c80dc2dc113e94f59f4dcb77e77de7c58 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 1 Dec 2017 01:24:15 +0900 Subject: [PATCH 1748/1765] [SPARK-22635][SQL][ORC] FileNotFoundException while reading ORC files containing special characters ## What changes were proposed in this pull request? SPARK-22146 fix the FileNotFoundException issue only for the `inferSchema` method, ie. only for the schema inference, but it doesn't fix the problem when actually reading the data. Thus nearly the same exception happens when someone tries to use the data. This PR covers fixing the problem also there. ## How was this patch tested? enhanced UT Author: Marco Gaido Closes #19844 from mgaido91/SPARK-22635. --- .../org/apache/spark/sql/hive/orc/OrcFileFormat.scala | 11 +++++------ .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 3b33a9ff082f3..95741c7b30289 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -133,10 +133,12 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value + val filePath = new Path(new URI(file.filePath)) + // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty if (isEmptyFile) { Iterator.empty } else { @@ -146,15 +148,12 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val job = Job.getInstance(conf) FileInputFormat.setInputPaths(job, file.filePath) - val fileSplit = new FileSplit( - new Path(new URI(file.filePath)), file.start, file.length, Array.empty - ) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) // Custom OrcRecordReader is used to get // ObjectInspector during recordReader creation itself and can // avoid NameNode call in unwrapOrcStructs per file. // Specifically would be helpful for partitioned datasets. - val orcReader = OrcFile.createReader( - new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) + val orcReader = OrcFile.createReader(filePath, OrcFile.readerOptions(conf)) new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index a1060476f2211..c8caba83bf365 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1350,7 +1350,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTempDir { dir => val tmpFile = s"$dir/$nameWithSpecialChars" spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - spark.read.format(format).load(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) } } } From 999ec137a97844abbbd483dd98c7ded2f8ff356c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 1 Dec 2017 02:28:24 +0800 Subject: [PATCH 1749/1765] [SPARK-22570][SQL] Avoid to create a lot of global variables by using a local variable with allocation of an object in generated code ## What changes were proposed in this pull request? This PR reduces # of global variables in generated code by replacing a global variable with a local variable with an allocation of an object every time. When a lot of global variables were generated, the generated code may meet 64K constant pool limit. This PR reduces # of generated global variables in the following three operations: * `Cast` with String to primitive byte/short/int/long * `RegExpReplace` * `CreateArray` I intentionally leave [this part](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L595-L603). This is because this variable keeps a class that is dynamically generated. In other word, it is not possible to reuse one class. ## How was this patch tested? Added test cases Author: Kazuaki Ishizaki Closes #19797 from kiszk/SPARK-22570. --- .../spark/sql/catalyst/expressions/Cast.scala | 24 ++++++------- .../expressions/complexTypeCreator.scala | 36 +++++++++++-------- .../expressions/regexpExpressions.scala | 8 ++--- .../sql/catalyst/expressions/CastSuite.scala | 8 +++++ .../expressions/RegexpExpressionsSuite.scala | 11 +++++- .../optimizer/complexTypesSuite.scala | 7 ++++ 6 files changed, 61 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8cafaef61c7d1..f4ecbdb8393ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" @@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" @@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" @@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.LongWrapper", wrapper, - s"$wrapper = new UTF8String.LongWrapper();") + val wrapper = ctx.freshName("longWrapper") (c, evPrim, evNull) => s""" + UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 57a7f2e207738..fc68bf478e1c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + ctx.splitExpressions(assigns) + postprocess, + code = preprocess + assigns + postprocess, value = arrayData, isNull = "false") } @@ -77,24 +77,22 @@ private [sql] object GenArrayData { * * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements - * @param elementsCode a set of [[ExprCode]] for each element of an underlying array + * @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array * @param isMapKey if true, throw an exception when the element is null - * @return (code pre-assignments, assignments to each array elements, code post-assignments, - * arrayData name) + * @return (code pre-assignments, concatenated assignments to each array elements, + * code post-assignments, arrayData name) */ def genCodeToCreateArrayData( ctx: CodegenContext, elementType: DataType, elementsCode: Seq[ExprCode], - isMapKey: Boolean): (String, Seq[String], String, String) = { - val arrayName = ctx.freshName("array") + isMapKey: Boolean): (String, String, String, String) = { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length if (!ctx.isPrimitiveType(elementType)) { + val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName - ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -110,17 +108,21 @@ private [sql] object GenArrayData { } """ } + val assignmentString = ctx.splitExpressions( + expressions = assignments, + funcName = "apply", + extraArguments = ("Object[]", arrayDataName) :: Nil) - ("", - assignments, + (s"Object[] $arrayName = new Object[$numElements];", + assignmentString, s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", arrayDataName) } else { + val arrayName = ctx.freshName("array") val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName) val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -137,14 +139,18 @@ private [sql] object GenArrayData { } """ } + val assignmentString = ctx.splitExpressions( + expressions = assignments, + funcName = "apply", + extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) (s""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - $arrayDataName = new UnsafeArrayData(); + UnsafeArrayData $arrayDataName = new UnsafeArrayData(); Platform.putLong($arrayName, $baseOffset, $numElements); $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, - assignments, + assignmentString, "", arrayDataName) } @@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { s""" final boolean ${ev.isNull} = false; $preprocessKeyData - ${ctx.splitExpressions(assignKeys)} + $assignKeys $postprocessKeyData $preprocessValueData - ${ctx.splitExpressions(assignValues)} + $assignValues $postprocessValueData final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index d0d663f63f5db..53d7096dd87d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termLastReplacement = ctx.freshName("lastReplacement") val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") + val termResult = ctx.freshName("termResult") val classNamePattern = classOf[Pattern].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName @@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -355,7 +352,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ${termLastReplacementInUTF8} = $rep.clone(); ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); } - ${termResult}.delete(0, ${termResult}.length()); + $classNameStringBuffer ${termResult} = new $classNameStringBuffer(); java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); while (${matcher}.find()) { @@ -363,6 +360,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } ${matcher}.appendTail(${termResult}); ${ev.value} = UTF8String.fromString(${termResult}.toString()); + ${termResult} = null; $setEvNotNull """ }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 7837d6529d12b..65617be05a434 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner)) checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter) } + + test("SPARK-22570: Cast should not create a lot of global variables") { + val ctx = new CodegenContext + cast("1", IntegerType).genCode(ctx) + cast("2", LongType).genCode(ctx) + assert(ctx.mutableStates.length == 0) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 1ce150e091981..4fa61fbaf66c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.StringType /** * Unit tests for regular expression (regexp) related SQL expressions. @@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(nonNullExpr, "num-num", row1) } + test("SPARK-22570: RegExpReplace should not create a lot of global variables") { + val ctx = new CodegenContext + RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx) + // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) + // are always required + assert(ctx.mutableStates.length == 4) + } + test("RegexExtract") { val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 3634accf1ec21..e3675367d78e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{ comparePlans(Optimizer execute query, expected) } + test("SPARK-22570: CreateArray should not create a lot of global variables") { + val ctx = new CodegenContext + CreateArray(Seq(Literal(1))).genCode(ctx) + assert(ctx.mutableStates.length == 0) + } + test("simplify map ops") { val rel = relation .select( From 6ac57fd0d1c82b834eb4bf0dd57596b92a99d6de Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Thu, 30 Nov 2017 14:25:10 -0800 Subject: [PATCH 1750/1765] [SPARK-21417][SQL] Infer join conditions using propagated constraints ## What changes were proposed in this pull request? This PR adds an optimization rule that infers join conditions using propagated constraints. For instance, if there is a join, where the left relation has 'a = 1' and the right relation has 'b = 1', then the rule infers 'a = b' as a join predicate. Only semantically new predicates are appended to the existing join condition. Refer to the corresponding ticket and tests for more details. ## How was this patch tested? This patch comes with a new test suite to cover the implemented logic. Author: aokolnychyi Closes #18692 from aokolnychyi/spark-21417. --- .../expressions/EquivalentExpressionMap.scala | 66 +++++ .../catalyst/expressions/ExpressionSet.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../spark/sql/catalyst/optimizer/joins.scala | 60 +++++ .../EquivalentExpressionMapSuite.scala | 56 +++++ .../optimizer/EliminateCrossJoinSuite.scala | 238 ++++++++++++++++++ 6 files changed, 423 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala new file mode 100644 index 0000000000000..cf1614afb1a76 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.EquivalentExpressionMap.SemanticallyEqualExpr + +/** + * A class that allows you to map an expression into a set of equivalent expressions. The keys are + * handled based on their semantic meaning and ignoring cosmetic differences. The values are + * represented as [[ExpressionSet]]s. + * + * The underlying representation of keys depends on the [[Expression.semanticHash]] and + * [[Expression.semanticEquals]] methods. + * + * {{{ + * val map = new EquivalentExpressionMap() + * + * map.put(1 + 2, a) + * map.put(rand(), b) + * + * map.get(2 + 1) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent + * map.get(1 + 2) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent + * map.get(rand()) => Set() // non-deterministic expressions are not equivalent + * }}} + */ +class EquivalentExpressionMap { + + private val equivalenceMap = mutable.HashMap.empty[SemanticallyEqualExpr, ExpressionSet] + + def put(expression: Expression, equivalentExpression: Expression): Unit = { + val equivalentExpressions = equivalenceMap.getOrElseUpdate(expression, ExpressionSet.empty) + equivalenceMap(expression) = equivalentExpressions + equivalentExpression + } + + def get(expression: Expression): Set[Expression] = + equivalenceMap.getOrElse(expression, ExpressionSet.empty) +} + +object EquivalentExpressionMap { + + private implicit class SemanticallyEqualExpr(val expr: Expression) { + override def equals(obj: Any): Boolean = obj match { + case other: SemanticallyEqualExpr => expr.semanticEquals(other.expr) + case _ => false + } + + override def hashCode: Int = expr.semanticHash() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 7e8e7b8cd5f18..e9890837af07d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -27,6 +27,8 @@ object ExpressionSet { expressions.foreach(set.add) set } + + val empty: ExpressionSet = ExpressionSet(Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d961bf2e6e5e..8a5c486912abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -87,6 +87,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PushProjectionThroughUnion, ReorderJoin, EliminateOuterJoin, + EliminateCrossJoin, InferFiltersFromConstraints, BooleanSimplification, PushPredicateThroughJoin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index edbeaf273fd6f..29a3a7f109b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec +import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins @@ -152,3 +153,62 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } + +/** + * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints. + * + * The optimization is applicable only to CROSS joins. For other join types, adding inferred join + * conditions would potentially shuffle children as child node's partitioning won't satisfy the JOIN + * node's requirements which otherwise could have. + * + * For instance, given a CROSS join with the constraint 'a = 1' from the left child and the + * constraint 'b = 1' from the right child, this rule infers a new join predicate 'a = b' and + * converts it to an Inner join. + */ +object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.constraintPropagationEnabled) { + eliminateCrossJoin(plan) + } else { + plan + } + } + + private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform { + case join @ Join(leftPlan, rightPlan, Cross, None) => + val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet)) + val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet)) + val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints) + val joinConditionOpt = inferredJoinPredicates.reduceOption(And) + if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, joinConditionOpt) else join + } + + private def inferJoinPredicates( + leftConstraints: Set[Expression], + rightConstraints: Set[Expression]): mutable.Set[EqualTo] = { + + val equivalentExpressionMap = new EquivalentExpressionMap() + + leftConstraints.foreach { + case EqualTo(attr: Attribute, expr: Expression) => + equivalentExpressionMap.put(expr, attr) + case EqualTo(expr: Expression, attr: Attribute) => + equivalentExpressionMap.put(expr, attr) + case _ => + } + + val joinConditions = mutable.Set.empty[EqualTo] + + rightConstraints.foreach { + case EqualTo(attr: Attribute, expr: Expression) => + joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _)) + case EqualTo(expr: Expression, attr: Attribute) => + joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _)) + case _ => + } + + joinConditions + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala new file mode 100644 index 0000000000000..bad7e17bb6cf2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class EquivalentExpressionMapSuite extends SparkFunSuite { + + private val onePlusTwo = Literal(1) + Literal(2) + private val twoPlusOne = Literal(2) + Literal(1) + private val rand = Rand(10) + + test("behaviour of the equivalent expression map") { + val equivalentExpressionMap = new EquivalentExpressionMap() + equivalentExpressionMap.put(onePlusTwo, 'a) + equivalentExpressionMap.put(Literal(1) + Literal(3), 'b) + equivalentExpressionMap.put(rand, 'c) + + // 1 + 2 should be equivalent to 2 + 1 + assertResult(ExpressionSet(Seq('a)))(equivalentExpressionMap.get(twoPlusOne)) + // non-deterministic expressions should not be equivalent + assertResult(ExpressionSet.empty)(equivalentExpressionMap.get(rand)) + + // if the same (key, value) is added several times, the map still returns only one entry + equivalentExpressionMap.put(onePlusTwo, 'a) + equivalentExpressionMap.put(twoPlusOne, 'a) + assertResult(ExpressionSet(Seq('a)))(equivalentExpressionMap.get(twoPlusOne)) + + // get several equivalent attributes + equivalentExpressionMap.put(onePlusTwo, 'e) + assertResult(ExpressionSet(Seq('a, 'e)))(equivalentExpressionMap.get(onePlusTwo)) + assertResult(2)(equivalentExpressionMap.get(onePlusTwo).size) + + // several non-deterministic expressions should not be equivalent + equivalentExpressionMap.put(rand, 'd) + assertResult(ExpressionSet.empty)(equivalentExpressionMap.get(rand)) + assertResult(0)(equivalentExpressionMap.get(rand).size) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala new file mode 100644 index 0000000000000..e04dd28ee36a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, Not, Rand} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.types.IntegerType + +class EliminateCrossJoinSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Eliminate cross joins", FixedPoint(10), + EliminateCrossJoin, + PushPredicateThroughJoin) :: Nil + } + + val testRelation1 = LocalRelation('a.int, 'b.int) + val testRelation2 = LocalRelation('c.int, 'd.int) + + test("successful elimination of cross joins (1)") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'c === 1 && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1, + expectedRightRelationFilter = 'c === 1 && 'd === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c && 'a === 'd)) + } + + test("successful elimination of cross joins (2)") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'b === 2 && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1 && 'b === 2, + expectedRightRelationFilter = 'd === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'd)) + } + + test("successful elimination of cross joins (3)") { + // PushPredicateThroughJoin will push 'd === 'a into the join condition + // EliminateCrossJoin will NOT apply because the condition will be already present + // therefore, the join type will stay the same (i.e., CROSS) + checkJoinOptimization( + originalFilter = 'a === 1 && Literal(1) === 'd && 'd === 'a, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1, + expectedRightRelationFilter = Literal(1) === 'd, + expectedJoinType = Cross, + expectedJoinCondition = Some('a === 'd)) + } + + test("successful elimination of cross joins (4)") { + // Literal(1) * Literal(2) and Literal(2) * Literal(1) are semantically equal + checkJoinOptimization( + originalFilter = 'a === Literal(1) * Literal(2) && Literal(2) * Literal(1) === 'c, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === Literal(1) * Literal(2), + expectedRightRelationFilter = Literal(2) * Literal(1) === 'c, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("successful elimination of cross joins (5)") { + checkJoinOptimization( + originalFilter = 'a === 1 && Literal(1) === 'a && 'c === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, + expectedRightRelationFilter = 'c === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("successful elimination of cross joins (6)") { + checkJoinOptimization( + originalFilter = 'a === Cast("1", IntegerType) && 'c === Cast("1", IntegerType) && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === Cast("1", IntegerType), + expectedRightRelationFilter = 'c === Cast("1", IntegerType) && 'd === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("successful elimination of cross joins (7)") { + // The join condition appears due to PushPredicateThroughJoin + checkJoinOptimization( + originalFilter = (('a >= 1 && 'c === 1) || 'd === 10) && 'b === 10 && 'c === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'b === 10, + expectedRightRelationFilter = 'c === 1, + expectedJoinType = Cross, + expectedJoinCondition = Some(('a >= 1 && 'c === 1) || 'd === 10)) + } + + test("successful elimination of cross joins (8)") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'c === 1 && Literal(1) === 'a && Literal(1) === 'c, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, + expectedRightRelationFilter = 'c === 1 && Literal(1) === 'c, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("inability to detect join conditions when constant propagation is disabled") { + withSQLConf(CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'c === 1 && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1, + expectedRightRelationFilter = 'c === 1 && 'd === 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + } + + test("inability to detect join conditions (1)") { + checkJoinOptimization( + originalFilter = 'a >= 1 && 'c === 1 && 'd >= 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a >= 1, + expectedRightRelationFilter = 'c === 1 && 'd >= 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + test("inability to detect join conditions (2)") { + checkJoinOptimization( + originalFilter = Literal(1) === 'b && ('c === 1 || 'd === 1), + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = Literal(1) === 'b, + expectedRightRelationFilter = 'c === 1 || 'd === 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + test("inability to detect join conditions (3)") { + checkJoinOptimization( + originalFilter = Literal(1) === 'b && 'c === 1, + originalJoinType = Cross, + originalJoinCondition = Some('c === 'b), + expectedFilter = None, + expectedLeftRelationFilter = Literal(1) === 'b, + expectedRightRelationFilter = 'c === 1, + expectedJoinType = Cross, + expectedJoinCondition = Some('c === 'b)) + } + + test("inability to detect join conditions (4)") { + checkJoinOptimization( + originalFilter = Not('a === 1) && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = Not('a === 1), + expectedRightRelationFilter = 'd === 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + test("inability to detect join conditions (5)") { + checkJoinOptimization( + originalFilter = 'a === Rand(10) && 'b === 1 && 'd === Rand(10) && 'c === 3, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = Some('a === Rand(10) && 'd === Rand(10)), + expectedLeftRelationFilter = 'b === 1, + expectedRightRelationFilter = 'c === 3, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + private def checkJoinOptimization( + originalFilter: Expression, + originalJoinType: JoinType, + originalJoinCondition: Option[Expression], + expectedFilter: Option[Expression], + expectedLeftRelationFilter: Expression, + expectedRightRelationFilter: Expression, + expectedJoinType: JoinType, + expectedJoinCondition: Option[Expression]): Unit = { + + val originalQuery = testRelation1 + .join(testRelation2, originalJoinType, originalJoinCondition) + .where(originalFilter) + val optimizedQuery = Optimize.execute(originalQuery.analyze) + + val left = testRelation1.where(expectedLeftRelationFilter) + val right = testRelation2.where(expectedRightRelationFilter) + val join = left.join(right, expectedJoinType, expectedJoinCondition) + val expectedQuery = expectedFilter.fold(join)(join.where(_)).analyze + + comparePlans(optimizedQuery, expectedQuery) + } +} From bcceab649510a45f4c4b8e44b157c9987adff6f4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 30 Nov 2017 15:36:26 -0800 Subject: [PATCH 1751/1765] [SPARK-22489][SQL] Shouldn't change broadcast join buildSide if user clearly specified ## What changes were proposed in this pull request? How to reproduce: ```scala import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("table1") spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value").createTempView("table2") val bl = sql("SELECT /*+ MAPJOIN(t1) */ * FROM table1 t1 JOIN table2 t2 ON t1.key = t2.key").queryExecution.executedPlan println(bl.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide) ``` The result is `BuildRight`, but should be `BuildLeft`. This PR fix this issue. ## How was this patch tested? unit tests Author: Yuming Wang Closes #19714 from wangyum/SPARK-22489. --- docs/sql-programming-guide.md | 58 ++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 67 +++++++++++++----- .../execution/joins/BroadcastJoinSuite.scala | 69 ++++++++++++++++++- 3 files changed, 177 insertions(+), 17 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 983770d506836..a1b9c3bbfd059 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1492,6 +1492,64 @@ that these options will be deprecated in future release as more optimizations ar +## Broadcast Hint for SQL Queries + +The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. +When both sides of a join are specified, Spark broadcasts the one having the lower statistics. +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +support BHJ. When the broadcast nested loop join is selected, we still respect the hint. + +
    + +
    + +{% highlight scala %} +import org.apache.spark.sql.functions.broadcast +broadcast(spark.table("src")).join(spark.table("records"), "key").show() +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import static org.apache.spark.sql.functions.broadcast; +broadcast(spark.table("src")).join(spark.table("records"), "key").show(); +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.sql.functions import broadcast +broadcast(spark.table("src")).join(spark.table("records"), "key").show() +{% endhighlight %} + +
    + +
    + +{% highlight r %} +src <- sql("SELECT * FROM src") +records <- sql("SELECT * FROM records") +head(join(broadcast(src), records, src$key == records$key)) +{% endhighlight %} + +
    + +
    + +{% highlight sql %} +-- We accept BROADCAST, BROADCASTJOIN and MAPJOIN for broadcast hint +SELECT /*+ BROADCAST(r) */ * FROM records r JOIN src s ON r.key = s.key +{% endhighlight %} + +
    +
    + # Distributed SQL Engine Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 19b858faba6ea..1fe3cb1c8750a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery @@ -91,12 +91,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * predicates can be evaluated by matching join keys. If found, Join implementations are chosen * with the following precedence: * - * - Broadcast: if one side of the join has an estimated physical size that is smaller than the - * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold - * or if that side has an explicit broadcast hint (e.g. the user applied the - * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side - * of the join will be broadcasted and the other side will be streamed, with no shuffling - * performed. If both sides of the join are eligible to be broadcasted then the + * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the + * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame). + * If both sides have the broadcast hint, we prefer to broadcast the side with a smaller + * estimated physical size. If neither one of the sides has the broadcast hint, + * we only broadcast the join side if its estimated physical size that is smaller than + * the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold. * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. * - Sort merge: if the matching join keys are sortable. @@ -112,9 +112,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats.hints.broadcast || - (plan.stats.sizeInBytes >= 0 && - plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold) + plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold } /** @@ -149,10 +147,45 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => false } + private def broadcastSide( + canBuildLeft: Boolean, + canBuildRight: Boolean, + left: LogicalPlan, + right: LogicalPlan): BuildSide = { + + def smallerSide = + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + + val buildRight = canBuildRight && right.stats.hints.broadcast + val buildLeft = canBuildLeft && left.stats.hints.broadcast + + if (buildRight && buildLeft) { + // Broadcast smaller side base on its estimated physical size + // if both sides have broadcast hint + smallerSide + } else if (buildRight) { + BuildRight + } else if (buildLeft) { + BuildLeft + } else if (canBuildRight && canBuildLeft) { + // for the last default broadcast nested loop join + smallerSide + } else { + throw new AnalysisException("Can not decide which side to broadcast for this join") + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- BroadcastHashJoin -------------------------------------------------------------------- + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if (canBuildRight(joinType) && right.stats.hints.broadcast) || + (canBuildLeft(joinType) && left.stats.hints.broadcast) => + val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + Seq(joins.BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if canBuildRight(joinType) && canBroadcast(right) => Seq(joins.BroadcastHashJoinExec( @@ -189,6 +222,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ // Pick BroadcastNestedLoopJoin if one side could be broadcasted + case j @ logical.Join(left, right, joinType, condition) + if (canBuildRight(joinType) && right.stats.hints.broadcast) || + (canBuildLeft(joinType) && left.stats.hints.broadcast) => + val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case j @ logical.Join(left, right, joinType, condition) if canBuildRight(joinType) && canBroadcast(right) => joins.BroadcastNestedLoopJoinExec( @@ -203,12 +243,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { - BuildRight - } else { - BuildLeft - } + val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a0fad862b44c7..67e2cdc7394bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -223,4 +223,71 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) } + + test("Shouldn't change broadcast join buildSide if user clearly specified") { + def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { + val executedPlan = sql(sqlStr).queryExecution.executedPlan + executedPlan match { + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case w: WholeStageCodegenExec => + assert(w.children.head.getClass.getSimpleName === joinMethod) + assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) + } + } + + withTempView("t1", "t2") { + spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") + spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") + .createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + val bh = BroadcastHashJoinExec.toString + val bl = BroadcastNestedLoopJoinExec.toString + + // INNER JOIN && t1Size < t2Size => BuildLeft + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + // LEFT JOIN => BuildRight + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight) + // RIGHT JOIN => BuildLeft + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + // INNER JOIN && broadcast(t1) => BuildLeft + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + // INNER JOIN && broadcast(t2) => BuildRight + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight) + + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // INNER JOIN && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) + // FULL JOIN && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 FULL JOIN t2", bl, BuildLeft) + // LEFT JOIN => BuildRight + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildRight) + // RIGHT JOIN => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2", bl, BuildLeft) + // INNER JOIN && broadcast(t1) => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2", bl, BuildLeft) + // INNER JOIN && broadcast(t2) => BuildRight + assertJoinBuildSide("SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2", bl, BuildRight) + // FULL OUTER && broadcast(t1) => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + // FULL OUTER && broadcast(t2) => BuildRight + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t2) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildRight) + // FULL OUTER && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + } + } + } } From f5f8e84d9d35751dad51490b6ae22931aa88db7b Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Thu, 30 Nov 2017 15:41:34 -0800 Subject: [PATCH 1752/1765] [SPARK-22614] Dataset API: repartitionByRange(...) ## What changes were proposed in this pull request? This PR introduces a way to explicitly range-partition a Dataset. So far, only round-robin and hash partitioning were possible via `df.repartition(...)`, but sometimes range partitioning might be desirable: e.g. when writing to disk, for better compression without the cost of global sort. The current implementation piggybacks on the existing `RepartitionByExpression` `LogicalPlan` and simply adds the following logic: If its expressions are of type `SortOrder`, then it will do `RangePartitioning`; otherwise `HashPartitioning`. This was by far the least intrusive solution I could come up with. ## How was this patch tested? Unit test for `RepartitionByExpression` changes, a test to ensure we're not changing the behavior of existing `.repartition()` and a few end-to-end tests in `DataFrameSuite`. Author: Adrian Ionescu Closes #19828 from adrian-ionescu/repartitionByRange. --- .../plans/logical/basicLogicalOperators.scala | 20 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 26 +++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 57 +++++++++++++++++-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../org/apache/spark/sql/DataFrameSuite.scala | 57 +++++++++++++++++++ 5 files changed, 157 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c2750c3079814..93de7c1daf5c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -838,6 +839,25 @@ case class RepartitionByExpression( require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") + val partitioning: Partitioning = { + val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder]) + + require(sortOrder.isEmpty || nonSortOrder.isEmpty, + s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " + + "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " + + "means `HashPartitioning`. In this case we have:" + + s""" + |SortOrder: ${sortOrder} + |NonSortOrder: ${nonSortOrder} + """.stripMargin) + + if (sortOrder.nonEmpty) { + RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions) + } else { + HashPartitioning(nonSortOrder, numPartitions) + } + } + override def maxRows: Option[Long] = child.maxRows override def shuffle: Boolean = true } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e56a5d6368318..0e2e706a31a05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.types._ @@ -514,4 +515,29 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq("Number of column aliases does not match number of columns. " + "Number of column aliases: 5; number of columns: 4.")) } + + test("SPARK-22614 RepartitionByExpression partitioning") { + def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = { + val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning + assert(partitioning.isInstanceOf[T]) + } + + checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20)) + checkPartitioning[HashPartitioning](numPartitions = 10, exprs = 'a.attr, 'b.attr) + + checkPartitioning[RangePartitioning](numPartitions = 10, + exprs = SortOrder(Literal(10), Ascending)) + checkPartitioning[RangePartitioning](numPartitions = 10, + exprs = SortOrder('a.attr, Ascending), SortOrder('b.attr, Descending)) + + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = 0, exprs = Literal(20)) + } + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = -1, exprs = Literal(20)) + } + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = 10, exprs = SortOrder('a.attr, Ascending), 'b.attr) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1620ab3aa2094..167c9d050c3c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2732,8 +2732,18 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. + // However, we don't want to complicate the semantics of this API method. + // Instead, let's give users a friendly error message, pointing them to the new method. + val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder]) + if (sortOrders.nonEmpty) throw new IllegalArgumentException( + s"""Invalid partitionExprs specified: $sortOrders + |For range partitioning use repartitionByRange(...) instead. + """.stripMargin) + withTypedPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + } } /** @@ -2747,9 +2757,46 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + def repartition(partitionExprs: Column*): Dataset[T] = { + repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) + } + + /** + * Returns a new Dataset partitioned by the given partitioning expressions into + * `numPartitions`. The resulting Dataset is range partitioned. + * + * At least one partition-by expression must be specified. + * When no explicit sort order is specified, "ascending nulls first" is assumed. + * + * @group typedrel + * @since 2.3.0 + */ + @scala.annotation.varargs + def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") + val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { + case expr: SortOrder => expr + case expr: Expression => SortOrder(expr, Ascending) + }) + withTypedPlan { + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) + } + } + + /** + * Returns a new Dataset partitioned by the given partitioning expressions, using + * `spark.sql.shuffle.partitions` as number of partitions. + * The resulting Dataset is range partitioned. + * + * At least one partition-by expression must be specified. + * When no explicit sort order is specified, "ascending nulls first" is assumed. + * + * @group typedrel + * @since 2.3.0 + */ + @scala.annotation.varargs + def repartitionByRange(partitionExprs: Column*): Dataset[T] = { + repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1fe3cb1c8750a..9e713cd7bbe2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -482,9 +482,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => execution.RangeExec(r) :: Nil - case logical.RepartitionByExpression(expressions, child, numPartitions) => - exchange.ShuffleExchangeExec(HashPartitioning( - expressions, numPartitions), planLater(child)) :: Nil + case r: logical.RepartitionByExpression => + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 72a5cc98fbec3..5e4c1a6a484fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -358,6 +358,63 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select('key).collect().toSeq) } + test("repartition with SortOrder") { + // passing SortOrder expressions to .repartition() should result in an informative error + + def checkSortOrderErrorMsg[T](data: => Dataset[T]): Unit = { + val ex = intercept[IllegalArgumentException](data) + assert(ex.getMessage.contains("repartitionByRange")) + } + + checkSortOrderErrorMsg { + Seq(0).toDF("a").repartition(2, $"a".asc) + } + + checkSortOrderErrorMsg { + Seq((0, 0)).toDF("a", "b").repartition(2, $"a".asc, $"b") + } + } + + test("repartitionByRange") { + val data1d = Random.shuffle(0.to(9)) + val data2d = data1d.map(i => (i, data1d.size - i)) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, $"val".asc) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, $"val".desc) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, data1d.size - 1 - i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, lit(42)) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(0, i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, lit(null), $"val".asc, rand()) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, i))) + + // .repartitionByRange() assumes .asc by default if no explicit sort order is specified + checkAnswer( + data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b") + .select(spark_partition_id().as("id"), $"a", $"b"), + data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b".asc) + .select(spark_partition_id().as("id"), $"a", $"b")) + + // at least one partition-by expression must be specified + intercept[IllegalArgumentException] { + data1d.toDF("val").repartitionByRange(data1d.size) + } + intercept[IllegalArgumentException] { + data1d.toDF("val").repartitionByRange(data1d.size, Seq.empty: _*) + } + } + test("coalesce") { intercept[IllegalArgumentException] { testData.select('key).coalesce(0) From 7e5f669eb684629c88218f8ec26c01a41a6fef32 Mon Sep 17 00:00:00 2001 From: gaborgsomogyi Date: Thu, 30 Nov 2017 19:20:32 -0600 Subject: [PATCH 1753/1765] =?UTF-8?q?[SPARK-22428][DOC]=20Add=20spark=20ap?= =?UTF-8?q?plication=20garbage=20collector=20configurat=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The spark properties for configuring the ContextCleaner are not documented in the official documentation at https://spark.apache.org/docs/latest/configuration.html#available-properties. This PR adds the doc. ## How was this patch tested? Manual. ``` cd docs jekyll build open _site/configuration.html ``` Author: gaborgsomogyi Closes #19826 from gaborgsomogyi/SPARK-22428. --- docs/configuration.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index e42f866c40566..ef061dd39dcba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1132,6 +1132,46 @@ Apart from these, the following properties are also available, and may be useful to get the replication level of the block to the initial number. + + spark.cleaner.periodicGC.interval + 30min + + Controls how often to trigger a garbage collection.

    + This context cleaner triggers cleanups only when weak references are garbage collected. + In long-running applications with large driver JVMs, where there is little memory pressure + on the driver, this may happen very occasionally or not at all. Not cleaning at all may + lead to executors running out of disk space after a while. + + + + spark.cleaner.referenceTracking + true + + Enables or disables context cleaning. + + + + spark.cleaner.referenceTracking.blocking + true + + Controls whether the cleaning thread should block on cleanup tasks (other than shuffle, which is controlled by + spark.cleaner.referenceTracking.blocking.shuffle Spark property). + + + + spark.cleaner.referenceTracking.blocking.shuffle + false + + Controls whether the cleaning thread should block on shuffle cleanup tasks. + + + + spark.cleaner.referenceTracking.cleanCheckpoints + false + + Controls whether to clean checkpoint files if the reference is out of scope. + + ### Execution Behavior From 7da1f5708cc96c18ddb3acd09542621275e71d83 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Thu, 30 Nov 2017 19:24:44 -0600 Subject: [PATCH 1754/1765] =?UTF-8?q?[SPARK-22373]=20Bump=20Janino=20depen?= =?UTF-8?q?dency=20version=20to=20fix=20thread=20safety=20issue=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with Janino when compiling generated code. ## What changes were proposed in this pull request? Bump up Janino dependency version to fix thread safety issue during compiling generated code ## How was this patch tested? Check https://issues.apache.org/jira/browse/SPARK-22373 for details. Converted part of the code in CodeGenerator into a standalone application, so the issue can be consistently reproduced locally. Verified that changing Janino dependency version resolved this issue. Author: Min Shen Closes #19839 from Victsm/SPARK-22373. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 50ac6d139bbd4..8f508219c2ded 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.0.jar +commons-compiler-3.0.7.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.0.jar +janino-3.0.7.jar java-xmlbuilder-1.0.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1b1e3166d53db..68e937f50b391 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.0.jar +commons-compiler-3.0.7.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.0.jar +janino-3.0.7.jar java-xmlbuilder-1.0.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar diff --git a/pom.xml b/pom.xml index 7bc66e7d19540..731ee86439eff 100644 --- a/pom.xml +++ b/pom.xml @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.0 + 3.0.7 2.22.2 2.9.3 3.5.2 From dc365422bb337d19ef39739c7c3cf9e53ec85d09 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 1 Dec 2017 10:53:16 +0800 Subject: [PATCH 1755/1765] =?UTF-8?q?[SPARK-22653]=20executorAddress=20reg?= =?UTF-8?q?istered=20in=20CoarseGrainedSchedulerBac=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-22653 executorRef.address can be null, pass the executorAddress which accounts for it being null a few lines above the fix. Manually tested this patch. You can reproduce the issue by running a simple spark-shell in yarn client mode with dynamic allocation and request some executors up front. Let those executors idle timeout. Get a heap dump. Without this fix, you will see that addressToExecutorId still contains the ids, with the fix addressToExecutorId is properly cleaned up. Author: Thomas Graves Closes #19850 from tgravescs/SPARK-22653. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 22d9c4cf81c55..7bfb4d53c1834 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -182,7 +182,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val data = new ExecutorData(executorRef, executorRef.address, hostname, + val data = new ExecutorData(executorRef, executorAddress, hostname, cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors From 16adaf634bcca3074b448d95e72177eefdf50069 Mon Sep 17 00:00:00 2001 From: sujith71955 Date: Thu, 30 Nov 2017 20:45:30 -0800 Subject: [PATCH 1756/1765] [SPARK-22601][SQL] Data load is getting displayed successful on providing non existing nonlocal file path ## What changes were proposed in this pull request? When user tries to load data with a non existing hdfs file path system is not validating it and the load command operation is getting successful. This is misleading to the user. already there is a validation in the scenario of none existing local file path. This PR has added validation in the scenario of nonexisting hdfs file path ## How was this patch tested? UT has been added for verifying the issue, also snapshots has been added after the verification in a spark yarn cluster Author: sujith71955 Closes #19823 from sujith71955/master_LoadComand_Issue. --- .../org/apache/spark/sql/execution/command/tables.scala | 9 ++++++++- .../apache/spark/sql/hive/execution/HiveDDLSuite.scala | 9 +++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index c9f6e571ddab3..c42e6c3257fad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -340,7 +340,7 @@ case class LoadDataCommand( uri } else { val uri = new URI(path) - if (uri.getScheme() != null && uri.getAuthority() != null) { + val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) { uri } else { // Follow Hive's behavior: @@ -380,6 +380,13 @@ case class LoadDataCommand( } new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment()) } + val hadoopConf = sparkSession.sessionState.newHadoopConf() + val srcPath = new Path(hdfsUri) + val fs = srcPath.getFileSystem(hadoopConf) + if (!fs.exists(srcPath)) { + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } + hdfsUri } if (partition.nonEmpty) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 9063ef066aa84..6c11905ba8904 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2141,4 +2141,13 @@ class HiveDDLSuite } } } + + test("load command for non local invalid path validation") { + withTable("tbl") { + sql("CREATE TABLE tbl(i INT, j STRING)") + val e = intercept[AnalysisException]( + sql("load data inpath '/doesnotexist.csv' into table tbl")) + assert(e.message.contains("LOAD DATA input path does not exist")) + } + } } From 9d06a9e0cf05af99ba210fabae1e77eccfce7986 Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Fri, 1 Dec 2017 05:14:12 -0600 Subject: [PATCH 1757/1765] [SPARK-22393][SPARK-SHELL] spark-shell can't find imported types in class constructors, extends clause ## What changes were proposed in this pull request? [SPARK-22393](https://issues.apache.org/jira/browse/SPARK-22393) ## How was this patch tested? With a new test case in `RepSuite` ---- This code is a retrofit of the Scala [SI-9881](https://github.com/scala/bug/issues/9881) bug fix, which never made it into the Scala 2.11 branches. Pushing these changes directly to the Scala repo is not practical (see: https://github.com/scala/scala/pull/6195). Author: Mark Petruska Closes #19846 from mpetruska/SPARK-22393. --- .../apache/spark/repl/SparkExprTyper.scala | 74 +++++++++++++ .../org/apache/spark/repl/SparkILoop.scala | 4 + .../spark/repl/SparkILoopInterpreter.scala | 103 ++++++++++++++++++ .../org/apache/spark/repl/ReplSuite.scala | 10 ++ 4 files changed, 191 insertions(+) create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000000..724ce9af49f77 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import scala.tools.nsc.interpreter.{ExprTyper, IR} + +trait SparkExprTyper extends ExprTyper { + + import repl._ + import global.{reporter => _, Import => _, _} + import naming.freshInternalVarName + + def doInterpret(code: String): IR.Result = { + // interpret/interpretSynthetic may change the phase, + // which would have unintended effects on types. + val savedPhase = phase + try interpretSynthetic(code) finally phase = savedPhase + } + + override def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. + val line = "def " + name + " = " + code + + doInterpret(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + doInterpret(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + + def asError(): Symbol = { + doInterpret(code) + NoSymbol + } + + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 3ce7cc7c85f74..e69441a475e9a 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + override def createInterpreter(): Unit = { + intp = new SparkILoopInterpreter(settings, out) + } + val initializationCommands: Seq[String] = Seq( """ @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala new file mode 100644 index 0000000000000..0803426403af5 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter._ + +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { + self => + + override lazy val memberHandlers = new { + val intp: self.type = self + } with MemberHandlers { + import intp.global._ + + override def chooseHandler(member: intp.global.Tree): MemberHandler = member match { + case member: Import => new SparkImportHandler(member) + case _ => super.chooseHandler (member) + } + + class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) { + + override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match { + case NoSymbol => intp.typeOfExpression("" + expr) + case sym => sym.tpe + } + + private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s)) + private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx + private def pos(name: Name, s: String): Int = { + var i = name.pos(s.charAt(0), 0) + val sLen = s.length() + if (sLen == 1) return i + while (i + sLen <= name.length) { + var j = 1 + while (s.charAt(j) == name.charAt(i + j)) { + j += 1 + if (j == sLen) return i + } + i = name.pos(s.charAt(0), i + 1) + } + name.length + } + + private def isFlattenedSymbol(sym: Symbol): Boolean = + sym.owner.isPackageClass && + sym.name.containsName(nme.NAME_JOIN_STRING) && + sym.owner.info.member(sym.name.take( + safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol + + private def importableTargetMembers = + importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList + + def isIndividualImport(s: ImportSelector): Boolean = + s.name != nme.WILDCARD && s.rename != nme.WILDCARD + def isWildcardImport(s: ImportSelector): Boolean = + s.name == nme.WILDCARD + + // non-wildcard imports + private def individualSelectors = selectors filter isIndividualImport + + override val importsWildcard: Boolean = selectors exists isWildcardImport + + lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = { + val selectorRenameMap = + individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap + importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _)) + } + + override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1) + override lazy val wildcardSymbols: List[Symbol] = + if (importsWildcard) importableTargetMembers else Nil + + } + + } + + object expressionTyper extends { + val repl: SparkILoopInterpreter.this.type = self + } with SparkExprTyper { } + + override def symbolOfLine(code: String): global.Symbol = + expressionTyper.symbolOfLine(code) + + override def typeOfExpression(expr: String, silent: Boolean): global.Type = + expressionTyper.typeOfExpression(expr, silent) + +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 905b41cdc1594..a5053521f8e31 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -227,4 +227,14 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error: not found: value sc", output) } + test("spark-shell should find imported types in class constructors and extends clause") { + val output = runInterpreter("local", + """ + |import org.apache.spark.Partition + |class P(p: Partition) + |class P(val index: Int) extends Partition + """.stripMargin) + assertDoesNotContain("error: not found: type Partition", output) + } + } From ee10ca7ec6cf7fbaab3f95a097b46936d97d0835 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 1 Dec 2017 13:02:03 -0800 Subject: [PATCH 1758/1765] [SPARK-22638][SS] Use a separate queue for StreamingQueryListenerBus ## What changes were proposed in this pull request? Use a separate Spark event queue for StreamingQueryListenerBus so that if there are many non-streaming events, streaming query listeners don't need to wait for other Spark listeners and can catch up. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #19838 from zsxwing/SPARK-22638. --- .../scala/org/apache/spark/scheduler/LiveListenerBus.scala | 4 +++- .../sql/execution/streaming/StreamingQueryListenerBus.scala | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 2f93c497c5771..23121402b1025 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -87,7 +87,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { * of each other (each one uses a separate thread for delivering events), allowing slower * listeners to be somewhat isolated from others. */ - private def addToQueue(listener: SparkListenerInterface, queue: String): Unit = synchronized { + private[spark] def addToQueue( + listener: SparkListenerInterface, + queue: String): Unit = synchronized { if (stopped.get()) { throw new IllegalStateException("LiveListenerBus is stopped.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index 07e39023c8366..7dd491ede9d05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -40,7 +40,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) import StreamingQueryListener._ - sparkListenerBus.addToSharedQueue(this) + sparkListenerBus.addToQueue(this, StreamingQueryListenerBus.STREAM_EVENT_QUERY) /** * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus @@ -130,3 +130,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) } } } + +object StreamingQueryListenerBus { + val STREAM_EVENT_QUERY = "streams" +} From aa4cf2b19e4cf5588af7e2192e0e9f687cd84bc5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 2 Dec 2017 11:55:43 +0900 Subject: [PATCH 1759/1765] [SPARK-22651][PYTHON][ML] Prevent initiating multiple Hive clients for ImageSchema.readImages ## What changes were proposed in this pull request? Calling `ImageSchema.readImages` multiple times as below in PySpark shell: ```python from pyspark.ml.image import ImageSchema data_path = 'data/mllib/images/kittens' _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() ``` throws an error as below: ``` ... org.datanucleus.exceptions.NucleusDataStoreException: Unable to open a test connection to the given database. JDBC url = jdbc:derby:;databaseName=metastore_db;create=true, username = APP. Terminating connection pool (set lazyInit to true if you expect to start your database after your app). Original Exception: ------ java.sql.SQLException: Failed to start database 'metastore_db' with class loader org.apache.spark.sql.hive.client.IsolatedClientLoader$$anon$1742f639f, see the next exception for details. ... at org.apache.derby.jdbc.AutoloadedDriver.connect(Unknown Source) ... at org.apache.hadoop.hive.metastore.HiveMetaStore.newRetryingHMSHandler(HiveMetaStore.java:5762) ... at org.apache.spark.sql.hive.client.HiveClientImpl.newState(HiveClientImpl.scala:180) ... at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$databaseExists$1.apply$mcZ$sp(HiveExternalCatalog.scala:195) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$databaseExists$1.apply(HiveExternalCatalog.scala:195) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$databaseExists$1.apply(HiveExternalCatalog.scala:195) at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:97) at org.apache.spark.sql.hive.HiveExternalCatalog.databaseExists(HiveExternalCatalog.scala:194) at org.apache.spark.sql.internal.SharedState.externalCatalog$lzycompute(SharedState.scala:100) at org.apache.spark.sql.internal.SharedState.externalCatalog(SharedState.scala:88) at org.apache.spark.sql.hive.HiveSessionStateBuilder.externalCatalog(HiveSessionStateBuilder.scala:39) at org.apache.spark.sql.hive.HiveSessionStateBuilder.catalog$lzycompute(HiveSessionStateBuilder.scala:54) at org.apache.spark.sql.hive.HiveSessionStateBuilder.catalog(HiveSessionStateBuilder.scala:52) at org.apache.spark.sql.hive.HiveSessionStateBuilder$$anon$1.(HiveSessionStateBuilder.scala:69) at org.apache.spark.sql.hive.HiveSessionStateBuilder.analyzer(HiveSessionStateBuilder.scala:69) at org.apache.spark.sql.internal.BaseSessionStateBuilder$$anonfun$build$2.apply(BaseSessionStateBuilder.scala:293) at org.apache.spark.sql.internal.BaseSessionStateBuilder$$anonfun$build$2.apply(BaseSessionStateBuilder.scala:293) at org.apache.spark.sql.internal.SessionState.analyzer$lzycompute(SessionState.scala:79) at org.apache.spark.sql.internal.SessionState.analyzer(SessionState.scala:79) at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:70) at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:68) at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:51) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:70) at org.apache.spark.sql.SparkSession.internalCreateDataFrame(SparkSession.scala:574) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:593) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:348) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:348) at org.apache.spark.ml.image.ImageSchema$$anonfun$readImages$2$$anonfun$apply$1.apply(ImageSchema.scala:253) ... Caused by: ERROR XJ040: Failed to start database 'metastore_db' with class loader org.apache.spark.sql.hive.client.IsolatedClientLoader$$anon$1742f639f, see the next exception for details. at org.apache.derby.iapi.error.StandardException.newException(Unknown Source) at org.apache.derby.impl.jdbc.SQLExceptionFactory.wrapArgsForTransportAcrossDRDA(Unknown Source) ... 121 more Caused by: ERROR XSDB6: Another instance of Derby may have already booted the database /.../spark/metastore_db. ... Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/ml/image.py", line 190, in readImages dropImageFailures, float(sampleRatio), seed) File "/.../spark/python/lib/py4j-0.10.6-src.zip/py4j/java_gateway.py", line 1160, in __call__ File "/.../spark/python/pyspark/sql/utils.py", line 69, in deco raise AnalysisException(s.split(': ', 1)[1], stackTrace) pyspark.sql.utils.AnalysisException: u'java.lang.RuntimeException: java.lang.RuntimeException: Unable to instantiate org.apache.hadoop.hive.ql.metadata.SessionHiveMetaStoreClient;' ``` Seems we better stick to `SparkSession.builder.getOrCreate()` like: https://github.com/apache/spark/blob/51620e288b5e0a7fffc3899c9deadabace28e6d7/python/pyspark/sql/streaming.py#L329 https://github.com/apache/spark/blob/dc5d34d8dcd6526d1dfdac8606661561c7576a62/python/pyspark/sql/column.py#L541 https://github.com/apache/spark/blob/33d43bf1b6f55594187066f0e38ba3985fa2542b/python/pyspark/sql/readwriter.py#L105 ## How was this patch tested? This was tested as below in PySpark shell: ```python from pyspark.ml.image import ImageSchema data_path = 'data/mllib/images/kittens' _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() ``` Author: hyukjinkwon Closes #19845 from HyukjinKwon/SPARK-22651. --- python/pyspark/ml/image.py | 5 ++--- python/pyspark/ml/tests.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 2b61aa9c0d9e9..384599dc0c532 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -201,9 +201,8 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. versionadded:: 2.3.0 """ - ctx = SparkContext._active_spark_context - spark = SparkSession(ctx) - image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema + spark = SparkSession.builder.getOrCreate() + image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema jsession = spark._jsparkSession jresult = image_schema.readImages(path, jsession, recursive, numPartitions, dropImageFailures, float(sampleRatio), seed) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 89ef555cf3442..3a0b816c367ec 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -67,7 +67,7 @@ from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams, JavaWrapper from pyspark.serializers import PickleSerializer -from pyspark.sql import DataFrame, Row, SparkSession +from pyspark.sql import DataFrame, Row, SparkSession, HiveContext from pyspark.sql.functions import rand from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * @@ -1855,6 +1855,35 @@ def test_read_images(self): lambda: ImageSchema.toImage("a")) +class ImageReaderTest2(PySparkTestCase): + + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + # Note that here we enable Hive's support. + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") + except TypeError: + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") + cls.spark = HiveContext._createForTesting(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.sparkSession.stop() + + def test_read_images_multiple_times(self): + # This test case is to check if `ImageSchema.readImages` tries to + # initiate Hive client multiple times. See SPARK-22651. + data_path = 'data/mllib/images/kittens' + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From d2cf95aa63f5f5c9423f0455c2bfbee7833c9982 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Dec 2017 07:37:02 -0600 Subject: [PATCH 1760/1765] [SPARK-22634][BUILD] Update Bouncy Castle to 1.58 ## What changes were proposed in this pull request? Update Bouncy Castle to 1.58, and jets3t to 0.9.4 to (sort of) match. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19859 from srowen/SPARK-22634. --- dev/deps/spark-deps-hadoop-2.6 | 8 +++----- dev/deps/spark-deps-hadoop-2.7 | 8 +++----- pom.xml | 8 +++++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8f508219c2ded..2c68b73095c4d 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -21,7 +21,7 @@ avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar -bcprov-jdk15on-1.51.jar +bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -97,7 +97,7 @@ jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.7.jar -java-xmlbuilder-1.0.jar +java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -115,7 +115,7 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.3.jar +jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar @@ -137,14 +137,12 @@ log4j-1.2.17.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar -mail-1.4.7.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar -mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 68e937f50b391..2aaac600b3ec3 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -21,7 +21,7 @@ avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar -bcprov-jdk15on-1.51.jar +bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -97,7 +97,7 @@ jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.7.jar -java-xmlbuilder-1.0.jar +java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -115,7 +115,7 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.3.jar +jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar @@ -138,14 +138,12 @@ log4j-1.2.17.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar -mail-1.4.7.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar -mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar diff --git a/pom.xml b/pom.xml index 731ee86439eff..07bca9d267da0 100644 --- a/pom.xml +++ b/pom.xml @@ -141,7 +141,7 @@ 3.1.5 1.7.7 hadoop2 - 0.9.3 + 0.9.4 1.7.3 1.11.76 @@ -985,6 +985,12 @@ + + org.bouncycastle + bcprov-jdk15on + + 1.58 + org.apache.hadoop hadoop-yarn-api From f23dddf105aef88531b3572ad70889cf2fc300c9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 3 Dec 2017 22:21:44 +0800 Subject: [PATCH 1761/1765] [SPARK-20682][SPARK-15474][SPARK-21791] Add new ORCFileFormat based on ORC 1.4.1 ## What changes were proposed in this pull request? Since [SPARK-2883](https://issues.apache.org/jira/browse/SPARK-2883), Apache Spark supports Apache ORC inside `sql/hive` module with Hive dependency. This PR aims to add a new ORC data source inside `sql/core` and to replace the old ORC data source eventually. This PR resolves the following three issues. - [SPARK-20682](https://issues.apache.org/jira/browse/SPARK-20682): Add new ORCFileFormat based on Apache ORC 1.4.1 - [SPARK-15474](https://issues.apache.org/jira/browse/SPARK-15474): ORC data source fails to write and read back empty dataframe - [SPARK-21791](https://issues.apache.org/jira/browse/SPARK-21791): ORC should support column names with dot ## How was this patch tested? Pass the Jenkins with the existing all tests and new tests for SPARK-15474 and SPARK-21791. Author: Dongjoon Hyun Author: Wenchen Fan Closes #19651 from dongjoon-hyun/SPARK-20682. --- .../datasources/orc/OrcDeserializer.scala | 243 ++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 139 +++++++++- .../datasources/orc/OrcFilters.scala | 210 +++++++++++++++ .../datasources/orc/OrcOutputWriter.scala | 53 ++++ .../datasources/orc/OrcSerializer.scala | 228 ++++++++++++++++ .../execution/datasources/orc/OrcUtils.scala | 113 ++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 25 ++ 7 files changed, 1009 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala new file mode 100644 index 0000000000000..4ecc54bd2fd96 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize ORC structs to Spark rows. + */ +class OrcDeserializer( + dataSchema: StructType, + requiredSchema: StructType, + requestedColIds: Array[Int]) { + + private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + + private val fieldWriters: Array[WritableComparable[_] => Unit] = { + requiredSchema.zipWithIndex + // The value of missing columns are always null, do not need writers. + .filterNot { case (_, index) => requestedColIds(index) == -1 } + .map { case (f, index) => + val writer = newWriter(f.dataType, new RowUpdater(resultRow)) + (value: WritableComparable[_]) => writer(index, value) + }.toArray + } + + private val validColIds = requestedColIds.filterNot(_ == -1) + + def deserialize(orcStruct: OrcStruct): InternalRow = { + var i = 0 + while (i < validColIds.length) { + val value = orcStruct.getFieldValue(validColIds(i)) + if (value == null) { + resultRow.setNullAt(i) + } else { + fieldWriters(i)(value) + } + i += 1 + } + resultRow + } + + /** + * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. + */ + private def newWriter( + dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit = + dataType match { + case NullType => (ordinal, _) => + updater.setNullAt(ordinal) + + case BooleanType => (ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + + case ByteType => (ordinal, value) => + updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + + case ShortType => (ordinal, value) => + updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + + case IntegerType => (ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[IntWritable].get) + + case LongType => (ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[LongWritable].get) + + case FloatType => (ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + + case DoubleType => (ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + + case StringType => (ordinal, value) => + updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) + + case BinaryType => (ordinal, value) => + val binary = value.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + updater.set(ordinal, bytes) + + case DateType => (ordinal, value) => + updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) + + case TimestampType => (ordinal, value) => + updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) + + case DecimalType.Fixed(precision, scale) => (ordinal, value) => + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + updater.set(ordinal, v) + + case st: StructType => (ordinal, value) => + val result = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(result) + val fieldConverters = st.map(_.dataType).map { dt => + newWriter(dt, fieldUpdater) + }.toArray + val orcStruct = value.asInstanceOf[OrcStruct] + + var i = 0 + while (i < st.length) { + val value = orcStruct.getFieldValue(i) + if (value == null) { + result.setNullAt(i) + } else { + fieldConverters(i)(i, value) + } + i += 1 + } + + updater.set(ordinal, result) + + case ArrayType(elementType, _) => (ordinal, value) => + val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]] + val length = orcArray.size() + val result = createArrayData(elementType, length) + val elementUpdater = new ArrayDataUpdater(result) + val elementConverter = newWriter(elementType, elementUpdater) + + var i = 0 + while (i < length) { + val value = orcArray.get(i) + if (value == null) { + result.setNullAt(i) + } else { + elementConverter(i, value) + } + i += 1 + } + + updater.set(ordinal, result) + + case MapType(keyType, valueType, _) => (ordinal, value) => + val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + val length = orcMap.size() + val keyArray = createArrayData(keyType, length) + val keyUpdater = new ArrayDataUpdater(keyArray) + val keyConverter = newWriter(keyType, keyUpdater) + val valueArray = createArrayData(valueType, length) + val valueUpdater = new ArrayDataUpdater(valueArray) + val valueConverter = newWriter(valueType, valueUpdater) + + var i = 0 + val it = orcMap.entrySet().iterator() + while (it.hasNext) { + val entry = it.next() + keyConverter(i, entry.getKey) + val value = entry.getValue + if (value == null) { + valueArray.setNullAt(i) + } else { + valueConverter(i, value) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 215740e90fe84..75c42213db3c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -17,10 +17,29 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.orc.TypeDescription +import java.io._ +import java.net.URI +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc._ +import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA} +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce._ + +import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration private[sql] object OrcFileFormat { private def checkFieldName(name: String): Unit = { @@ -39,3 +58,119 @@ private[sql] object OrcFileFormat { names.foreach(checkFieldName) } } + +/** + * New ORC File Format based on Apache ORC. + */ +class OrcFileFormat + extends FileFormat + with DataSourceRegister + with Serializable { + + override def shortName(): String = "orc" + + override def toString: String = "ORC" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat] + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + OrcUtils.readSchema(sparkSession, files) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + + val conf = job.getConfiguration + + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) + + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + + conf.asInstanceOf[JobConf] + .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]]) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + } + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + true + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { + OrcFilters.createFilter(dataSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) + } + } + + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + + if (requestedColIdsOrEmptyFile.isEmpty) { + Iterator.empty + } else { + val requestedColIds = requestedColIdsOrEmptyFile.get + assert(requestedColIds.length == requiredSchema.length, + "[BUG] requested column IDs do not match required schema") + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + requestedColIds.filter(_ != -1).sorted.mkString(",")) + + val fileSplit = + new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val unsafeProjection = UnsafeProjection.create(requiredSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala new file mode 100644 index 0000000000000..cec256cc1b498 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory} +import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.serde2.io.HiveDecimalWritable + +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types._ + +/** + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. + */ +private[orc] object OrcFilters { + + /** + * Create ORC filter as a SearchArgument instance. + */ + def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + + // First, tries to convert each filter individually to see whether it's convertible, and then + // collect all convertible ones to build the final `SearchArgument`. + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + } yield filter + + for { + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() + } + + /** + * Return true if this is a searchable type in ORC. + * Both CharType and VarcharType are cleaned at AstBuilder. + */ + private def isSearchableType(dataType: DataType) = dataType match { + // TODO: SPARK-21787 Support for pushing down filters for DateType in ORC + case BinaryType | DateType => false + case _: AtomicType => true + case _ => false + } + + /** + * Get PredicateLeafType which is corresponding to the given DataType. + */ + private def getPredicateLeafType(dataType: DataType) = dataType match { + case BooleanType => PredicateLeaf.Type.BOOLEAN + case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG + case FloatType | DoubleType => PredicateLeaf.Type.FLOAT + case StringType => PredicateLeaf.Type.STRING + case DateType => PredicateLeaf.Type.DATE + case TimestampType => PredicateLeaf.Type.TIMESTAMP + case _: DecimalType => PredicateLeaf.Type.DECIMAL + case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + } + + /** + * Cast literal values for filters. + * + * We need to cast to long because ORC raises exceptions + * at 'checkLiteralType' of SearchArgumentImpl.java. + */ + private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { + case ByteType | ShortType | IntegerType | LongType => + value.asInstanceOf[Number].longValue + case FloatType | DoubleType => + value.asInstanceOf[Number].doubleValue() + case _: DecimalType => + val decimal = value.asInstanceOf[java.math.BigDecimal] + val decimalWritable = new HiveDecimalWritable(decimal.longValue) + decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale) + decimalWritable + case _ => value + } + + /** + * Build a SearchArgument and return the builder so far. + */ + private def buildSearchArgument( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder): Option[Builder] = { + def newBuilder = SearchArgumentFactory.newBuilder() + + def getType(attribute: String): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(attribute)) + + import org.apache.spark.sql.sources._ + + expression match { + case And(left, right) => + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) + } yield rhs.end() + + case Or(left, right) => + for { + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(dataTypeMap, child, newBuilder) + negate <- buildSearchArgument(dataTypeMap, child, builder.startNot()) + } yield negate.end() + + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). + + case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end()) + + case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end()) + + case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end()) + + case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end()) + + case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end()) + + case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end()) + + case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + Some(builder.startAnd().isNull(attribute, getType(attribute)).end()) + + case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + Some(builder.startNot().isNull(attribute, getType(attribute)).end()) + + case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + Some(builder.startAnd().in(attribute, getType(attribute), + castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) + + case _ => None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala new file mode 100644 index 0000000000000..84755bfa301f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcOutputFormat + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types._ + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val serializer = new OrcSerializer(dataSchema) + + private val recordWriter = { + new OrcOutputFormat[OrcStruct]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + }.getRecordWriter(context) + } + + override def write(row: InternalRow): Unit = { + recordWriter.write(NullWritable.get(), serializer.serialize(row)) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala new file mode 100644 index 0000000000000..899af0750cadf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.TypeDescription +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize Spark rows to ORC structs. + */ +class OrcSerializer(dataSchema: StructType) { + + private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct] + private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray + + def serialize(row: InternalRow): OrcStruct = { + var i = 0 + while (i < converters.length) { + if (row.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, converters(i)(row, i)) + } + i += 1 + } + result + } + + private type Converter = (SpecializedGetters, Int) => WritableComparable[_] + + /** + * Creates a converter to convert Catalyst data at the given ordinal to ORC values. + */ + private def newConverter( + dataType: DataType, + reuseObj: Boolean = true): Converter = dataType match { + case NullType => (getter, ordinal) => null + + case BooleanType => + if (reuseObj) { + val result = new BooleanWritable() + (getter, ordinal) => + result.set(getter.getBoolean(ordinal)) + result + } else { + (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal)) + } + + case ByteType => + if (reuseObj) { + val result = new ByteWritable() + (getter, ordinal) => + result.set(getter.getByte(ordinal)) + result + } else { + (getter, ordinal) => new ByteWritable(getter.getByte(ordinal)) + } + + case ShortType => + if (reuseObj) { + val result = new ShortWritable() + (getter, ordinal) => + result.set(getter.getShort(ordinal)) + result + } else { + (getter, ordinal) => new ShortWritable(getter.getShort(ordinal)) + } + + case IntegerType => + if (reuseObj) { + val result = new IntWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new IntWritable(getter.getInt(ordinal)) + } + + + case LongType => + if (reuseObj) { + val result = new LongWritable() + (getter, ordinal) => + result.set(getter.getLong(ordinal)) + result + } else { + (getter, ordinal) => new LongWritable(getter.getLong(ordinal)) + } + + case FloatType => + if (reuseObj) { + val result = new FloatWritable() + (getter, ordinal) => + result.set(getter.getFloat(ordinal)) + result + } else { + (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal)) + } + + case DoubleType => + if (reuseObj) { + val result = new DoubleWritable() + (getter, ordinal) => + result.set(getter.getDouble(ordinal)) + result + } else { + (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal)) + } + + + // Don't reuse the result object for string and binary as it would cause extra data copy. + case StringType => (getter, ordinal) => + new Text(getter.getUTF8String(ordinal).getBytes) + + case BinaryType => (getter, ordinal) => + new BytesWritable(getter.getBinary(ordinal)) + + case DateType => + if (reuseObj) { + val result = new DateWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) + } + + // The following cases are already expensive, reusing object or not doesn't matter. + + case TimestampType => (getter, ordinal) => + val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) + val result = new OrcTimestamp(ts.getTime) + result.setNanos(ts.getNanos) + result + + case DecimalType.Fixed(precision, scale) => (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + + case st: StructType => (getter, ordinal) => + val result = createOrcValue(st).asInstanceOf[OrcStruct] + val fieldConverters = st.map(_.dataType).map(newConverter(_)) + val numFields = st.length + val struct = getter.getStruct(ordinal, numFields) + var i = 0 + while (i < numFields) { + if (struct.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, fieldConverters(i)(struct, i)) + } + i += 1 + } + result + + case ArrayType(elementType, _) => (getter, ordinal) => + val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. + val elementConverter = newConverter(elementType, reuseObj = false) + val array = getter.getArray(ordinal) + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(array, i)) + } + i += 1 + } + result + + case MapType(keyType, valueType, _) => (getter, ordinal) => + val result = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. + val keyConverter = newConverter(keyType, reuseObj = false) + val valueConverter = newConverter(valueType, reuseObj = false) + val map = getter.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + val key = keyConverter(keyArray, i) + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case udt: UserDefinedType[_] => newConverter(udt.sqlType) + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + /** + * Return a Orc value object for the given Spark schema. + */ + private def createOrcValue(dataType: DataType) = { + OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala new file mode 100644 index 0000000000000..b03ee06d04a16 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.orc.{OrcFile, TypeDescription} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types._ + +object OrcUtils extends Logging { + + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDirectory) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + paths + } + + def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } + + def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) + : Option[StructType] = { + val conf = sparkSession.sessionState.newHadoopConf() + // TODO: We need to support merge schema. Please see SPARK-11412. + files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + logDebug(s"Reading schema from file $files, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] + } + } + + /** + * Returns the requested column ids from the given ORC file. Column id can be -1, which means the + * requested column doesn't exist in the ORC file. Returns None if the given ORC file is empty. + */ + def requestedColumnIds( + isCaseSensitive: Boolean, + dataSchema: StructType, + requiredSchema: StructType, + file: Path, + conf: Configuration): Option[Array[Int]] = { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val orcFieldNames = reader.getSchema.getFieldNames.asScala + if (orcFieldNames.isEmpty) { + // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. + None + } else { + if (orcFieldNames.forall(_.startsWith("_col"))) { + // This is a ORC file written by Hive, no field names in the physical schema, assume the + // physical schema maps to the data scheme by index. + assert(orcFieldNames.length <= dataSchema.length, "The given data schema " + + s"${dataSchema.simpleString} has less fields than the actual ORC physical schema, " + + "no idea which columns were dropped, fail to read.") + Some(requiredSchema.fieldNames.map { name => + val index = dataSchema.fieldIndex(name) + if (index < orcFieldNames.length) { + index + } else { + -1 + } + }) + } else { + val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution + Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) }) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5d0bba69daca1..31d9b909ad463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2757,4 +2757,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + // Only New OrcFileFormat supports this + Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, + "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF.limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + test("SPARK-21791 ORC should support column names with dot") { + val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path) + assert(spark.read.format(orc).load(path).collect().length == 2) + } + } } From 2c16267f7ca392d717942a7654e90db60ba60770 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 3 Dec 2017 22:56:03 +0800 Subject: [PATCH 1762/1765] [SPARK-22669][SQL] Avoid unnecessary function calls in code generation ## What changes were proposed in this pull request? In many parts of the codebase for code generation, we are splitting the code to avoid exceptions due to the 64KB method size limit. This is generating a lot of methods which are called every time, even though sometime this is not needed. As pointed out here: https://github.com/apache/spark/pull/19752#discussion_r153081547, this is a not negligible overhead which can be avoided. The PR applies the same approach used in #19752 also to the other places where this was feasible. ## How was this patch tested? existing UTs. Author: Marco Gaido Closes #19860 from mgaido91/SPARK-22669. --- .../expressions/nullExpressions.scala | 141 ++++++++++++------ .../sql/catalyst/expressions/predicates.scala | 68 ++++++--- 2 files changed, 140 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 173e171910b69..3b52a0efd404a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -75,23 +75,51 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) ctx.addMutableState(ctx.javaType(dataType), ev.value) + // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) s""" - if (${ev.isNull}) { - ${eval.code} - if (!${eval.isNull}) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - } - """ + |${eval.code} + |if (!${eval.isNull}) { + | ${ev.isNull} = false; + | ${ev.value} = ${eval.value}; + | continue; + |} + """.stripMargin } + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + evals.mkString("\n") + } else { + ctx.splitExpressions(evals, "coalesce", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + makeSplitFunction = { + func => + s""" + |do { + | $func + |} while (false); + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + |$funcCall; + |if (!${ev.isNull}) { + | continue; + |} + """.stripMargin + }.mkString + }) + } - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.splitExpressions(evals)}""") + ev.copy(code = + s""" + |${ev.isNull} = true; + |${ev.value} = ${ctx.defaultValue(dataType)}; + |do { + | $code + |} while (false); + """.stripMargin) } } @@ -358,53 +386,70 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") + // all evals are meant to be inside a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull} && !Double.isNaN(${eval.value})) { - $nonnull += 1; - } - } - """ + |if ($nonnull < $n) { + | ${eval.code} + | if (!${eval.isNull} && !Double.isNaN(${eval.value})) { + | $nonnull += 1; + | } + |} else { + | continue; + |} + """.stripMargin case _ => s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull}) { - $nonnull += 1; - } - } - """ + |if ($nonnull < $n) { + | ${eval.code} + | if (!${eval.isNull}) { + | $nonnull += 1; + | } + |} else { + | continue; + |} + """.stripMargin } } val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions( - expressions = evals, - funcName = "atLeastNNonNulls", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, - returnType = "int", - makeSplitFunction = { body => - s""" - $body - return $nonnull; - """ - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n") - } - ) - } + evals.mkString("\n") + } else { + ctx.splitExpressions( + expressions = evals, + funcName = "atLeastNNonNulls", + arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil, + returnType = ctx.JAVA_INT, + makeSplitFunction = { body => + s""" + |do { + | $body + |} while (false); + |return $nonnull; + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$nonnull = $funcCall; + |if ($nonnull >= $n) { + | continue; + |} + """.stripMargin).mkString("\n") + } + ) + } - ev.copy(code = s""" - int $nonnull = 0; - $code - boolean ${ev.value} = $nonnull >= $n;""", isNull = "false") + ev.copy(code = + s""" + |${ctx.JAVA_INT} $nonnull = 0; + |do { + | $code + |} while (false); + |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1aaaaf1db48d1..75cc9b3bd8045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -234,36 +234,62 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val valueArg = ctx.freshName("valueArg") + // All the blocks are meant to be inside a do { ... } while (false); loop. + // The evaluation of variables can be stopped when we find a matching value. val listCode = listGen.map(x => s""" - if (!${ev.value}) { - ${x.code} - if (${x.isNull}) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - ${ev.isNull} = false; - ${ev.value} = true; + |${x.code} + |if (${x.isNull}) { + | ${ev.isNull} = true; + |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | continue; + |} + """.stripMargin) + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + listCode.mkString("\n") + } else { + ctx.splitExpressions( + expressions = listCode, + funcName = "valueIn", + arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil, + makeSplitFunction = { body => + s""" + |do { + | $body + |} while (false); + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$funcCall; + |if (${ev.value}) { + | continue; + |} + """.stripMargin).mkString("\n") } - } - """) - val listCodes = ctx.splitExpressions( - expressions = listCode, - funcName = "valueIn", - extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil) - ev.copy(code = s""" - ${valueGen.code} - ${ev.value} = false; - ${ev.isNull} = ${valueGen.isNull}; - if (!${ev.isNull}) { - ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value}; - $listCodes + ) } - """) + ev.copy(code = + s""" + |${valueGen.code} + |${ev.value} = false; + |${ev.isNull} = ${valueGen.isNull}; + |if (!${ev.isNull}) { + | $javaDataType $valueArg = ${valueGen.value}; + | do { + | $code + | } while (false); + |} + """.stripMargin) } override def sql: String = { From dff440f1ecdbce65cad377c44fd2abfd4eff9b44 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 3 Dec 2017 23:05:31 +0800 Subject: [PATCH 1763/1765] [SPARK-22626][SQL] deals with wrong Hive's statistics (zero rowCount) This pr to ensure that the Hive's statistics `totalSize` (or `rawDataSize`) > 0, `rowCount` also must be > 0. Otherwise may cause OOM when CBO is enabled. unit tests Author: Yuming Wang Closes #19831 from wangyum/SPARK-22626. --- .../sql/hive/client/HiveClientImpl.scala | 6 +++--- .../spark/sql/hive/StatisticsSuite.scala | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 47ce6ba83866c..77e836003b39f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -418,7 +418,7 @@ private[hive] class HiveClientImpl( // Note that this statistics could be overridden by Spark's statistics if that's available. val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) - val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)).filter(_ >= 0) + val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. @@ -430,9 +430,9 @@ private[hive] class HiveClientImpl( // so when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, // return None. Later, we will use the other ways to estimate the statistics. if (totalSize.isDefined && totalSize.get > 0L) { - Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount)) + Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount.filter(_ > 0))) } else if (rawDataSize.isDefined && rawDataSize.get > 0) { - Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount)) + Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount.filter(_ > 0))) } else { // TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything? None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0cdd9305c6b6f..ee027e5308265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -1360,4 +1360,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + + test("Deals with wrong Hive's statistics (zero rowCount)") { + withTable("maybe_big") { + sql("CREATE TABLE maybe_big (c1 bigint)" + + "TBLPROPERTIES ('numRows'='0', 'rawDataSize'='60000000000', 'totalSize'='8000000000000')") + + val relation = spark.table("maybe_big").queryExecution.analyzed.children.head + .asInstanceOf[HiveTableRelation] + + val properties = relation.tableMeta.ignoredProperties + assert(properties("totalSize").toLong > 0) + assert(properties("rawDataSize").toLong > 0) + assert(properties("numRows").toLong == 0) + + assert(relation.stats.sizeInBytes > 0) + // May be cause OOM if rowCount == 0 when enables CBO, see SPARK-22626 for details. + assert(relation.stats.rowCount.isEmpty) + } + } } From 4131ad03f4d2dfcfb1e166e5dfdf0752479f7340 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 3 Dec 2017 23:52:37 -0800 Subject: [PATCH 1764/1765] [SPARK-22489][DOC][FOLLOWUP] Update broadcast behavior changes in migration section ## What changes were proposed in this pull request? Update broadcast behavior changes in migration section. ## How was this patch tested? N/A Author: Yuming Wang Closes #19858 from wangyum/SPARK-22489-migration. --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a1b9c3bbfd059..b76be9132dd03 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1776,6 +1776,8 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. + + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). ## Upgrading From Spark SQL 2.1 to 2.2 From 117fa4bb859717dd5cbe794d7eee4b5c291b5846 Mon Sep 17 00:00:00 2001 From: "sergei.rubtcov" Date: Thu, 21 Dec 2017 18:25:15 +0200 Subject: [PATCH 1765/1765] [SPARK-19228][SQL] add java.time.format.Formatter to more accurately infer the type of time, fix order of inferred types --- .../execution/datasources/csv/CSVInferSchema.scala | 7 ++++--- .../sql/execution/datasources/csv/CSVOptions.scala | 11 +++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 6249a235ad502..b3ef9d6d5b84c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -92,7 +92,8 @@ private[csv] object CSVInferSchema { case DoubleType => tryParseDouble(field, options) case DateType => tryParseDate(field, options) case TimestampType => - findTightestCommonType(typeSoFar, tryParseTimestamp(field, options)).getOrElse(StringType) + findTightestCommonType(typeSoFar, tryParseTimestamp(field, options)).getOrElse( + tryParseBoolean(field, options)) case BooleanType => tryParseBoolean(field, options) case StringType => StringType case other: DataType => @@ -149,7 +150,7 @@ private[csv] object CSVInferSchema { private def tryParseDate(field: String, options: CSVOptions): DataType = { // This case infers a custom `dateFormat` is set. - if ((allCatch opt options.dateFormat.parse(field)).isDefined) { + if ((allCatch opt options.dateFormatter.parse(field)).isDefined) { DateType } else { tryParseTimestamp(field, options) @@ -158,7 +159,7 @@ private[csv] object CSVInferSchema { private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { // This case infers a custom `timestampFormat` is set. - if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { + if ((allCatch opt options.timestampFormatter.parse(field)).isDefined) { TimestampType } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { // We keep this for backwards compatibility. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index a13a5a34b4a84..3c22fbfa8e486 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets +import java.time.format.{DateTimeFormatter, ResolverStyle} import java.util.{Locale, TimeZone} import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} @@ -142,6 +143,16 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + def dateFormatter: DateTimeFormatter = { + DateTimeFormatter.ofPattern(dateFormat.getPattern) + .withLocale(Locale.US).withZone(timeZone.toZoneId).withResolverStyle(ResolverStyle.SMART) + } + + def timestampFormatter: DateTimeFormatter = { + DateTimeFormatter.ofPattern(timestampFormat.getPattern) + .withLocale(Locale.US).withZone(timeZone.toZoneId).withResolverStyle(ResolverStyle.SMART) + } + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat